packet_packer.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
  1. package gquic
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "net"
  7. "time"
  8. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/ackhandler"
  9. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/handshake"
  10. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/protocol"
  11. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/utils"
  12. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/wire"
  13. )
  14. type packer interface {
  15. PackPacket() (*packedPacket, error)
  16. MaybePackAckPacket() (*packedPacket, error)
  17. PackRetransmission(packet *ackhandler.Packet) ([]*packedPacket, error)
  18. PackConnectionClose(*wire.ConnectionCloseFrame) (*packedPacket, error)
  19. HandleTransportParameters(*handshake.TransportParameters)
  20. ChangeDestConnectionID(protocol.ConnectionID)
  21. }
  22. type packedPacket struct {
  23. header *wire.Header
  24. raw []byte
  25. frames []wire.Frame
  26. encryptionLevel protocol.EncryptionLevel
  27. }
  28. func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet {
  29. return &ackhandler.Packet{
  30. PacketNumber: p.header.PacketNumber,
  31. PacketType: p.header.Type,
  32. Frames: p.frames,
  33. Length: protocol.ByteCount(len(p.raw)),
  34. EncryptionLevel: p.encryptionLevel,
  35. SendTime: time.Now(),
  36. }
  37. }
  38. func getMaxPacketSize(addr net.Addr) protocol.ByteCount {
  39. maxSize := protocol.ByteCount(protocol.MinInitialPacketSize)
  40. // If this is not a UDP address, we don't know anything about the MTU.
  41. // Use the minimum size of an Initial packet as the max packet size.
  42. if udpAddr, ok := addr.(*net.UDPAddr); ok {
  43. // If ip is not an IPv4 address, To4 returns nil.
  44. // Note that there might be some corner cases, where this is not correct.
  45. // See https://stackoverflow.com/questions/22751035/golang-distinguish-ipv4-ipv6.
  46. if udpAddr.IP.To4() == nil {
  47. maxSize = protocol.MaxPacketSizeIPv6
  48. } else {
  49. maxSize = protocol.MaxPacketSizeIPv4
  50. }
  51. }
  52. return maxSize
  53. }
  54. type sealingManager interface {
  55. GetSealer() (protocol.EncryptionLevel, handshake.Sealer)
  56. GetSealerForCryptoStream() (protocol.EncryptionLevel, handshake.Sealer)
  57. GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (handshake.Sealer, error)
  58. }
  59. type frameSource interface {
  60. AppendStreamFrames([]wire.Frame, protocol.ByteCount) []wire.Frame
  61. AppendControlFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount)
  62. }
  63. type ackFrameSource interface {
  64. GetAckFrame() *wire.AckFrame
  65. GetStopWaitingFrame(bool) *wire.StopWaitingFrame
  66. }
  67. type packetPacker struct {
  68. destConnID protocol.ConnectionID
  69. srcConnID protocol.ConnectionID
  70. perspective protocol.Perspective
  71. version protocol.VersionNumber
  72. cryptoSetup sealingManager
  73. token []byte
  74. packetNumberGenerator *packetNumberGenerator
  75. getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen
  76. cryptoStream cryptoStream
  77. framer frameSource
  78. acks ackFrameSource
  79. maxPacketSize protocol.ByteCount
  80. hasSentPacket bool // has the packetPacker already sent a packet
  81. numNonRetransmittableAcks int
  82. }
  83. var _ packer = &packetPacker{}
  84. func newPacketPacker(
  85. destConnID protocol.ConnectionID,
  86. srcConnID protocol.ConnectionID,
  87. initialPacketNumber protocol.PacketNumber,
  88. getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen,
  89. remoteAddr net.Addr, // only used for determining the max packet size
  90. token []byte,
  91. cryptoStream cryptoStream,
  92. cryptoSetup sealingManager,
  93. framer frameSource,
  94. acks ackFrameSource,
  95. perspective protocol.Perspective,
  96. version protocol.VersionNumber,
  97. ) *packetPacker {
  98. return &packetPacker{
  99. cryptoStream: cryptoStream,
  100. cryptoSetup: cryptoSetup,
  101. token: token,
  102. destConnID: destConnID,
  103. srcConnID: srcConnID,
  104. perspective: perspective,
  105. version: version,
  106. framer: framer,
  107. acks: acks,
  108. getPacketNumberLen: getPacketNumberLen,
  109. packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength),
  110. maxPacketSize: getMaxPacketSize(remoteAddr),
  111. }
  112. }
  113. // PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame
  114. func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*packedPacket, error) {
  115. frames := []wire.Frame{ccf}
  116. encLevel, sealer := p.cryptoSetup.GetSealer()
  117. header := p.getHeader(encLevel)
  118. raw, err := p.writeAndSealPacket(header, frames, sealer)
  119. return &packedPacket{
  120. header: header,
  121. raw: raw,
  122. frames: frames,
  123. encryptionLevel: encLevel,
  124. }, err
  125. }
  126. func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
  127. ack := p.acks.GetAckFrame()
  128. if ack == nil {
  129. return nil, nil
  130. }
  131. encLevel, sealer := p.cryptoSetup.GetSealer()
  132. header := p.getHeader(encLevel)
  133. frames := []wire.Frame{ack}
  134. raw, err := p.writeAndSealPacket(header, frames, sealer)
  135. return &packedPacket{
  136. header: header,
  137. raw: raw,
  138. frames: frames,
  139. encryptionLevel: encLevel,
  140. }, err
  141. }
  142. // PackRetransmission packs a retransmission
  143. // For packets sent after completion of the handshake, it might happen that 2 packets have to be sent.
  144. // This can happen e.g. when a longer packet number is used in the header.
  145. func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedPacket, error) {
  146. if packet.EncryptionLevel != protocol.EncryptionForwardSecure {
  147. p, err := p.packHandshakeRetransmission(packet)
  148. return []*packedPacket{p}, err
  149. }
  150. var controlFrames []wire.Frame
  151. var streamFrames []*wire.StreamFrame
  152. for _, f := range packet.Frames {
  153. if sf, ok := f.(*wire.StreamFrame); ok {
  154. sf.DataLenPresent = true
  155. streamFrames = append(streamFrames, sf)
  156. } else {
  157. controlFrames = append(controlFrames, f)
  158. }
  159. }
  160. var packets []*packedPacket
  161. encLevel, sealer := p.cryptoSetup.GetSealer()
  162. for len(controlFrames) > 0 || len(streamFrames) > 0 {
  163. var frames []wire.Frame
  164. var length protocol.ByteCount
  165. header := p.getHeader(encLevel)
  166. headerLength, err := header.GetLength(p.version)
  167. if err != nil {
  168. return nil, err
  169. }
  170. maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLength
  171. for len(controlFrames) > 0 {
  172. frame := controlFrames[0]
  173. frameLen := frame.Length(p.version)
  174. if length+frameLen > maxSize {
  175. break
  176. }
  177. length += frameLen
  178. frames = append(frames, frame)
  179. controlFrames = controlFrames[1:]
  180. }
  181. for len(streamFrames) > 0 && length+protocol.MinStreamFrameSize < maxSize {
  182. frame := streamFrames[0]
  183. frame.DataLenPresent = false
  184. frameToAdd := frame
  185. sf, err := frame.MaybeSplitOffFrame(maxSize-length, p.version)
  186. if err != nil {
  187. return nil, err
  188. }
  189. if sf != nil {
  190. frameToAdd = sf
  191. } else {
  192. streamFrames = streamFrames[1:]
  193. }
  194. frame.DataLenPresent = true
  195. length += frameToAdd.Length(p.version)
  196. frames = append(frames, frameToAdd)
  197. }
  198. if sf, ok := frames[len(frames)-1].(*wire.StreamFrame); ok {
  199. sf.DataLenPresent = false
  200. }
  201. raw, err := p.writeAndSealPacket(header, frames, sealer)
  202. if err != nil {
  203. return nil, err
  204. }
  205. packets = append(packets, &packedPacket{
  206. header: header,
  207. raw: raw,
  208. frames: frames,
  209. encryptionLevel: encLevel,
  210. })
  211. }
  212. return packets, nil
  213. }
  214. // packHandshakeRetransmission retransmits a handshake packet, that was sent with less than forward-secure encryption
  215. func (p *packetPacker) packHandshakeRetransmission(packet *ackhandler.Packet) (*packedPacket, error) {
  216. sealer, err := p.cryptoSetup.GetSealerWithEncryptionLevel(packet.EncryptionLevel)
  217. if err != nil {
  218. return nil, err
  219. }
  220. // make sure that the retransmission for an Initial packet is sent as an Initial packet
  221. if packet.PacketType == protocol.PacketTypeInitial {
  222. p.hasSentPacket = false
  223. }
  224. header := p.getHeader(packet.EncryptionLevel)
  225. header.Type = packet.PacketType
  226. raw, err := p.writeAndSealPacket(header, packet.Frames, sealer)
  227. return &packedPacket{
  228. header: header,
  229. raw: raw,
  230. frames: packet.Frames,
  231. encryptionLevel: packet.EncryptionLevel,
  232. }, err
  233. }
  234. // PackPacket packs a new packet
  235. // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
  236. func (p *packetPacker) PackPacket() (*packedPacket, error) {
  237. packet, err := p.maybePackCryptoPacket()
  238. if err != nil {
  239. return nil, err
  240. }
  241. if packet != nil {
  242. return packet, nil
  243. }
  244. // if this is the first packet to be send, make sure it contains stream data
  245. if !p.hasSentPacket && packet == nil {
  246. return nil, nil
  247. }
  248. encLevel, sealer := p.cryptoSetup.GetSealer()
  249. header := p.getHeader(encLevel)
  250. headerLength, err := header.GetLength(p.version)
  251. if err != nil {
  252. return nil, err
  253. }
  254. maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLength
  255. frames, err := p.composeNextPacket(maxSize, p.canSendData(encLevel))
  256. if err != nil {
  257. return nil, err
  258. }
  259. // Check if we have enough frames to send
  260. if len(frames) == 0 {
  261. return nil, nil
  262. }
  263. // check if this packet only contains an ACK
  264. if !ackhandler.HasRetransmittableFrames(frames) {
  265. if p.numNonRetransmittableAcks >= protocol.MaxNonRetransmittableAcks {
  266. frames = append(frames, &wire.PingFrame{})
  267. p.numNonRetransmittableAcks = 0
  268. } else {
  269. p.numNonRetransmittableAcks++
  270. }
  271. } else {
  272. p.numNonRetransmittableAcks = 0
  273. }
  274. raw, err := p.writeAndSealPacket(header, frames, sealer)
  275. if err != nil {
  276. return nil, err
  277. }
  278. return &packedPacket{
  279. header: header,
  280. raw: raw,
  281. frames: frames,
  282. encryptionLevel: encLevel,
  283. }, nil
  284. }
  285. func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
  286. if !p.cryptoStream.hasData() {
  287. return nil, nil
  288. }
  289. encLevel, sealer := p.cryptoSetup.GetSealerForCryptoStream()
  290. header := p.getHeader(encLevel)
  291. headerLength, err := header.GetLength(p.version)
  292. if err != nil {
  293. return nil, err
  294. }
  295. maxLen := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - protocol.NonForwardSecurePacketSizeReduction - headerLength
  296. sf, _ := p.cryptoStream.popStreamFrame(maxLen)
  297. sf.DataLenPresent = false
  298. frames := []wire.Frame{sf}
  299. raw, err := p.writeAndSealPacket(header, frames, sealer)
  300. if err != nil {
  301. return nil, err
  302. }
  303. return &packedPacket{
  304. header: header,
  305. raw: raw,
  306. frames: frames,
  307. encryptionLevel: encLevel,
  308. }, nil
  309. }
  310. func (p *packetPacker) composeNextPacket(
  311. maxFrameSize protocol.ByteCount,
  312. canSendStreamFrames bool,
  313. ) ([]wire.Frame, error) {
  314. var length protocol.ByteCount
  315. var frames []wire.Frame
  316. // ACKs need to go first, so that the sentPacketHandler will recognize them
  317. if ack := p.acks.GetAckFrame(); ack != nil {
  318. frames = append(frames, ack)
  319. length += ack.Length(p.version)
  320. }
  321. var lengthAdded protocol.ByteCount
  322. frames, lengthAdded = p.framer.AppendControlFrames(frames, maxFrameSize-length)
  323. length += lengthAdded
  324. if !canSendStreamFrames {
  325. return frames, nil
  326. }
  327. // temporarily increase the maxFrameSize by the (minimum) length of the DataLen field
  328. // this leads to a properly sized packet in all cases, since we do all the packet length calculations with STREAM frames that have the DataLen set
  329. // however, for the last STREAM frame in the packet, we can omit the DataLen, thus yielding a packet of exactly the correct size
  330. // the length is encoded to either 1 or 2 bytes
  331. maxFrameSize++
  332. frames = p.framer.AppendStreamFrames(frames, maxFrameSize-length)
  333. if len(frames) > 0 {
  334. lastFrame := frames[len(frames)-1]
  335. if sf, ok := lastFrame.(*wire.StreamFrame); ok {
  336. sf.DataLenPresent = false
  337. }
  338. }
  339. return frames, nil
  340. }
  341. func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header {
  342. pnum := p.packetNumberGenerator.Peek()
  343. packetNumberLen := p.getPacketNumberLen(pnum)
  344. header := &wire.Header{
  345. PacketNumber: pnum,
  346. PacketNumberLen: packetNumberLen,
  347. Version: p.version,
  348. DestConnectionID: p.destConnID,
  349. }
  350. if encLevel != protocol.EncryptionForwardSecure {
  351. header.IsLongHeader = true
  352. header.SrcConnectionID = p.srcConnID
  353. // Set the payload len to maximum size.
  354. // Since it is encoded as a varint, this guarantees us that the header will end up at most as big as GetLength() returns.
  355. header.PayloadLen = p.maxPacketSize
  356. if !p.hasSentPacket && p.perspective == protocol.PerspectiveClient {
  357. header.Type = protocol.PacketTypeInitial
  358. header.Token = p.token
  359. } else {
  360. header.Type = protocol.PacketTypeHandshake
  361. }
  362. }
  363. return header
  364. }
  365. func (p *packetPacker) writeAndSealPacket(
  366. header *wire.Header,
  367. frames []wire.Frame,
  368. sealer handshake.Sealer,
  369. ) ([]byte, error) {
  370. raw := *getPacketBuffer()
  371. buffer := bytes.NewBuffer(raw[:0])
  372. // the payload length is only needed for Long Headers
  373. if header.IsLongHeader {
  374. if header.Type == protocol.PacketTypeInitial {
  375. headerLen, _ := header.GetLength(p.version)
  376. header.PayloadLen = protocol.ByteCount(protocol.MinInitialPacketSize) - headerLen
  377. } else {
  378. payloadLen := protocol.ByteCount(sealer.Overhead())
  379. for _, frame := range frames {
  380. payloadLen += frame.Length(p.version)
  381. }
  382. header.PayloadLen = payloadLen
  383. }
  384. }
  385. if err := header.Write(buffer, p.perspective, p.version); err != nil {
  386. return nil, err
  387. }
  388. payloadStartIndex := buffer.Len()
  389. // the Initial packet needs to be padded, so the last STREAM frame must have the data length present
  390. if header.Type == protocol.PacketTypeInitial {
  391. lastFrame := frames[len(frames)-1]
  392. if sf, ok := lastFrame.(*wire.StreamFrame); ok {
  393. sf.DataLenPresent = true
  394. }
  395. }
  396. for _, frame := range frames {
  397. if err := frame.Write(buffer, p.version); err != nil {
  398. return nil, err
  399. }
  400. }
  401. // if this is an Initial packet, we need to pad it to fulfill the minimum size requirement
  402. if header.Type == protocol.PacketTypeInitial {
  403. paddingLen := protocol.MinInitialPacketSize - sealer.Overhead() - buffer.Len()
  404. if paddingLen > 0 {
  405. buffer.Write(bytes.Repeat([]byte{0}, paddingLen))
  406. }
  407. }
  408. if size := protocol.ByteCount(buffer.Len() + sealer.Overhead()); size > p.maxPacketSize {
  409. return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize)
  410. }
  411. raw = raw[0:buffer.Len()]
  412. _ = sealer.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], header.PacketNumber, raw[:payloadStartIndex])
  413. raw = raw[0 : buffer.Len()+sealer.Overhead()]
  414. num := p.packetNumberGenerator.Pop()
  415. if num != header.PacketNumber {
  416. return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
  417. }
  418. p.hasSentPacket = true
  419. return raw, nil
  420. }
  421. func (p *packetPacker) canSendData(encLevel protocol.EncryptionLevel) bool {
  422. if p.perspective == protocol.PerspectiveClient {
  423. return encLevel >= protocol.EncryptionSecure
  424. }
  425. return encLevel == protocol.EncryptionForwardSecure
  426. }
  427. func (p *packetPacker) ChangeDestConnectionID(connID protocol.ConnectionID) {
  428. p.destConnID = connID
  429. }
  430. func (p *packetPacker) HandleTransportParameters(params *handshake.TransportParameters) {
  431. if params.MaxPacketSize != 0 {
  432. p.maxPacketSize = utils.MinByteCount(p.maxPacketSize, params.MaxPacketSize)
  433. }
  434. }