mint_utils.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. package quic
  2. import (
  3. "bytes"
  4. gocrypto "crypto"
  5. "crypto/tls"
  6. "crypto/x509"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "github.com/bifurcation/mint"
  11. "github.com/lucas-clemente/quic-go/internal/crypto"
  12. "github.com/lucas-clemente/quic-go/internal/handshake"
  13. "github.com/lucas-clemente/quic-go/internal/protocol"
  14. "github.com/lucas-clemente/quic-go/internal/utils"
  15. "github.com/lucas-clemente/quic-go/internal/wire"
  16. )
  17. type mintController struct {
  18. csc *handshake.CryptoStreamConn
  19. conn *mint.Conn
  20. }
  21. var _ handshake.MintTLS = &mintController{}
  22. func newMintController(
  23. csc *handshake.CryptoStreamConn,
  24. mconf *mint.Config,
  25. pers protocol.Perspective,
  26. ) handshake.MintTLS {
  27. var conn *mint.Conn
  28. if pers == protocol.PerspectiveClient {
  29. conn = mint.Client(csc, mconf)
  30. } else {
  31. conn = mint.Server(csc, mconf)
  32. }
  33. return &mintController{
  34. csc: csc,
  35. conn: conn,
  36. }
  37. }
  38. func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams {
  39. return mc.conn.ConnectionState().CipherSuite
  40. }
  41. func (mc *mintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
  42. return mc.conn.ComputeExporter(label, context, keyLength)
  43. }
  44. func (mc *mintController) Handshake() mint.Alert {
  45. return mc.conn.Handshake()
  46. }
  47. func (mc *mintController) State() mint.State {
  48. return mc.conn.ConnectionState().HandshakeState
  49. }
  50. func (mc *mintController) ConnectionState() mint.ConnectionState {
  51. return mc.conn.ConnectionState()
  52. }
  53. func (mc *mintController) SetCryptoStream(stream io.ReadWriter) {
  54. mc.csc.SetStream(stream)
  55. }
  56. func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Config, error) {
  57. mconf := &mint.Config{
  58. NonBlocking: true,
  59. CipherSuites: []mint.CipherSuite{
  60. mint.TLS_AES_128_GCM_SHA256,
  61. mint.TLS_AES_256_GCM_SHA384,
  62. },
  63. }
  64. if tlsConf != nil {
  65. mconf.ServerName = tlsConf.ServerName
  66. mconf.InsecureSkipVerify = tlsConf.InsecureSkipVerify
  67. mconf.Certificates = make([]*mint.Certificate, len(tlsConf.Certificates))
  68. mconf.RootCAs = tlsConf.RootCAs
  69. mconf.VerifyPeerCertificate = tlsConf.VerifyPeerCertificate
  70. for i, certChain := range tlsConf.Certificates {
  71. mconf.Certificates[i] = &mint.Certificate{
  72. Chain: make([]*x509.Certificate, len(certChain.Certificate)),
  73. PrivateKey: certChain.PrivateKey.(gocrypto.Signer),
  74. }
  75. for j, cert := range certChain.Certificate {
  76. c, err := x509.ParseCertificate(cert)
  77. if err != nil {
  78. return nil, err
  79. }
  80. mconf.Certificates[i].Chain[j] = c
  81. }
  82. }
  83. switch tlsConf.ClientAuth {
  84. case tls.NoClientCert:
  85. case tls.RequireAnyClientCert:
  86. mconf.RequireClientAuth = true
  87. default:
  88. return nil, errors.New("mint currently only support ClientAuthType RequireAnyClientCert")
  89. }
  90. }
  91. if err := mconf.Init(pers == protocol.PerspectiveClient); err != nil {
  92. return nil, err
  93. }
  94. return mconf, nil
  95. }
  96. // unpackInitialOrRetryPacket unpacks packets Initial and Retry packets
  97. // These packets must contain a STREAM_FRAME for the crypto stream, starting at offset 0.
  98. func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, logger utils.Logger, version protocol.VersionNumber) (*wire.StreamFrame, error) {
  99. decrypted, err := aead.Open(data[:0], data, hdr.PacketNumber, hdr.Raw)
  100. if err != nil {
  101. return nil, err
  102. }
  103. var frame *wire.StreamFrame
  104. r := bytes.NewReader(decrypted)
  105. for {
  106. f, err := wire.ParseNextFrame(r, hdr, version)
  107. if err != nil {
  108. return nil, err
  109. }
  110. var ok bool
  111. if frame, ok = f.(*wire.StreamFrame); ok || frame == nil {
  112. break
  113. }
  114. }
  115. if frame == nil {
  116. return nil, errors.New("Packet doesn't contain a STREAM_FRAME")
  117. }
  118. if frame.StreamID != version.CryptoStreamID() {
  119. return nil, fmt.Errorf("Received STREAM_FRAME for wrong stream (Stream ID %d)", frame.StreamID)
  120. }
  121. // We don't need a check for the stream ID here.
  122. // The packetUnpacker checks that there's no unencrypted stream data except for the crypto stream.
  123. if frame.Offset != 0 {
  124. return nil, errors.New("received stream data with non-zero offset")
  125. }
  126. if logger.Debug() {
  127. logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.DestConnectionID)
  128. hdr.Log(logger)
  129. wire.LogFrame(logger, frame, false)
  130. }
  131. return frame, nil
  132. }
  133. // packUnencryptedPacket provides a low-overhead way to pack a packet.
  134. // It is supposed to be used in the early stages of the handshake, before a session (which owns a packetPacker) is available.
  135. func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, f wire.Frame, pers protocol.Perspective, logger utils.Logger) ([]byte, error) {
  136. raw := *getPacketBuffer()
  137. buffer := bytes.NewBuffer(raw[:0])
  138. if err := hdr.Write(buffer, pers, hdr.Version); err != nil {
  139. return nil, err
  140. }
  141. payloadStartIndex := buffer.Len()
  142. if err := f.Write(buffer, hdr.Version); err != nil {
  143. return nil, err
  144. }
  145. raw = raw[0:buffer.Len()]
  146. _ = aead.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], hdr.PacketNumber, raw[:payloadStartIndex])
  147. raw = raw[0 : buffer.Len()+aead.Overhead()]
  148. if logger.Debug() {
  149. logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %s, %s", hdr.PacketNumber, len(raw), hdr.SrcConnectionID, protocol.EncryptionUnencrypted)
  150. hdr.Log(logger)
  151. wire.LogFrame(logger, f, true)
  152. }
  153. return raw, nil
  154. }