packet_unpacker.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. package quic
  2. import (
  3. "bytes"
  4. "fmt"
  5. "time"
  6. "github.com/Psiphon-Labs/quic-go/internal/handshake"
  7. "github.com/Psiphon-Labs/quic-go/internal/protocol"
  8. "github.com/Psiphon-Labs/quic-go/internal/qerr"
  9. "github.com/Psiphon-Labs/quic-go/internal/wire"
  10. )
  11. type headerDecryptor interface {
  12. DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
  13. }
  14. type headerParseError struct {
  15. err error
  16. }
  17. func (e *headerParseError) Unwrap() error {
  18. return e.err
  19. }
  20. func (e *headerParseError) Error() string {
  21. return e.err.Error()
  22. }
  23. type unpackedPacket struct {
  24. hdr *wire.ExtendedHeader
  25. encryptionLevel protocol.EncryptionLevel
  26. data []byte
  27. }
  28. // The packetUnpacker unpacks QUIC packets.
  29. type packetUnpacker struct {
  30. cs handshake.CryptoSetup
  31. shortHdrConnIDLen int
  32. }
  33. var _ unpacker = &packetUnpacker{}
  34. func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int) *packetUnpacker {
  35. return &packetUnpacker{
  36. cs: cs,
  37. shortHdrConnIDLen: shortHdrConnIDLen,
  38. }
  39. }
  40. // UnpackLongHeader unpacks a Long Header packet.
  41. // If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits.
  42. // If any other error occurred when parsing the header, the error is of type headerParseError.
  43. // If decrypting the payload fails for any reason, the error is the error returned by the AEAD.
  44. func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.VersionNumber) (*unpackedPacket, error) {
  45. var encLevel protocol.EncryptionLevel
  46. var extHdr *wire.ExtendedHeader
  47. var decrypted []byte
  48. //nolint:exhaustive // Retry packets can't be unpacked.
  49. switch hdr.Type {
  50. case protocol.PacketTypeInitial:
  51. encLevel = protocol.EncryptionInitial
  52. opener, err := u.cs.GetInitialOpener()
  53. if err != nil {
  54. return nil, err
  55. }
  56. extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
  57. if err != nil {
  58. return nil, err
  59. }
  60. case protocol.PacketTypeHandshake:
  61. encLevel = protocol.EncryptionHandshake
  62. opener, err := u.cs.GetHandshakeOpener()
  63. if err != nil {
  64. return nil, err
  65. }
  66. extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
  67. if err != nil {
  68. return nil, err
  69. }
  70. case protocol.PacketType0RTT:
  71. encLevel = protocol.Encryption0RTT
  72. opener, err := u.cs.Get0RTTOpener()
  73. if err != nil {
  74. return nil, err
  75. }
  76. extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v)
  77. if err != nil {
  78. return nil, err
  79. }
  80. default:
  81. return nil, fmt.Errorf("unknown packet type: %s", hdr.Type)
  82. }
  83. if len(decrypted) == 0 {
  84. return nil, &qerr.TransportError{
  85. ErrorCode: qerr.ProtocolViolation,
  86. ErrorMessage: "empty packet",
  87. }
  88. }
  89. return &unpackedPacket{
  90. hdr: extHdr,
  91. encryptionLevel: encLevel,
  92. data: decrypted,
  93. }, nil
  94. }
  95. func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
  96. opener, err := u.cs.Get1RTTOpener()
  97. if err != nil {
  98. return 0, 0, 0, nil, err
  99. }
  100. pn, pnLen, kp, decrypted, err := u.unpackShortHeaderPacket(opener, rcvTime, data)
  101. if err != nil {
  102. return 0, 0, 0, nil, err
  103. }
  104. if len(decrypted) == 0 {
  105. return 0, 0, 0, nil, &qerr.TransportError{
  106. ErrorCode: qerr.ProtocolViolation,
  107. ErrorMessage: "empty packet",
  108. }
  109. }
  110. return pn, pnLen, kp, decrypted, nil
  111. }
  112. func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, []byte, error) {
  113. extHdr, parseErr := u.unpackLongHeader(opener, hdr, data, v)
  114. // If the reserved bits are set incorrectly, we still need to continue unpacking.
  115. // This avoids a timing side-channel, which otherwise might allow an attacker
  116. // to gain information about the header encryption.
  117. if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
  118. return nil, nil, parseErr
  119. }
  120. extHdrLen := extHdr.ParsedLen()
  121. extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen)
  122. decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen])
  123. if err != nil {
  124. return nil, nil, err
  125. }
  126. if parseErr != nil {
  127. return nil, nil, parseErr
  128. }
  129. return extHdr, decrypted, nil
  130. }
  131. func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
  132. l, pn, pnLen, kp, parseErr := u.unpackShortHeader(opener, data)
  133. // If the reserved bits are set incorrectly, we still need to continue unpacking.
  134. // This avoids a timing side-channel, which otherwise might allow an attacker
  135. // to gain information about the header encryption.
  136. if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
  137. return 0, 0, 0, nil, &headerParseError{parseErr}
  138. }
  139. pn = opener.DecodePacketNumber(pn, pnLen)
  140. decrypted, err := opener.Open(data[l:l], data[l:], rcvTime, pn, kp, data[:l])
  141. if err != nil {
  142. return 0, 0, 0, nil, err
  143. }
  144. return pn, pnLen, kp, decrypted, parseErr
  145. }
  146. func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int, protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, error) {
  147. hdrLen := 1 /* first header byte */ + u.shortHdrConnIDLen
  148. if len(data) < hdrLen+4+16 {
  149. return 0, 0, 0, 0, fmt.Errorf("packet too small, expected at least 20 bytes after the header, got %d", len(data)-hdrLen)
  150. }
  151. origPNBytes := make([]byte, 4)
  152. copy(origPNBytes, data[hdrLen:hdrLen+4])
  153. // 2. decrypt the header, assuming a 4 byte packet number
  154. hd.DecryptHeader(
  155. data[hdrLen+4:hdrLen+4+16],
  156. &data[0],
  157. data[hdrLen:hdrLen+4],
  158. )
  159. // 3. parse the header (and learn the actual length of the packet number)
  160. l, pn, pnLen, kp, parseErr := wire.ParseShortHeader(data, u.shortHdrConnIDLen)
  161. if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
  162. return l, pn, pnLen, kp, parseErr
  163. }
  164. // 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier
  165. if pnLen != protocol.PacketNumberLen4 {
  166. copy(data[hdrLen+int(pnLen):hdrLen+4], origPNBytes[int(pnLen):])
  167. }
  168. return l, pn, pnLen, kp, parseErr
  169. }
  170. // The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError.
  171. func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, error) {
  172. extHdr, err := unpackLongHeader(hd, hdr, data, v)
  173. if err != nil && err != wire.ErrInvalidReservedBits {
  174. return nil, &headerParseError{err: err}
  175. }
  176. return extHdr, err
  177. }
  178. func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.VersionNumber) (*wire.ExtendedHeader, error) {
  179. r := bytes.NewReader(data)
  180. hdrLen := hdr.ParsedLen()
  181. if protocol.ByteCount(len(data)) < hdrLen+4+16 {
  182. //nolint:stylecheck
  183. return nil, fmt.Errorf("Packet too small. Expected at least 20 bytes after the header, got %d", protocol.ByteCount(len(data))-hdrLen)
  184. }
  185. // The packet number can be up to 4 bytes long, but we won't know the length until we decrypt it.
  186. // 1. save a copy of the 4 bytes
  187. origPNBytes := make([]byte, 4)
  188. copy(origPNBytes, data[hdrLen:hdrLen+4])
  189. // 2. decrypt the header, assuming a 4 byte packet number
  190. hd.DecryptHeader(
  191. data[hdrLen+4:hdrLen+4+16],
  192. &data[0],
  193. data[hdrLen:hdrLen+4],
  194. )
  195. // 3. parse the header (and learn the actual length of the packet number)
  196. extHdr, parseErr := hdr.ParseExtended(r, v)
  197. if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
  198. return nil, parseErr
  199. }
  200. // 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier
  201. if extHdr.PacketNumberLen != protocol.PacketNumberLen4 {
  202. copy(data[extHdr.ParsedLen():hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):])
  203. }
  204. return extHdr, parseErr
  205. }