| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912 |
- package quic
- import (
- crand "crypto/rand"
- "encoding/binary"
- "errors"
- "fmt"
- "golang.org/x/exp/rand"
- "github.com/Psiphon-Labs/quic-go/internal/ackhandler"
- "github.com/Psiphon-Labs/quic-go/internal/handshake"
- "github.com/Psiphon-Labs/quic-go/internal/protocol"
- "github.com/Psiphon-Labs/quic-go/internal/qerr"
- "github.com/Psiphon-Labs/quic-go/internal/wire"
- )
// errNothingToPack is returned by appendPacket (and thus by AppendPacket and
// PackAckOnlyPacket) when there is no payload to put into a 1-RTT packet.
var errNothingToPack = errors.New("nothing to pack")

// packer builds outgoing packets. It is implemented by packetPacker.
type packer interface {
	// PackCoalescedPacket packs Initial, Handshake and 0-RTT / 1-RTT packets into
	// a single coalesced packet. It should only be used before the handshake is confirmed.
	PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error)
	// PackAckOnlyPacket packs an ACK-only 1-RTT packet into a new packet buffer.
	PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error)
	// AppendPacket packs a 1-RTT packet and appends it to buf.
	AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error)
	// MaybePackProbePacket packs a probe packet at the given encryption level, if there is something to send.
	MaybePackProbePacket(protocol.EncryptionLevel, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error)
	// PackConnectionClose packs a CONNECTION_CLOSE carrying a transport error.
	PackConnectionClose(*qerr.TransportError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error)
	// PackApplicationClose packs a CONNECTION_CLOSE carrying an application error.
	PackApplicationClose(*qerr.ApplicationError, protocol.ByteCount, protocol.VersionNumber) (*coalescedPacket, error)
	// PackMTUProbePacket packs a 1-RTT packet of the given size containing the ping frame plus padding.
	PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error)
	// SetToken sets the token that is put into the header of Initial packets.
	SetToken([]byte)
}

// sealer is the sealing functionality needed to encrypt a packet and apply header
// protection. Both long header and short header (1-RTT) sealers are used through it.
type sealer interface {
	handshake.LongHeaderSealer
}
// payload collects the frames that are serialized into a single packet.
type payload struct {
	streamFrames []ackhandler.StreamFrame
	frames       []ackhandler.Frame
	ack          *wire.AckFrame
	// length is the serialized length of ack, frames and streamFrames.
	// It does not include any padding.
	length protocol.ByteCount
}

// longHeaderPacket describes an Initial, Handshake or 0-RTT packet that was
// appended to a packet buffer.
type longHeaderPacket struct {
	header       *wire.ExtendedHeader
	ack          *wire.AckFrame
	frames       []ackhandler.Frame
	streamFrames []ackhandler.StreamFrame // only used for 0-RTT packets
	// length is the size of the sealed packet: header, payload, padding and AEAD overhead.
	length protocol.ByteCount
}

// shortHeaderPacket describes a 1-RTT packet that was appended to a packet buffer.
type shortHeaderPacket struct {
	PacketNumber protocol.PacketNumber
	Frames       []ackhandler.Frame
	StreamFrames []ackhandler.StreamFrame
	Ack          *wire.AckFrame
	// Length is the size of the sealed packet, including header, padding and AEAD overhead.
	Length protocol.ByteCount
	// IsPathMTUProbePacket is set for packets packed by PackMTUProbePacket.
	IsPathMTUProbePacket bool
	// used for logging
	DestConnID      protocol.ConnectionID
	PacketNumberLen protocol.PacketNumberLen
	KeyPhase        protocol.KeyPhaseBit
}
- func (p *shortHeaderPacket) IsAckEliciting() bool { return ackhandler.HasAckElicitingFrames(p.Frames) }
// coalescedPacket is a packet buffer holding one or more packets serialized back
// to back: zero or more long header packets, optionally followed by a single
// short header (1-RTT) packet.
type coalescedPacket struct {
	buffer         *packetBuffer
	longHdrPackets []*longHeaderPacket
	shortHdrPacket *shortHeaderPacket // nil if there is no 1-RTT packet
}
- // IsOnlyShortHeaderPacket says if this packet only contains a short header packet (and no long header packets).
- func (p *coalescedPacket) IsOnlyShortHeaderPacket() bool {
- return len(p.longHdrPackets) == 0 && p.shortHdrPacket != nil
- }
- func (p *longHeaderPacket) EncryptionLevel() protocol.EncryptionLevel {
- //nolint:exhaustive // Will never be called for Retry packets (and they don't have encrypted data).
- switch p.header.Type {
- case protocol.PacketTypeInitial:
- return protocol.EncryptionInitial
- case protocol.PacketTypeHandshake:
- return protocol.EncryptionHandshake
- case protocol.PacketType0RTT:
- return protocol.Encryption0RTT
- default:
- panic("can't determine encryption level")
- }
- }
- func (p *longHeaderPacket) IsAckEliciting() bool { return ackhandler.HasAckElicitingFrames(p.frames) }
// packetNumberManager hands out packet numbers per encryption level.
// The packer peeks a packet number while building a packet and pops it once the
// packet has been appended; appendLongHeaderPacket and appendShortHeaderPacket
// return an error if the peeked and popped numbers differ.
type packetNumberManager interface {
	PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen)
	PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber
}

