Jelajahi Sumber

Merge pull request #487 from rod-hynes/master

QUIC-OSSH updates
Rod Hynes 7 tahun lalu
induk
melakukan
b355d52109
53 mengubah file dengan 1773 tambahan dan 1728 penghapusan
  1. 0 1
      psiphon/common/protocol/protocol.go
  2. 4 4
      psiphon/common/quic/quic.go
  3. 5 0
      vendor/github.com/lucas-clemente/quic-go/Changelog.md
  4. 2 2
      vendor/github.com/lucas-clemente/quic-go/appveyor.yml
  5. 191 155
      vendor/github.com/lucas-clemente/quic-go/client.go
  6. 0 96
      vendor/github.com/lucas-clemente/quic-go/client_multiplexer.go
  7. 6 6
      vendor/github.com/lucas-clemente/quic-go/crypto_stream.go
  8. 20 4
      vendor/github.com/lucas-clemente/quic-go/interface.go
  9. 2 1
      vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go
  10. 39 33
      vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go
  11. 2 2
      vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go
  12. 14 2
      vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go
  13. 0 12
      vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go
  14. 1 4
      vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go
  15. 0 9
      vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go
  16. 2 4
      vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go
  17. 0 51
      vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go
  18. 86 0
      vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_protector.go
  19. 14 12
      vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go
  20. 25 29
      vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go
  21. 26 58
      vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go
  22. 6 13
      vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go
  23. 3 0
      vendor/github.com/lucas-clemente/quic-go/internal/handshake/mockgen.go
  24. 78 9
      vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension.go
  25. 12 36
      vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client.go
  26. 13 35
      vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server.go
  27. 84 58
      vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go
  28. 15 2
      vendor/github.com/lucas-clemente/quic-go/internal/protocol/connection_id.go
  29. 5 0
      vendor/github.com/lucas-clemente/quic-go/internal/protocol/perspective.go
  30. 6 0
      vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go
  31. 6 8
      vendor/github.com/lucas-clemente/quic-go/internal/protocol/server_parameters.go
  32. 11 0
      vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go
  33. 18 1
      vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go
  34. 5 0
      vendor/github.com/lucas-clemente/quic-go/internal/wire/frame_parser.go
  35. 219 48
      vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go
  36. 235 0
      vendor/github.com/lucas-clemente/quic-go/internal/wire/header_parser.go
  37. 0 205
      vendor/github.com/lucas-clemente/quic-go/internal/wire/ietf_header.go
  38. 0 244
      vendor/github.com/lucas-clemente/quic-go/internal/wire/public_header.go
  39. 2 1
      vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go
  40. 0 116
      vendor/github.com/lucas-clemente/quic-go/mint_utils.go
  41. 17 16
      vendor/github.com/lucas-clemente/quic-go/mockgen.go
  42. 13 4
      vendor/github.com/lucas-clemente/quic-go/mockgen_private.sh
  43. 63 0
      vendor/github.com/lucas-clemente/quic-go/multiplexer.go
  44. 135 15
      vendor/github.com/lucas-clemente/quic-go/packet_handler_map.go
  45. 10 5
      vendor/github.com/lucas-clemente/quic-go/packet_packer.go
  46. 7 10
      vendor/github.com/lucas-clemente/quic-go/send_stream.go
  47. 104 155
      vendor/github.com/lucas-clemente/quic-go/server.go
  48. 63 0
      vendor/github.com/lucas-clemente/quic-go/server_session.go
  49. 94 169
      vendor/github.com/lucas-clemente/quic-go/server_tls.go
  50. 79 62
      vendor/github.com/lucas-clemente/quic-go/session.go
  51. 2 2
      vendor/github.com/lucas-clemente/quic-go/stream_framer.go
  52. 2 2
      vendor/github.com/lucas-clemente/quic-go/window_update_queue.go
  53. 27 27
      vendor/vendor.json

+ 0 - 1
psiphon/common/protocol/protocol.go

@@ -103,7 +103,6 @@ var SupportedTunnelProtocols = TunnelProtocols{
 }
 
 var DefaultDisabledTunnelProtocols = TunnelProtocols{
-	TUNNEL_PROTOCOL_QUIC_OBFUSCATED_SSH,
 	TUNNEL_PROTOCOL_MARIONETTE_OBFUSCATED_SSH,
 	TUNNEL_PROTOCOL_TAPDANCE_OBFUSCATED_SSH,
 }

+ 4 - 4
psiphon/common/quic/quic.go

