packet_unpacker.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. package quic
  2. import (
  3. "fmt"
  4. "time"
  5. "github.com/Psiphon-Labs/quic-go/internal/handshake"
  6. "github.com/Psiphon-Labs/quic-go/internal/protocol"
  7. "github.com/Psiphon-Labs/quic-go/internal/qerr"
  8. "github.com/Psiphon-Labs/quic-go/internal/wire"
  9. )
  // headerDecryptor removes header protection from a QUIC packet.
//
// Callers in this file pass a 16-byte sample taken from just after the
// (assumed 4-byte) packet number field, a pointer to the first header byte,
// and the 4 packet number bytes. firstByte and pnBytes alias the packet buffer,
// which is parsed right afterwards, so implementations are expected to write
// the unmasked values back through them.
type headerDecryptor interface {
	DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
}
  // headerParseError wraps an error that occurred while parsing a packet header,
// so callers can tell header failures apart from payload decryption failures.
type headerParseError struct {
	err error
}

// Error reports the wrapped error's message unchanged.
func (e *headerParseError) Error() string { return e.err.Error() }

// Unwrap exposes the wrapped error to errors.Is and errors.As.
func (e *headerParseError) Unwrap() error { return e.err }
  22. type unpackedPacket struct {
  23. hdr *wire.ExtendedHeader
  24. encryptionLevel protocol.EncryptionLevel
  25. data []byte
  26. }
  27. // The packetUnpacker unpacks QUIC packets.
  28. type packetUnpacker struct {
  29. cs handshake.CryptoSetup
  30. shortHdrConnIDLen int
  31. }
  32. var _ unpacker = &packetUnpacker{}
  33. func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int) *packetUnpacker {
  34. return &packetUnpacker{
  35. cs: cs,
  36. shortHdrConnIDLen: shortHdrConnIDLen,
  37. }
  38. }
  39. // UnpackLongHeader unpacks a Long Header packet.
  40. // If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits.
  41. // If any other error occurred when parsing the header, the error is of type headerParseError.
  42. // If decrypting the payload fails for any reason, the error is the error returned by the AEAD.
  43. func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
  44. var encLevel protocol.EncryptionLevel
  45. var extHdr *wire.ExtendedHeader
  46. var decrypted []byte
  47. //nolint:exhaustive // Retry packets can't be unpacked.
  48. switch hdr.Type {
  49. case protocol.PacketTypeInitial:
  50. encLevel = protocol.EncryptionInitial
  51. opener, err := u.cs.GetInitialOpener()
  52. if err != nil {
  53. return nil, err
  54. }
  55. extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
  56. if err != nil {
  57. return nil, err
  58. }
  59. case protocol.PacketTypeHandshake:
  60. encLevel = protocol.EncryptionHandshake
  61. opener, err := u.cs.GetHandshakeOpener()
  62. if err != nil {
  63. return nil, err
  64. }
  65. extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
  66. if err != nil {
  67. return nil, err
  68. }
  69. case protocol.PacketType0RTT:
  70. encLevel = protocol.Encryption0RTT
  71. opener, err := u.cs.Get0RTTOpener()
  72. if err != nil {
  73. return nil, err
  74. }
  75. extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data)
  76. if err != nil {
  77. return nil, err
  78. }
  79. default:
  80. return nil, fmt.Errorf("unknown packet type: %s", hdr.Type)
  81. }
  82. if len(decrypted) == 0 {
  83. return nil, &qerr.TransportError{
  84. ErrorCode: qerr.ProtocolViolation,
  85. ErrorMessage: "empty packet",
  86. }
  87. }
  88. return &unpackedPacket{
  89. hdr: extHdr,
  90. encryptionLevel: encLevel,
  91. data: decrypted,
  92. }, nil
  93. }
  94. func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
  95. opener, err := u.cs.Get1RTTOpener()
  96. if err != nil {
  97. return 0, 0, 0, nil, err
  98. }
  99. pn, pnLen, kp, decrypted, err := u.unpackShortHeaderPacket(opener, rcvTime, data)
  100. if err != nil {
  101. return 0, 0, 0, nil, err
  102. }
  103. if len(decrypted) == 0 {
  104. return 0, 0, 0, nil, &qerr.TransportError{
  105. ErrorCode: qerr.ProtocolViolation,
  106. ErrorMessage: "empty packet",
  107. }
  108. }
  109. return pn, pnLen, kp, decrypted, nil
  110. }
  111. func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) {
  112. extHdr, parseErr := u.unpackLongHeader(opener, hdr, data)
  113. // If the reserved bits are set incorrectly, we still need to continue unpacking.
  114. // This avoids a timing side-channel, which otherwise might allow an attacker
  115. // to gain information about the header encryption.
  116. if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
  117. return nil, nil, parseErr
  118. }
  119. extHdrLen := extHdr.ParsedLen()
  120. extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen)
  121. decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen])
  122. if err != nil {
  123. return nil, nil, err
  124. }
  125. if parseErr != nil {
  126. return nil, nil, parseErr
  127. }
  128. return extHdr, decrypted, nil
  129. }
  130. func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
  131. l, pn, pnLen, kp, parseErr := u.unpackShortHeader(opener, data)
  132. // If the reserved bits are set incorrectly, we still need to continue unpacking.
  133. // This avoids a timing side-channel, which otherwise might allow an attacker
  134. // to gain information about the header encryption.
  135. if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
  136. return 0, 0, 0, nil, &headerParseError{parseErr}
  137. }
  138. pn = opener.DecodePacketNumber(pn, pnLen)
  139. decrypted, err := opener.Open(data[l:l], data[l:], rcvTime, pn, kp, data[:l])
  140. if err != nil {
  141. return 0, 0, 0, nil, err
  142. }
  143. return pn, pnLen, kp, decrypted, parseErr
  144. }
  145. func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int, protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, error) {
  146. hdrLen := 1 /* first header byte */ + u.shortHdrConnIDLen
  147. if len(data) < hdrLen+4+16 {
  148. return 0, 0, 0, 0, fmt.Errorf("packet too small, expected at least 20 bytes after the header, got %d", len(data)-hdrLen)
  149. }
  150. origPNBytes := make([]byte, 4)
  151. copy(origPNBytes, data[hdrLen:hdrLen+4])
  152. // 2. decrypt the header, assuming a 4 byte packet number
  153. hd.DecryptHeader(
  154. data[hdrLen+4:hdrLen+4+16],
  155. &data[0],
  156. data[hdrLen:hdrLen+4],
  157. )
  158. // 3. parse the header (and learn the actual length of the packet number)
  159. l, pn, pnLen, kp, parseErr := wire.ParseShortHeader(data, u.shortHdrConnIDLen)
  160. if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
  161. return l, pn, pnLen, kp, parseErr
  162. }
  163. // 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier
  164. if pnLen != protocol.PacketNumberLen4 {
  165. copy(data[hdrLen+int(pnLen):hdrLen+4], origPNBytes[int(pnLen):])
  166. }
  167. return l, pn, pnLen, kp, parseErr
  168. }
  169. // The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError.
  170. func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) {
  171. extHdr, err := unpackLongHeader(hd, hdr, data)
  172. if err != nil && err != wire.ErrInvalidReservedBits {
  173. return nil, &headerParseError{err: err}
  174. }
  175. return extHdr, err
  176. }
  177. func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) {
  178. hdrLen := hdr.ParsedLen()
  179. if protocol.ByteCount(len(data)) < hdrLen+4+16 {
  180. //nolint:stylecheck
  181. return nil, fmt.Errorf("Packet too small. Expected at least 20 bytes after the header, got %d", protocol.ByteCount(len(data))-hdrLen)
  182. }
  183. // The packet number can be up to 4 bytes long, but we won't know the length until we decrypt it.
  184. // 1. save a copy of the 4 bytes
  185. origPNBytes := make([]byte, 4)
  186. copy(origPNBytes, data[hdrLen:hdrLen+4])
  187. // 2. decrypt the header, assuming a 4 byte packet number
  188. hd.DecryptHeader(
  189. data[hdrLen+4:hdrLen+4+16],
  190. &data[0],
  191. data[hdrLen:hdrLen+4],
  192. )
  193. // 3. parse the header (and learn the actual length of the packet number)
  194. extHdr, parseErr := hdr.ParseExtended(data)
  195. if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
  196. return nil, parseErr
  197. }
  198. // 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier
  199. if extHdr.PacketNumberLen != protocol.PacketNumberLen4 {
  200. copy(data[extHdr.ParsedLen():hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):])
  201. }
  202. return extHdr, parseErr
  203. }