// sealingManager provides the sealer for each encryption level.
// Depending on the call site, the packer treats handshake.ErrKeysNotYetAvailable and
// handshake.ErrKeysDropped as "no keys at this level" instead of as fatal errors.
type sealingManager interface {
	GetInitialSealer() (handshake.LongHeaderSealer, error)
	GetHandshakeSealer() (handshake.LongHeaderSealer, error)
	Get0RTTSealer() (handshake.LongHeaderSealer, error)
	Get1RTTSealer() (handshake.ShortHeaderSealer, error)
}

// frameSource provides application data (control and STREAM) frames.
// The Append* methods receive the remaining space in the packet and return the
// extended slice together with the number of bytes added (see composeNextPacket).
type frameSource interface {
	HasData() bool
	AppendStreamFrames([]ackhandler.StreamFrame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.StreamFrame, protocol.ByteCount)
	AppendControlFrames([]ackhandler.Frame, protocol.ByteCount, protocol.VersionNumber) ([]ackhandler.Frame, protocol.ByteCount)
}

// ackFrameSource provides ACK frames for a packet number space.
// NOTE(review): onlyIfQueued presumably restricts the result to ACKs that are due to
// be sent; the packer sets it when there is no other data to send — confirm.
type ackFrameSource interface {
	GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame
}
// packetPacker assembles, seals and serializes outgoing packets.
type packetPacker struct {
	srcConnID     protocol.ConnectionID
	getDestConnID func() protocol.ConnectionID

	perspective protocol.Perspective
	cryptoSetup sealingManager

	initialStream   cryptoStream
	handshakeStream cryptoStream

	token []byte // put into the header of Initial packets (see getLongHeader)

	pnManager           packetNumberManager
	framer              frameSource
	acks                ackFrameSource
	datagramQueue       *datagramQueue // may be nil (checked in composeNextPacket)
	retransmissionQueue *retransmissionQueue

	rand rand.Rand // shuffles the order of frames in a packet (see appendPacketPayload)
	// numNonAckElicitingAcks counts consecutive 0-RTT / 1-RTT payloads that only
	// contained an ACK; maybeGetAppDataPacket adds a PING once it reaches
	// protocol.MaxNonAckElicitingAcks.
	numNonAckElicitingAcks int
}