@@ -169,7 +169,7 @@ func Dial(
 
 		stream, err := session.OpenStream()
 		if err != nil {
-			session.Close(nil)
+			session.Close()
 			resultChannel <- dialResult{err: err}
 			return
 		}
@@ -191,7 +191,7 @@ func Dial(
 	case <-ctx.Done():
 		err = ctx.Err()
 		// Interrupt the goroutine
-		session.Close(nil)
+		session.Close()
 		<-resultChannel
 	}
 
@@ -234,7 +234,7 @@ func (conn *Conn) doDeferredAcceptStream() error {
 
 	stream, err := conn.session.AcceptStream()
 	if err != nil {
-		conn.session.Close(nil)
+		conn.session.Close()
 		conn.acceptErr = common.ContextError(err)
 		return conn.acceptErr
 	}
@@ -290,7 +290,7 @@ func (conn *Conn) Write(b []byte) (int, error) {
 }
 
 func (conn *Conn) Close() error {
-	err := conn.session.Close(nil)
+	err := conn.session.Close()
 	if conn.packetConn != nil {
 		err1 := conn.packetConn.Close()
 		if err == nil {

+ 5 - 0
vendor/github.com/lucas-clemente/quic-go/Changelog.md

@@ -1,5 +1,10 @@
 # Changelog
 
+## v0.9.0 (2018-08-15)
+
+- Add a `quic.Config` option for the length of the connection ID (for IETF QUIC).
+- Split Session.Close into one method for regular closing and one for closing with an error.
+
 ## v0.8.0 (2018-06-26)
 
 - Add support for unidirectional streams (for IETF QUIC).

+ 2 - 2
vendor/github.com/lucas-clemente/quic-go/appveyor.yml

@@ -14,8 +14,8 @@ clone_folder: c:\gopath\src\github.com\lucas-clemente\quic-go
 
 install:
   - rmdir c:\go /s /q
-  - appveyor DownloadFile https://storage.googleapis.com/golang/go1.10.3.windows-amd64.zip
-  - 7z x go1.10.3.windows-amd64.zip -y -oC:\ > NUL
+  - appveyor DownloadFile https://storage.googleapis.com/golang/go1.11rc1.windows-amd64.zip
+  - 7z x go1.11rc1.windows-amd64.zip -y -oC:\ > NUL
   - set PATH=%PATH%;%GOPATH%\bin\windows_%GOARCH%;%GOPATH%\bin
   - echo %PATH%
   - echo %GOPATH%

+ 191 - 155
vendor/github.com/lucas-clemente/quic-go/client.go

@@ -7,10 +7,9 @@ import (
 	"errors"
 	"fmt"
 	"net"
-	"strings"
 	"sync"
-	"time"
 
+	"github.com/bifurcation/mint"
 	"github.com/lucas-clemente/quic-go/internal/handshake"
 	"github.com/lucas-clemente/quic-go/internal/protocol"
 	"github.com/lucas-clemente/quic-go/internal/utils"
@@ -21,18 +20,25 @@ import (
 type client struct {
 	mutex sync.Mutex
 
-	conn     connection
+	conn connection
+	// If the client is created with DialAddr, we create a packet conn.
+	// If it is started with Dial, we take a packet conn as a parameter.
+	createdPacketConn bool
+
 	hostname string
 
-	receivedRetry bool
+	packetHandlers packetHandlerManager
+
+	token      []byte
+	numRetries int
 
 	versionNegotiated                bool // has the server accepted our version
 	receivedVersionNegotiationPacket bool
 	negotiatedVersions               []protocol.VersionNumber // the list of versions from the version negotiation packet
 
-	tlsConf *tls.Config
-	config  *Config
-	tls     handshake.MintTLS // only used when using TLS
+	tlsConf  *tls.Config
+	mintConf *mint.Config
+	config   *Config
 
 	srcConnID  protocol.ConnectionID
 	destConnID protocol.ConnectionID
@@ -41,6 +47,7 @@ type client struct {
 	version        protocol.VersionNumber
 
 	handshakeChan chan struct{}
+	closeCallback func(protocol.ConnectionID)
 
 	session quicSession
 
@@ -51,8 +58,10 @@ var _ packetHandler = &client{}
 
 var (
 	// make it possible to mock connection ID generation in the tests
-	generateConnectionID         = protocol.GenerateConnectionID
-	errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
+	generateConnectionID           = protocol.GenerateConnectionID
+	generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
+	errCloseSessionForNewVersion   = errors.New("closing session in order to recreate it with a new version")
+	errCloseSessionForRetry        = errors.New("closing session in response to a stateless retry")
 )
 
 // DialAddr establishes a new QUIC connection to a server.
@@ -81,15 +90,7 @@ func DialAddrContext(
 	if err != nil {
 		return nil, err
 	}
-	c, err := newClient(udpConn, udpAddr, config, tlsConf, addr)
-	if err != nil {
-		return nil, err
-	}
-	go c.listen()
-	if err := c.dial(ctx); err != nil {
-		return nil, err
-	}
-	return c.session, nil
+	return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, true)
 }
 
 // Dial establishes a new QUIC connection to a server using a net.PacketConn.
@@ -114,37 +115,49 @@ func DialContext(
 	tlsConf *tls.Config,
 	config *Config,
 ) (Session, error) {
-	c, err := newClient(pconn, remoteAddr, config, tlsConf, host)
+	return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false)
+}
+
+func dialContext(
+	ctx context.Context,
+	pconn net.PacketConn,
+	remoteAddr net.Addr,
+	host string,
+	tlsConf *tls.Config,
+	config *Config,
+	createdPacketConn bool,
+) (Session, error) {
+	config = populateClientConfig(config, createdPacketConn)
+	packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength)
+	if err != nil {
+		return nil, err
+	}
+	c, err := newClient(pconn, remoteAddr, config, tlsConf, host, packetHandlers.Remove, createdPacketConn)
 	if err != nil {
 		return nil, err
 	}
-	getClientMultiplexer().Add(pconn, c.srcConnID, c)
+	c.packetHandlers = packetHandlers
 	if err := c.dial(ctx); err != nil {
 		return nil, err
 	}
 	return c.session, nil
 }
 
-func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, host string) (*client, error) {
-	clientConfig := populateClientConfig(config)
-	version := clientConfig.Versions[0]
-	srcConnID, err := generateConnectionID()
-	if err != nil {
-		return nil, err
-	}
-	destConnID := srcConnID
-	if version.UsesTLS() {
-		destConnID, err = generateConnectionID()
-		if err != nil {
-			return nil, err
-		}
-	}
-
+func newClient(
+	pconn net.PacketConn,
+	remoteAddr net.Addr,
+	config *Config,
+	tlsConf *tls.Config,
+	host string,
+	closeCallback func(protocol.ConnectionID),
+	createdPacketConn bool,
+) (*client, error) {
 	var hostname string
 	if tlsConf != nil {
 		hostname = tlsConf.ServerName
 	}
 	if hostname == "" {
+		var err error
 		hostname, _, err = net.SplitHostPort(host)
 		if err != nil {
 			return nil, err
@@ -159,22 +172,27 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon
 			}
 		}
 	}
-	return &client{
-		conn:          &conn{pconn: pconn, currentAddr: remoteAddr},
-		srcConnID:     srcConnID,
-		destConnID:    destConnID,
-		hostname:      hostname,
-		tlsConf:       tlsConf,
-		config:        clientConfig,
-		version:       version,
-		handshakeChan: make(chan struct{}),
-		logger:        utils.DefaultLogger.WithPrefix("client"),
-	}, nil
+	onClose := func(protocol.ConnectionID) {}
+	if closeCallback != nil {
+		onClose = closeCallback
+	}
+	c := &client{
+		conn:              &conn{pconn: pconn, currentAddr: remoteAddr},
+		createdPacketConn: createdPacketConn,
+		hostname:          hostname,
+		tlsConf:           tlsConf,
+		config:            config,
+		version:           config.Versions[0],
+		handshakeChan:     make(chan struct{}),
+		closeCallback:     onClose,
+		logger:            utils.DefaultLogger.WithPrefix("client"),
+	}
+	return c, c.generateConnectionIDs()
 }
 
 // populateClientConfig populates fields in the quic.Config with their default values, if none are set
 // it may be called with nil
-func populateClientConfig(config *Config) *Config {
+func populateClientConfig(config *Config, createdPacketConn bool) *Config {
 	if config == nil {
 		config = &Config{}
 	}
@@ -212,12 +230,17 @@ func populateClientConfig(config *Config) *Config {
 	} else if maxIncomingUniStreams < 0 {
 		maxIncomingUniStreams = 0
 	}
+	connIDLen := config.ConnectionIDLength
+	if connIDLen == 0 && !createdPacketConn {
+		connIDLen = protocol.DefaultConnectionIDLength
+	}
 
 	return &Config{
 		Versions:                              versions,
 		HandshakeTimeout:                      handshakeTimeout,
 		IdleTimeout:                           idleTimeout,
 		RequestConnectionIDOmission:           config.RequestConnectionIDOmission,
+		ConnectionIDLength:                    connIDLen,
 		MaxReceiveStreamFlowControlWindow:     maxReceiveStreamFlowControlWindow,
 		MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
 		MaxIncomingStreams:                    maxIncomingStreams,
@@ -226,6 +249,27 @@ func populateClientConfig(config *Config) *Config {
 	}
 }
 
+func (c *client) generateConnectionIDs() error {
+	connIDLen := protocol.ConnectionIDLenGQUIC
+	if c.version.UsesTLS() {
+		connIDLen = c.config.ConnectionIDLength
+	}
+	srcConnID, err := generateConnectionID(connIDLen)
+	if err != nil {
+		return err
+	}
+	destConnID := srcConnID
+	if c.version.UsesTLS() {
+		destConnID, err = generateConnectionIDForInitial()
+		if err != nil {
+			return err
+		}
+	}
+	c.srcConnID = srcConnID
+	c.destConnID = destConnID
+	return nil
+}
+
 func (c *client) dial(ctx context.Context) error {
 	c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
 
@@ -235,9 +279,6 @@ func (c *client) dial(ctx context.Context) error {
 	} else {
 		err = c.dialGQUIC(ctx)
 	}
-	if err == errCloseSessionForNewVersion {
-		return c.dial(ctx)
-	}
 	return err
 }
 
@@ -245,7 +286,11 @@ func (c *client) dialGQUIC(ctx context.Context) error {
 	if err := c.createNewGQUICSession(); err != nil {
 		return err
 	}
-	return c.establishSecureConnection(ctx)
+	err := c.establishSecureConnection(ctx)
+	if err == errCloseSessionForNewVersion {
+		return c.dial(ctx)
+	}
+	return err
 }
 
 func (c *client) dialTLS(ctx context.Context) error {
@@ -256,8 +301,8 @@ func (c *client) dialTLS(ctx context.Context) error {
 		OmitConnectionID:            c.config.RequestConnectionIDOmission,
 		MaxBidiStreams:              uint16(c.config.MaxIncomingStreams),
 		MaxUniStreams:               uint16(c.config.MaxIncomingUniStreams),
+		DisableMigration:            true,
 	}
-	csc := handshake.NewCryptoStreamConn(nil)
 	extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version, c.logger)
 	mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
 	if err != nil {
@@ -265,27 +310,16 @@ func (c *client) dialTLS(ctx context.Context) error {
 	}
 	mintConf.ExtensionHandler = extHandler
 	mintConf.ServerName = c.hostname
-	c.tls = newMintController(csc, mintConf, protocol.PerspectiveClient)
+	c.mintConf = mintConf
 
 	if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
 		return err
 	}
-	if err := c.establishSecureConnection(ctx); err != nil {
-		if err != handshake.ErrCloseSessionForRetry {
-			return err
-		}
-		c.logger.Infof("Received a Retry packet. Recreating session.")
-		c.mutex.Lock()
-		c.receivedRetry = true
-		c.mutex.Unlock()
-		if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
-			return err
-		}
-		if err := c.establishSecureConnection(ctx); err != nil {
-			return err
-		}
+	err = c.establishSecureConnection(ctx)
+	if err == errCloseSessionForRetry || err == errCloseSessionForNewVersion {
+		return c.dial(ctx)
 	}
-	return nil
+	return err
 }
 
 // establishSecureConnection runs the session, and tries to establish a secure connection
@@ -299,13 +333,16 @@ func (c *client) establishSecureConnection(ctx context.Context) error {
 
 	go func() {
 		err := c.session.run() // returns as soon as the session is closed
+		if err != errCloseSessionForRetry && err != errCloseSessionForNewVersion && c.createdPacketConn {
+			c.conn.Close()
+		}
 		errorChan <- err
 	}()
 
 	select {
 	case <-ctx.Done():
-		// The session sending a PeerGoingAway error to the server.
-		c.session.Close(nil)
+		// The session will send a PeerGoingAway error to the server.
+		c.session.Close()
 		return ctx.Err()
 	case err := <-errorChan:
 		return err
@@ -315,53 +352,6 @@ func (c *client) establishSecureConnection(ctx context.Context) error {
 	}
 }
 
-// Listen listens on the underlying connection and passes packets on for handling.
-// It returns when the connection is closed.
-func (c *client) listen() {
-	var err error
-
-	for {
-		var n int
-		var addr net.Addr
-		data := *getPacketBuffer()
-		data = data[:protocol.MaxReceivePacketSize]
-		// The packet size should not exceed protocol.MaxReceivePacketSize bytes
-		// If it does, we only read a truncated packet, which will then end up undecryptable
-		n, addr, err = c.conn.Read(data)
-		if err != nil {
-			if !strings.HasSuffix(err.Error(), "use of closed network connection") {
-				c.mutex.Lock()
-				if c.session != nil {
-					c.session.Close(err)
-				}
-				c.mutex.Unlock()
-			}
-			break
-		}
-		c.handleRead(addr, data[:n])
-	}
-}
-
-func (c *client) handleRead(remoteAddr net.Addr, packet []byte) {
-	rcvTime := time.Now()
-
-	r := bytes.NewReader(packet)
-	hdr, err := wire.ParseHeaderSentByServer(r)
-	// drop the packet if we can't parse the header
-	if err != nil {
-		c.logger.Errorf("error handling packet: %s", err)
-		return
-	}
-	hdr.Raw = packet[:len(packet)-r.Len()]
-	packetData := packet[len(packet)-r.Len():]
-	c.handlePacket(&receivedPacket{
-		remoteAddr: remoteAddr,
-		header:     hdr,
-		data:       packetData,
-		rcvTime:    rcvTime,
-	})
-}
-
 func (c *client) handlePacket(p *receivedPacket) {
 	if err := c.handlePacketImpl(p); err != nil {
 		c.logger.Errorf("error handling packet: %s", err)
@@ -374,16 +364,12 @@ func (c *client) handlePacketImpl(p *receivedPacket) error {
 
 	// handle Version Negotiation Packets
 	if p.header.IsVersionNegotiation {
-		// ignore delayed / duplicated version negotiation packets
-		if c.receivedVersionNegotiationPacket || c.versionNegotiated {
-			return errors.New("received a delayed Version Negotiation Packet")
+		err := c.handleVersionNegotiationPacket(p.header)
+		if err != nil {
+			c.session.destroy(err)
 		}
-
 		// version negotiation packets have no payload
-		if err := c.handleVersionNegotiationPacket(p.header); err != nil {
-			c.session.Close(err)
-		}
-		return nil
+		return err
 	}
 
 	if p.header.IsPublicHeader {
@@ -400,18 +386,12 @@ func (c *client) handleIETFQUICPacket(p *receivedPacket) error {
 	if p.header.IsLongHeader {
 		switch p.header.Type {
 		case protocol.PacketTypeRetry:
-			if c.receivedRetry {
-				return nil
-			}
+			c.handleRetryPacket(p.header)
+			return nil
 		case protocol.PacketTypeHandshake:
 		default:
 			return fmt.Errorf("Received unsupported packet type: %s", p.header.Type)
 		}
-		if protocol.ByteCount(len(p.data)) < p.header.PayloadLen {
-			return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(p.data), p.header.PayloadLen)
-		}
-		p.data = p.data[:int(p.header.PayloadLen)]
-		// TODO(#1312): implement parsing of compound packets
 	}
 
 	// this is the first packet we are receiving
@@ -462,6 +442,12 @@ func (c *client) handleGQUICPacket(p *receivedPacket) error {
 }
 
 func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
+	// ignore delayed / duplicated version negotiation packets
+	if c.receivedVersionNegotiationPacket || c.versionNegotiated {
+		c.logger.Debugf("Received a delayed Version Negotiation Packet.")
+		return nil
+	}
+
 	for _, v := range hdr.SupportedVersions {
 		if v == c.version {
 			// the version negotiation packet contains the version that we offered
@@ -472,7 +458,6 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
 	}
 
 	c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions)
-
 	newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
 	if !ok {
 		return qerr.InvalidVersion
@@ -483,28 +468,46 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
 	// switch to negotiated version
 	c.initialVersion = c.version
 	c.version = newVersion
-	var err error
-	c.destConnID, err = generateConnectionID()
-	if err != nil {
+	if err := c.generateConnectionIDs(); err != nil {
 		return err
 	}
-	// in gQUIC, there's only one connection ID
-	if !c.version.UsesTLS() {
-		c.srcConnID = c.destConnID
-	}
+
 	c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
-	c.session.Close(errCloseSessionForNewVersion)
+	c.session.destroy(errCloseSessionForNewVersion)
 	return nil
 }
 
-func (c *client) createNewGQUICSession() (err error) {
+func (c *client) handleRetryPacket(hdr *wire.Header) {
+	c.logger.Debugf("<- Received Retry")
+	hdr.Log(c.logger)
+	// A server that performs multiple retries must use a source connection ID of at least 8 bytes.
+	// Only a server that won't send additional Retries can use shorter connection IDs.
+	if hdr.OrigDestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
+		c.logger.Debugf("Received a Retry with a too short Original Destination Connection ID: %d bytes, must have at least %d bytes.", hdr.OrigDestConnectionID.Len(), protocol.MinConnectionIDLenInitial)
+		return
+	}
+	if !hdr.OrigDestConnectionID.Equal(c.destConnID) {
+		c.logger.Debugf("Received spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
+		return
+	}
+	c.numRetries++
+	if c.numRetries > protocol.MaxRetries {
+		c.session.destroy(qerr.CryptoTooManyRejects)
+		return
+	}
+	c.destConnID = hdr.SrcConnectionID
+	c.token = hdr.Token
+	c.session.destroy(errCloseSessionForRetry)
+}
+
+func (c *client) createNewGQUICSession() error {
 	c.mutex.Lock()
 	defer c.mutex.Unlock()
 	runner := &runner{
 		onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
-		removeConnectionIDImpl:  func(protocol.ConnectionID) {},
+		removeConnectionIDImpl:  c.closeCallback,
 	}
-	c.session, err = newClientSession(
+	sess, err := newClientSession(
 		c.conn,
 		runner,
 		c.hostname,
@@ -516,40 +519,73 @@ func (c *client) createNewGQUICSession() (err error) {
 		c.negotiatedVersions,
 		c.logger,
 	)
-	return err
+	if err != nil {
+		return err
+	}
+	c.session = sess
+	c.packetHandlers.Add(c.srcConnID, c)
+	if c.config.RequestConnectionIDOmission {
+		c.packetHandlers.Add(protocol.ConnectionID{}, c)
+	}
+	return nil
 }
 
 func (c *client) createNewTLSSession(
 	paramsChan <-chan handshake.TransportParameters,
 	version protocol.VersionNumber,
-) (err error) {
+) error {
 	c.mutex.Lock()
 	defer c.mutex.Unlock()
 	runner := &runner{
 		onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
-		removeConnectionIDImpl:  func(protocol.ConnectionID) {},
+		removeConnectionIDImpl:  c.closeCallback,
 	}
-	c.session, err = newTLSClientSession(
+	sess, err := newTLSClientSession(
 		c.conn,
 		runner,
-		c.hostname,
-		c.version,
+		c.token,
 		c.destConnID,
 		c.srcConnID,
 		c.config,
-		c.tls,
+		c.mintConf,
 		paramsChan,
 		1,
 		c.logger,
+		c.version,
 	)
-	return err
+	if err != nil {
+		return err
+	}
+	c.session = sess
+	c.packetHandlers.Add(c.srcConnID, c)
+	return nil
 }
 
-func (c *client) Close(err error) error {
+func (c *client) Close() error {
 	c.mutex.Lock()
 	defer c.mutex.Unlock()
 	if c.session == nil {
 		return nil
 	}
-	return c.session.Close(err)
+	return c.session.Close()
+}
+
+func (c *client) destroy(e error) {
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+	if c.session == nil {
+		return
+	}
+	c.session.destroy(e)
+}
+
+func (c *client) GetVersion() protocol.VersionNumber {
+	c.mutex.Lock()
+	v := c.version
+	c.mutex.Unlock()
+	return v
+}
+
+func (c *client) GetPerspective() protocol.Perspective {
+	return protocol.PerspectiveClient
 }

+ 0 - 96
vendor/github.com/lucas-clemente/quic-go/client_multiplexer.go

@@ -1,96 +0,0 @@
-package quic
-
-import (
-	"bytes"
-	"net"
-	"strings"
-	"sync"
-	"time"
-
-	"github.com/lucas-clemente/quic-go/internal/protocol"
-	"github.com/lucas-clemente/quic-go/internal/utils"
-	"github.com/lucas-clemente/quic-go/internal/wire"
-)
-
-var (
-	clientMuxerOnce sync.Once
-	clientMuxer     *clientMultiplexer
-)
-
-// The clientMultiplexer listens on multiple net.PacketConns and dispatches
-// incoming packets to the session handler.
-type clientMultiplexer struct {
-	mutex sync.Mutex
-
-	conns map[net.PacketConn]packetHandlerManager
-
-	logger utils.Logger
-}
-
-func getClientMultiplexer() *clientMultiplexer {
-	clientMuxerOnce.Do(func() {
-		clientMuxer = &clientMultiplexer{
-			conns:  make(map[net.PacketConn]packetHandlerManager),
-			logger: utils.DefaultLogger.WithPrefix("client muxer"),
-		}
-	})
-	return clientMuxer
-}
-
-func (m *clientMultiplexer) Add(c net.PacketConn, connID protocol.ConnectionID, handler packetHandler) {
-	m.mutex.Lock()
-	defer m.mutex.Unlock()
-	sessions, ok := m.conns[c]
-	if !ok {
-		sessions = newPacketHandlerMap()
-		m.conns[c] = sessions
-	}
-	sessions.Add(connID, handler)
-	if ok {
-		return
-	}
-
-	// If we didn't know this packet conn before, listen for incoming packets
-	// and dispatch them to the right sessions.
-	go m.listen(c, sessions)
-}
-
-func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManager) {
-	for {
-		data := *getPacketBuffer()
-		data = data[:protocol.MaxReceivePacketSize]
-		// The packet size should not exceed protocol.MaxReceivePacketSize bytes
-		// If it does, we only read a truncated packet, which will then end up undecryptable
-		n, addr, err := c.ReadFrom(data)
-		if err != nil {
-			if !strings.HasSuffix(err.Error(), "use of closed network connection") {
-				sessions.Close(err)
-			}
-			return
-		}
-		data = data[:n]
-		rcvTime := time.Now()
-
-		r := bytes.NewReader(data)
-		hdr, err := wire.ParseHeaderSentByServer(r)
-		// drop the packet if we can't parse the header
-		if err != nil {
-			m.logger.Debugf("error parsing packet from %s: %s", addr, err)
-			continue
-		}
-		hdr.Raw = data[:len(data)-r.Len()]
-		packetData := data[len(data)-r.Len():]
-
-		client, ok := sessions.Get(hdr.DestConnectionID)
-		if !ok {
-			m.logger.Debugf("received a packet with an unexpected connection ID %s", hdr.DestConnectionID)
-			continue
-		}
-		client.handlePacket(&receivedPacket{
-			remoteAddr: addr,
-			header:     hdr,
-			data:       packetData,
-			rcvTime:    rcvTime,
-		})
-	}
-}

+ 6 - 6
vendor/github.com/lucas-clemente/quic-go/crypto_stream.go

@@ -8,7 +8,7 @@ import (
 	"github.com/lucas-clemente/quic-go/internal/wire"
 )
 
-type cryptoStreamI interface {
+type cryptoStream interface {
 	StreamID() protocol.StreamID
 	io.Reader
 	io.Writer
@@ -21,21 +21,21 @@ type cryptoStreamI interface {
 	handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
 }
 
-type cryptoStream struct {
+type cryptoStreamImpl struct {
 	*stream
 }
 
-var _ cryptoStreamI = &cryptoStream{}
+var _ cryptoStream = &cryptoStreamImpl{}
 
-func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStreamI {
+func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStream {
 	str := newStream(version.CryptoStreamID(), sender, flowController, version)
-	return &cryptoStream{str}
+	return &cryptoStreamImpl{str}
 }
 
 // SetReadOffset sets the read offset.
 // It is only needed for the crypto stream.
 // It must not be called concurrently with any other stream methods, especially Read and Write.
-func (s *cryptoStream) setReadOffset(offset protocol.ByteCount) {
+func (s *cryptoStreamImpl) setReadOffset(offset protocol.ByteCount) {
 	s.receiveStream.readOffset = offset
 	s.receiveStream.frameQueue.readPosition = offset
 }

+ 20 - 4
vendor/github.com/lucas-clemente/quic-go/interface.go

@@ -16,8 +16,14 @@ type StreamID = protocol.StreamID
 // A VersionNumber is a QUIC version number.
 type VersionNumber = protocol.VersionNumber
 
-// VersionGQUIC39 is gQUIC version 39.
-const VersionGQUIC39 = protocol.Version39
+const (
+	// VersionGQUIC39 is gQUIC version 39.
+	VersionGQUIC39 = protocol.Version39
+	// VersionGQUIC42 is gQUIC version 42.
+	VersionGQUIC42 = protocol.Version42
+	// VersionGQUIC43 is gQUIC version 43.
+	VersionGQUIC43 = protocol.Version43
+)
 
 // A Cookie can be used to verify the ownership of the client address.
 type Cookie = handshake.Cookie
@@ -139,8 +145,11 @@ type Session interface {
 	LocalAddr() net.Addr
 	// RemoteAddr returns the address of the peer.
 	RemoteAddr() net.Addr
-	// Close closes the connection. The error will be sent to the remote peer in a CONNECTION_CLOSE frame. An error value of nil is allowed and will cause a normal PeerGoingAway to be sent.
-	Close(error) error
+	// Close the connection.
+	io.Closer
+	// Close the connection with an error.
+	// The error must not be nil.
+	CloseWithError(ErrorCode, error) error
 	// The context is cancelled when the session is closed.
 	// Warning: This API should not be considered stable and might change soon.
 	Context() context.Context
@@ -159,6 +168,13 @@ type Config struct {
 	// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
 	// Currently only valid for the client.
 	RequestConnectionIDOmission bool
+	// The length of the connection ID in bytes. Only valid for IETF QUIC.
+	// It can be 0, or any value between 4 and 18.
+	// If not set, the interpretation depends on where the Config is used:
+	// If used for dialing an address, a 0 byte connection ID will be used.
+	// If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used.
+	// When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call.
+	ConnectionIDLength int
 	// HandshakeTimeout is the maximum duration that the cryptographic handshake may take.
 	// If the timeout is exceeded, the connection is closed.
 	// If this value is zero, the timeout is set to 10 seconds.

+ 2 - 1
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go

@@ -29,7 +29,8 @@ type SentPacketHandler interface {
 
 	GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
 	GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
-	DequeuePacketForRetransmission() (packet *Packet)
+	DequeuePacketForRetransmission() *Packet
+	DequeueProbePacket() (*Packet, error)
 	GetPacketNumberLen(protocol.PacketNumber) protocol.PacketNumberLen
 
 	GetAlarmTimeout() time.Time

+ 39 - 33
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go

@@ -1,6 +1,7 @@
 package ackhandler
 
 import (
+	"errors"
 	"fmt"
 	"math"
 	"time"
@@ -373,41 +374,50 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, priorInFlight proto
 }
 
 func (h *sentPacketHandler) OnAlarm() error {
-	now := time.Now()
+	// When all outstanding are acknowledged, the alarm is canceled in
+	// updateLossDetectionAlarm. This doesn't reset the timer in the session though.
+	// When OnAlarm is called, we therefore need to make sure that there are
+	// actually packets outstanding.
+	if h.packetHistory.HasOutstandingPackets() {
+		if err := h.onVerifiedAlarm(); err != nil {
+			return err
+		}
+	}
+	h.updateLossDetectionAlarm()
+	return nil
+}
 
+func (h *sentPacketHandler) onVerifiedAlarm() error {
 	var err error
-	if !h.handshakeComplete {
+	if h.packetHistory.HasOutstandingHandshakePackets() {
 		if h.logger.Debug() {
-			h.logger.Debugf("Loss detection alarm fired in handshake mode")
+			h.logger.Debugf("Loss detection alarm fired in handshake mode. Handshake count: %d", h.handshakeCount)
 		}
 		h.handshakeCount++
 		err = h.queueHandshakePacketsForRetransmission()
 	} else if !h.lossTime.IsZero() {
 		if h.logger.Debug() {
-			h.logger.Debugf("Loss detection alarm fired in loss timer mode")
+			h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", h.lossTime)
 		}
 		// Early retransmit or time loss detection
-		err = h.detectLostPackets(now, h.bytesInFlight)
-	} else if h.tlpCount < maxTLPs {
+		err = h.detectLostPackets(time.Now(), h.bytesInFlight)
+	} else if h.tlpCount < maxTLPs { // TLP
 		if h.logger.Debug() {
-			h.logger.Debugf("Loss detection alarm fired in TLP mode")
+			h.logger.Debugf("Loss detection alarm fired in TLP mode. TLP count: %d", h.tlpCount)
 		}
 		h.allowTLP = true
 		h.tlpCount++
-	} else {
+	} else { // RTO
 		if h.logger.Debug() {
-			h.logger.Debugf("Loss detection alarm fired in RTO mode")
+			h.logger.Debugf("Loss detection alarm fired in RTO mode. RTO count: %d", h.rtoCount)
+		}
+		if h.rtoCount == 0 {
+			h.largestSentBeforeRTO = h.lastSentPacketNumber
 		}
-		// RTO
 		h.rtoCount++
 		h.numRTOs += 2
-		err = h.queueRTOs()
-	}
-	if err != nil {
-		return err
 	}
-	h.updateLossDetectionAlarm()
-	return nil
+	return err
 }
 
 func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
@@ -496,6 +506,19 @@ func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
 	return packet
 }
 
+func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) {
+	if len(h.retransmissionQueue) == 0 {
+		p := h.packetHistory.FirstOutstanding()
+		if p == nil {
+			return nil, errors.New("cannot dequeue a probe packet. No outstanding packets")
+		}
+		if err := h.queuePacketForRetransmission(p); err != nil {
+			return nil, err
+		}
+	}
+	return h.DequeuePacketForRetransmission(), nil
+}
+
 func (h *sentPacketHandler) GetPacketNumberLen(p protocol.PacketNumber) protocol.PacketNumberLen {
 	return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked(), h.version)
 }
@@ -559,23 +582,6 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int {
 	return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay)))
 }
 
-// retransmit the oldest two packets
-func (h *sentPacketHandler) queueRTOs() error {
-	h.largestSentBeforeRTO = h.lastSentPacketNumber
-	// Queue the first two outstanding packets for retransmission.
-	// This does NOT declare this packets as lost:
-	// They are still tracked in the packet history and count towards the bytes in flight.
-	for i := 0; i < 2; i++ {
-		if p := h.packetHistory.FirstOutstanding(); p != nil {
-			h.logger.Debugf("Queueing packet %#x for retransmission (RTO)", p.PacketNumber)
-			if err := h.queuePacketForRetransmission(p); err != nil {
-				return err
-			}
-		}
-	}
-	return nil
-}
-
 func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error {
 	var handshakePackets []*Packet
 	h.packetHistory.Iterate(func(p *Packet) (bool, error) {

+ 2 - 2
vendor/github.com/lucas-clemente/quic-go/internal/crypto/key_derivation.go

@@ -15,7 +15,7 @@ const (
 
 // A TLSExporter gets the negotiated ciphersuite and computes exporter
 type TLSExporter interface {
-	GetCipherSuite() mint.CipherSuiteParams
+	ConnectionState() mint.ConnectionState
 	ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
 }
 
@@ -49,7 +49,7 @@ func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) {
 }
 
 func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error) {
-	cs := tls.GetCipherSuite()
+	cs := tls.ConnectionState().CipherSuite
 	secret, err := tls.ComputeExporter(label, nil, cs.Hash.Size())
 	if err != nil {
 		return nil, nil, err

+ 14 - 2
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/base_flow_controller.go

@@ -11,8 +11,9 @@ import (
 
 type baseFlowController struct {
 	// for sending data
-	bytesSent  protocol.ByteCount
-	sendWindow protocol.ByteCount
+	bytesSent     protocol.ByteCount
+	sendWindow    protocol.ByteCount
+	lastBlockedAt protocol.ByteCount
 
 	// for receiving data
 	mutex                sync.RWMutex
@@ -29,6 +30,17 @@ type baseFlowController struct {
 	logger utils.Logger
 }
 
+// IsNewlyBlocked says if it is newly blocked by flow control.
+// For every offset, it only returns true once.
+// If it is blocked, the offset is returned.
+func (c *baseFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
+	if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
+		return false, 0
+	}
+	c.lastBlockedAt = c.sendWindow
+	return true, c.sendWindow
+}
+
 func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
 	c.bytesSent += n
 }

+ 0 - 12
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/connection_flow_controller.go

@@ -10,7 +10,6 @@ import (
 )
 
 type connectionFlowController struct {
-	lastBlockedAt protocol.ByteCount
 	baseFlowController
 
 	queueWindowUpdate func()
@@ -43,17 +42,6 @@ func (c *connectionFlowController) SendWindowSize() protocol.ByteCount {
 	return c.baseFlowController.sendWindowSize()
 }
 
-// IsNewlyBlocked says if it is newly blocked by flow control.
-// For every offset, it only returns true once.
-// If it is blocked, the offset is returned.
-func (c *connectionFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
-	if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
-		return false, 0
-	}
-	c.lastBlockedAt = c.sendWindow
-	return true, c.sendWindow
-}
-
 // IncrementHighestReceived adds an increment to the highestReceived value
 func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error {
 	c.mutex.Lock()

+ 1 - 4
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/interface.go

@@ -11,13 +11,12 @@ type flowController interface {
 	AddBytesRead(protocol.ByteCount)
 	GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary
 	MaybeQueueWindowUpdate()             //  queues a window update, if necessary
+	IsNewlyBlocked() (bool, protocol.ByteCount)
 }
 
 // A StreamFlowController is a flow controller for a QUIC stream.
 type StreamFlowController interface {
 	flowController
-	// for sending
-	IsBlocked() (bool, protocol.ByteCount)
 	// for receiving
 	// UpdateHighestReceived should be called when a new highest offset is received
 	// final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame
@@ -27,8 +26,6 @@ type StreamFlowController interface {
 // The ConnectionFlowController is the flow controller for the connection.
 type ConnectionFlowController interface {
 	flowController
-	// for sending
-	IsNewlyBlocked() (bool, protocol.ByteCount)
 }
 
 type connectionFlowControllerI interface {

+ 0 - 9
vendor/github.com/lucas-clemente/quic-go/internal/flowcontrol/stream_flow_controller.go

@@ -115,15 +115,6 @@ func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
 	return window
 }
 
-// IsBlocked says if it is blocked by stream-level flow control.
-// If it is blocked, the offset is returned.
-func (c *streamFlowController) IsBlocked() (bool, protocol.ByteCount) {
-	if c.sendWindowSize() != 0 {
-		return false, 0
-	}
-	return true, c.sendWindow
-}
-
 func (c *streamFlowController) MaybeQueueWindowUpdate() {
 	c.mutex.Lock()
 	hasWindowUpdate := !c.receivedFinalOffset && c.hasWindowUpdate()

+ 2 - 4
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_generator.go

@@ -5,8 +5,6 @@ import (
 	"fmt"
 	"net"
 	"time"
-
-	"github.com/bifurcation/mint"
 )
 
 const (
@@ -29,12 +27,12 @@ type token struct {
 
 // A CookieGenerator generates Cookies
 type CookieGenerator struct {
-	cookieProtector mint.CookieProtector
+	cookieProtector cookieProtector
 }
 
 // NewCookieGenerator initializes a new CookieGenerator
 func NewCookieGenerator() (*CookieGenerator, error) {
-	cookieProtector, err := mint.NewDefaultCookieProtector()
+	cookieProtector, err := newCookieProtector()
 	if err != nil {
 		return nil, err
 	}

+ 0 - 51
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_handler.go

@@ -1,51 +0,0 @@
-package handshake
-
-import (
-	"net"
-
-	"github.com/bifurcation/mint"
-	"github.com/lucas-clemente/quic-go/internal/utils"
-)
-
-// A CookieHandler generates and validates cookies.
-// The cookie is sent in the TLS Retry.
-// By including the cookie in its ClientHello, a client can proof ownership of its source address.
-type CookieHandler struct {
-	callback        func(net.Addr, *Cookie) bool
-	cookieGenerator *CookieGenerator
-
-	logger utils.Logger
-}
-
-var _ mint.CookieHandler = &CookieHandler{}
-
-// NewCookieHandler creates a new CookieHandler.
-func NewCookieHandler(callback func(net.Addr, *Cookie) bool, logger utils.Logger) (*CookieHandler, error) {
-	cookieGenerator, err := NewCookieGenerator()
-	if err != nil {
-		return nil, err
-	}
-	return &CookieHandler{
-		callback:        callback,
-		cookieGenerator: cookieGenerator,
-		logger:          logger,
-	}, nil
-}
-
-// Generate a new cookie for a mint connection.
-func (h *CookieHandler) Generate(conn *mint.Conn) ([]byte, error) {
-	if h.callback(conn.RemoteAddr(), nil) {
-		return nil, nil
-	}
-	return h.cookieGenerator.NewToken(conn.RemoteAddr())
-}
-
-// Validate a cookie.
-func (h *CookieHandler) Validate(conn *mint.Conn, token []byte) bool {
-	data, err := h.cookieGenerator.DecodeToken(token)
-	if err != nil {
-		h.logger.Debugf("Couldn't decode cookie from %s: %s", conn.RemoteAddr(), err.Error())
-		return false
-	}
-	return h.callback(conn.RemoteAddr(), data)
-}

+ 86 - 0
vendor/github.com/lucas-clemente/quic-go/internal/handshake/cookie_protector.go

@@ -0,0 +1,86 @@
+package handshake
+
+import (
+	"crypto/aes"
+	"crypto/cipher"
+	"crypto/rand"
+	"crypto/sha256"
+	"fmt"
+	"io"
+
+	"golang.org/x/crypto/hkdf"
+)
+
+// CookieProtector is used to create and verify a cookie
+type cookieProtector interface {
+	// NewToken creates a new token
+	NewToken([]byte) ([]byte, error)
+	// DecodeToken decodes a token
+	DecodeToken([]byte) ([]byte, error)
+}
+
+const (
+	cookieSecretSize = 32
+	cookieNonceSize  = 32
+)
+
+// cookieProtector is used to create and verify a cookie
+type cookieProtectorImpl struct {
+	secret []byte
+}
+
+// newCookieProtector creates a source for source address tokens
+func newCookieProtector() (cookieProtector, error) {
+	secret := make([]byte, cookieSecretSize)
+	if _, err := rand.Read(secret); err != nil {
+		return nil, err
+	}
+	return &cookieProtectorImpl{secret: secret}, nil
+}
+
+// NewToken encodes data into a new token.
+func (s *cookieProtectorImpl) NewToken(data []byte) ([]byte, error) {
+	nonce := make([]byte, cookieNonceSize)
+	if _, err := rand.Read(nonce); err != nil {
+		return nil, err
+	}
+	aead, aeadNonce, err := s.createAEAD(nonce)
+	if err != nil {
+		return nil, err
+	}
+	return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil
+}
+
+// DecodeToken decodes a token.
+func (s *cookieProtectorImpl) DecodeToken(p []byte) ([]byte, error) {
+	if len(p) < cookieNonceSize {
+		return nil, fmt.Errorf("Token too short: %d", len(p))
+	}
+	nonce := p[:cookieNonceSize]
+	aead, aeadNonce, err := s.createAEAD(nonce)
+	if err != nil {
+		return nil, err
+	}
+	return aead.Open(nil, aeadNonce, p[cookieNonceSize:], nil)
+}
+
+func (s *cookieProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
+	h := hkdf.New(sha256.New, s.secret, nonce, []byte("quic-go cookie source"))
+	key := make([]byte, 32) // use a 32 byte key, in order to select AES-256
+	if _, err := io.ReadFull(h, key); err != nil {
+		return nil, nil, err
+	}
+	aeadNonce := make([]byte, 12)
+	if _, err := io.ReadFull(h, aeadNonce); err != nil {
+		return nil, nil, err
+	}
+	c, err := aes.NewCipher(key)
+	if err != nil {
+		return nil, nil, err
+	}
+	aead, err := cipher.NewGCM(c)
+	if err != nil {
+		return nil, nil, err
+	}
+	return aead, aeadNonce, nil
+}

+ 14 - 12
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_client.go

@@ -86,18 +86,20 @@ func NewCryptoSetupClient(
 	}
 	divNonceChan := make(chan struct{})
 	cs := &cryptoSetupClient{
-		cryptoStream:       cryptoStream,
-		hostname:           hostname,
-		connID:             connID,
-		version:            version,
-		certManager:        crypto.NewCertManager(tlsConfig),
-		params:             params,
-		keyDerivation:      crypto.DeriveQuicCryptoAESKeys,
-		nullAEAD:           nullAEAD,
-		paramsChan:         paramsChan,
-		handshakeEvent:     handshakeEvent,
-		initialVersion:     initialVersion,
-		negotiatedVersions: negotiatedVersions,
+		cryptoStream:   cryptoStream,
+		hostname:       hostname,
+		connID:         connID,
+		version:        version,
+		certManager:    crypto.NewCertManager(tlsConfig),
+		params:         params,
+		keyDerivation:  crypto.DeriveQuicCryptoAESKeys,
+		nullAEAD:       nullAEAD,
+		paramsChan:     paramsChan,
+		handshakeEvent: handshakeEvent,
+		initialVersion: initialVersion,
+		// The server might have sent greased versions in the Version Negotiation packet.
+		// We need strip those from the list, since they won't be included in the handshake tag.
+		negotiatedVersions: protocol.StripGreasedVersions(negotiatedVersions),
 		divNonceChan:       divNonceChan,
 		logger:             logger,
 	}

+ 25 - 29
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup_tls.go

@@ -11,9 +11,6 @@ import (
 	"github.com/lucas-clemente/quic-go/internal/protocol"
 )
 
-// ErrCloseSessionForRetry is returned by HandleCryptoStream when the server wishes to perform a stateless retry
-var ErrCloseSessionForRetry = errors.New("closing session in order to recreate after a retry")
-
 // KeyDerivationFunction is used for key derivation
 type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error)
 
@@ -26,8 +23,8 @@ type cryptoSetupTLS struct {
 	nullAEAD      crypto.AEAD
 	aead          crypto.AEAD
 
-	tls            MintTLS
-	cryptoStream   *CryptoStreamConn
+	tls            mintTLS
+	conn           *cryptoStreamConn
 	handshakeEvent chan<- struct{}
 }
 
@@ -35,39 +32,46 @@ var _ CryptoSetupTLS = &cryptoSetupTLS{}
 
 // NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
 func NewCryptoSetupTLSServer(
-	tls MintTLS,
-	cryptoStream *CryptoStreamConn,
-	nullAEAD crypto.AEAD,
+	cryptoStream io.ReadWriter,
+	connID protocol.ConnectionID,
+	config *mint.Config,
 	handshakeEvent chan<- struct{},
 	version protocol.VersionNumber,
-) CryptoSetupTLS {
+) (CryptoSetupTLS, error) {
+	nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
+	if err != nil {
+		return nil, err
+	}
+	conn := newCryptoStreamConn(cryptoStream)
+	tls := mint.Server(conn, config)
 	return &cryptoSetupTLS{
 		tls:            tls,
-		cryptoStream:   cryptoStream,
+		conn:           conn,
 		nullAEAD:       nullAEAD,
 		perspective:    protocol.PerspectiveServer,
 		keyDerivation:  crypto.DeriveAESKeys,
 		handshakeEvent: handshakeEvent,
-	}
+	}, nil
 }
 
 // NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client
 func NewCryptoSetupTLSClient(
 	cryptoStream io.ReadWriter,
 	connID protocol.ConnectionID,
-	hostname string,
+	config *mint.Config,
 	handshakeEvent chan<- struct{},
-	tls MintTLS,
 	version protocol.VersionNumber,
 ) (CryptoSetupTLS, error) {
 	nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
 	if err != nil {
 		return nil, err
 	}
-
+	conn := newCryptoStreamConn(cryptoStream)
+	tls := mint.Client(conn, config)
 	return &cryptoSetupTLS{
-		perspective:    protocol.PerspectiveClient,
 		tls:            tls,
+		conn:           conn,
+		perspective:    protocol.PerspectiveClient,
 		nullAEAD:       nullAEAD,
 		keyDerivation:  crypto.DeriveAESKeys,
 		handshakeEvent: handshakeEvent,
@@ -75,24 +79,16 @@ func NewCryptoSetupTLSClient(
 }
 
 func (h *cryptoSetupTLS) HandleCryptoStream() error {
-	if h.perspective == protocol.PerspectiveServer {
-		// mint already wrote the ServerHello, EncryptedExtensions and the certificate chain to the buffer
-		// send out that data now
-		if _, err := h.cryptoStream.Flush(); err != nil {
-			return err
-		}
-	}
-
-handshakeLoop:
 	for {
 		if alert := h.tls.Handshake(); alert != mint.AlertNoAlert {
 			return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
 		}
-		switch h.tls.State() {
-		case mint.StateClientStart: // this happens if a stateless retry is performed
-			return ErrCloseSessionForRetry
-		case mint.StateClientConnected, mint.StateServerConnected:
-			break handshakeLoop
+		state := h.tls.ConnectionState().HandshakeState
+		if err := h.conn.Flush(); err != nil {
+			return err
+		}
+		if state == mint.StateClientConnected || state == mint.StateServerConnected {
+			break
 		}
 	}
 

+ 26 - 58
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_stream_conn.go

@@ -7,95 +7,63 @@ import (
 	"time"
 )
 
-// The CryptoStreamConn is used as the net.Conn passed to mint.
-// It has two operating modes:
-// 1. It can read and write to bytes.Buffers.
-// 2. It can use a quic.Stream for reading and writing.
-// The buffer-mode is only used by the server, in order to statelessly handle retries.
-type CryptoStreamConn struct {
-	remoteAddr net.Addr
-
-	// the buffers are used before the session is initialized
-	readBuf  bytes.Buffer
-	writeBuf bytes.Buffer
-
-	// stream will be set once the session is initialized
+type cryptoStreamConn struct {
+	buffer *bytes.Buffer
 	stream io.ReadWriter
 }
 
-var _ net.Conn = &CryptoStreamConn{}
-
-// NewCryptoStreamConn creates a new CryptoStreamConn
-func NewCryptoStreamConn(remoteAddr net.Addr) *CryptoStreamConn {
-	return &CryptoStreamConn{remoteAddr: remoteAddr}
-}
+var _ net.Conn = &cryptoStreamConn{}
 
-func (c *CryptoStreamConn) Read(b []byte) (int, error) {
-	if c.stream != nil {
-		return c.stream.Read(b)
+func newCryptoStreamConn(stream io.ReadWriter) *cryptoStreamConn {
+	return &cryptoStreamConn{
+		stream: stream,
+		buffer: &bytes.Buffer{},
 	}
-	return c.readBuf.Read(b)
 }
 
-// AddDataForReading adds data to the read buffer.
-// This data will ONLY be read when the stream has not been set.
-func (c *CryptoStreamConn) AddDataForReading(data []byte) {
-	c.readBuf.Write(data)
+func (c *cryptoStreamConn) Read(b []byte) (int, error) {
+	return c.stream.Read(b)
 }
 
-func (c *CryptoStreamConn) Write(p []byte) (int, error) {
-	if c.stream != nil {
-		return c.stream.Write(p)
-	}
-	return c.writeBuf.Write(p)
+func (c *cryptoStreamConn) Write(p []byte) (int, error) {
+	return c.buffer.Write(p)
 }
 
-// GetDataForWriting returns all data currently in the write buffer, and resets this buffer.
-func (c *CryptoStreamConn) GetDataForWriting() []byte {
-	defer c.writeBuf.Reset()
-	data := make([]byte, c.writeBuf.Len())
-	copy(data, c.writeBuf.Bytes())
-	return data
-}
-
-// SetStream sets the stream.
-// After setting the stream, the read and write buffer won't be used any more.
-func (c *CryptoStreamConn) SetStream(stream io.ReadWriter) {
-	c.stream = stream
-}
-
-// Flush copies the contents of the write buffer to the stream
-func (c *CryptoStreamConn) Flush() (int, error) {
-	n, err := io.Copy(c.stream, &c.writeBuf)
-	return int(n), err
+func (c *cryptoStreamConn) Flush() error {
+	if c.buffer.Len() == 0 {
+		return nil
+	}
+	_, err := c.stream.Write(c.buffer.Bytes())
+	c.buffer.Reset()
+	return err
 }
 
 // Close is not implemented
-func (c *CryptoStreamConn) Close() error {
+func (c *cryptoStreamConn) Close() error {
 	return nil
 }
 
 // LocalAddr is not implemented
-func (c *CryptoStreamConn) LocalAddr() net.Addr {
+func (c *cryptoStreamConn) LocalAddr() net.Addr {
 	return nil
 }
 
-// RemoteAddr returns the remote address
-func (c *CryptoStreamConn) RemoteAddr() net.Addr {
-	return c.remoteAddr
+// RemoteAddr is not implemented
+func (c *cryptoStreamConn) RemoteAddr() net.Addr {
+	return nil
 }
 
 // SetReadDeadline is not implemented
-func (c *CryptoStreamConn) SetReadDeadline(time.Time) error {
+func (c *cryptoStreamConn) SetReadDeadline(time.Time) error {
 	return nil
 }
 
 // SetWriteDeadline is not implemented
-func (c *CryptoStreamConn) SetWriteDeadline(time.Time) error {
+func (c *cryptoStreamConn) SetWriteDeadline(time.Time) error {
 	return nil
 }
 
 // SetDeadline is not implemented
-func (c *CryptoStreamConn) SetDeadline(time.Time) error {
+func (c *cryptoStreamConn) SetDeadline(time.Time) error {
 	return nil
 }

+ 6 - 13
vendor/github.com/lucas-clemente/quic-go/internal/handshake/interface.go

@@ -2,7 +2,6 @@ package handshake
 
 import (
 	"crypto/x509"
-	"io"
 
 	"github.com/bifurcation/mint"
 	"github.com/lucas-clemente/quic-go/internal/crypto"
@@ -15,6 +14,12 @@ type Sealer interface {
 	Overhead() int
 }
 
+// mintTLS combines some methods needed to interact with mint.
+type mintTLS interface {
+	crypto.TLSExporter
+	Handshake() mint.Alert
+}
+
 // A TLSExtensionHandler sends and received the QUIC TLS extension.
 // It provides the parameters sent by the peer on a channel.
 type TLSExtensionHandler interface {
@@ -23,18 +28,6 @@ type TLSExtensionHandler interface {
 	GetPeerParams() <-chan TransportParameters
 }
 
-// MintTLS combines some methods needed to interact with mint.
-type MintTLS interface {
-	crypto.TLSExporter
-
-	// additional methods
-	Handshake() mint.Alert
-	State() mint.State
-	ConnectionState() mint.ConnectionState
-
-	SetCryptoStream(io.ReadWriter)
-}
-
 type baseCryptoSetup interface {
 	HandleCryptoStream() error
 	ConnectionState() ConnectionState

+ 3 - 0
vendor/github.com/lucas-clemente/quic-go/internal/handshake/mockgen.go

@@ -0,0 +1,3 @@
+package handshake
+
+//go:generate sh -c "../mockgen_internal.sh handshake mock_mint_tls_test.go github.com/lucas-clemente/quic-go/internal/handshake mintTLS"

+ 78 - 9
vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension.go

@@ -1,7 +1,14 @@
 package handshake
 
 import (
+	"bytes"
+	"encoding/binary"
+	"errors"
+	"fmt"
+
 	"github.com/bifurcation/mint"
+	"github.com/lucas-clemente/quic-go/internal/protocol"
+	"github.com/lucas-clemente/quic-go/internal/utils"
 )
 
 type transportParameterID uint16
@@ -16,22 +23,84 @@ const (
 	maxPacketSizeParameterID         transportParameterID = 0x5
 	statelessResetTokenParameterID   transportParameterID = 0x6
 	initialMaxUniStreamsParameterID  transportParameterID = 0x8
+	disableMigrationParameterID      transportParameterID = 0x9
 )
 
-type transportParameter struct {
-	Parameter transportParameterID
-	Value     []byte `tls:"head=2"`
+type clientHelloTransportParameters struct {
+	InitialVersion protocol.VersionNumber
+	Parameters     TransportParameters
 }
 
-type clientHelloTransportParameters struct {
-	InitialVersion uint32               // actually a protocol.VersionNumber
-	Parameters     []transportParameter `tls:"head=2"`
+func (p *clientHelloTransportParameters) Marshal() []byte {
+	const lenOffset = 4
+	b := &bytes.Buffer{}
+	utils.BigEndian.WriteUint32(b, uint32(p.InitialVersion))
+	b.Write([]byte{0, 0}) // length. Will be replaced later
+	p.Parameters.marshal(b)
+	data := b.Bytes()
+	binary.BigEndian.PutUint16(data[lenOffset:lenOffset+2], uint16(len(data)-lenOffset-2))
+	return data
+}
+
+func (p *clientHelloTransportParameters) Unmarshal(data []byte) error {
+	if len(data) < 6 {
+		return errors.New("transport parameter data too short")
+	}
+	p.InitialVersion = protocol.VersionNumber(binary.BigEndian.Uint32(data[:4]))
+	paramsLen := int(binary.BigEndian.Uint16(data[4:6]))
+	data = data[6:]
+	if len(data) != paramsLen {
+		return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", paramsLen, len(data))
+	}
+	return p.Parameters.unmarshal(data)
 }
 
 type encryptedExtensionsTransportParameters struct {
-	NegotiatedVersion uint32               // actually a protocol.VersionNumber
-	SupportedVersions []uint32             `tls:"head=1"` // actually a protocol.VersionNumber
-	Parameters        []transportParameter `tls:"head=2"`
+	NegotiatedVersion protocol.VersionNumber
+	SupportedVersions []protocol.VersionNumber
+	Parameters        TransportParameters
+}
+
+func (p *encryptedExtensionsTransportParameters) Marshal() []byte {
+	b := &bytes.Buffer{}
+	utils.BigEndian.WriteUint32(b, uint32(p.NegotiatedVersion))
+	b.WriteByte(uint8(4 * len(p.SupportedVersions)))
+	for _, v := range p.SupportedVersions {
+		utils.BigEndian.WriteUint32(b, uint32(v))
+	}
+	lenOffset := b.Len()
+	b.Write([]byte{0, 0}) // length. Will be replaced later
+	p.Parameters.marshal(b)
+	data := b.Bytes()
+	binary.BigEndian.PutUint16(data[lenOffset:lenOffset+2], uint16(len(data)-lenOffset-2))
+	return data
+}
+
+func (p *encryptedExtensionsTransportParameters) Unmarshal(data []byte) error {
+	if len(data) < 5 {
+		return errors.New("transport parameter data too short")
+	}
+	p.NegotiatedVersion = protocol.VersionNumber(binary.BigEndian.Uint32(data[:4]))
+	numVersions := int(data[4])
+	if numVersions%4 != 0 {
+		return fmt.Errorf("invalid length for version list: %d", numVersions)
+	}
+	numVersions /= 4
+	data = data[5:]
+	if len(data) < 4*numVersions+2 /*length field for the parameter list */ {
+		return errors.New("transport parameter data too short")
+	}
+	p.SupportedVersions = make([]protocol.VersionNumber, numVersions)
+	for i := 0; i < numVersions; i++ {
+		p.SupportedVersions[i] = protocol.VersionNumber(binary.BigEndian.Uint32(data[:4]))
+		data = data[4:]
+	}
+	paramsLen := int(binary.BigEndian.Uint16(data[:2]))
+	data = data[2:]
+	if len(data) != paramsLen {
+		return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", paramsLen, len(data))
+	}
+	return p.Parameters.unmarshal(data)
 }
 
 type tlsExtensionBody struct {

+ 12 - 36
vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_client.go

@@ -7,7 +7,6 @@ import (
 	"github.com/lucas-clemente/quic-go/qerr"
 
 	"github.com/bifurcation/mint"
-	"github.com/bifurcation/mint/syntax"
 	"github.com/lucas-clemente/quic-go/internal/protocol"
 	"github.com/lucas-clemente/quic-go/internal/utils"
 )
@@ -52,16 +51,12 @@ func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.Extensi
 	if hType != mint.HandshakeTypeClientHello {
 		return nil
 	}
-
 	h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams)
-	data, err := syntax.Marshal(clientHelloTransportParameters{
-		InitialVersion: uint32(h.initialVersion),
-		Parameters:     h.ourParams.getTransportParameters(),
-	})
-	if err != nil {
-		return err
+	chtp := &clientHelloTransportParameters{
+		InitialVersion: h.initialVersion,
+		Parameters:     *h.ourParams,
 	}
-	return el.Add(&tlsExtensionBody{data})
+	return el.Add(&tlsExtensionBody{data: chtp.Marshal()})
 }
 
 func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
@@ -84,50 +79,31 @@ func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.Exte
 	}
 
 	eetp := &encryptedExtensionsTransportParameters{}
-	if _, err := syntax.Unmarshal(ext.data, eetp); err != nil {
+	if err := eetp.Unmarshal(ext.data); err != nil {
 		return err
 	}
-	serverSupportedVersions := make([]protocol.VersionNumber, len(eetp.SupportedVersions))
-	for i, v := range eetp.SupportedVersions {
-		serverSupportedVersions[i] = protocol.VersionNumber(v)
-	}
 	// check that the negotiated_version is the current version
-	if protocol.VersionNumber(eetp.NegotiatedVersion) != h.version {
+	if eetp.NegotiatedVersion != h.version {
 		return qerr.Error(qerr.VersionNegotiationMismatch, "current version doesn't match negotiated_version")
 	}
 	// check that the current version is included in the supported versions
-	if !protocol.IsSupportedVersion(serverSupportedVersions, h.version) {
+	if !protocol.IsSupportedVersion(eetp.SupportedVersions, h.version) {
 		return qerr.Error(qerr.VersionNegotiationMismatch, "current version not included in the supported versions")
 	}
 	// if version negotiation was performed, check that we would have selected the current version based on the supported versions sent by the server
 	if h.version != h.initialVersion {
-		negotiatedVersion, ok := protocol.ChooseSupportedVersion(h.supportedVersions, serverSupportedVersions)
+		negotiatedVersion, ok := protocol.ChooseSupportedVersion(h.supportedVersions, eetp.SupportedVersions)
 		if !ok || h.version != negotiatedVersion {
 			return qerr.Error(qerr.VersionNegotiationMismatch, "would have picked a different version")
 		}
 	}
 
-	// check that the server sent the stateless reset token
-	var foundStatelessResetToken bool
-	for _, p := range eetp.Parameters {
-		if p.Parameter == statelessResetTokenParameterID {
-			if len(p.Value) != 16 {
-				return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", len(p.Value))
-			}
-			foundStatelessResetToken = true
-			// TODO: handle this value
-		}
-	}
-	if !foundStatelessResetToken {
-		// TODO: return the right error here
+	// check that the server sent a stateless reset token
+	if len(eetp.Parameters.StatelessResetToken) == 0 {
 		return errors.New("server didn't sent stateless_reset_token")
 	}
-	params, err := readTransportParameters(eetp.Parameters)
-	if err != nil {
-		return err
-	}
-	h.logger.Debugf("Received Transport Parameters: %s", params)
-	h.paramsChan <- *params
+	h.logger.Debugf("Received Transport Parameters: %s", &eetp.Parameters)
+	h.paramsChan <- eetp.Parameters
 	return nil
 }
 

+ 13 - 35
vendor/github.com/lucas-clemente/quic-go/internal/handshake/tls_extension_handler_server.go

@@ -1,14 +1,12 @@
 package handshake
 
 import (
-	"bytes"
 	"errors"
 	"fmt"
 
 	"github.com/lucas-clemente/quic-go/qerr"
 
 	"github.com/bifurcation/mint"
-	"github.com/bifurcation/mint/syntax"
 	"github.com/lucas-clemente/quic-go/internal/protocol"
 	"github.com/lucas-clemente/quic-go/internal/utils"
 )
@@ -49,27 +47,13 @@ func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.Extensi
 	if hType != mint.HandshakeTypeEncryptedExtensions {
 		return nil
 	}
-
-	transportParams := append(
-		h.ourParams.getTransportParameters(),
-		// TODO(#855): generate a real token
-		transportParameter{statelessResetTokenParameterID, bytes.Repeat([]byte{42}, 16)},
-	)
-	supportedVersions := protocol.GetGreasedVersions(h.supportedVersions)
-	versions := make([]uint32, len(supportedVersions))
-	for i, v := range supportedVersions {
-		versions[i] = uint32(v)
-	}
 	h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams)
-	data, err := syntax.Marshal(encryptedExtensionsTransportParameters{
-		NegotiatedVersion: uint32(h.version),
-		SupportedVersions: versions,
-		Parameters:        transportParams,
-	})
-	if err != nil {
-		return err
+	eetp := &encryptedExtensionsTransportParameters{
+		NegotiatedVersion: h.version,
+		SupportedVersions: protocol.GetGreasedVersions(h.supportedVersions),
+		Parameters:        *h.ourParams,
 	}
-	return el.Add(&tlsExtensionBody{data})
+	return el.Add(&tlsExtensionBody{data: eetp.Marshal()})
 }
 
 func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
@@ -90,30 +74,24 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte
 		return errors.New("ClientHello didn't contain a QUIC extension")
 	}
 	chtp := &clientHelloTransportParameters{}
-	if _, err := syntax.Unmarshal(ext.data, chtp); err != nil {
+	if err := chtp.Unmarshal(ext.data); err != nil {
 		return err
 	}
-	initialVersion := protocol.VersionNumber(chtp.InitialVersion)
 
 	// perform the stateless version negotiation validation:
 	// make sure that we would have sent a Version Negotiation Packet if the client offered the initial version
 	// this is the case if and only if the initial version is not contained in the supported versions
-	if initialVersion != h.version && protocol.IsSupportedVersion(h.supportedVersions, initialVersion) {
+	if chtp.InitialVersion != h.version && protocol.IsSupportedVersion(h.supportedVersions, chtp.InitialVersion) {
 		return qerr.Error(qerr.VersionNegotiationMismatch, "Client should have used the initial version")
 	}
 
-	for _, p := range chtp.Parameters {
-		if p.Parameter == statelessResetTokenParameterID {
-			// TODO: return the correct error type
-			return errors.New("client sent a stateless reset token")
-		}
-	}
-	params, err := readTransportParameters(chtp.Parameters)
-	if err != nil {
-		return err
+	// check that the client didn't send a stateless reset token
+	if len(chtp.Parameters.StatelessResetToken) != 0 {
+		// TODO: return the correct error type
+		return errors.New("client sent a stateless reset token")
 	}
-	h.logger.Debugf("Received Transport Parameters: %s", params)
-	h.paramsChan <- *params
+	h.logger.Debugf("Received Transport Parameters: %s", &chtp.Parameters)
+	h.paramsChan <- chtp.Parameters
 	return nil
 }
 

+ 84 - 58
vendor/github.com/lucas-clemente/quic-go/internal/handshake/transport_parameters.go

@@ -26,8 +26,10 @@ type TransportParameters struct {
 	MaxBidiStreams uint16 // only used for IETF QUIC
 	MaxStreams     uint32 // only used for gQUIC
 
-	OmitConnectionID bool // only used for gQUIC
-	IdleTimeout      time.Duration
+	OmitConnectionID    bool // only used for gQUIC
+	IdleTimeout         time.Duration
+	DisableMigration    bool   // only used for IETF QUIC
+	StatelessResetToken []byte // only used for IETF QUIC
 }
 
 // readHelloMap reads the transport parameters from the tags sent in a gQUIC handshake message
@@ -94,86 +96,110 @@ func (p *TransportParameters) getHelloMap() map[Tag][]byte {
 	return tags
 }
 
-// readTransportParameters reads the transport parameters sent in the QUIC TLS extension
-func readTransportParameters(paramsList []transportParameter) (*TransportParameters, error) {
-	params := &TransportParameters{}
-
-	var foundInitialMaxStreamData bool
-	var foundInitialMaxData bool
+func (p *TransportParameters) unmarshal(data []byte) error {
 	var foundIdleTimeout bool
 
-	for _, p := range paramsList {
-		switch p.Parameter {
+	for len(data) >= 4 {
+		paramID := binary.BigEndian.Uint16(data[:2])
+		paramLen := int(binary.BigEndian.Uint16(data[2:4]))
+		data = data[4:]
+		if len(data) < paramLen {
+			return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", len(data), paramLen)
+		}
+		switch transportParameterID(paramID) {
 		case initialMaxStreamDataParameterID:
-			foundInitialMaxStreamData = true
-			if len(p.Value) != 4 {
-				return nil, fmt.Errorf("wrong length for initial_max_stream_data: %d (expected 4)", len(p.Value))
+			if paramLen != 4 {
+				return fmt.Errorf("wrong length for initial_max_stream_data: %d (expected 4)", paramLen)
 			}
-			params.StreamFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value))
+			p.StreamFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(data[:4]))
 		case initialMaxDataParameterID:
-			foundInitialMaxData = true
-			if len(p.Value) != 4 {
-				return nil, fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", len(p.Value))
+			if paramLen != 4 {
+				return fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", paramLen)
 			}
-			params.ConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(p.Value))
+			p.ConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(data[:4]))
 		case initialMaxBidiStreamsParameterID:
-			if len(p.Value) != 2 {
-				return nil, fmt.Errorf("wrong length for initial_max_stream_id_bidi: %d (expected 2)", len(p.Value))
+			if paramLen != 2 {
+				return fmt.Errorf("wrong length for initial_max_stream_id_bidi: %d (expected 2)", paramLen)
 			}
-			params.MaxBidiStreams = binary.BigEndian.Uint16(p.Value)
+			p.MaxBidiStreams = binary.BigEndian.Uint16(data[:2])
 		case initialMaxUniStreamsParameterID:
-			if len(p.Value) != 2 {
-				return nil, fmt.Errorf("wrong length for initial_max_stream_id_uni: %d (expected 2)", len(p.Value))
+			if paramLen != 2 {
+				return fmt.Errorf("wrong length for initial_max_stream_id_uni: %d (expected 2)", paramLen)
 			}
-			params.MaxUniStreams = binary.BigEndian.Uint16(p.Value)
+			p.MaxUniStreams = binary.BigEndian.Uint16(data[:2])
 		case idleTimeoutParameterID:
 			foundIdleTimeout = true
-			if len(p.Value) != 2 {
-				return nil, fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", len(p.Value))
+			if paramLen != 2 {
+				return fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", paramLen)
 			}
-			params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(binary.BigEndian.Uint16(p.Value))*time.Second)
+			p.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(binary.BigEndian.Uint16(data[:2]))*time.Second)
 		case maxPacketSizeParameterID:
-			if len(p.Value) != 2 {
-				return nil, fmt.Errorf("wrong length for max_packet_size: %d (expected 2)", len(p.Value))
+			if paramLen != 2 {
+				return fmt.Errorf("wrong length for max_packet_size: %d (expected 2)", paramLen)
 			}
-			maxPacketSize := protocol.ByteCount(binary.BigEndian.Uint16(p.Value))
+			maxPacketSize := protocol.ByteCount(binary.BigEndian.Uint16(data[:2]))
 			if maxPacketSize < 1200 {
-				return nil, fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", maxPacketSize)
+				return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", maxPacketSize)
+			}
+			p.MaxPacketSize = maxPacketSize
+		case disableMigrationParameterID:
+			if paramLen != 0 {
+				return fmt.Errorf("wrong length for disable_migration: %d (expected empty)", paramLen)
 			}
-			params.MaxPacketSize = maxPacketSize
+			p.DisableMigration = true
+		case statelessResetTokenParameterID:
+			if paramLen != 16 {
+				return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen)
+			}
+			p.StatelessResetToken = data[:16]
 		}
+		data = data[paramLen:]
 	}
 
-	if !(foundInitialMaxStreamData && foundInitialMaxData && foundIdleTimeout) {
-		return nil, errors.New("missing parameter")
+	if len(data) != 0 {
+		return fmt.Errorf("should have read all data. Still have %d bytes", len(data))
 	}
-	return params, nil
+	if !foundIdleTimeout {
+		return errors.New("missing parameter")
+	}
+	return nil
 }
 
-// GetTransportParameters gets the parameters needed for the TLS handshake.
-// It doesn't send the initial_max_stream_id_uni parameter, so the peer isn't allowed to open any unidirectional streams.
-func (p *TransportParameters) getTransportParameters() []transportParameter {
-	initialMaxStreamData := make([]byte, 4)
-	binary.BigEndian.PutUint32(initialMaxStreamData, uint32(p.StreamFlowControlWindow))
-	initialMaxData := make([]byte, 4)
-	binary.BigEndian.PutUint32(initialMaxData, uint32(p.ConnectionFlowControlWindow))
-	initialMaxBidiStreamID := make([]byte, 2)
-	binary.BigEndian.PutUint16(initialMaxBidiStreamID, p.MaxBidiStreams)
-	initialMaxUniStreamID := make([]byte, 2)
-	binary.BigEndian.PutUint16(initialMaxUniStreamID, p.MaxUniStreams)
-	idleTimeout := make([]byte, 2)
-	binary.BigEndian.PutUint16(idleTimeout, uint16(p.IdleTimeout/time.Second))
-	maxPacketSize := make([]byte, 2)
-	binary.BigEndian.PutUint16(maxPacketSize, uint16(protocol.MaxReceivePacketSize))
-	params := []transportParameter{
-		{initialMaxStreamDataParameterID, initialMaxStreamData},
-		{initialMaxDataParameterID, initialMaxData},
-		{initialMaxBidiStreamsParameterID, initialMaxBidiStreamID},
-		{initialMaxUniStreamsParameterID, initialMaxUniStreamID},
-		{idleTimeoutParameterID, idleTimeout},
-		{maxPacketSizeParameterID, maxPacketSize},
+func (p *TransportParameters) marshal(b *bytes.Buffer) {
+	// initial_max_stream_data
+	utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataParameterID))
+	utils.BigEndian.WriteUint16(b, 4)
+	utils.BigEndian.WriteUint32(b, uint32(p.StreamFlowControlWindow))
+	// initial_max_data
+	utils.BigEndian.WriteUint16(b, uint16(initialMaxDataParameterID))
+	utils.BigEndian.WriteUint16(b, 4)
+	utils.BigEndian.WriteUint32(b, uint32(p.ConnectionFlowControlWindow))
+	// initial_max_bidi_streams
+	utils.BigEndian.WriteUint16(b, uint16(initialMaxBidiStreamsParameterID))
+	utils.BigEndian.WriteUint16(b, 2)
+	utils.BigEndian.WriteUint16(b, p.MaxBidiStreams)
+	// initial_max_uni_streams
+	utils.BigEndian.WriteUint16(b, uint16(initialMaxUniStreamsParameterID))
+	utils.BigEndian.WriteUint16(b, 2)
+	utils.BigEndian.WriteUint16(b, p.MaxUniStreams)
+	// idle_timeout
+	utils.BigEndian.WriteUint16(b, uint16(idleTimeoutParameterID))
+	utils.BigEndian.WriteUint16(b, 2)
+	utils.BigEndian.WriteUint16(b, uint16(p.IdleTimeout/time.Second))
+	// max_packet_size
+	utils.BigEndian.WriteUint16(b, uint16(maxPacketSizeParameterID))
+	utils.BigEndian.WriteUint16(b, 2)
+	utils.BigEndian.WriteUint16(b, uint16(protocol.MaxReceivePacketSize))
+	// disable_migration
+	if p.DisableMigration {
+		utils.BigEndian.WriteUint16(b, uint16(disableMigrationParameterID))
+		utils.BigEndian.WriteUint16(b, 0)
+	}
+	if len(p.StatelessResetToken) > 0 {
+		utils.BigEndian.WriteUint16(b, uint16(statelessResetTokenParameterID))
+		utils.BigEndian.WriteUint16(b, uint16(len(p.StatelessResetToken))) // should always be 16 bytes
+		b.Write(p.StatelessResetToken)
 	}
-	return params
 }
 
 // String returns a string representation, intended for logging.

+ 15 - 2
vendor/github.com/lucas-clemente/quic-go/internal/protocol/connection_id.go

@@ -10,15 +10,28 @@ import (
 // A ConnectionID in QUIC
 type ConnectionID []byte
 
+const maxConnectionIDLen = 18
+
 // GenerateConnectionID generates a connection ID using cryptographic random
-func GenerateConnectionID() (ConnectionID, error) {
-	b := make([]byte, ConnectionIDLen)
+func GenerateConnectionID(len int) (ConnectionID, error) {
+	b := make([]byte, len)
 	if _, err := rand.Read(b); err != nil {
 		return nil, err
 	}
 	return ConnectionID(b), nil
 }
 
+// GenerateConnectionIDForInitial generates a connection ID for the Initial packet.
+// It uses a length randomly chosen between 8 and 18 bytes.
+func GenerateConnectionIDForInitial() (ConnectionID, error) {
+	r := make([]byte, 1)
+	if _, err := rand.Read(r); err != nil {
+		return nil, err
+	}
+	len := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1)
+	return GenerateConnectionID(len)
+}
+
 // ReadConnectionID reads a connection ID of length len from the given io.Reader.
 // It returns io.EOF if there are not enough bytes to read.
 func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) {

+ 5 - 0
vendor/github.com/lucas-clemente/quic-go/internal/protocol/perspective.go

@@ -9,6 +9,11 @@ const (
 	PerspectiveClient Perspective = 2
 )
 
+// Opposite returns the perspective of the peer
+func (p Perspective) Opposite() Perspective {
+	return 3 - p
+}
+
 func (p Perspective) String() string {
 	switch p {
 	case PerspectiveServer:

+ 6 - 0
vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go

@@ -82,3 +82,9 @@ const MinInitialPacketSize = 1200
 // * one failure due to an incorrect or missing source-address token
 // * one failure due the server's certificate chain being unavailable and the server being unwilling to send it without a valid source-address token
 const MaxClientHellos = 3
+
+// ConnectionIDLenGQUIC is the length of the source Connection ID used on gQUIC QUIC packets.
+const ConnectionIDLenGQUIC = 8
+
+// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
+const MinConnectionIDLenInitial = 8

+ 6 - 8
vendor/github.com/lucas-clemente/quic-go/internal/protocol/server_parameters.go

@@ -146,11 +146,9 @@ const MaxAckFrameSize ByteCount = 1000
 // Example: For a packet pacing delay of 20 microseconds, we would send 5 packets at once, wait for 100 microseconds, and so forth.
 const MinPacingDelay time.Duration = 100 * time.Microsecond
 
-// ConnectionIDLen is the length of the source Connection ID used on IETF QUIC packets.
-// The Short Header contains the connection ID, but not the length,
-// so we need to know this value in advance (or encode it into the connection ID).
-// TODO: make this configurable
-const ConnectionIDLen = 8
-
-// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
-const MinConnectionIDLenInitial = 8
+// DefaultConnectionIDLength is the connection ID length that is used for multiplexed connections
+// if no other value is configured.
+const DefaultConnectionIDLength = 4
+
+// MaxRetries is the maximum number of Retries a client will do before failing the connection.
+const MaxRetries = 3

+ 11 - 0
vendor/github.com/lucas-clemente/quic-go/internal/protocol/version.go

@@ -152,3 +152,14 @@ func GetGreasedVersions(supported []VersionNumber) []VersionNumber {
 	copy(greased[randPos+1:], supported[randPos:])
 	return greased
 }
+
+// StripGreasedVersions strips all greased versions from a slice of versions
+func StripGreasedVersions(versions []VersionNumber) []VersionNumber {
+	realVersions := make([]VersionNumber, 0, len(versions))
+	for _, v := range versions {
+		if v&0x0f0f0f0f != 0x0a0a0a0a {
+			realVersions = append(realVersions, v)
+		}
+	}
+	return realVersions
+}

+ 18 - 1
vendor/github.com/lucas-clemente/quic-go/internal/wire/ack_frame.go

@@ -19,8 +19,16 @@ type AckFrame struct {
 	DelayTime time.Duration
 }
 
-// parseAckFrame reads an ACK frame
 func parseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, error) {
+	return parseAckOrAckEcnFrame(r, false, version)
+}
+
+func parseAckEcnFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, error) {
+	return parseAckOrAckEcnFrame(r, true, version)
+}
+
+// parseAckFrame reads an ACK frame
+func parseAckOrAckEcnFrame(r *bytes.Reader, ecn bool, version protocol.VersionNumber) (*AckFrame, error) {
 	if !version.UsesIETFFrameFormat() {
 		return parseAckFrameLegacy(r, version)
 	}
@@ -41,6 +49,15 @@ func parseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame,
 		return nil, err
 	}
 	frame.DelayTime = time.Duration(delay*1<<ackDelayExponent) * time.Microsecond
+
+	if ecn {
+		for i := 0; i < 3; i++ {
+			if _, err := utils.ReadVarInt(r); err != nil {
+				return nil, err
+			}
+		}
+	}
+
 	numBlocks, err := utils.ReadVarInt(r)
 	if err != nil {
 		return nil, err

+ 5 - 0
vendor/github.com/lucas-clemente/quic-go/internal/wire/frame_parser.go

@@ -100,6 +100,11 @@ func parseIETFFrame(r *bytes.Reader, typeByte byte, v protocol.VersionNumber) (F
 		if err != nil {
 			err = qerr.Error(qerr.InvalidFrameData, err.Error())
 		}
+	case 0x1a:
+		frame, err = parseAckEcnFrame(r, v)
+		if err != nil {
+			err = qerr.Error(qerr.InvalidAckData, err.Error())
+		}
 	default:
 		err = qerr.Error(qerr.InvalidFrameData, fmt.Sprintf("unknown type byte 0x%x", typeByte))
 	}

+ 219 - 48
vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go

@@ -2,6 +2,9 @@ package wire
 
 import (
 	"bytes"
+	"crypto/rand"
+	"errors"
+	"fmt"
 
 	"github.com/lucas-clemente/quic-go/internal/protocol"
 	"github.com/lucas-clemente/quic-go/internal/utils"
@@ -16,8 +19,9 @@ type Header struct {
 
 	Version protocol.VersionNumber
 
-	DestConnectionID protocol.ConnectionID
-	SrcConnectionID  protocol.ConnectionID
+	DestConnectionID     protocol.ConnectionID
+	SrcConnectionID      protocol.ConnectionID
+	OrigDestConnectionID protocol.ConnectionID // only needed in the Retry packet
 
 	PacketNumberLen protocol.PacketNumberLen
 	PacketNumber    protocol.PacketNumber
@@ -35,74 +39,175 @@ type Header struct {
 	IsLongHeader bool
 	KeyPhase     int
 	PayloadLen   protocol.ByteCount
+	Token        []byte
 }
 
-// ParseHeaderSentByServer parses the header for a packet that was sent by the server.
-func ParseHeaderSentByServer(b *bytes.Reader) (*Header, error) {
-	typeByte, err := b.ReadByte()
-	if err != nil {
-		return nil, err
-	}
-	_ = b.UnreadByte() // unread the type byte
+var errInvalidPacketNumberLen6 = errors.New("invalid packet number length: 6 bytes")
 
-	var isPublicHeader bool
-	if typeByte&0x80 > 0 { // gQUIC always has 0x80 unset. IETF Long Header or Version Negotiation
-		isPublicHeader = false
-	} else {
-		// gQUIC never uses 6 byte packet numbers, so the third and fourth bit will never be 11
-		isPublicHeader = typeByte&0x30 != 0x30
+// Write writes the Header.
+func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, version protocol.VersionNumber) error {
+	if !version.UsesTLS() {
+		h.IsPublicHeader = true // save that this is a Public Header, so we can log it correctly later
+		return h.writePublicHeader(b, pers, version)
+	}
+	// write an IETF QUIC header
+	if h.IsLongHeader {
+		return h.writeLongHeader(b)
 	}
-	return parsePacketHeader(b, protocol.PerspectiveServer, isPublicHeader)
+	return h.writeShortHeader(b)
 }
 
-// ParseHeaderSentByClient parses the header for a packet that was sent by the client.
-func ParseHeaderSentByClient(b *bytes.Reader) (*Header, error) {
-	typeByte, err := b.ReadByte()
+// TODO: add support for the key phase
+func (h *Header) writeLongHeader(b *bytes.Buffer) error {
+	b.WriteByte(byte(0x80 | h.Type))
+	utils.BigEndian.WriteUint32(b, uint32(h.Version))
+	connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
 	if err != nil {
-		return nil, err
-	}
-	_ = b.UnreadByte() // unread the type byte
-
-	// In an IETF QUIC packet header
-	// * either 0x80 is set (for the Long Header)
-	// * or 0x8 is unset (for the Short Header)
-	// In a gQUIC Public Header
-	// * 0x80 is always unset and
-	// * and 0x8 is always set (this is the Connection ID flag, which the client always sets)
-	isPublicHeader := typeByte&0x88 == 0x8
-	return parsePacketHeader(b, protocol.PerspectiveClient, isPublicHeader)
-}
+		return err
+	}
+	b.WriteByte(connIDLen)
+	b.Write(h.DestConnectionID.Bytes())
+	b.Write(h.SrcConnectionID.Bytes())
 
-func parsePacketHeader(b *bytes.Reader, sentBy protocol.Perspective, isPublicHeader bool) (*Header, error) {
-	// This is a gQUIC Public Header.
-	if isPublicHeader {
-		hdr, err := parsePublicHeader(b, sentBy)
+	if h.Type == protocol.PacketTypeInitial {
+		utils.WriteVarInt(b, uint64(len(h.Token)))
+		b.Write(h.Token)
+	}
+
+	if h.Type == protocol.PacketTypeRetry {
+		odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID)
 		if err != nil {
-			return nil, err
+			return err
 		}
-		hdr.IsPublicHeader = true // save that this is a Public Header, so we can log it correctly later
-		return hdr, nil
+		// randomize the first 4 bits
+		odcilByte := make([]byte, 1)
+		_, _ = rand.Read(odcilByte) // it's safe to ignore the error here
+		odcilByte[0] = (odcilByte[0] & 0xf0) | odcil
+		b.Write(odcilByte)
+		b.Write(h.OrigDestConnectionID.Bytes())
+		b.Write(h.Token)
+		return nil
 	}
-	return parseHeader(b)
+
+	utils.WriteVarInt(b, uint64(h.PayloadLen))
+	return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
 }
 
-// Write writes the Header.
-func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, version protocol.VersionNumber) error {
-	if !version.UsesTLS() {
-		h.IsPublicHeader = true // save that this is a Public Header, so we can log it correctly later
-		return h.writePublicHeader(b, pers, version)
+func (h *Header) writeShortHeader(b *bytes.Buffer) error {
+	typeByte := byte(0x30)
+	typeByte |= byte(h.KeyPhase << 6)
+	b.WriteByte(typeByte)
+
+	b.Write(h.DestConnectionID.Bytes())
+	return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
+}
+
+// writePublicHeader writes a Public Header.
+func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ protocol.VersionNumber) error {
+	if h.ResetFlag || (h.VersionFlag && pers == protocol.PerspectiveServer) {
+		return errors.New("PublicHeader: Can only write regular packets")
+	}
+	if h.SrcConnectionID.Len() != 0 {
+		return errors.New("PublicHeader: SrcConnectionID must not be set")
+	}
+	if len(h.DestConnectionID) != 0 && len(h.DestConnectionID) != 8 {
+		return fmt.Errorf("PublicHeader: wrong length for Connection ID: %d (expected 8)", len(h.DestConnectionID))
+	}
+
+	publicFlagByte := uint8(0x00)
+	if h.VersionFlag {
+		publicFlagByte |= 0x01
+	}
+	if h.DestConnectionID.Len() > 0 {
+		publicFlagByte |= 0x08
+	}
+	if len(h.DiversificationNonce) > 0 {
+		if len(h.DiversificationNonce) != 32 {
+			return errors.New("invalid diversification nonce length")
+		}
+		publicFlagByte |= 0x04
+	}
+	switch h.PacketNumberLen {
+	case protocol.PacketNumberLen1:
+		publicFlagByte |= 0x00
+	case protocol.PacketNumberLen2:
+		publicFlagByte |= 0x10
+	case protocol.PacketNumberLen4:
+		publicFlagByte |= 0x20
+	}
+	b.WriteByte(publicFlagByte)
+
+	if h.DestConnectionID.Len() > 0 {
+		b.Write(h.DestConnectionID)
+	}
+	if h.VersionFlag && pers == protocol.PerspectiveClient {
+		utils.BigEndian.WriteUint32(b, uint32(h.Version))
+	}
+	if len(h.DiversificationNonce) > 0 {
+		b.Write(h.DiversificationNonce)
+	}
+
+	switch h.PacketNumberLen {
+	case protocol.PacketNumberLen1:
+		b.WriteByte(uint8(h.PacketNumber))
+	case protocol.PacketNumberLen2:
+		utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
+	case protocol.PacketNumberLen4:
+		utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
+	case protocol.PacketNumberLen6:
+		return errInvalidPacketNumberLen6
+	default:
+		return errors.New("PublicHeader: PacketNumberLen not set")
 	}
-	return h.writeHeader(b)
+
+	return nil
 }
 
 // GetLength determines the length of the Header.
-func (h *Header) GetLength(pers protocol.Perspective, version protocol.VersionNumber) (protocol.ByteCount, error) {
+func (h *Header) GetLength(version protocol.VersionNumber) (protocol.ByteCount, error) {
 	if !version.UsesTLS() {
-		return h.getPublicHeaderLength(pers)
+		return h.getPublicHeaderLength()
 	}
 	return h.getHeaderLength()
 }
 
+func (h *Header) getHeaderLength() (protocol.ByteCount, error) {
+	if h.IsLongHeader {
+		length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + utils.VarIntLen(uint64(h.PayloadLen)) + protocol.ByteCount(h.PacketNumberLen)
+		if h.Type == protocol.PacketTypeInitial {
+			length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
+		}
+		return length, nil
+	}
+
+	length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
+	if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
+		return 0, fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
+	}
+	length += protocol.ByteCount(h.PacketNumberLen)
+	return length, nil
+}
+
+// getPublicHeaderLength gets the length of the publicHeader in bytes.
+// It can only be called for regular packets.
+func (h *Header) getPublicHeaderLength() (protocol.ByteCount, error) {
+	length := protocol.ByteCount(1) // 1 byte for public flags
+	if h.PacketNumberLen == protocol.PacketNumberLen6 {
+		return 0, errInvalidPacketNumberLen6
+	}
+	if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
+		return 0, errPacketNumberLenNotSet
+	}
+	length += protocol.ByteCount(h.PacketNumberLen)
+	length += protocol.ByteCount(h.DestConnectionID.Len())
+	// Version Number in packets sent by the client
+	if h.VersionFlag {
+		length += 4
+	}
+	length += protocol.ByteCount(len(h.DiversificationNonce))
+	return length, nil
+}
+
 // Log logs the Header
 func (h *Header) Log(logger utils.Logger) {
 	if h.IsPublicHeader {
@@ -111,3 +216,69 @@ func (h *Header) Log(logger utils.Logger) {
 		h.logHeader(logger)
 	}
 }
+
+func (h *Header) logHeader(logger utils.Logger) {
+	if h.IsLongHeader {
+		if h.Version == 0 {
+			logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
+		} else {
+			var token string
+			if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
+				if len(h.Token) == 0 {
+					token = "Token: (empty), "
+				} else {
+					token = fmt.Sprintf("Token: %#x, ", h.Token)
+				}
+			}
+			if h.Type == protocol.PacketTypeRetry {
+				logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
+				return
+			}
+			logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.PayloadLen, h.Version)
+		}
+	} else {
+		logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
+	}
+}
+
+func (h *Header) logPublicHeader(logger utils.Logger) {
+	ver := "(unset)"
+	if h.Version != 0 {
+		ver = h.Version.String()
+	}
+	logger.Debugf("\tPublic Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce)
+}
+
+func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
+	dcil, err := encodeSingleConnIDLen(dest)
+	if err != nil {
+		return 0, err
+	}
+	scil, err := encodeSingleConnIDLen(src)
+	if err != nil {
+		return 0, err
+	}
+	return scil | dcil<<4, nil
+}
+
+func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
+	len := id.Len()
+	if len == 0 {
+		return 0, nil
+	}
+	if len < 4 || len > 18 {
+		return 0, fmt.Errorf("invalid connection ID length: %d bytes", len)
+	}
+	return byte(len - 3), nil
+}
+
+func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) {
+	return decodeSingleConnIDLen(enc >> 4), decodeSingleConnIDLen(enc & 0xf)
+}
+
+func decodeSingleConnIDLen(enc uint8) int {
+	if enc == 0 {
+		return 0
+	}
+	return int(enc) + 3
+}

+ 235 - 0
vendor/github.com/lucas-clemente/quic-go/internal/wire/header_parser.go

@@ -0,0 +1,235 @@
+package wire
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+
+	"github.com/lucas-clemente/quic-go/internal/protocol"
+	"github.com/lucas-clemente/quic-go/internal/utils"
+	"github.com/lucas-clemente/quic-go/qerr"
+)
+
+// The InvariantHeader is the version independent part of the header
+type InvariantHeader struct {
+	IsLongHeader     bool
+	Version          protocol.VersionNumber
+	SrcConnectionID  protocol.ConnectionID
+	DestConnectionID protocol.ConnectionID
+
+	typeByte byte
+}
+
+// ParseInvariantHeader parses the version independent part of the header
+func ParseInvariantHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*InvariantHeader, error) {
+	typeByte, err := b.ReadByte()
+	if err != nil {
+		return nil, err
+	}
+
+	h := &InvariantHeader{typeByte: typeByte}
+	h.IsLongHeader = typeByte&0x80 > 0
+
+	// If this is not a Long Header, it could either be a Public Header or a Short Header.
+	if !h.IsLongHeader {
+		// In the Public Header 0x8 is the Connection ID Flag.
+		// In the IETF Short Header:
+		// * 0x8 it is the gQUIC Demultiplexing bit, and always 0.
+		// * 0x20 and 0x10 are always 1.
+		var connIDLen int
+		if typeByte&0x8 > 0 { // Public Header containing a connection ID
+			connIDLen = 8
+		}
+		if typeByte&0x38 == 0x30 { // Short Header
+			connIDLen = shortHeaderConnIDLen
+		}
+		if connIDLen > 0 {
+			h.DestConnectionID, err = protocol.ReadConnectionID(b, connIDLen)
+			if err != nil {
+				return nil, err
+			}
+		}
+		return h, nil
+	}
+	// Long Header
+	v, err := utils.BigEndian.ReadUint32(b)
+	if err != nil {
+		return nil, err
+	}
+	h.Version = protocol.VersionNumber(v)
+	connIDLenByte, err := b.ReadByte()
+	if err != nil {
+		return nil, err
+	}
+	dcil, scil := decodeConnIDLen(connIDLenByte)
+	h.DestConnectionID, err = protocol.ReadConnectionID(b, dcil)
+	if err != nil {
+		return nil, err
+	}
+	h.SrcConnectionID, err = protocol.ReadConnectionID(b, scil)
+	if err != nil {
+		return nil, err
+	}
+	return h, nil
+}
+
+// Parse parses the version dependent part of the header
+func (iv *InvariantHeader) Parse(b *bytes.Reader, sentBy protocol.Perspective, ver protocol.VersionNumber) (*Header, error) {
+	if iv.IsLongHeader {
+		if iv.Version == 0 { // Version Negotiation Packet
+			return iv.parseVersionNegotiationPacket(b)
+		}
+		return iv.parseLongHeader(b)
+	}
+	// The Public Header never uses 6 byte packet numbers.
+	// Therefore, the third and fourth bit will never be 11.
+	// For the Short Header, the third and fourth bit are always 11.
+	if iv.typeByte&0x30 != 0x30 {
+		if sentBy == protocol.PerspectiveServer && iv.typeByte&0x1 > 0 {
+			return iv.parseVersionNegotiationPacket(b)
+		}
+		return iv.parsePublicHeader(b, sentBy, ver)
+	}
+	return iv.parseShortHeader(b)
+
+}
+
+func (iv *InvariantHeader) toHeader() *Header {
+	return &Header{
+		IsLongHeader:     iv.IsLongHeader,
+		DestConnectionID: iv.DestConnectionID,
+		SrcConnectionID:  iv.SrcConnectionID,
+		Version:          iv.Version,
+	}
+}
+
+func (iv *InvariantHeader) parseVersionNegotiationPacket(b *bytes.Reader) (*Header, error) {
+	h := iv.toHeader()
+	h.VersionFlag = true
+	if b.Len() == 0 {
+		return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list")
+	}
+	h.IsVersionNegotiation = true
+	h.SupportedVersions = make([]protocol.VersionNumber, b.Len()/4)
+	for i := 0; b.Len() > 0; i++ {
+		v, err := utils.BigEndian.ReadUint32(b)
+		if err != nil {
+			return nil, qerr.InvalidVersionNegotiationPacket
+		}
+		h.SupportedVersions[i] = protocol.VersionNumber(v)
+	}
+	return h, nil
+}
+
+func (iv *InvariantHeader) parseLongHeader(b *bytes.Reader) (*Header, error) {
+	h := iv.toHeader()
+	h.Type = protocol.PacketType(iv.typeByte & 0x7f)
+
+	if h.Type != protocol.PacketTypeInitial && h.Type != protocol.PacketTypeRetry && h.Type != protocol.PacketType0RTT && h.Type != protocol.PacketTypeHandshake {
+		return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", h.Type))
+	}
+
+	if h.Type == protocol.PacketTypeRetry {
+		odcilByte, err := b.ReadByte()
+		if err != nil {
+			return nil, err
+		}
+		odcil := decodeSingleConnIDLen(odcilByte & 0xf)
+		h.OrigDestConnectionID, err = protocol.ReadConnectionID(b, odcil)
+		if err != nil {
+			return nil, err
+		}
+		h.Token = make([]byte, b.Len())
+		if _, err := io.ReadFull(b, h.Token); err != nil {
+			return nil, err
+		}
+		return h, nil
+	}
+
+	if h.Type == protocol.PacketTypeInitial {
+		tokenLen, err := utils.ReadVarInt(b)
+		if err != nil {
+			return nil, err
+		}
+		if tokenLen > uint64(b.Len()) {
+			return nil, io.EOF
+		}
+		h.Token = make([]byte, tokenLen)
+		if _, err := io.ReadFull(b, h.Token); err != nil {
+			return nil, err
+		}
+	}
+
+	pl, err := utils.ReadVarInt(b)
+	if err != nil {
+		return nil, err
+	}
+	h.PayloadLen = protocol.ByteCount(pl)
+	pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
+	if err != nil {
+		return nil, err
+	}
+	h.PacketNumber = pn
+	h.PacketNumberLen = pnLen
+
+	return h, nil
+}
+
+func (iv *InvariantHeader) parseShortHeader(b *bytes.Reader) (*Header, error) {
+	h := iv.toHeader()
+	h.KeyPhase = int(iv.typeByte&0x40) >> 6
+	pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
+	if err != nil {
+		return nil, err
+	}
+	h.PacketNumber = pn
+	h.PacketNumberLen = pnLen
+	return h, nil
+}
+
+func (iv *InvariantHeader) parsePublicHeader(b *bytes.Reader, sentBy protocol.Perspective, ver protocol.VersionNumber) (*Header, error) {
+	h := iv.toHeader()
+	h.IsPublicHeader = true
+	h.ResetFlag = iv.typeByte&0x2 > 0
+	if h.ResetFlag {
+		return h, nil
+	}
+
+	h.VersionFlag = iv.typeByte&0x1 > 0
+	if h.VersionFlag && sentBy == protocol.PerspectiveClient {
+		v, err := utils.BigEndian.ReadUint32(b)
+		if err != nil {
+			return nil, err
+		}
+		h.Version = protocol.VersionNumber(v)
+	}
+
+	// Contrary to what the gQUIC wire spec says, the 0x4 bit only indicates the presence of the diversification nonce for packets sent by the server.
+	// It doesn't have any meaning when sent by the client.
+	if sentBy == protocol.PerspectiveServer && iv.typeByte&0x4 > 0 {
+		h.DiversificationNonce = make([]byte, 32)
+		if _, err := io.ReadFull(b, h.DiversificationNonce); err != nil {
+			if err == io.ErrUnexpectedEOF {
+				return nil, io.EOF
+			}
+			return nil, err
+		}
+	}
+
+	switch iv.typeByte & 0x30 {
+	case 0x00:
+		h.PacketNumberLen = protocol.PacketNumberLen1
+	case 0x10:
+		h.PacketNumberLen = protocol.PacketNumberLen2
+	case 0x20:
+		h.PacketNumberLen = protocol.PacketNumberLen4
+	}
+
+	pn, err := utils.BigEndian.ReadUintN(b, uint8(h.PacketNumberLen))
+	if err != nil {
+		return nil, err
+	}
+	h.PacketNumber = protocol.PacketNumber(pn)
+
+	return h, nil
+}

+ 0 - 205
vendor/github.com/lucas-clemente/quic-go/internal/wire/ietf_header.go

@@ -1,205 +0,0 @@
-package wire
-
-import (
-	"bytes"
-	"errors"
-	"fmt"
-	"io"
-
-	"github.com/lucas-clemente/quic-go/internal/protocol"
-	"github.com/lucas-clemente/quic-go/internal/utils"
-	"github.com/lucas-clemente/quic-go/qerr"
-)
-
-// parseHeader parses the header.
-func parseHeader(b *bytes.Reader) (*Header, error) {
-	typeByte, err := b.ReadByte()
-	if err != nil {
-		return nil, err
-	}
-	if typeByte&0x80 > 0 {
-		return parseLongHeader(b, typeByte)
-	}
-	return parseShortHeader(b, typeByte)
-}
-
-// parse long header and version negotiation packets
-func parseLongHeader(b *bytes.Reader, typeByte byte) (*Header, error) {
-	v, err := utils.BigEndian.ReadUint32(b)
-	if err != nil {
-		return nil, err
-	}
-
-	connIDLenByte, err := b.ReadByte()
-	if err != nil {
-		return nil, err
-	}
-	dcil, scil := decodeConnIDLen(connIDLenByte)
-	destConnID, err := protocol.ReadConnectionID(b, dcil)
-	if err != nil {
-		return nil, err
-	}
-	srcConnID, err := protocol.ReadConnectionID(b, scil)
-	if err != nil {
-		return nil, err
-	}
-
-	h := &Header{
-		IsLongHeader:     true,
-		Version:          protocol.VersionNumber(v),
-		DestConnectionID: destConnID,
-		SrcConnectionID:  srcConnID,
-	}
-
-	if v == 0 { // version negotiation packet
-		if b.Len() == 0 {
-			return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list")
-		}
-		h.IsVersionNegotiation = true
-		h.SupportedVersions = make([]protocol.VersionNumber, b.Len()/4)
-		for i := 0; b.Len() > 0; i++ {
-			v, err := utils.BigEndian.ReadUint32(b)
-			if err != nil {
-				return nil, qerr.InvalidVersionNegotiationPacket
-			}
-			h.SupportedVersions[i] = protocol.VersionNumber(v)
-		}
-		return h, nil
-	}
-
-	pl, err := utils.ReadVarInt(b)
-	if err != nil {
-		return nil, err
-	}
-	h.PayloadLen = protocol.ByteCount(pl)
-	pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
-	if err != nil {
-		return nil, err
-	}
-	h.PacketNumber = pn
-	h.PacketNumberLen = pnLen
-	h.Type = protocol.PacketType(typeByte & 0x7f)
-
-	if h.Type != protocol.PacketTypeInitial && h.Type != protocol.PacketTypeRetry && h.Type != protocol.PacketType0RTT && h.Type != protocol.PacketTypeHandshake {
-		return nil, qerr.Error(qerr.InvalidPacketHeader, fmt.Sprintf("Received packet with invalid packet type: %d", h.Type))
-	}
-	return h, nil
-}
-
-func parseShortHeader(b *bytes.Reader, typeByte byte) (*Header, error) {
-	connID := make(protocol.ConnectionID, 8)
-	if _, err := io.ReadFull(b, connID); err != nil {
-		if err == io.ErrUnexpectedEOF {
-			err = io.EOF
-		}
-		return nil, err
-	}
-	// bits 2 and 3 must be set, bit 4 must be unset
-	if typeByte&0x38 != 0x30 {
-		return nil, errors.New("invalid bits 3, 4 and 5")
-	}
-	pn, pnLen, err := utils.ReadVarIntPacketNumber(b)
-	if err != nil {
-		return nil, err
-	}
-	return &Header{
-		KeyPhase:         int(typeByte&0x40) >> 6,
-		DestConnectionID: connID,
-		PacketNumber:     pn,
-		PacketNumberLen:  pnLen,
-	}, nil
-}
-
-// writeHeader writes the Header.
-func (h *Header) writeHeader(b *bytes.Buffer) error {
-	if h.IsLongHeader {
-		return h.writeLongHeader(b)
-	}
-	return h.writeShortHeader(b)
-}
-
-// TODO: add support for the key phase
-func (h *Header) writeLongHeader(b *bytes.Buffer) error {
-	if h.SrcConnectionID.Len() != protocol.ConnectionIDLen {
-		return fmt.Errorf("Header: source connection ID must be %d bytes, is %d", protocol.ConnectionIDLen, h.SrcConnectionID.Len())
-	}
-	b.WriteByte(byte(0x80 | h.Type))
-	utils.BigEndian.WriteUint32(b, uint32(h.Version))
-	connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
-	if err != nil {
-		return err
-	}
-	b.WriteByte(connIDLen)
-	b.Write(h.DestConnectionID.Bytes())
-	b.Write(h.SrcConnectionID.Bytes())
-	utils.WriteVarInt(b, uint64(h.PayloadLen))
-	return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
-}
-
-func (h *Header) writeShortHeader(b *bytes.Buffer) error {
-	typeByte := byte(0x30)
-	typeByte |= byte(h.KeyPhase << 6)
-	b.WriteByte(typeByte)
-
-	b.Write(h.DestConnectionID.Bytes())
-	return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
-}
-
-func (h *Header) getHeaderLength() (protocol.ByteCount, error) {
-	if h.IsLongHeader {
-		return 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + utils.VarIntLen(uint64(h.PayloadLen)) + protocol.ByteCount(h.PacketNumberLen), nil
-	}
-
-	length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
-	if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
-		return 0, fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
-	}
-	length += protocol.ByteCount(h.PacketNumberLen)
-	return length, nil
-}
-
-func (h *Header) logHeader(logger utils.Logger) {
-	if h.IsLongHeader {
-		if h.Version == 0 {
-			logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
-		} else {
-			logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, h.PacketNumber, h.PacketNumberLen, h.PayloadLen, h.Version)
-		}
-	} else {
-		logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
-	}
-}
-
-func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
-	dcil, err := encodeSingleConnIDLen(dest)
-	if err != nil {
-		return 0, err
-	}
-	scil, err := encodeSingleConnIDLen(src)
-	if err != nil {
-		return 0, err
-	}
-	return scil | dcil<<4, nil
-}
-
-func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
-	len := id.Len()
-	if len == 0 {
-		return 0, nil
-	}
-	if len < 4 || len > 18 {
-		return 0, fmt.Errorf("invalid connection ID length: %d bytes", len)
-	}
-	return byte(len - 3), nil
-}
-
-func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) {
-	return decodeSingleConnIDLen(enc >> 4), decodeSingleConnIDLen(enc & 0xf)
-}
-
-func decodeSingleConnIDLen(enc uint8) int {
-	if enc == 0 {
-		return 0
-	}
-	return int(enc) + 3
-}

+ 0 - 244
vendor/github.com/lucas-clemente/quic-go/internal/wire/public_header.go

@@ -1,244 +0,0 @@
-package wire
-
-import (
-	"bytes"
-	"errors"
-	"fmt"
-	"io"
-
-	"github.com/lucas-clemente/quic-go/internal/protocol"
-	"github.com/lucas-clemente/quic-go/internal/utils"
-	"github.com/lucas-clemente/quic-go/qerr"
-)
-
-var (
-	errResetAndVersionFlagSet            = errors.New("PublicHeader: Reset Flag and Version Flag should not be set at the same time")
-	errInvalidConnectionID               = qerr.Error(qerr.InvalidPacketHeader, "connection ID cannot be 0")
-	errGetLengthNotForVersionNegotiation = errors.New("PublicHeader: GetLength cannot be called for VersionNegotiation packets")
-	errInvalidPacketNumberLen6           = errors.New("invalid packet number length: 6 bytes")
-)
-
-// writePublicHeader writes a Public Header.
-func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ protocol.VersionNumber) error {
-	if h.VersionFlag && pers == protocol.PerspectiveServer {
-		return errors.New("PublicHeader: Writing of Version Negotiation Packets not supported")
-	}
-	if h.VersionFlag && h.ResetFlag {
-		return errResetAndVersionFlagSet
-	}
-	if h.SrcConnectionID.Len() != 0 {
-		return errors.New("PublicHeader: SrcConnectionID must not be set")
-	}
-	if len(h.DestConnectionID) != 0 && len(h.DestConnectionID) != 8 {
-		return fmt.Errorf("PublicHeader: wrong length for Connection ID: %d (expected 8)", len(h.DestConnectionID))
-	}
-
-	publicFlagByte := uint8(0x00)
-	if h.VersionFlag {
-		publicFlagByte |= 0x01
-	}
-	if h.ResetFlag {
-		publicFlagByte |= 0x02
-	}
-	if h.DestConnectionID.Len() > 0 {
-		publicFlagByte |= 0x08
-	}
-	if len(h.DiversificationNonce) > 0 {
-		if len(h.DiversificationNonce) != 32 {
-			return errors.New("invalid diversification nonce length")
-		}
-		publicFlagByte |= 0x04
-	}
-	// only set PacketNumberLen bits if a packet number will be written
-	if h.hasPacketNumber(pers) {
-		switch h.PacketNumberLen {
-		case protocol.PacketNumberLen1:
-			publicFlagByte |= 0x00
-		case protocol.PacketNumberLen2:
-			publicFlagByte |= 0x10
-		case protocol.PacketNumberLen4:
-			publicFlagByte |= 0x20
-		}
-	}
-	b.WriteByte(publicFlagByte)
-
-	if h.DestConnectionID.Len() > 0 {
-		b.Write(h.DestConnectionID)
-	}
-	if h.VersionFlag && pers == protocol.PerspectiveClient {
-		utils.BigEndian.WriteUint32(b, uint32(h.Version))
-	}
-	if len(h.DiversificationNonce) > 0 {
-		b.Write(h.DiversificationNonce)
-	}
-	// if we're a server, and the VersionFlag is set, we must not include anything else in the packet
-	if !h.hasPacketNumber(pers) {
-		return nil
-	}
-
-	switch h.PacketNumberLen {
-	case protocol.PacketNumberLen1:
-		b.WriteByte(uint8(h.PacketNumber))
-	case protocol.PacketNumberLen2:
-		utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
-	case protocol.PacketNumberLen4:
-		utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
-	case protocol.PacketNumberLen6:
-		return errInvalidPacketNumberLen6
-	default:
-		return errors.New("PublicHeader: PacketNumberLen not set")
-	}
-
-	return nil
-}
-
-// parsePublicHeader parses a QUIC packet's Public Header.
-// The packetSentBy is the perspective of the peer that sent this PublicHeader, i.e. if we're the server, packetSentBy should be PerspectiveClient.
-func parsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Header, error) {
-	header := &Header{}
-
-	// First byte
-	publicFlagByte, err := b.ReadByte()
-	if err != nil {
-		return nil, err
-	}
-	header.ResetFlag = publicFlagByte&0x02 > 0
-	header.VersionFlag = publicFlagByte&0x01 > 0
-
-	// TODO: activate this check once Chrome sends the correct value
-	// see https://github.com/lucas-clemente/quic-go/issues/232
-	// if publicFlagByte&0x04 > 0 {
-	// 	return nil, errors.New("diversification nonces should only be sent by servers")
-	// }
-
-	hasConnectionID := publicFlagByte&0x08 > 0
-	if !hasConnectionID && packetSentBy == protocol.PerspectiveClient {
-		return nil, qerr.Error(qerr.InvalidPacketHeader, "receiving packets with omitted ConnectionID is not supported")
-	}
-	if header.hasPacketNumber(packetSentBy) {
-		switch publicFlagByte & 0x30 {
-		case 0x30:
-			return nil, errInvalidPacketNumberLen6
-		case 0x20:
-			header.PacketNumberLen = protocol.PacketNumberLen4
-		case 0x10:
-			header.PacketNumberLen = protocol.PacketNumberLen2
-		case 0x00:
-			header.PacketNumberLen = protocol.PacketNumberLen1
-		}
-	}
-
-	// Connection ID
-	if hasConnectionID {
-		connID, err := protocol.ReadConnectionID(b, 8)
-		if err != nil {
-			return nil, err
-		}
-		if connID[0] == 0 && connID[1] == 0 && connID[2] == 0 && connID[3] == 0 && connID[4] == 0 && connID[5] == 0 && connID[6] == 0 && connID[7] == 0 {
-			return nil, errInvalidConnectionID
-		}
-		header.DestConnectionID = connID
-	}
-
-	// Contrary to what the gQUIC wire spec says, the 0x4 bit only indicates the presence of the diversification nonce for packets sent by the server.
-	// It doesn't have any meaning when sent by the client.
-	if packetSentBy == protocol.PerspectiveServer && publicFlagByte&0x04 > 0 {
-		if !header.VersionFlag && !header.ResetFlag {
-			header.DiversificationNonce = make([]byte, 32)
-			if _, err := io.ReadFull(b, header.DiversificationNonce); err != nil {
-				return nil, err
-			}
-		}
-	}
-
-	// Version (optional)
-	if !header.ResetFlag && header.VersionFlag {
-		if packetSentBy == protocol.PerspectiveServer { // parse the version negotiation packet
-			if b.Len() == 0 {
-				return nil, qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list")
-			}
-			if b.Len()%4 != 0 {
-				return nil, qerr.InvalidVersionNegotiationPacket
-			}
-			header.IsVersionNegotiation = true
-			header.SupportedVersions = make([]protocol.VersionNumber, 0)
-			for {
-				var versionTag uint32
-				versionTag, err = utils.BigEndian.ReadUint32(b)
-				if err != nil {
-					break
-				}
-				v := protocol.VersionNumber(versionTag)
-				header.SupportedVersions = append(header.SupportedVersions, v)
-			}
-			// a version negotiation packet doesn't have a packet number
-			return header, nil
-		}
-		// packet was sent by the client. Read the version number
-		var versionTag uint32
-		versionTag, err = utils.BigEndian.ReadUint32(b)
-		if err != nil {
-			return nil, err
-		}
-		header.Version = protocol.VersionNumber(versionTag)
-	}
-
-	// Packet number
-	if header.hasPacketNumber(packetSentBy) {
-		packetNumber, err := utils.BigEndian.ReadUintN(b, uint8(header.PacketNumberLen))
-		if err != nil {
-			return nil, err
-		}
-		header.PacketNumber = protocol.PacketNumber(packetNumber)
-	}
-	return header, nil
-}
-
-// getPublicHeaderLength gets the length of the publicHeader in bytes.
-// It can only be called for regular packets.
-func (h *Header) getPublicHeaderLength(pers protocol.Perspective) (protocol.ByteCount, error) {
-	if h.VersionFlag && h.ResetFlag {
-		return 0, errResetAndVersionFlagSet
-	}
-	if h.VersionFlag && pers == protocol.PerspectiveServer {
-		return 0, errGetLengthNotForVersionNegotiation
-	}
-
-	length := protocol.ByteCount(1) // 1 byte for public flags
-	if h.PacketNumberLen == protocol.PacketNumberLen6 {
-		return 0, errInvalidPacketNumberLen6
-	}
-	if h.hasPacketNumber(pers) {
-		if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
-			return 0, errPacketNumberLenNotSet
-		}
-		length += protocol.ByteCount(h.PacketNumberLen)
-	}
-	length += protocol.ByteCount(h.DestConnectionID.Len()) // if set, always 8 bytes
-	// Version Number in packets sent by the client
-	if h.VersionFlag {
-		length += 4
-	}
-	length += protocol.ByteCount(len(h.DiversificationNonce))
-	return length, nil
-}
-
-// hasPacketNumber determines if this Public Header will contain a packet number
-// this depends on the ResetFlag, the VersionFlag and who sent the packet
-func (h *Header) hasPacketNumber(packetSentBy protocol.Perspective) bool {
-	if h.ResetFlag {
-		return false
-	}
-	if h.VersionFlag && packetSentBy == protocol.PerspectiveServer {
-		return false
-	}
-	return true
-}
-
-func (h *Header) logPublicHeader(logger utils.Logger) {
-	ver := "(unset)"
-	if h.Version != 0 {
-		ver = h.Version.String()
-	}
-	logger.Debugf("\tPublic Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce)
-}

+ 2 - 1
vendor/github.com/lucas-clemente/quic-go/internal/wire/version_negotiation.go

@@ -22,7 +22,8 @@ func ComposeGQUICVersionNegotiation(connID protocol.ConnectionID, versions []pro
 // ComposeVersionNegotiation composes a Version Negotiation according to the IETF draft
 func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, versions []protocol.VersionNumber) ([]byte, error) {
 	greasedVersions := protocol.GetGreasedVersions(versions)
-	buf := bytes.NewBuffer(make([]byte, 0, 1+8+4+len(greasedVersions)*4))
+	expectedLen := 1 /* type byte */ + 4 /* version field */ + 1 /* connection ID length field */ + destConnID.Len() + srcConnID.Len() + len(greasedVersions)*4
+	buf := bytes.NewBuffer(make([]byte, 0, expectedLen))
 	r := make([]byte, 1)
 	_, _ = rand.Read(r) // ignore the error here. It is not critical to have perfect random here.
 	buf.WriteByte(r[0] | 0x80)

+ 0 - 116
vendor/github.com/lucas-clemente/quic-go/mint_utils.go

@@ -1,70 +1,15 @@
 package quic
 
 import (
-	"bytes"
 	gocrypto "crypto"
 	"crypto/tls"
 	"crypto/x509"
 	"errors"
-	"fmt"
-	"io"
 
 	"github.com/bifurcation/mint"
-	"github.com/lucas-clemente/quic-go/internal/crypto"
-	"github.com/lucas-clemente/quic-go/internal/handshake"
 	"github.com/lucas-clemente/quic-go/internal/protocol"
-	"github.com/lucas-clemente/quic-go/internal/utils"
-	"github.com/lucas-clemente/quic-go/internal/wire"
 )
 
-type mintController struct {
-	csc  *handshake.CryptoStreamConn
-	conn *mint.Conn
-}
-
-var _ handshake.MintTLS = &mintController{}
-
-func newMintController(
-	csc *handshake.CryptoStreamConn,
-	mconf *mint.Config,
-	pers protocol.Perspective,
-) handshake.MintTLS {
-	var conn *mint.Conn
-	if pers == protocol.PerspectiveClient {
-		conn = mint.Client(csc, mconf)
-	} else {
-		conn = mint.Server(csc, mconf)
-	}
-	return &mintController{
-		csc:  csc,
-		conn: conn,
-	}
-}
-
-func (mc *mintController) GetCipherSuite() mint.CipherSuiteParams {
-	return mc.conn.ConnectionState().CipherSuite
-}
-
-func (mc *mintController) ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) {
-	return mc.conn.ComputeExporter(label, context, keyLength)
-}
-
-func (mc *mintController) Handshake() mint.Alert {
-	return mc.conn.Handshake()
-}
-
-func (mc *mintController) State() mint.State {
-	return mc.conn.ConnectionState().HandshakeState
-}
-
-func (mc *mintController) ConnectionState() mint.ConnectionState {
-	return mc.conn.ConnectionState()
-}
-
-func (mc *mintController) SetCryptoStream(stream io.ReadWriter) {
-	mc.csc.SetStream(stream)
-}
-
 func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Config, error) {
 	mconf := &mint.Config{
 		NonBlocking: true,
@@ -105,64 +50,3 @@ func tlsToMintConfig(tlsConf *tls.Config, pers protocol.Perspective) (*mint.Conf
 	}
 	return mconf, nil
 }
-
-// unpackInitialOrRetryPacket unpacks packets Initial and Retry packets
-// These packets must contain a STREAM_FRAME for the crypto stream, starting at offset 0.
-func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, logger utils.Logger, version protocol.VersionNumber) (*wire.StreamFrame, error) {
-	decrypted, err := aead.Open(data[:0], data, hdr.PacketNumber, hdr.Raw)
-	if err != nil {
-		return nil, err
-	}
-	var frame *wire.StreamFrame
-	r := bytes.NewReader(decrypted)
-	for {
-		f, err := wire.ParseNextFrame(r, hdr, version)
-		if err != nil {
-			return nil, err
-		}
-		var ok bool
-		if frame, ok = f.(*wire.StreamFrame); ok || frame == nil {
-			break
-		}
-	}
-	if frame == nil {
-		return nil, errors.New("Packet doesn't contain a STREAM_FRAME")
-	}
-	if frame.StreamID != version.CryptoStreamID() {
-		return nil, fmt.Errorf("Received STREAM_FRAME for wrong stream (Stream ID %d)", frame.StreamID)
-	}
-	// We don't need a check for the stream ID here.
-	// The packetUnpacker checks that there's no unencrypted stream data except for the crypto stream.
-	if frame.Offset != 0 {
-		return nil, errors.New("received stream data with non-zero offset")
-	}
-	if logger.Debug() {
-		logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.DestConnectionID)
-		hdr.Log(logger)
-		wire.LogFrame(logger, frame, false)
-	}
-	return frame, nil
-}
-
-// packUnencryptedPacket provides a low-overhead way to pack a packet.
-// It is supposed to be used in the early stages of the handshake, before a session (which owns a packetPacker) is available.
-func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, f wire.Frame, pers protocol.Perspective, logger utils.Logger) ([]byte, error) {
-	raw := *getPacketBuffer()
-	buffer := bytes.NewBuffer(raw[:0])
-	if err := hdr.Write(buffer, pers, hdr.Version); err != nil {
-		return nil, err
-	}
-	payloadStartIndex := buffer.Len()
-	if err := f.Write(buffer, hdr.Version); err != nil {
-		return nil, err
-	}
-	raw = raw[0:buffer.Len()]
-	_ = aead.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], hdr.PacketNumber, raw[:payloadStartIndex])
-	raw = raw[0 : buffer.Len()+aead.Overhead()]
-	if logger.Debug() {
-		logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %s, %s", hdr.PacketNumber, len(raw), hdr.SrcConnectionID, protocol.EncryptionUnencrypted)
-		hdr.Log(logger)
-		wire.LogFrame(logger, f, true)
-	}
-	return raw, nil
-}

+ 17 - 16
vendor/github.com/lucas-clemente/quic-go/mockgen.go

@@ -1,18 +1,19 @@
 package quic
 
-//go:generate sh -c "./mockgen_private.sh quic mock_stream_internal_test.go github.com/lucas-clemente/quic-go streamI StreamI"
-//go:generate sh -c "./mockgen_private.sh quic mock_receive_stream_internal_test.go github.com/lucas-clemente/quic-go receiveStreamI ReceiveStreamI"
-//go:generate sh -c "./mockgen_private.sh quic mock_send_stream_internal_test.go github.com/lucas-clemente/quic-go sendStreamI SendStreamI"
-//go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/lucas-clemente/quic-go streamSender StreamSender"
-//go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/lucas-clemente/quic-go streamGetter StreamGetter"
-//go:generate sh -c "./mockgen_private.sh quic mock_stream_frame_source_test.go github.com/lucas-clemente/quic-go streamFrameSource StreamFrameSource"
-//go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStreamI CryptoStream"
-//go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager StreamManager"
-//go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker Unpacker"
-//go:generate sh -c "./mockgen_private.sh quic mock_quic_aead_test.go github.com/lucas-clemente/quic-go quicAEAD QuicAEAD"
-//go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD GQUICAEAD"
-//go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner SessionRunner"
-//go:generate sh -c "./mockgen_private.sh quic mock_quic_session_test.go github.com/lucas-clemente/quic-go quicSession QuicSession"
-//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_manager_test.go github.com/lucas-clemente/quic-go packetHandlerManager PacketHandlerManager"
-//go:generate sh -c "find . -type f -name 'mock_*_test.go' | xargs sed -i '' 's/quic_go.//g'"
-//go:generate sh -c "goimports -w mock*_test.go"
+//go:generate sh -c "./mockgen_private.sh quic mock_stream_internal_test.go github.com/lucas-clemente/quic-go streamI"
+//go:generate sh -c "./mockgen_private.sh quic mock_receive_stream_internal_test.go github.com/lucas-clemente/quic-go receiveStreamI"
+//go:generate sh -c "./mockgen_private.sh quic mock_send_stream_internal_test.go github.com/lucas-clemente/quic-go sendStreamI"
+//go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/lucas-clemente/quic-go streamSender"
+//go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/lucas-clemente/quic-go streamGetter"
+//go:generate sh -c "./mockgen_private.sh quic mock_stream_frame_source_test.go github.com/lucas-clemente/quic-go streamFrameSource"
+//go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStream"
+//go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager"
+//go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker"
+//go:generate sh -c "./mockgen_private.sh quic mock_quic_aead_test.go github.com/lucas-clemente/quic-go quicAEAD"
+//go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD"
+//go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner"
+//go:generate sh -c "./mockgen_private.sh quic mock_quic_session_test.go github.com/lucas-clemente/quic-go quicSession"
+//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/lucas-clemente/quic-go packetHandler"
+//go:generate sh -c "./mockgen_private.sh quic mock_unknown_packet_handler_test.go github.com/lucas-clemente/quic-go unknownPacketHandler"
+//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_manager_test.go github.com/lucas-clemente/quic-go packetHandlerManager"
+//go:generate sh -c "./mockgen_private.sh quic mock_multiplexer_test.go github.com/lucas-clemente/quic-go multiplexer"

+ 13 - 4
vendor/github.com/lucas-clemente/quic-go/mockgen_private.sh

@@ -7,13 +7,22 @@
 TEMP_DIR=$(mktemp -d)
 mkdir -p $TEMP_DIR/src/github.com/lucas-clemente/quic-go/
 
+# uppercase the name of the interface
+INTERFACE_NAME="$(tr '[:lower:]' '[:upper:]' <<< ${4:0:1})${4:1}"
+
 # copy all .go files to a temporary directory
-# golang.org/x/crypto/curve25519/ uses Go compiler directives, which is confusing to mockgen
-rsync -r --exclude 'vendor/golang.org/x/crypto/curve25519/' --include='*.go' --include '*/' --exclude '*'   $GOPATH/src/github.com/lucas-clemente/quic-go/ $TEMP_DIR/src/github.com/lucas-clemente/quic-go/
-echo "type $5 = $4" >> $TEMP_DIR/src/github.com/lucas-clemente/quic-go/interface.go
+rsync -r --exclude 'vendor' --include='*.go' --include '*/' --exclude '*'   $GOPATH/src/github.com/lucas-clemente/quic-go/ $TEMP_DIR/src/github.com/lucas-clemente/quic-go/
+
+# create a public alias for the interface, so that mockgen can process it
+echo -e "package $1\n" > $TEMP_DIR/src/github.com/lucas-clemente/quic-go/mockgen_interface.go
+echo "type $INTERFACE_NAME = $4" >> $TEMP_DIR/src/github.com/lucas-clemente/quic-go/mockgen_interface.go
 
 export GOPATH="$TEMP_DIR:$GOPATH"
 
-mockgen -package $1 -self_package $1 -destination $2 $3 $5
+mockgen -package $1 -self_package $1 -destination $2 $3 $INTERFACE_NAME
+
+# mockgen imports quic-go as 'import quic_go github.com/lucas_clemente/quic-go'
+sed -i '' 's/quic_go.//g' $2
+goimports -w $2
 
 rm -r "$TEMP_DIR"

+ 63 - 0
vendor/github.com/lucas-clemente/quic-go/multiplexer.go

@@ -0,0 +1,63 @@
+package quic
+
+import (
+	"fmt"
+	"net"
+	"sync"
+
+	"github.com/lucas-clemente/quic-go/internal/utils"
+)
+
+var (
+	connMuxerOnce sync.Once
+	connMuxer     multiplexer
+)
+
+type multiplexer interface {
+	AddConn(net.PacketConn, int) (packetHandlerManager, error)
+}
+
+type connManager struct {
+	connIDLen int
+	manager   packetHandlerManager
+}
+
+// The connMultiplexer listens on multiple net.PacketConns and dispatches
+// incoming packets to the session handler.
+type connMultiplexer struct {
+	mutex sync.Mutex
+
+	conns                   map[net.PacketConn]connManager
+	newPacketHandlerManager func(net.PacketConn, int, utils.Logger) packetHandlerManager // so it can be replaced in the tests
+
+	logger utils.Logger
+}
+
+var _ multiplexer = &connMultiplexer{}
+
+func getMultiplexer() multiplexer {
+	connMuxerOnce.Do(func() {
+		connMuxer = &connMultiplexer{
+			conns:                   make(map[net.PacketConn]connManager),
+			logger:                  utils.DefaultLogger.WithPrefix("muxer"),
+			newPacketHandlerManager: newPacketHandlerMap,
+		}
+	})
+	return connMuxer
+}
+
+func (m *connMultiplexer) AddConn(c net.PacketConn, connIDLen int) (packetHandlerManager, error) {
+	m.mutex.Lock()
+	defer m.mutex.Unlock()
+
+	p, ok := m.conns[c]
+	if !ok {
+		manager := m.newPacketHandlerManager(c, connIDLen, m.logger)
+		p = connManager{connIDLen: connIDLen, manager: manager}
+		m.conns[c] = p
+	}
+	if p.connIDLen != connIDLen {
+		return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen)
+	}
+	return p.manager, nil
+}

+ 135 - 15
vendor/github.com/lucas-clemente/quic-go/packet_handler_map.go

@@ -1,10 +1,15 @@
 package quic
 
 import (
+	"bytes"
+	"fmt"
+	"net"
 	"sync"
 	"time"
 
 	"github.com/lucas-clemente/quic-go/internal/protocol"
+	"github.com/lucas-clemente/quic-go/internal/utils"
+	"github.com/lucas-clemente/quic-go/internal/wire"
 )
 
 // The packetHandlerMap stores packetHandlers, identified by connection ID.
@@ -14,26 +19,30 @@ import (
 type packetHandlerMap struct {
 	mutex sync.RWMutex
 
+	conn      net.PacketConn
+	connIDLen int
+
 	handlers map[string] /* string(ConnectionID)*/ packetHandler
+	server   unknownPacketHandler
 	closed   bool
 
 	deleteClosedSessionsAfter time.Duration
+
+	logger utils.Logger
 }
 
 var _ packetHandlerManager = &packetHandlerMap{}
 
-func newPacketHandlerMap() packetHandlerManager {
-	return &packetHandlerMap{
+func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger) packetHandlerManager {
+	m := &packetHandlerMap{
+		conn:                      conn,
+		connIDLen:                 connIDLen,
 		handlers:                  make(map[string]packetHandler),
 		deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
+		logger:                    logger,
 	}
-}
-
-func (h *packetHandlerMap) Get(id protocol.ConnectionID) (packetHandler, bool) {
-	h.mutex.RLock()
-	sess, ok := h.handlers[string(id)]
-	h.mutex.RUnlock()
-	return sess, ok
+	go m.listen()
+	return m
 }
 
 func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) {
@@ -43,22 +52,51 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler)
 }
 
 func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
+	h.removeByConnectionIDAsString(string(id))
+}
+
+func (h *packetHandlerMap) removeByConnectionIDAsString(id string) {
 	h.mutex.Lock()
-	h.handlers[string(id)] = nil
+	h.handlers[id] = nil
 	h.mutex.Unlock()
 
 	time.AfterFunc(h.deleteClosedSessionsAfter, func() {
 		h.mutex.Lock()
-		delete(h.handlers, string(id))
+		delete(h.handlers, id)
 		h.mutex.Unlock()
 	})
 }
 
-func (h *packetHandlerMap) Close(err error) {
+func (h *packetHandlerMap) SetServer(s unknownPacketHandler) {
+	h.mutex.Lock()
+	h.server = s
+	h.mutex.Unlock()
+}
+
+func (h *packetHandlerMap) CloseServer() {
+	h.mutex.Lock()
+	h.server = nil
+	var wg sync.WaitGroup
+	for id, handler := range h.handlers {
+		if handler != nil && handler.GetPerspective() == protocol.PerspectiveServer {
+			wg.Add(1)
+			go func(id string, handler packetHandler) {
+				// session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
+				_ = handler.Close()
+				h.removeByConnectionIDAsString(id)
+				wg.Done()
+			}(id, handler)
+		}
+	}
+	h.mutex.Unlock()
+	wg.Wait()
+}
+
+func (h *packetHandlerMap) close(e error) error {
 	h.mutex.Lock()
 	if h.closed {
 		h.mutex.Unlock()
-		return
+		return nil
 	}
 	h.closed = true
 
@@ -67,12 +105,94 @@ func (h *packetHandlerMap) Close(err error) {
 		if handler != nil {
 			wg.Add(1)
 			go func(handler packetHandler) {
-				// session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped
-				_ = handler.Close(err)
+				handler.destroy(e)
 				wg.Done()
 			}(handler)
 		}
 	}
+
+	if h.server != nil {
+		h.server.closeWithError(e)
+	}
 	h.mutex.Unlock()
 	wg.Wait()
+	return nil
+}
+
+func (h *packetHandlerMap) listen() {
+	for {
+		data := *getPacketBuffer()
+		data = data[:protocol.MaxReceivePacketSize]
+		// The packet size should not exceed protocol.MaxReceivePacketSize bytes
+		// If it does, we only read a truncated packet, which will then end up undecryptable
+		n, addr, err := h.conn.ReadFrom(data)
+		if err != nil {
+			h.close(err)
+			return
+		}
+		data = data[:n]
+
+		if err := h.handlePacket(addr, data); err != nil {
+			h.logger.Debugf("error handling packet from %s: %s", addr, err)
+		}
+	}
+}
+
+func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
+	rcvTime := time.Now()
+
+	r := bytes.NewReader(data)
+	iHdr, err := wire.ParseInvariantHeader(r, h.connIDLen)
+	// drop the packet if we can't parse the header
+	if err != nil {
+		return fmt.Errorf("error parsing invariant header: %s", err)
+	}
+
+	h.mutex.RLock()
+	handler, ok := h.handlers[string(iHdr.DestConnectionID)]
+	server := h.server
+	h.mutex.RUnlock()
+
+	var sentBy protocol.Perspective
+	var version protocol.VersionNumber
+	var handlePacket func(*receivedPacket)
+	if ok && handler == nil {
+		// Late packet for closed session
+		return nil
+	}
+	if !ok {
+		if server == nil { // no server set
+			return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
+		}
+		handlePacket = server.handlePacket
+		sentBy = protocol.PerspectiveClient
+		version = iHdr.Version
+	} else {
+		sentBy = handler.GetPerspective().Opposite()
+		version = handler.GetVersion()
+		handlePacket = handler.handlePacket
+	}
+
+	hdr, err := iHdr.Parse(r, sentBy, version)
+	if err != nil {
+		return fmt.Errorf("error parsing header: %s", err)
+	}
+	hdr.Raw = data[:len(data)-r.Len()]
+	packetData := data[len(data)-r.Len():]
+
+	if hdr.IsLongHeader {
+		if protocol.ByteCount(len(packetData)) < hdr.PayloadLen {
+			return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen)
+		}
+		packetData = packetData[:int(hdr.PayloadLen)]
+		// TODO(#1312): implement parsing of compound packets
+	}
+
+	handlePacket(&receivedPacket{
+		remoteAddr: addr,
+		header:     hdr,
+		data:       packetData,
+		rcvTime:    rcvTime,
+	})
+	return nil
 }

+ 10 - 5
vendor/github.com/lucas-clemente/quic-go/packet_packer.go

@@ -51,9 +51,11 @@ type packetPacker struct {
 
 	perspective protocol.Perspective
 	version     protocol.VersionNumber
-	divNonce    []byte
 	cryptoSetup sealingManager
 
+	token    []byte
+	divNonce []byte
+
 	packetNumberGenerator *packetNumberGenerator
 	getPacketNumberLen    func(protocol.PacketNumber) protocol.PacketNumberLen
 	streams               streamFrameSource
@@ -75,6 +77,7 @@ func newPacketPacker(
 	initialPacketNumber protocol.PacketNumber,
 	getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen,
 	remoteAddr net.Addr, // only used for determining the max packet size
+	token []byte,
 	divNonce []byte,
 	cryptoSetup sealingManager,
 	streamFramer streamFrameSource,
@@ -97,6 +100,7 @@ func newPacketPacker(
 	return &packetPacker{
 		cryptoSetup:           cryptoSetup,
 		divNonce:              divNonce,
+		token:                 token,
 		destConnID:            destConnID,
 		srcConnID:             srcConnID,
 		perspective:           perspective,
@@ -172,7 +176,7 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP
 		var payloadLength protocol.ByteCount
 
 		header := p.getHeader(encLevel)
-		headerLength, err := header.GetLength(p.perspective, p.version)
+		headerLength, err := header.GetLength(p.version)
 		if err != nil {
 			return nil, err
 		}
@@ -298,7 +302,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) {
 	encLevel, sealer := p.cryptoSetup.GetSealer()
 
 	header := p.getHeader(encLevel)
-	headerLength, err := header.GetLength(p.perspective, p.version)
+	headerLength, err := header.GetLength(p.version)
 	if err != nil {
 		return nil, err
 	}
@@ -352,7 +356,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) {
 func (p *packetPacker) packCryptoPacket() (*packedPacket, error) {
 	encLevel, sealer := p.cryptoSetup.GetSealerForCryptoStream()
 	header := p.getHeader(encLevel)
-	headerLength, err := header.GetLength(p.perspective, p.version)
+	headerLength, err := header.GetLength(p.version)
 	if err != nil {
 		return nil, err
 	}
@@ -463,6 +467,7 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header
 		header.PayloadLen = p.maxPacketSize
 		if !p.hasSentPacket && p.perspective == protocol.PerspectiveClient {
 			header.Type = protocol.PacketTypeInitial
+			header.Token = p.token
 		} else {
 			header.Type = protocol.PacketTypeHandshake
 		}
@@ -498,7 +503,7 @@ func (p *packetPacker) writeAndSealPacket(
 	// the payload length is only needed for Long Headers
 	if header.IsLongHeader {
 		if header.Type == protocol.PacketTypeInitial {
-			headerLen, _ := header.GetLength(p.perspective, p.version)
+			headerLen, _ := header.GetLength(p.version)
 			header.PayloadLen = protocol.ByteCount(protocol.MinInitialPacketSize) - headerLen
 		} else {
 			payloadLen := protocol.ByteCount(sealer.Overhead())

+ 7 - 10
vendor/github.com/lucas-clemente/quic-go/send_stream.go

@@ -166,22 +166,19 @@ func (s *sendStream) popStreamFrameImpl(maxBytes protocol.ByteCount) (bool /* co
 		if s.dataForWriting == nil {
 			return false, nil, false
 		}
-		isBlocked, _ := s.flowController.IsBlocked()
-		return false, nil, !isBlocked
-	}
-	if frame.FinBit {
-		s.finSent = true
-		return true, frame, s.dataForWriting != nil
-	} else if s.streamID != s.version.CryptoStreamID() { // TODO(#657): Flow control for the crypto stream
-		if isBlocked, offset := s.flowController.IsBlocked(); isBlocked {
+		if isBlocked, offset := s.flowController.IsNewlyBlocked(); isBlocked {
 			s.sender.queueControlFrame(&wire.StreamBlockedFrame{
 				StreamID: s.streamID,
 				Offset:   offset,
 			})
-			return false, frame, false
+			return false, nil, false
 		}
+		return false, nil, true
+	}
+	if frame.FinBit {
+		s.finSent = true
 	}
-	return false, frame, s.dataForWriting != nil
+	return frame.FinBit, frame, s.dataForWriting != nil
 }
 
 func (s *sendStream) getDataForWriting(maxBytes protocol.ByteCount) ([]byte, bool /* should send FIN */) {

+ 104 - 155
vendor/github.com/lucas-clemente/quic-go/server.go

@@ -1,10 +1,10 @@
 package quic
 
 import (
-	"bytes"
 	"crypto/tls"
 	"errors"
 	"fmt"
+	"io"
 	"net"
 	"time"
 
@@ -13,28 +13,35 @@ import (
 	"github.com/lucas-clemente/quic-go/internal/protocol"
 	"github.com/lucas-clemente/quic-go/internal/utils"
 	"github.com/lucas-clemente/quic-go/internal/wire"
-	"github.com/lucas-clemente/quic-go/qerr"
 )
 
 // packetHandler handles packets
 type packetHandler interface {
 	handlePacket(*receivedPacket)
-	Close(error) error
+	io.Closer
+	destroy(error)
+	GetVersion() protocol.VersionNumber
+	GetPerspective() protocol.Perspective
+}
+
+type unknownPacketHandler interface {
+	handlePacket(*receivedPacket)
+	closeWithError(error) error
 }
 
 type packetHandlerManager interface {
 	Add(protocol.ConnectionID, packetHandler)
-	Get(protocol.ConnectionID) (packetHandler, bool)
+	SetServer(unknownPacketHandler)
 	Remove(protocol.ConnectionID)
-	Close(error)
+	CloseServer()
 }
 
 type quicSession interface {
 	Session
 	handlePacket(*receivedPacket)
-	getCryptoStream() cryptoStreamI
 	GetVersion() protocol.VersionNumber
 	run() error
+	destroy(error)
 	closeRemote(error)
 }
 
@@ -59,6 +66,9 @@ type server struct {
 	config  *Config
 
 	conn net.PacketConn
+	// If the server is started with ListenAddr, we create a packet conn.
+	// If it is started with Listen, we take a packet conn as a parameter.
+	createdPacketConn bool
 
 	supportsTLS bool
 	serverTLS   *serverTLS
@@ -81,6 +91,7 @@ type server struct {
 }
 
 var _ Listener = &server{}
+var _ unknownPacketHandler = &server{}
 
 // ListenAddr creates a QUIC server listening on a given address.
 // The tls.Config must not be nil, the quic.Config may be nil.
@@ -93,12 +104,21 @@ func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, err
 	if err != nil {
 		return nil, err
 	}
-	return Listen(conn, tlsConf, config)
+	serv, err := listen(conn, tlsConf, config)
+	if err != nil {
+		return nil, err
+	}
+	serv.createdPacketConn = true
+	return serv, nil
 }
 
 // Listen listens for QUIC connections on a given net.PacketConn.
 // The tls.Config must not be nil, the quic.Config may be nil.
 func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
+	return listen(conn, tlsConf, config)
+}
+
+func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, error) {
 	certChain := crypto.NewCertChain(tlsConf)
 	kex, err := crypto.NewCurve25519KEX()
 	if err != nil {
@@ -122,6 +142,10 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
 		}
 	}
 
+	sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength)
+	if err != nil {
+		return nil, err
+	}
 	s := &server{
 		conn:           conn,
 		tlsConf:        tlsConf,
@@ -129,7 +153,7 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
 		certChain:      certChain,
 		scfg:           scfg,
 		newSession:     newSession,
-		sessionHandler: newPacketHandlerMap(),
+		sessionHandler: sessionHandler,
 		sessionQueue:   make(chan Session, 5),
 		errorChan:      make(chan struct{}),
 		supportsTLS:    supportsTLS,
@@ -141,7 +165,7 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener,
 			return nil, err
 		}
 	}
-	go s.serve()
+	sessionHandler.SetServer(s)
 	s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
 	return s, nil
 }
@@ -154,11 +178,7 @@ func (s *server) setup() {
 }
 
 func (s *server) setupTLS() error {
-	cookieHandler, err := handshake.NewCookieHandler(s.config.AcceptCookie, s.logger)
-	if err != nil {
-		return err
-	}
-	serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, s.sessionRunner, cookieHandler, s.tlsConf, s.logger)
+	serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, s.sessionRunner, s.tlsConf, s.logger)
 	if err != nil {
 		return err
 	}
@@ -170,9 +190,10 @@ func (s *server) setupTLS() error {
 			case <-s.errorChan:
 				return
 			case tlsSession := <-sessionChan:
-				// The connection ID is a randomly chosen 8 byte value.
+				// The connection ID is a randomly chosen value.
 				// It is safe to assume that it doesn't collide with other randomly chosen values.
-				s.sessionHandler.Add(tlsSession.connID, tlsSession.sess)
+				serverSession := newServerSession(tlsSession.sess, s.config, s.logger)
+				s.sessionHandler.Add(tlsSession.connID, serverSession)
 			}
 		}
 	}()
@@ -240,6 +261,10 @@ func populateServerConfig(config *Config) *Config {
 	} else if maxIncomingUniStreams < 0 {
 		maxIncomingUniStreams = 0
 	}
+	connIDLen := config.ConnectionIDLength
+	if connIDLen == 0 {
+		connIDLen = protocol.DefaultConnectionIDLength
+	}
 
 	return &Config{
 		Versions:                              versions,
@@ -251,27 +276,7 @@ func populateServerConfig(config *Config) *Config {
 		MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
 		MaxIncomingStreams:                    maxIncomingStreams,
 		MaxIncomingUniStreams:                 maxIncomingUniStreams,
-	}
-}
-
-// serve listens on an existing PacketConn
-func (s *server) serve() {
-	for {
-		data := *getPacketBuffer()
-		data = data[:protocol.MaxReceivePacketSize]
-		// The packet size should not exceed protocol.MaxReceivePacketSize bytes
-		// If it does, we only read a truncated packet, which will then end up undecryptable
-		n, remoteAddr, err := s.conn.ReadFrom(data)
-		if err != nil {
-			s.serverError = err
-			close(s.errorChan)
-			_ = s.Close()
-			return
-		}
-		data = data[:n]
-		if err := s.handlePacket(remoteAddr, data); err != nil {
-			s.logger.Errorf("error handling packet: %s", err.Error())
-		}
+		ConnectionIDLength:                    connIDLen,
 	}
 }
 
@@ -288,9 +293,17 @@ func (s *server) Accept() (Session, error) {
 
 // Close the server
 func (s *server) Close() error {
-	s.sessionHandler.Close(nil)
-	err := s.conn.Close()
-	<-s.errorChan // wait for serve() to return
+	s.sessionHandler.CloseServer()
+	if s.serverError == nil {
+		s.serverError = errors.New("server closed")
+	}
+	var err error
+	// If the server was started with ListenAddr, we created the packet conn.
+	// We need to close it in order to make the go routine reading from that conn return.
+	if s.createdPacketConn {
+		err = s.conn.Close()
+	}
+	close(s.errorChan)
 	return err
 }
 
@@ -299,141 +312,77 @@ func (s *server) Addr() net.Addr {
 	return s.conn.LocalAddr()
 }
 
-func (s *server) handlePacket(remoteAddr net.Addr, packet []byte) error {
-	rcvTime := time.Now()
-
-	r := bytes.NewReader(packet)
-	hdr, err := wire.ParseHeaderSentByClient(r)
-	if err != nil {
-		return qerr.Error(qerr.InvalidPacketHeader, err.Error())
-	}
-	hdr.Raw = packet[:len(packet)-r.Len()]
-	packetData := packet[len(packet)-r.Len():]
-
-	if hdr.IsPublicHeader {
-		return s.handleGQUICPacket(hdr, packetData, remoteAddr, rcvTime)
-	}
-	return s.handleIETFQUICPacket(hdr, packetData, remoteAddr, rcvTime)
+func (s *server) closeWithError(e error) error {
+	s.serverError = e
+	return s.Close()
 }
 
-func (s *server) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error {
-	if hdr.IsLongHeader {
-		if !s.supportsTLS {
-			return errors.New("Received an IETF QUIC Long Header")
-		}
-		if protocol.ByteCount(len(packetData)) < hdr.PayloadLen {
-			return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen)
-		}
-		packetData = packetData[:int(hdr.PayloadLen)]
-		// TODO(#1312): implement parsing of compound packets
-
-		switch hdr.Type {
-		case protocol.PacketTypeInitial:
-			go s.serverTLS.HandleInitial(remoteAddr, hdr, packetData)
-			return nil
-		case protocol.PacketTypeHandshake:
-			// nothing to do here. Packet will be passed to the session.
-		default:
-			// Note that this also drops 0-RTT packets.
-			return fmt.Errorf("Received unsupported packet type: %s", hdr.Type)
-		}
-	}
-
-	session, sessionKnown := s.sessionHandler.Get(hdr.DestConnectionID)
-	if sessionKnown && session == nil {
-		// Late packet for closed session
-		return nil
+func (s *server) handlePacket(p *receivedPacket) {
+	if err := s.handlePacketImpl(p); err != nil {
+		s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
 	}
-	if !sessionKnown {
-		s.logger.Debugf("Received %s packet for unknown connection %s.", hdr.Type, hdr.DestConnectionID)
-		return nil
-	}
-
-	session.handlePacket(&receivedPacket{
-		remoteAddr: remoteAddr,
-		header:     hdr,
-		data:       packetData,
-		rcvTime:    rcvTime,
-	})
-	return nil
 }
 
-func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error {
-	// ignore all Public Reset packets
-	if hdr.ResetFlag {
-		s.logger.Infof("Received unexpected Public Reset for connection %s.", hdr.DestConnectionID)
-		return nil
-	}
+func (s *server) handlePacketImpl(p *receivedPacket) error {
+	hdr := p.header
 
-	session, sessionKnown := s.sessionHandler.Get(hdr.DestConnectionID)
-	if sessionKnown && session == nil {
-		// Late packet for closed session
+	if hdr.VersionFlag || hdr.IsLongHeader {
+		// send a Version Negotiation Packet if the client is speaking a different protocol version
+		if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
+			return s.sendVersionNegotiationPacket(p)
+		}
+	}
+	if hdr.Type == protocol.PacketTypeInitial {
+		go s.serverTLS.HandleInitial(p)
 		return nil
 	}
 
-	// If we don't have a session for this connection, and this packet cannot open a new connection, send a Public Reset
-	// This should only happen after a server restart, when we still receive packets for connections that we lost the state for.
-	if !sessionKnown && !hdr.VersionFlag {
-		_, err := s.conn.WriteTo(wire.WritePublicReset(hdr.DestConnectionID, 0, 0), remoteAddr)
+	// TODO(#943): send Stateless Reset, if this an IETF QUIC packet
+	if !hdr.VersionFlag {
+		_, err := s.conn.WriteTo(wire.WritePublicReset(hdr.DestConnectionID, 0, 0), p.remoteAddr)
 		return err
 	}
 
-	// a session is only created once the client sent a supported version
-	// if we receive a packet for a connection that already has session, it's probably an old packet that was sent by the client before the version was negotiated
-	// it is safe to drop it
-	if sessionKnown && hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
-		return nil
+	// This is (potentially) a Client Hello.
+	// Make sure it has the minimum required size before spending any more ressources on it.
+	if len(p.data) < protocol.MinClientHelloSize {
+		return errors.New("dropping small packet for unknown connection")
 	}
 
-	// send a Version Negotiation Packet if the client is speaking a different protocol version
-	// since the client send a Public Header (only gQUIC has a Version Flag), we need to send a gQUIC Version Negotiation Packet
-	if hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
-		// drop packets that are too small to be valid first packets
-		if len(packetData) < protocol.MinClientHelloSize {
-			return errors.New("dropping small packet with unknown version")
-		}
-		s.logger.Infof("Client offered version %s, sending Version Negotiation Packet", hdr.Version)
-		_, err := s.conn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.DestConnectionID, s.config.Versions), remoteAddr)
+	s.logger.Infof("Serving new connection: %s, version %s from %v", hdr.DestConnectionID, hdr.Version, p.remoteAddr)
+	sess, err := s.newSession(
+		&conn{pconn: s.conn, currentAddr: p.remoteAddr},
+		s.sessionRunner,
+		hdr.Version,
+		hdr.DestConnectionID,
+		s.scfg,
+		s.tlsConf,
+		s.config,
+		s.logger,
+	)
+	if err != nil {
 		return err
 	}
+	s.sessionHandler.Add(hdr.DestConnectionID, newServerSession(sess, s.config, s.logger))
+	go sess.run()
+	sess.handlePacket(p)
+	return nil
+}
 
-	if !sessionKnown {
-		// This is (potentially) a Client Hello.
-		// Make sure it has the minimum required size before spending any more ressources on it.
-		if len(packetData) < protocol.MinClientHelloSize {
-			return errors.New("dropping small packet for unknown connection")
-		}
-
-		version := hdr.Version
-		if !protocol.IsSupportedVersion(s.config.Versions, version) {
-			return errors.New("Server BUG: negotiated version not supported")
-		}
+func (s *server) sendVersionNegotiationPacket(p *receivedPacket) error {
+	hdr := p.header
+	s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
 
-		s.logger.Infof("Serving new connection: %s, version %s from %v", hdr.DestConnectionID, version, remoteAddr)
-		sess, err := s.newSession(
-			&conn{pconn: s.conn, currentAddr: remoteAddr},
-			s.sessionRunner,
-			version,
-			hdr.DestConnectionID,
-			s.scfg,
-			s.tlsConf,
-			s.config,
-			s.logger,
-		)
+	var data []byte
+	if hdr.Version.UsesIETFFrameFormat() {
+		var err error
+		data, err = wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
 		if err != nil {
 			return err
 		}
-		s.sessionHandler.Add(hdr.DestConnectionID, sess)
-
-		go sess.run()
-		session = sess
+	} else {
+		data = wire.ComposeGQUICVersionNegotiation(hdr.DestConnectionID, s.config.Versions)
 	}
-
-	session.handlePacket(&receivedPacket{
-		remoteAddr: remoteAddr,
-		header:     hdr,
-		data:       packetData,
-		rcvTime:    rcvTime,
-	})
-	return nil
+	_, err := s.conn.WriteTo(data, p.remoteAddr)
+	return err
 }

+ 63 - 0
vendor/github.com/lucas-clemente/quic-go/server_session.go

@@ -0,0 +1,63 @@
+package quic
+
+import (
+	"fmt"
+
+	"github.com/lucas-clemente/quic-go/internal/protocol"
+	"github.com/lucas-clemente/quic-go/internal/utils"
+)
+
+type serverSession struct {
+	quicSession
+
+	config *Config
+
+	logger utils.Logger
+}
+
+var _ packetHandler = &serverSession{}
+
+func newServerSession(sess quicSession, config *Config, logger utils.Logger) packetHandler {
+	return &serverSession{
+		quicSession: sess,
+		config:      config,
+		logger:      logger,
+	}
+}
+
+func (s *serverSession) handlePacket(p *receivedPacket) {
+	if err := s.handlePacketImpl(p); err != nil {
+		s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
+	}
+}
+
+func (s *serverSession) handlePacketImpl(p *receivedPacket) error {
+	hdr := p.header
+	// ignore all Public Reset packets
+	if hdr.ResetFlag {
+		return fmt.Errorf("Received unexpected Public Reset for connection %s", hdr.DestConnectionID)
+	}
+
+	// Probably an old packet that was sent by the client before the version was negotiated.
+	// It is safe to drop it.
+	if (hdr.VersionFlag || hdr.IsLongHeader) && hdr.Version != s.quicSession.GetVersion() {
+		return nil
+	}
+
+	if hdr.IsLongHeader {
+		switch hdr.Type {
+		case protocol.PacketTypeHandshake:
+			// nothing to do here. Packet will be passed to the session.
+		default:
+			// Note that this also drops 0-RTT packets.
+			return fmt.Errorf("Received unsupported packet type: %s", hdr.Type)
+		}
+	}
+
+	s.quicSession.handlePacket(p)
+	return nil
+}
+
+func (s *serverSession) GetPerspective() protocol.Perspective {
+	return protocol.PerspectiveServer
+}

+ 94 - 169
vendor/github.com/lucas-clemente/quic-go/server_tls.go

@@ -1,48 +1,31 @@
 package quic
 
 import (
+	"bytes"
 	"crypto/tls"
 	"errors"
-	"fmt"
 	"net"
 
 	"github.com/bifurcation/mint"
-	"github.com/lucas-clemente/quic-go/internal/crypto"
 	"github.com/lucas-clemente/quic-go/internal/handshake"
 	"github.com/lucas-clemente/quic-go/internal/protocol"
 	"github.com/lucas-clemente/quic-go/internal/utils"
 	"github.com/lucas-clemente/quic-go/internal/wire"
-	"github.com/lucas-clemente/quic-go/qerr"
 )
 
-type nullAEAD struct {
-	aead crypto.AEAD
-}
-
-var _ quicAEAD = &nullAEAD{}
-
-func (n *nullAEAD) OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
-	return n.aead.Open(dst, src, packetNumber, associatedData)
-}
-
-func (n *nullAEAD) Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
-	return nil, errors.New("no 1-RTT keys")
-}
-
 type tlsSession struct {
 	connID protocol.ConnectionID
-	sess   packetHandler
+	sess   quicSession
 }
 
 type serverTLS struct {
-	conn              net.PacketConn
-	config            *Config
-	supportedVersions []protocol.VersionNumber
-	mintConf          *mint.Config
-	params            *handshake.TransportParameters
-	newMintConn       func(*handshake.CryptoStreamConn, protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error)
+	conn            net.PacketConn
+	config          *Config
+	mintConf        *mint.Config
+	params          *handshake.TransportParameters
+	cookieGenerator *handshake.CookieGenerator
 
-	newSession func(connection, sessionRunner, protocol.ConnectionID, protocol.ConnectionID, protocol.PacketNumber, *Config, handshake.MintTLS, *handshake.CryptoStreamConn, crypto.AEAD, *handshake.TransportParameters, protocol.VersionNumber, utils.Logger) (quicSession, error)
+	newSession func(connection, sessionRunner, protocol.ConnectionID, protocol.ConnectionID, protocol.ConnectionID, protocol.PacketNumber, *Config, *mint.Config, *handshake.TransportParameters, utils.Logger, protocol.VersionNumber) (quicSession, error)
 
 	sessionRunner sessionRunner
 	sessionChan   chan<- tlsSession
@@ -54,48 +37,47 @@ func newServerTLS(
 	conn net.PacketConn,
 	config *Config,
 	runner sessionRunner,
-	cookieHandler *handshake.CookieHandler,
 	tlsConf *tls.Config,
 	logger utils.Logger,
 ) (*serverTLS, <-chan tlsSession, error) {
-	mconf, err := tlsToMintConfig(tlsConf, protocol.PerspectiveServer)
+	cookieGenerator, err := handshake.NewCookieGenerator()
 	if err != nil {
 		return nil, nil, err
 	}
-	mconf.RequireCookie = true
-	cs, err := mint.NewDefaultCookieProtector()
+	params := &handshake.TransportParameters{
+		StreamFlowControlWindow:     protocol.ReceiveStreamFlowControlWindow,
+		ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
+		IdleTimeout:                 config.IdleTimeout,
+		MaxBidiStreams:              uint16(config.MaxIncomingStreams),
+		MaxUniStreams:               uint16(config.MaxIncomingUniStreams),
+		DisableMigration:            true,
+		// TODO(#855): generate a real token
+		StatelessResetToken: bytes.Repeat([]byte{42}, 16),
+	}
+	mconf, err := tlsToMintConfig(tlsConf, protocol.PerspectiveServer)
 	if err != nil {
 		return nil, nil, err
 	}
-	mconf.CookieProtector = cs
-	mconf.CookieHandler = cookieHandler
 
 	sessionChan := make(chan tlsSession)
 	s := &serverTLS{
-		conn:              conn,
-		config:            config,
-		supportedVersions: config.Versions,
-		mintConf:          mconf,
-		sessionRunner:     runner,
-		sessionChan:       sessionChan,
-		params: &handshake.TransportParameters{
-			StreamFlowControlWindow:     protocol.ReceiveStreamFlowControlWindow,
-			ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
-			IdleTimeout:                 config.IdleTimeout,
-			MaxBidiStreams:              uint16(config.MaxIncomingStreams),
-			MaxUniStreams:               uint16(config.MaxIncomingUniStreams),
-		},
-		newSession: newTLSServerSession,
-		logger:     logger,
-	}
-	s.newMintConn = s.newMintConnImpl
+		conn:            conn,
+		config:          config,
+		mintConf:        mconf,
+		sessionRunner:   runner,
+		sessionChan:     sessionChan,
+		cookieGenerator: cookieGenerator,
+		params:          params,
+		newSession:      newTLSServerSession,
+		logger:          logger,
+	}
 	return s, sessionChan, nil
 }
 
-func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data []byte) {
+func (s *serverTLS) HandleInitial(p *receivedPacket) {
 	// TODO: add a check that DestConnID == SrcConnID
-	s.logger.Debugf("Received a Packet. Handling it statelessly.")
-	sess, connID, err := s.handleInitialImpl(remoteAddr, hdr, data)
+	s.logger.Debugf("<- Received Initial packet.")
+	sess, connID, err := s.handleInitialImpl(p)
 	if err != nil {
 		s.logger.Errorf("Error occurred handling initial packet: %s", err)
 		return
@@ -109,145 +91,88 @@ func (s *serverTLS) HandleInitial(remoteAddr net.Addr, hdr *wire.Header, data []
 	}
 }
 
-// will be set to s.newMintConn by the constructor
-func (s *serverTLS) newMintConnImpl(bc *handshake.CryptoStreamConn, v protocol.VersionNumber) (handshake.MintTLS, <-chan handshake.TransportParameters, error) {
-	extHandler := handshake.NewExtensionHandlerServer(s.params, s.config.Versions, v, s.logger)
-	conf := s.mintConf.Clone()
-	conf.ExtensionHandler = extHandler
-	return newMintController(bc, conf, protocol.PerspectiveServer), extHandler.GetPeerParams(), nil
-}
-
-func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Header, aead crypto.AEAD, closeErr error) error {
-	ccf := &wire.ConnectionCloseFrame{
-		ErrorCode:    qerr.HandshakeFailed,
-		ReasonPhrase: closeErr.Error(),
-	}
-	replyHdr := &wire.Header{
-		IsLongHeader:     true,
-		Type:             protocol.PacketTypeHandshake,
-		SrcConnectionID:  clientHdr.DestConnectionID,
-		DestConnectionID: clientHdr.SrcConnectionID,
-		PacketNumber:     1, // random packet number
-		PacketNumberLen:  protocol.PacketNumberLen1,
-		Version:          clientHdr.Version,
-	}
-	data, err := packUnencryptedPacket(aead, replyHdr, ccf, protocol.PerspectiveServer, s.logger)
-	if err != nil {
-		return err
-	}
-	_, err = s.conn.WriteTo(data, remoteAddr)
-	return err
-}
-
-func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, protocol.ConnectionID, error) {
-	if hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
+func (s *serverTLS) handleInitialImpl(p *receivedPacket) (quicSession, protocol.ConnectionID, error) {
+	hdr := p.header
+	if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
 		return nil, nil, errors.New("dropping Initial packet with too short connection ID")
 	}
-	if len(hdr.Raw)+len(data) < protocol.MinInitialPacketSize {
+	if len(hdr.Raw)+len(p.data) < protocol.MinInitialPacketSize {
 		return nil, nil, errors.New("dropping too small Initial packet")
 	}
-	// check version, if not matching send VNP
-	if !protocol.IsSupportedVersion(s.supportedVersions, hdr.Version) {
-		s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
-		vnp, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.supportedVersions)
-		if err != nil {
-			return nil, nil, err
-		}
-		_, err = s.conn.WriteTo(vnp, remoteAddr)
-		return nil, nil, err
-	}
-
-	// unpack packet and check stream frame contents
-	aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.DestConnectionID, protocol.VersionTLS)
-	if err != nil {
-		return nil, nil, err
-	}
-	frame, err := unpackInitialPacket(aead, hdr, data, s.logger, hdr.Version)
-	if err != nil {
-		s.logger.Debugf("Error unpacking initial packet: %s", err)
-		return nil, nil, nil
-	}
-	sess, connID, err := s.handleUnpackedInitial(remoteAddr, hdr, frame, aead)
-	if err != nil {
-		if ccerr := s.sendConnectionClose(remoteAddr, hdr, aead, err); ccerr != nil {
-			s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", ccerr)
-		}
-		return nil, nil, err
-	}
-	return sess, connID, nil
-}
 
-func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (packetHandler, protocol.ConnectionID, error) {
-	version := hdr.Version
-	bc := handshake.NewCryptoStreamConn(remoteAddr)
-	bc.AddDataForReading(frame.Data)
-	tls, paramsChan, err := s.newMintConn(bc, version)
-	if err != nil {
-		return nil, nil, err
-	}
-	alert := tls.Handshake()
-	if alert == mint.AlertStatelessRetry {
-		// the HelloRetryRequest was written to the bufferConn
-		// Take that data and write send a Retry packet
-		f := &wire.StreamFrame{
-			StreamID: version.CryptoStreamID(),
-			Data:     bc.GetDataForWriting(),
-		}
-		replyHdr := &wire.Header{
-			IsLongHeader:     true,
-			Type:             protocol.PacketTypeRetry,
-			DestConnectionID: hdr.SrcConnectionID,
-			SrcConnectionID:  hdr.DestConnectionID,
-			PayloadLen:       f.Length(version) + protocol.ByteCount(aead.Overhead()),
-			PacketNumber:     hdr.PacketNumber, // echo the client's packet number
-			PacketNumberLen:  hdr.PacketNumberLen,
-			Version:          version,
+	var cookie *handshake.Cookie
+	if len(hdr.Token) > 0 {
+		c, err := s.cookieGenerator.DecodeToken(hdr.Token)
+		if err == nil {
+			cookie = c
 		}
-		data, err := packUnencryptedPacket(aead, replyHdr, f, protocol.PerspectiveServer, s.logger)
-		if err != nil {
-			return nil, nil, err
-		}
-		_, err = s.conn.WriteTo(data, remoteAddr)
-		return nil, nil, err
-	}
-	if alert != mint.AlertNoAlert {
-		return nil, nil, alert
-	}
-	if tls.State() != mint.StateServerNegotiated {
-		return nil, nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerNegotiated, tls.State())
-	}
-	if alert := tls.Handshake(); alert != mint.AlertNoAlert {
-		return nil, nil, alert
 	}
-	if tls.State() != mint.StateServerWaitFlight2 {
-		return nil, nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State())
+	if !s.config.AcceptCookie(p.remoteAddr, cookie) {
+		// Log the Initial packet now.
+		// If no Retry is sent, the packet will be logged by the session.
+		p.header.Log(s.logger)
+		return nil, nil, s.sendRetry(p.remoteAddr, hdr)
 	}
-	params := <-paramsChan
-	connID, err := protocol.GenerateConnectionID()
+
+	extHandler := handshake.NewExtensionHandlerServer(s.params, s.config.Versions, hdr.Version, s.logger)
+	mconf := s.mintConf.Clone()
+	mconf.ExtensionHandler = extHandler
+
+	// A server is allowed to perform multiple Retries.
+	// It doesn't make much sense, but it's something that our API allows.
+	// In that case it must use a source connection ID of at least 8 bytes.
+	connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
 	if err != nil {
 		return nil, nil, err
 	}
-	s.logger.Debugf("Changing source connection ID to %s.", connID)
+	s.logger.Debugf("Changing connection ID to %s.", connID)
 	sess, err := s.newSession(
-		&conn{pconn: s.conn, currentAddr: remoteAddr},
+		&conn{pconn: s.conn, currentAddr: p.remoteAddr},
 		s.sessionRunner,
+		hdr.DestConnectionID,
 		hdr.SrcConnectionID,
 		connID,
-		protocol.PacketNumber(1), // TODO: use a random packet number here
+		1,
 		s.config,
-		tls,
-		bc,
-		aead,
-		&params,
-		version,
+		mconf,
+		s.params,
 		s.logger,
+		hdr.Version,
 	)
 	if err != nil {
 		return nil, nil, err
 	}
-	cs := sess.getCryptoStream()
-	cs.setReadOffset(frame.DataLen())
-	bc.SetStream(cs)
 	go sess.run()
+	sess.handlePacket(p)
 	return sess, connID, nil
 }
+
+func (s *serverTLS) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
+	token, err := s.cookieGenerator.NewToken(remoteAddr)
+	if err != nil {
+		return err
+	}
+	connID, err := protocol.GenerateConnectionIDForInitial()
+	if err != nil {
+		return err
+	}
+	replyHdr := &wire.Header{
+		IsLongHeader:         true,
+		Type:                 protocol.PacketTypeRetry,
+		Version:              hdr.Version,
+		SrcConnectionID:      connID,
+		DestConnectionID:     hdr.SrcConnectionID,
+		OrigDestConnectionID: hdr.DestConnectionID,
+		Token:                token,
+	}
+	s.logger.Debugf("Changing connection ID to %s.\n-> Sending Retry", connID)
+	replyHdr.Log(s.logger)
+	buf := &bytes.Buffer{}
+	if err := replyHdr.Write(buf, protocol.PerspectiveServer, hdr.Version); err != nil {
+		return err
+	}
+	if _, err := s.conn.WriteTo(buf.Bytes(), remoteAddr); err != nil {
+		s.logger.Debugf("Error sending Retry: %s", err)
+	}
+	return nil
+}

+ 79 - 62
vendor/github.com/lucas-clemente/quic-go/session.go

@@ -10,9 +10,9 @@ import (
 	"sync"
 	"time"
 
+	"github.com/bifurcation/mint"
 	"github.com/lucas-clemente/quic-go/internal/ackhandler"
 	"github.com/lucas-clemente/quic-go/internal/congestion"
-	"github.com/lucas-clemente/quic-go/internal/crypto"
 	"github.com/lucas-clemente/quic-go/internal/flowcontrol"
 	"github.com/lucas-clemente/quic-go/internal/handshake"
 	"github.com/lucas-clemente/quic-go/internal/protocol"
@@ -67,8 +67,9 @@ var (
 )
 
 type closeError struct {
-	err    error
-	remote bool
+	err       error
+	remote    bool
+	sendClose bool
 }
 
 // A Session is a QUIC session
@@ -85,7 +86,7 @@ type session struct {
 	conn connection
 
 	streamsMap   streamManager
-	cryptoStream cryptoStreamI
+	cryptoStream cryptoStream
 
 	rttStats *congestion.RTTStats
 
@@ -209,6 +210,7 @@ func newSession(
 		1,
 		s.sentPacketHandler.GetPacketNumberLen,
 		s.RemoteAddr(),
+		nil, // no token
 		divNonce,
 		cs,
 		s.streamFramer,
@@ -279,6 +281,7 @@ var newClientSession = func(
 		1,
 		s.sentPacketHandler.GetPacketNumberLen,
 		s.RemoteAddr(),
+		nil, // no token
 		nil, // no diversification nonce
 		cs,
 		s.streamFramer,
@@ -291,16 +294,15 @@ var newClientSession = func(
 func newTLSServerSession(
 	conn connection,
 	runner sessionRunner,
+	origConnID protocol.ConnectionID,
 	destConnID protocol.ConnectionID,
 	srcConnID protocol.ConnectionID,
 	initialPacketNumber protocol.PacketNumber,
 	config *Config,
-	tls handshake.MintTLS,
-	cryptoStreamConn *handshake.CryptoStreamConn,
-	nullAEAD crypto.AEAD,
+	mintConf *mint.Config,
 	peerParams *handshake.TransportParameters,
-	v protocol.VersionNumber,
 	logger utils.Logger,
+	v protocol.VersionNumber,
 ) (quicSession, error) {
 	handshakeEvent := make(chan struct{}, 1)
 	s := &session{
@@ -315,13 +317,16 @@ func newTLSServerSession(
 		logger:         logger,
 	}
 	s.preSetup()
-	cs := handshake.NewCryptoSetupTLSServer(
-		tls,
-		cryptoStreamConn,
-		nullAEAD,
+	cs, err := handshake.NewCryptoSetupTLSServer(
+		s.cryptoStream,
+		origConnID,
+		mintConf,
 		handshakeEvent,
 		v,
 	)
+	if err != nil {
+		return nil, err
+	}
 	s.cryptoStreamHandler = cs
 	s.streamsMap = newStreamsMap(s, s.newFlowController, s.config.MaxIncomingStreams, s.config.MaxIncomingUniStreams, s.perspective, s.version)
 	s.streamFramer = newStreamFramer(s.cryptoStream, s.streamsMap, s.version)
@@ -331,6 +336,7 @@ func newTLSServerSession(
 		initialPacketNumber,
 		s.sentPacketHandler.GetPacketNumberLen,
 		s.RemoteAddr(),
+		nil, // no token
 		nil, // no diversification nonce
 		cs,
 		s.streamFramer,
@@ -350,21 +356,21 @@ func newTLSServerSession(
 var newTLSClientSession = func(
 	conn connection,
 	runner sessionRunner,
-	hostname string,
-	v protocol.VersionNumber,
+	token []byte,
 	destConnID protocol.ConnectionID,
 	srcConnID protocol.ConnectionID,
-	config *Config,
-	tls handshake.MintTLS,
+	conf *Config,
+	mintConf *mint.Config,
 	paramsChan <-chan handshake.TransportParameters,
 	initialPacketNumber protocol.PacketNumber,
 	logger utils.Logger,
+	v protocol.VersionNumber,
 ) (quicSession, error) {
 	handshakeEvent := make(chan struct{}, 1)
 	s := &session{
 		conn:           conn,
 		sessionRunner:  runner,
-		config:         config,
+		config:         conf,
 		srcConnID:      srcConnID,
 		destConnID:     destConnID,
 		perspective:    protocol.PerspectiveClient,
@@ -374,13 +380,11 @@ var newTLSClientSession = func(
 		logger:         logger,
 	}
 	s.preSetup()
-	tls.SetCryptoStream(s.cryptoStream)
 	cs, err := handshake.NewCryptoSetupTLSClient(
 		s.cryptoStream,
 		s.destConnID,
-		hostname,
+		mintConf,
 		handshakeEvent,
-		tls,
 		v,
 	)
 	if err != nil {
@@ -396,6 +400,7 @@ var newTLSClientSession = func(
 		initialPacketNumber,
 		s.sentPacketHandler.GetPacketNumberLen,
 		s.RemoteAddr(),
+		token,
 		nil, // no diversification nonce
 		cs,
 		s.streamFramer,
@@ -441,7 +446,7 @@ func (s *session) run() error {
 
 	go func() {
 		if err := s.cryptoStreamHandler.HandleCryptoStream(); err != nil {
-			s.Close(err)
+			s.closeLocal(err)
 		}
 	}()
 
@@ -506,7 +511,8 @@ runLoop:
 			pacingDeadline = s.sentPacketHandler.TimeUntilSend()
 		}
 		if s.config.KeepAlive && !s.keepAlivePingSent && s.handshakeComplete && time.Since(s.lastNetworkActivityTime) >= s.peerParams.IdleTimeout/2 {
-			// send the PING frame since there is no activity in the session
+			// send a PING frame since there is no activity in the session
+			s.logger.Debugf("Sending a keep-alive ping to keep the connection alive.")
 			s.packer.QueueControlFrame(&wire.PingFrame{})
 			s.keepAlivePingSent = true
 		} else if !pacingDeadline.IsZero() && now.Before(pacingDeadline) {
@@ -823,9 +829,17 @@ func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encrypt
 	return nil
 }
 
+// closeLocal closes the session and send a CONNECTION_CLOSE containing the error
 func (s *session) closeLocal(e error) {
 	s.closeOnce.Do(func() {
-		s.closeChan <- closeError{err: e, remote: false}
+		s.closeChan <- closeError{err: e, sendClose: true, remote: false}
+	})
+}
+
+// destroy closes the session without sending the error on the wire
+func (s *session) destroy(e error) {
+	s.closeOnce.Do(func() {
+		s.closeChan <- closeError{err: e, sendClose: false, remote: false}
 	})
 }
 
@@ -835,10 +849,16 @@ func (s *session) closeRemote(e error) {
 	})
 }
 
-// Close the connection. If err is nil it will be set to qerr.PeerGoingAway.
+// Close the connection. It sends a qerr.PeerGoingAway.
 // It waits until the run loop has stopped before returning
-func (s *session) Close(e error) error {
-	s.closeLocal(e)
+func (s *session) Close() error {
+	s.closeLocal(nil)
+	<-s.ctx.Done()
+	return nil
+}
+
+func (s *session) CloseWithError(code protocol.ApplicationErrorCode, e error) error {
+	s.closeLocal(qerr.Error(qerr.ErrorCode(code), e.Error()))
 	<-s.ctx.Done()
 	return nil
 }
@@ -863,7 +883,7 @@ func (s *session) handleCloseError(closeErr closeError) error {
 	s.cryptoStream.closeForShutdown(quicErr)
 	s.streamsMap.CloseWithError(quicErr)
 
-	if closeErr.err == errCloseSessionForNewVersion || closeErr.err == handshake.ErrCloseSessionForRetry {
+	if !closeErr.sendClose {
 		return nil
 	}
 
@@ -913,37 +933,11 @@ sendLoop:
 			// There will only be a new ACK after receiving new packets.
 			// SendAck is only returned when we're congestion limited, so we don't need to set the pacingt timer.
 			return s.maybeSendAckOnlyPacket()
-		case ackhandler.SendRTO:
-			// try to send a retransmission first
-			sentPacket, err := s.maybeSendRetransmission()
-			if err != nil {
+		case ackhandler.SendTLP, ackhandler.SendRTO:
+			if err := s.sendProbePacket(); err != nil {
 				return err
 			}
-			if !sentPacket {
-				// In RTO mode, a probe packet has to be sent.
-				// Add a PING frame to make sure a (retransmittable) packet will be sent.
-				s.queueControlFrame(&wire.PingFrame{})
-				sentPacket, err := s.sendPacket()
-				if err != nil {
-					return err
-				}
-				if !sentPacket {
-					return errors.New("session BUG: expected a packet to be sent in RTO mode")
-				}
-			}
 			numPacketsSent++
-		case ackhandler.SendTLP:
-			// In TLP mode, a probe packet has to be sent.
-			// Add a PING frame to make sure a (retransmittable) packet will be sent.
-			s.queueControlFrame(&wire.PingFrame{})
-			sentPacket, err := s.sendPacket()
-			if err != nil {
-				return err
-			}
-			if !sentPacket {
-				return errors.New("session BUG: expected a packet to be sent in TLP mode")
-			}
-			return nil
 		case ackhandler.SendRetransmission:
 			sentPacket, err := s.maybeSendRetransmission()
 			if err != nil {
@@ -1045,6 +1039,33 @@ func (s *session) maybeSendRetransmission() (bool, error) {
 	return true, nil
 }
 
+func (s *session) sendProbePacket() error {
+	p, err := s.sentPacketHandler.DequeueProbePacket()
+	if err != nil {
+		return err
+	}
+	s.logger.Debugf("Sending a retransmission for %#x as a probe packet.", p.PacketNumber)
+
+	if s.version.UsesStopWaitingFrames() {
+		s.packer.QueueControlFrame(s.sentPacketHandler.GetStopWaitingFrame(true))
+	}
+	packets, err := s.packer.PackRetransmission(p)
+	if err != nil {
+		return err
+	}
+	ackhandlerPackets := make([]*ackhandler.Packet, len(packets))
+	for i, packet := range packets {
+		ackhandlerPackets[i] = packet.ToAckHandlerPacket()
+	}
+	s.sentPacketHandler.SentPacketsAsRetransmission(ackhandlerPackets, p.PacketNumber)
+	for _, packet := range packets {
+		if err := s.sendPackedPacket(packet); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func (s *session) sendPacket() (bool, error) {
 	if isBlocked, offset := s.connFlowController.IsNewlyBlocked(); isBlocked {
 		s.packer.QueueControlFrame(&wire.BlockedFrame{Offset: offset})
@@ -1165,7 +1186,7 @@ func (s *session) newFlowController(id protocol.StreamID) flowcontrol.StreamFlow
 	)
 }
 
-func (s *session) newCryptoStream() cryptoStreamI {
+func (s *session) newCryptoStream() cryptoStream {
 	id := s.version.CryptoStreamID()
 	flowController := flowcontrol.NewStreamFlowController(
 		id,
@@ -1182,7 +1203,7 @@ func (s *session) newCryptoStream() cryptoStreamI {
 }
 
 func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error {
-	s.logger.Infof("Sending public reset for connection %x, packet number %d", s.destConnID, rejectedPacketNumber)
+	s.logger.Infof("Sending PUBLIC_RESET for connection %s, packet number %d", s.destConnID, rejectedPacketNumber)
 	return s.conn.Write(wire.WritePublicReset(s.destConnID, rejectedPacketNumber, 0))
 }
 
@@ -1241,7 +1262,7 @@ func (s *session) onHasStreamData(id protocol.StreamID) {
 
 func (s *session) onStreamCompleted(id protocol.StreamID) {
 	if err := s.streamsMap.DeleteStream(id); err != nil {
-		s.Close(err)
+		s.closeLocal(err)
 	}
 }
 
@@ -1253,10 +1274,6 @@ func (s *session) RemoteAddr() net.Addr {
 	return s.conn.RemoteAddr()
 }
 
-func (s *session) getCryptoStream() cryptoStreamI {
-	return s.cryptoStream
-}
-
 func (s *session) GetVersion() protocol.VersionNumber {
 	return s.version
 }

+ 2 - 2
vendor/github.com/lucas-clemente/quic-go/stream_framer.go

@@ -9,7 +9,7 @@ import (
 
 type streamFramer struct {
 	streamGetter streamGetter
-	cryptoStream cryptoStreamI
+	cryptoStream cryptoStream
 	version      protocol.VersionNumber
 
 	streamQueueMutex    sync.Mutex
@@ -19,7 +19,7 @@ type streamFramer struct {
 }
 
 func newStreamFramer(
-	cryptoStream cryptoStreamI,
+	cryptoStream cryptoStream,
 	streamGetter streamGetter,
 	v protocol.VersionNumber,
 ) *streamFramer {

+ 2 - 2
vendor/github.com/lucas-clemente/quic-go/window_update_queue.go

@@ -14,7 +14,7 @@ type windowUpdateQueue struct {
 	queue      map[protocol.StreamID]bool // used as a set
 	queuedConn bool                       // connection-level window update
 
-	cryptoStream       cryptoStreamI
+	cryptoStream       cryptoStream
 	streamGetter       streamGetter
 	connFlowController flowcontrol.ConnectionFlowController
 	callback           func(wire.Frame)
@@ -22,7 +22,7 @@ type windowUpdateQueue struct {
 
 func newWindowUpdateQueue(
 	streamGetter streamGetter,
-	cryptoStream cryptoStreamI,
+	cryptoStream cryptoStream,
 	connFC flowcontrol.ConnectionFlowController,
 	cb func(wire.Frame),
 ) *windowUpdateQueue {

+ 27 - 27
vendor/vendor.json

@@ -219,10 +219,10 @@
 			"revisionTime": "2017-10-27T16:34:21Z"
 		},
 		{
-			"checksumSHA1": "26ynUbB0vZuG+xhYS8+61i5+Ds0=",
+			"checksumSHA1": "FUvpp4RI9ZqYdH46mt1bjojuuo0=",
 			"path": "github.com/lucas-clemente/quic-go",
-			"revision": "3f9212b5f73c2679268447584e76cd7c9edf8e98",
-			"revisionTime": "2018-06-27T02:56:43Z"
+			"revision": "4d2d2420a4389e2af24d96337feff51951839b22",
+			"revisionTime": "2018-08-20T11:33:42Z"
 		},
 		{
 			"checksumSHA1": "OA9E+y7g05x/mWJJHmA7oPxWKQo=",
@@ -231,58 +231,58 @@
 			"revisionTime": "2016-08-23T09:51:56Z"
 		},
 		{
-			"checksumSHA1": "q5Mmgdu/11zEFx8Qew7wt0BvR34=",
+			"checksumSHA1": "xofp3Exz+2Bna8U2fSFil8aeNK4=",
 			"path": "github.com/lucas-clemente/quic-go/internal/ackhandler",
-			"revision": "3f9212b5f73c2679268447584e76cd7c9edf8e98",
-			"revisionTime": "2018-06-27T02:56:43Z"
+			"revision": "4d2d2420a4389e2af24d96337feff51951839b22",
+			"revisionTime": "2018-08-20T11:33:42Z"
 		},
 		{
 			"checksumSHA1": "i1yfut7QQqMehw5yE9llhWNnrxk=",
 			"path": "github.com/lucas-clemente/quic-go/internal/congestion",
-			"revision": "3f9212b5f73c2679268447584e76cd7c9edf8e98",
-			"revisionTime": "2018-06-27T02:56:43Z"
+			"revision": "4d2d2420a4389e2af24d96337feff51951839b22",
+			"revisionTime": "2018-08-20T11:33:42Z"
 		},
 		{
-			"checksumSHA1": "8CRRInUpwdxqXFGWnrW1KTUYOUE=",
+			"checksumSHA1": "iDyiuv67gAM4KKfl51vU3QtOFz8=",
 			"path": "github.com/lucas-clemente/quic-go/internal/crypto",
-			"revision": "3f9212b5f73c2679268447584e76cd7c9edf8e98",
-			"revisionTime": "2018-06-27T02:56:43Z"
+			"revision": "4d2d2420a4389e2af24d96337feff51951839b22",
+			"revisionTime": "2018-08-20T11:33:42Z"
 		},
 		{
-			"checksumSHA1": "rnRicg73lPAeRh9Nko6a0CZQS5I=",
+			"checksumSHA1": "hLazAfY6qHoV3USMxA7pSPnTqy8=",
 			"path": "github.com/lucas-clemente/quic-go/internal/flowcontrol",
-			"revision": "3f9212b5f73c2679268447584e76cd7c9edf8e98",
-			"revisionTime": "2018-06-27T02:56:43Z"
+			"revision": "4d2d2420a4389e2af24d96337feff51951839b22",
+			"revisionTime": "2018-08-20T11:33:42Z"
 		},
 		{
-			"checksumSHA1": "VDAmO1aQcHrZGfHBolw4bMjlCIo=",
+			"checksumSHA1": "1EPOPYxoK/ZVqB91d7329CMSsE8=",
 			"path": "github.com/lucas-clemente/quic-go/internal/handshake",
-			"revision": "3f9212b5f73c2679268447584e76cd7c9edf8e98",
-			"revisionTime": "2018-06-27T02:56:43Z"
+			"revision": "4d2d2420a4389e2af24d96337feff51951839b22",
+			"revisionTime": "2018-08-20T11:33:42Z"
 		},
 		{
-			"checksumSHA1": "WcERuY6LQVVwsulsp733jXwSXrE=",
+			"checksumSHA1": "vh1QIciVIx9N+0C7J4hQfdAW4iY=",
 			"path": "github.com/lucas-clemente/quic-go/internal/protocol",
-			"revision": "3f9212b5f73c2679268447584e76cd7c9edf8e98",
-			"revisionTime": "2018-06-27T02:56:43Z"
+			"revision": "4d2d2420a4389e2af24d96337feff51951839b22",
+			"revisionTime": "2018-08-20T11:33:42Z"
 		},
 		{
 			"checksumSHA1": "0vSbWIQ7O34u4kDMR+FHr7/FINk=",
 			"path": "github.com/lucas-clemente/quic-go/internal/utils",
-			"revision": "3f9212b5f73c2679268447584e76cd7c9edf8e98",
-			"revisionTime": "2018-06-27T02:56:43Z"
+			"revision": "4d2d2420a4389e2af24d96337feff51951839b22",
+			"revisionTime": "2018-08-20T11:33:42Z"
 		},
 		{
-			"checksumSHA1": "XffpGTFqeLH70f+ToHMqmmt4wjE=",
+			"checksumSHA1": "bBhsaiBWBOUTXJoj2Rju7Q8BXnU=",
 			"path": "github.com/lucas-clemente/quic-go/internal/wire",
-			"revision": "3f9212b5f73c2679268447584e76cd7c9edf8e98",
-			"revisionTime": "2018-06-27T02:56:43Z"
+			"revision": "4d2d2420a4389e2af24d96337feff51951839b22",
+			"revisionTime": "2018-08-20T11:33:42Z"
 		},
 		{
 			"checksumSHA1": "bFSC4TOZGOZGBJEFmLAT3V4ieoo=",
 			"path": "github.com/lucas-clemente/quic-go/qerr",
-			"revision": "3f9212b5f73c2679268447584e76cd7c9edf8e98",
-			"revisionTime": "2018-06-27T02:56:43Z"
+			"revision": "4d2d2420a4389e2af24d96337feff51951839b22",
+			"revisionTime": "2018-08-20T11:33:42Z"
 		},
 		{
 			"checksumSHA1": "sY8sshVIEXnJgg3S6C5FcN33Vq4=",