// compile-time check that packetPacker implements packer
var _ packer = &packetPacker{}
- func newPacketPacker(
- srcConnID protocol.ConnectionID,
- getDestConnID func() protocol.ConnectionID,
- initialStream, handshakeStream cryptoStream,
- packetNumberManager packetNumberManager,
- retransmissionQueue *retransmissionQueue,
- cryptoSetup sealingManager,
- framer frameSource,
- acks ackFrameSource,
- datagramQueue *datagramQueue,
- perspective protocol.Perspective,
- ) *packetPacker {
- var b [8]byte
- _, _ = crand.Read(b[:])
- return &packetPacker{
- cryptoSetup: cryptoSetup,
- getDestConnID: getDestConnID,
- srcConnID: srcConnID,
- initialStream: initialStream,
- handshakeStream: handshakeStream,
- retransmissionQueue: retransmissionQueue,
- datagramQueue: datagramQueue,
- perspective: perspective,
- framer: framer,
- acks: acks,
- rand: *rand.New(rand.NewSource(binary.BigEndian.Uint64(b[:]))),
- pnManager: packetNumberManager,
- }
- }
- // PackConnectionClose packs a packet that closes the connection with a transport error.
- func (p *packetPacker) PackConnectionClose(e *qerr.TransportError, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
- var reason string
- // don't send details of crypto errors
- if !e.ErrorCode.IsCryptoError() {
- reason = e.ErrorMessage
- }
- return p.packConnectionClose(false, uint64(e.ErrorCode), e.FrameType, reason, maxPacketSize, v)
- }
- // PackApplicationClose packs a packet that closes the connection with an application error.
- func (p *packetPacker) PackApplicationClose(e *qerr.ApplicationError, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
- return p.packConnectionClose(true, uint64(e.ErrorCode), 0, e.ErrorMessage, maxPacketSize, v)
- }
- func (p *packetPacker) packConnectionClose(
- isApplicationError bool,
- errorCode uint64,
- frameType uint64,
- reason string,
- maxPacketSize protocol.ByteCount,
- v protocol.VersionNumber,
- ) (*coalescedPacket, error) {
- var sealers [4]sealer
- var hdrs [3]*wire.ExtendedHeader
- var payloads [4]payload
- var size protocol.ByteCount
- var connID protocol.ConnectionID
- var oneRTTPacketNumber protocol.PacketNumber
- var oneRTTPacketNumberLen protocol.PacketNumberLen
- var keyPhase protocol.KeyPhaseBit // only set for 1-RTT
- var numLongHdrPackets uint8
- encLevels := [4]protocol.EncryptionLevel{protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption0RTT, protocol.Encryption1RTT}
- for i, encLevel := range encLevels {
- if p.perspective == protocol.PerspectiveServer && encLevel == protocol.Encryption0RTT {
- continue
- }
- ccf := &wire.ConnectionCloseFrame{
- IsApplicationError: isApplicationError,
- ErrorCode: errorCode,
- FrameType: frameType,
- ReasonPhrase: reason,
- }
- // don't send application errors in Initial or Handshake packets
- if isApplicationError && (encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake) {
- ccf.IsApplicationError = false
- ccf.ErrorCode = uint64(qerr.ApplicationErrorErrorCode)
- ccf.ReasonPhrase = ""
- }
- pl := payload{
- frames: []ackhandler.Frame{{Frame: ccf}},
- length: ccf.Length(v),
- }
- var sealer sealer
- var err error
- switch encLevel {
- case protocol.EncryptionInitial:
- sealer, err = p.cryptoSetup.GetInitialSealer()
- case protocol.EncryptionHandshake:
- sealer, err = p.cryptoSetup.GetHandshakeSealer()
- case protocol.Encryption0RTT:
- sealer, err = p.cryptoSetup.Get0RTTSealer()
- case protocol.Encryption1RTT:
- var s handshake.ShortHeaderSealer
- s, err = p.cryptoSetup.Get1RTTSealer()
- if err == nil {
- keyPhase = s.KeyPhase()
- }
- sealer = s
- }
- if err == handshake.ErrKeysNotYetAvailable || err == handshake.ErrKeysDropped {
- continue
- }
- if err != nil {
- return nil, err
- }
- sealers[i] = sealer
- var hdr *wire.ExtendedHeader
- if encLevel == protocol.Encryption1RTT {
- connID = p.getDestConnID()
- oneRTTPacketNumber, oneRTTPacketNumberLen = p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
- size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, pl)
- } else {
- hdr = p.getLongHeader(encLevel, v)
- hdrs[i] = hdr
- size += p.longHeaderPacketLength(hdr, pl, v) + protocol.ByteCount(sealer.Overhead())
- numLongHdrPackets++
- }
- payloads[i] = pl
- }
- buffer := getPacketBuffer()
- packet := &coalescedPacket{
- buffer: buffer,
- longHdrPackets: make([]*longHeaderPacket, 0, numLongHdrPackets),
- }
- for i, encLevel := range encLevels {
- if sealers[i] == nil {
- continue
- }
- var paddingLen protocol.ByteCount
- if encLevel == protocol.EncryptionInitial {
- paddingLen = p.initialPaddingLen(payloads[i].frames, size, maxPacketSize)
- }
- if encLevel == protocol.Encryption1RTT {
- shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], paddingLen, maxPacketSize, sealers[i], false, v)
- if err != nil {
- return nil, err
- }
- packet.shortHdrPacket = &shp
- } else {
- longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], v)
- if err != nil {
- return nil, err
- }
- packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket)
- }
- }
- return packet, nil
- }
- // longHeaderPacketLength calculates the length of a serialized long header packet.
- // It takes into account that packets that have a tiny payload need to be padded,
- // such that len(payload) + packet number len >= 4 + AEAD overhead
- func (p *packetPacker) longHeaderPacketLength(hdr *wire.ExtendedHeader, pl payload, v protocol.VersionNumber) protocol.ByteCount {
- var paddingLen protocol.ByteCount
- pnLen := protocol.ByteCount(hdr.PacketNumberLen)
- if pl.length < 4-pnLen {
- paddingLen = 4 - pnLen - pl.length
- }
- return hdr.GetLength(v) + pl.length + paddingLen
- }
- // shortHeaderPacketLength calculates the length of a serialized short header packet.
- // It takes into account that packets that have a tiny payload need to be padded,
- // such that len(payload) + packet number len >= 4 + AEAD overhead
- func (p *packetPacker) shortHeaderPacketLength(connID protocol.ConnectionID, pnLen protocol.PacketNumberLen, pl payload) protocol.ByteCount {
- var paddingLen protocol.ByteCount
- if pl.length < 4-protocol.ByteCount(pnLen) {
- paddingLen = 4 - protocol.ByteCount(pnLen) - pl.length
- }
- return wire.ShortHeaderLen(connID, pnLen) + pl.length + paddingLen
- }
- // size is the expected size of the packet, if no padding was applied.
- func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, currentSize, maxPacketSize protocol.ByteCount) protocol.ByteCount {
- // For the server, only ack-eliciting Initial packets need to be padded.
- if p.perspective == protocol.PerspectiveServer && !ackhandler.HasAckElicitingFrames(frames) {
- return 0
- }
- if currentSize >= maxPacketSize {
- return 0
- }
- return maxPacketSize - currentSize
- }
// PackCoalescedPacket packs a new packet.
// It packs an Initial / Handshake if there is data to send in these packet number spaces.
// It should only be called before the handshake is confirmed.
//
// The coalesced packet contains, in this order, each only if a sealer is available
// and there is something to send:
//   - an Initial packet,
//   - a Handshake packet,
//   - a 1-RTT packet or, for clients without 1-RTT keys (and !onlyAck), a 0-RTT packet.
//
// The Handshake and 0-RTT / 1-RTT packets are only attempted if at least
// protocol.MinCoalescedPacketSize bytes are left. With onlyAck, they are only
// attempted if nothing was packed before. ACK frames are only added to the
// Handshake and 1-RTT packets if nothing was packed before them. Only the Initial
// packet is padded. It returns nil, nil if there is nothing to send.
func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
	var (
		initialHdr, handshakeHdr, zeroRTTHdr                            *wire.ExtendedHeader
		initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload
		oneRTTPacketNumber                                              protocol.PacketNumber
		oneRTTPacketNumberLen                                           protocol.PacketNumberLen
	)
	// Try packing an Initial packet.
	// Dropped Initial keys are not an error: there's simply nothing to pack at this level.
	initialSealer, err := p.cryptoSetup.GetInitialSealer()
	if err != nil && err != handshake.ErrKeysDropped {
		return nil, err
	}
	// size accumulates the size of all packets packed so far, including AEAD overhead.
	var size protocol.ByteCount
	if initialSealer != nil {
		initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), protocol.EncryptionInitial, onlyAck, true, v)
		if initialPayload.length > 0 {
			size += p.longHeaderPacketLength(initialHdr, initialPayload, v) + protocol.ByteCount(initialSealer.Overhead())
		}
	}

	// Add a Handshake packet.
	var handshakeSealer sealer
	if (onlyAck && size == 0) || (!onlyAck && size < maxPacketSize-protocol.MinCoalescedPacketSize) {
		var err error
		handshakeSealer, err = p.cryptoSetup.GetHandshakeSealer()
		if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable {
			return nil, err
		}
		if handshakeSealer != nil {
			// ACKs are only allowed if no Initial packet was packed (size == 0).
			handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), protocol.EncryptionHandshake, onlyAck, size == 0, v)
			if handshakePayload.length > 0 {
				s := p.longHeaderPacketLength(handshakeHdr, handshakePayload, v) + protocol.ByteCount(handshakeSealer.Overhead())
				size += s
			}
		}
	}

	// Add a 0-RTT / 1-RTT packet.
	var zeroRTTSealer sealer
	var oneRTTSealer handshake.ShortHeaderSealer
	var connID protocol.ConnectionID
	var kp protocol.KeyPhaseBit
	if (onlyAck && size == 0) || (!onlyAck && size < maxPacketSize-protocol.MinCoalescedPacketSize) {
		var err error
		oneRTTSealer, err = p.cryptoSetup.Get1RTTSealer()
		if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable {
			return nil, err
		}
		if err == nil { // 1-RTT
			kp = oneRTTSealer.KeyPhase()
			connID = p.getDestConnID()
			oneRTTPacketNumber, oneRTTPacketNumberLen = p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
			hdrLen := wire.ShortHeaderLen(connID, oneRTTPacketNumberLen)
			oneRTTPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, hdrLen, maxPacketSize-size, onlyAck, size == 0, v)
			if oneRTTPayload.length > 0 {
				size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, oneRTTPayload) + protocol.ByteCount(oneRTTSealer.Overhead())
			}
		} else if p.perspective == protocol.PerspectiveClient && !onlyAck { // 0-RTT packets can't contain ACK frames
			var err error
			zeroRTTSealer, err = p.cryptoSetup.Get0RTTSealer()
			if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable {
				return nil, err
			}
			if zeroRTTSealer != nil {
				zeroRTTHdr, zeroRTTPayload = p.maybeGetAppDataPacketFor0RTT(zeroRTTSealer, maxPacketSize-size, v)
				if zeroRTTPayload.length > 0 {
					size += p.longHeaderPacketLength(zeroRTTHdr, zeroRTTPayload, v) + protocol.ByteCount(zeroRTTSealer.Overhead())
				}
			}
		}
	}

	if initialPayload.length == 0 && handshakePayload.length == 0 && zeroRTTPayload.length == 0 && oneRTTPayload.length == 0 {
		return nil, nil
	}

	// Serialize the packets in order. Only the Initial packet is padded, using the
	// total size computed above.
	buffer := getPacketBuffer()
	packet := &coalescedPacket{
		buffer:         buffer,
		longHdrPackets: make([]*longHeaderPacket, 0, 3),
	}
	if initialPayload.length > 0 {
		padding := p.initialPaddingLen(initialPayload.frames, size, maxPacketSize)
		cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v)
		if err != nil {
			return nil, err
		}
		packet.longHdrPackets = append(packet.longHdrPackets, cont)
	}
	if handshakePayload.length > 0 {
		cont, err := p.appendLongHeaderPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, v)
		if err != nil {
			return nil, err
		}
		packet.longHdrPackets = append(packet.longHdrPackets, cont)
	}
	if zeroRTTPayload.length > 0 {
		longHdrPacket, err := p.appendLongHeaderPacket(buffer, zeroRTTHdr, zeroRTTPayload, 0, protocol.Encryption0RTT, zeroRTTSealer, v)
		if err != nil {
			return nil, err
		}
		packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket)
	} else if oneRTTPayload.length > 0 {
		shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, maxPacketSize, oneRTTSealer, false, v)
		if err != nil {
			return nil, err
		}
		packet.shortHdrPacket = &shp
	}
	return packet, nil
}
- // PackAckOnlyPacket packs a packet containing only an ACK in the application data packet number space.
- // It should be called after the handshake is confirmed.
- func (p *packetPacker) PackAckOnlyPacket(maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) {
- buf := getPacketBuffer()
- packet, err := p.appendPacket(buf, true, maxPacketSize, v)
- return packet, buf, err
- }
- // AppendPacket packs a packet in the application data packet number space.
- // It should be called after the handshake is confirmed.
- func (p *packetPacker) AppendPacket(buf *packetBuffer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) {
- return p.appendPacket(buf, false, maxPacketSize, v)
- }
- func (p *packetPacker) appendPacket(buf *packetBuffer, onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, error) {
- sealer, err := p.cryptoSetup.Get1RTTSealer()
- if err != nil {
- return shortHeaderPacket{}, err
- }
- pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
- connID := p.getDestConnID()
- hdrLen := wire.ShortHeaderLen(connID, pnLen)
- pl := p.maybeGetShortHeaderPacket(sealer, hdrLen, maxPacketSize, onlyAck, true, v)
- if pl.length == 0 {
- return shortHeaderPacket{}, errNothingToPack
- }
- kp := sealer.KeyPhase()
- return p.appendShortHeaderPacket(buf, connID, pn, pnLen, kp, pl, 0, maxPacketSize, sealer, false, v)
- }
// maybeGetCryptoPacket builds the header and payload of an Initial or Handshake packet.
//
// maxPacketSize is the space available for header and payload (callers have already
// subtracted the AEAD overhead). With onlyAck, the payload consists of an ACK frame
// only, if one is available. Otherwise an ACK is added if ackAllowed, followed by
// either retransmitted frames (which take precedence) or a new CRYPTO frame from the
// crypto stream of that encryption level.
// It returns a nil header and an empty payload if there is nothing to send.
func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool, v protocol.VersionNumber) (*wire.ExtendedHeader, payload) {
	if onlyAck {
		if ack := p.acks.GetAckFrame(encLevel, true); ack != nil {
			return p.getLongHeader(encLevel, v), payload{
				ack:    ack,
				length: ack.Length(v),
			}
		}
		return nil, payload{}
	}

	var s cryptoStream
	var handler ackhandler.FrameHandler
	var hasRetransmission bool
	//nolint:exhaustive // Initial and Handshake are the only two encryption levels here.
	switch encLevel {
	case protocol.EncryptionInitial:
		s = p.initialStream
		handler = p.retransmissionQueue.InitialAckHandler()
		hasRetransmission = p.retransmissionQueue.HasInitialData()
	case protocol.EncryptionHandshake:
		s = p.handshakeStream
		handler = p.retransmissionQueue.HandshakeAckHandler()
		hasRetransmission = p.retransmissionQueue.HasHandshakeData()
	}

	hasData := s.HasData()
	var ack *wire.AckFrame
	if ackAllowed {
		// The second argument (onlyIfQueued) is only set when there is nothing else to send.
		ack = p.acks.GetAckFrame(encLevel, !hasRetransmission && !hasData)
	}
	if !hasData && !hasRetransmission && ack == nil {
		// nothing to send
		return nil, payload{}
	}

	var pl payload
	if ack != nil {
		pl.ack = ack
		pl.length = ack.Length(v)
		maxPacketSize -= pl.length
	}
	hdr := p.getLongHeader(encLevel, v)
	maxPacketSize -= hdr.GetLength(v)
	if hasRetransmission {
		// Fill the packet with retransmitted frames until the queue returns nil.
		for {
			var f ackhandler.Frame
			//nolint:exhaustive // 0-RTT packets can't contain any retransmissions.
			switch encLevel {
			case protocol.EncryptionInitial:
				f.Frame = p.retransmissionQueue.GetInitialFrame(maxPacketSize, v)
				f.Handler = p.retransmissionQueue.InitialAckHandler()
			case protocol.EncryptionHandshake:
				f.Frame = p.retransmissionQueue.GetHandshakeFrame(maxPacketSize, v)
				f.Handler = p.retransmissionQueue.HandshakeAckHandler()
			}
			if f.Frame == nil {
				break
			}
			pl.frames = append(pl.frames, f)
			frameLen := f.Frame.Length(v)
			pl.length += frameLen
			maxPacketSize -= frameLen
		}
	} else if s.HasData() {
		cf := s.PopCryptoFrame(maxPacketSize)
		pl.frames = []ackhandler.Frame{{Frame: cf, Handler: handler}}
		pl.length += cf.Length(v)
	}
	return hdr, pl
}
- func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*wire.ExtendedHeader, payload) {
- if p.perspective != protocol.PerspectiveClient {
- return nil, payload{}
- }
- hdr := p.getLongHeader(protocol.Encryption0RTT, v)
- maxPayloadSize := maxPacketSize - hdr.GetLength(v) - protocol.ByteCount(sealer.Overhead())
- return hdr, p.maybeGetAppDataPacket(maxPayloadSize, false, false, v)
- }
- func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, hdrLen protocol.ByteCount, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload {
- maxPayloadSize := maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead())
- return p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed, v)
- }
- func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload {
- pl := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed, v)
- // check if we have anything to send
- if len(pl.frames) == 0 && len(pl.streamFrames) == 0 {
- if pl.ack == nil {
- return payload{}
- }
- // the packet only contains an ACK
- if p.numNonAckElicitingAcks >= protocol.MaxNonAckElicitingAcks {
- ping := &wire.PingFrame{}
- pl.frames = append(pl.frames, ackhandler.Frame{Frame: ping})
- pl.length += ping.Length(v)
- p.numNonAckElicitingAcks = 0
- } else {
- p.numNonAckElicitingAcks++
- }
- } else {
- p.numNonAckElicitingAcks = 0
- }
- return pl
- }
// composeNextPacket assembles the payload of a 0-RTT / 1-RTT packet using at most
// maxFrameSize bytes.
//
// With onlyAck, the payload only contains an application data ACK frame (if the ACK
// source returns one). Otherwise it is filled in this order: an ACK (if ackAllowed),
// one DATAGRAM frame (if queued and it fits), retransmitted frames, new control
// frames, and STREAM frames. If only an ACK is available and there is neither new
// data nor retransmissions, the payload is returned after the ACK and DATAGRAM step.
func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool, v protocol.VersionNumber) payload {
	if onlyAck {
		if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, true); ack != nil {
			return payload{ack: ack, length: ack.Length(v)}
		}
		return payload{}
	}

	pl := payload{streamFrames: make([]ackhandler.StreamFrame, 0, 1)}

	hasData := p.framer.HasData()
	hasRetransmission := p.retransmissionQueue.HasAppData()

	var hasAck bool
	if ackAllowed {
		// The second argument (onlyIfQueued) is only set when there is nothing else to send.
		if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData); ack != nil {
			pl.ack = ack
			pl.length += ack.Length(v)
			hasAck = true
		}
	}

	if p.datagramQueue != nil {
		if f := p.datagramQueue.Peek(); f != nil {
			size := f.Length(v)
			if size <= maxFrameSize-pl.length { // DATAGRAM frame fits
				pl.frames = append(pl.frames, ackhandler.Frame{Frame: f})
				pl.length += size
				p.datagramQueue.Pop()
			} else if !hasAck {
				// The DATAGRAM frame doesn't fit, and the packet doesn't contain an ACK.
				// Discard this frame. There's no point in retrying this in the next packet,
				// as it's unlikely that the available packet size will increase.
				p.datagramQueue.Pop()
			}
			// If the DATAGRAM frame was too large and the packet contained an ACK, we'll try to send it out later.
		}
	}

	if hasAck && !hasData && !hasRetransmission {
		return pl
	}

	if hasRetransmission {
		// Add retransmitted frames while at least protocol.MinStreamFrameSize bytes are left.
		for {
			remainingLen := maxFrameSize - pl.length
			if remainingLen < protocol.MinStreamFrameSize {
				break
			}
			f := p.retransmissionQueue.GetAppDataFrame(remainingLen, v)
			if f == nil {
				break
			}
			pl.frames = append(pl.frames, ackhandler.Frame{Frame: f, Handler: p.retransmissionQueue.AppDataAckHandler()})
			pl.length += f.Length(v)
		}
	}

	if hasData {
		var lengthAdded protocol.ByteCount
		startLen := len(pl.frames)
		pl.frames, lengthAdded = p.framer.AppendControlFrames(pl.frames, maxFrameSize-pl.length, v)
		pl.length += lengthAdded
		// add handlers for the control frames that were added
		for i := startLen; i < len(pl.frames); i++ {
			switch pl.frames[i].Frame.(type) {
			case *wire.PathChallengeFrame, *wire.PathResponseFrame:
				// Path probing is currently not supported, therefore we don't need to set the OnAcked callback yet.
				// PATH_CHALLENGE and PATH_RESPONSE are never retransmitted.
			default:
				pl.frames[i].Handler = p.retransmissionQueue.AppDataAckHandler()
			}
		}

		pl.streamFrames, lengthAdded = p.framer.AppendStreamFrames(pl.streamFrames, maxFrameSize-pl.length, v)
		pl.length += lengthAdded
	}
	return pl
}
// MaybePackProbePacket packs a probe packet at encLevel.
//
// For 1-RTT, it packs application data (ACKs allowed) into a short header packet.
// For Initial and Handshake, it packs retransmissions / CRYPTO data and an ACK into a
// long header packet; an Initial packet is padded according to initialPaddingLen.
// It returns nil, nil if there is nothing to send, the sealer's error if the keys
// are not available, and it panics for any other encryption level (i.e. 0-RTT).
func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
	if encLevel == protocol.Encryption1RTT {
		s, err := p.cryptoSetup.Get1RTTSealer()
		if err != nil {
			return nil, err
		}
		kp := s.KeyPhase()
		connID := p.getDestConnID()
		pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
		hdrLen := wire.ShortHeaderLen(connID, pnLen)
		pl := p.maybeGetAppDataPacket(maxPacketSize-protocol.ByteCount(s.Overhead())-hdrLen, false, true, v)
		if pl.length == 0 {
			return nil, nil
		}
		buffer := getPacketBuffer()
		packet := &coalescedPacket{buffer: buffer}
		shp, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, 0, maxPacketSize, s, false, v)
		if err != nil {
			return nil, err
		}
		packet.shortHdrPacket = &shp
		return packet, nil
	}

	var hdr *wire.ExtendedHeader
	var pl payload
	var sealer handshake.LongHeaderSealer
	//nolint:exhaustive // Probe packets are never sent for 0-RTT.
	switch encLevel {
	case protocol.EncryptionInitial:
		var err error
		sealer, err = p.cryptoSetup.GetInitialSealer()
		if err != nil {
			return nil, err
		}
		hdr, pl = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true, v)
	case protocol.EncryptionHandshake:
		var err error
		sealer, err = p.cryptoSetup.GetHandshakeSealer()
		if err != nil {
			return nil, err
		}
		hdr, pl = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionHandshake, false, true, v)
	default:
		panic("unknown encryption level")
	}

	if pl.length == 0 {
		return nil, nil
	}
	buffer := getPacketBuffer()
	packet := &coalescedPacket{buffer: buffer}
	// size of the packet including AEAD overhead, used to compute the Initial padding
	size := p.longHeaderPacketLength(hdr, pl, v) + protocol.ByteCount(sealer.Overhead())
	var padding protocol.ByteCount
	if encLevel == protocol.EncryptionInitial {
		padding = p.initialPaddingLen(pl.frames, size, maxPacketSize)
	}

	longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdr, pl, padding, encLevel, sealer, v)
	if err != nil {
		return nil, err
	}
	packet.longHdrPackets = []*longHeaderPacket{longHdrPacket}
	return packet, nil
}
- func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, v protocol.VersionNumber) (shortHeaderPacket, *packetBuffer, error) {
- pl := payload{
- frames: []ackhandler.Frame{ping},
- length: ping.Frame.Length(v),
- }
- buffer := getPacketBuffer()
- s, err := p.cryptoSetup.Get1RTTSealer()
- if err != nil {
- return shortHeaderPacket{}, nil, err
- }
- connID := p.getDestConnID()
- pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
- padding := size - p.shortHeaderPacketLength(connID, pnLen, pl) - protocol.ByteCount(s.Overhead())
- kp := s.KeyPhase()
- packet, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, pl, padding, size, s, true, v)
- return packet, buffer, err
- }
- func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel, v protocol.VersionNumber) *wire.ExtendedHeader {
- pn, pnLen := p.pnManager.PeekPacketNumber(encLevel)
- hdr := &wire.ExtendedHeader{
- PacketNumber: pn,
- PacketNumberLen: pnLen,
- }
- hdr.Version = v
- hdr.SrcConnectionID = p.srcConnID
- hdr.DestConnectionID = p.getDestConnID()
- //nolint:exhaustive // 1-RTT packets are not long header packets.
- switch encLevel {
- case protocol.EncryptionInitial:
- hdr.Type = protocol.PacketTypeInitial
- hdr.Token = p.token
- case protocol.EncryptionHandshake:
- hdr.Type = protocol.PacketTypeHandshake
- case protocol.Encryption0RTT:
- hdr.Type = protocol.PacketType0RTT
- }
- return hdr
- }
- func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.VersionNumber) (*longHeaderPacket, error) {
- var paddingLen protocol.ByteCount
- pnLen := protocol.ByteCount(header.PacketNumberLen)
- if pl.length < 4-pnLen {
- paddingLen = 4 - pnLen - pl.length
- }
- paddingLen += padding
- header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + pl.length + paddingLen
- startLen := len(buffer.Data)
- raw := buffer.Data[startLen:]
- raw, err := header.Append(raw, v)
- if err != nil {
- return nil, err
- }
- payloadOffset := protocol.ByteCount(len(raw))
- raw, err = p.appendPacketPayload(raw, pl, paddingLen, v)
- if err != nil {
- return nil, err
- }
- raw = p.encryptPacket(raw, sealer, header.PacketNumber, payloadOffset, pnLen)
- buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)]
- if pn := p.pnManager.PopPacketNumber(encLevel); pn != header.PacketNumber {
- return nil, fmt.Errorf("packetPacker BUG: Peeked and Popped packet numbers do not match: expected %d, got %d", pn, header.PacketNumber)
- }
- return &longHeaderPacket{
- header: header,
- ack: pl.ack,
- frames: pl.frames,
- streamFrames: pl.streamFrames,
- length: protocol.ByteCount(len(raw)),
- }, nil
- }
- func (p *packetPacker) appendShortHeaderPacket(
- buffer *packetBuffer,
- connID protocol.ConnectionID,
- pn protocol.PacketNumber,
- pnLen protocol.PacketNumberLen,
- kp protocol.KeyPhaseBit,
- pl payload,
- padding, maxPacketSize protocol.ByteCount,
- sealer sealer,
- isMTUProbePacket bool,
- v protocol.VersionNumber,
- ) (shortHeaderPacket, error) {
- var paddingLen protocol.ByteCount
- if pl.length < 4-protocol.ByteCount(pnLen) {
- paddingLen = 4 - protocol.ByteCount(pnLen) - pl.length
- }
- paddingLen += padding
- startLen := len(buffer.Data)
- raw := buffer.Data[startLen:]
- raw, err := wire.AppendShortHeader(raw, connID, pn, pnLen, kp)
- if err != nil {
- return shortHeaderPacket{}, err
- }
- payloadOffset := protocol.ByteCount(len(raw))
- raw, err = p.appendPacketPayload(raw, pl, paddingLen, v)
- if err != nil {
- return shortHeaderPacket{}, err
- }
- if !isMTUProbePacket {
- if size := protocol.ByteCount(len(raw) + sealer.Overhead()); size > maxPacketSize {
- return shortHeaderPacket{}, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, maxPacketSize)
- }
- }
- raw = p.encryptPacket(raw, sealer, pn, payloadOffset, protocol.ByteCount(pnLen))
- buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)]
- if newPN := p.pnManager.PopPacketNumber(protocol.Encryption1RTT); newPN != pn {
- return shortHeaderPacket{}, fmt.Errorf("packetPacker BUG: Peeked and Popped packet numbers do not match: expected %d, got %d", pn, newPN)
- }
- return shortHeaderPacket{
- PacketNumber: pn,
- PacketNumberLen: pnLen,
- KeyPhase: kp,
- StreamFrames: pl.streamFrames,
- Frames: pl.frames,
- Ack: pl.ack,
- Length: protocol.ByteCount(len(raw)),
- DestConnID: connID,
- IsPathMTUProbePacket: isMTUProbePacket,
- }, nil
- }
- // appendPacketPayload serializes the payload of a packet into the raw byte slice.
- // It modifies the order of payload.frames.
- func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.VersionNumber) ([]byte, error) {
- payloadOffset := len(raw)
- if pl.ack != nil {
- var err error
- raw, err = pl.ack.Append(raw, v)
- if err != nil {
- return nil, err
- }
- }
- if paddingLen > 0 {
- raw = append(raw, make([]byte, paddingLen)...)
- }
- // Randomize the order of the control frames.
- // This makes sure that the receiver doesn't rely on the order in which frames are packed.
- if len(pl.frames) > 1 {
- p.rand.Shuffle(len(pl.frames), func(i, j int) { pl.frames[i], pl.frames[j] = pl.frames[j], pl.frames[i] })
- }
- for _, f := range pl.frames {
- var err error
- raw, err = f.Frame.Append(raw, v)
- if err != nil {
- return nil, err
- }
- }
- for _, f := range pl.streamFrames {
- var err error
- raw, err = f.Frame.Append(raw, v)
- if err != nil {
- return nil, err
- }
- }
- if payloadSize := protocol.ByteCount(len(raw)-payloadOffset) - paddingLen; payloadSize != pl.length {
- return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", pl.length, payloadSize)
- }
- return raw, nil
- }
- func (p *packetPacker) encryptPacket(raw []byte, sealer sealer, pn protocol.PacketNumber, payloadOffset, pnLen protocol.ByteCount) []byte {
- _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], pn, raw[:payloadOffset])
- raw = raw[:len(raw)+sealer.Overhead()]
- // apply header protection
- pnOffset := payloadOffset - pnLen
- sealer.EncryptHeader(raw[pnOffset+4:pnOffset+4+16], &raw[0], raw[pnOffset:payloadOffset])
- return raw
- }
// SetToken sets the token that getLongHeader puts into the header of every Initial
// packet packed from now on. The slice is stored as is, not copied.
func (p *packetPacker) SetToken(token []byte) {
	p.token = token
}
|