protocol.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. package hysteria
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "fmt"
  6. "io"
  7. "github.com/apernet/quic-go/quicvarint"
  8. "github.com/xtls/xray-core/common/errors"
  9. )
  const (
	// FrameTypeTCPRequest is the QUIC-varint frame type written at the start
	// of every TCPRequest frame.
	FrameTypeTCPRequest = 0x401

	// Upper bounds on length fields read from the peer. They are checked
	// before any buffer is sized from the wire, to prevent DoS attacks.
	MaxAddressLength = 2048
	MaxMessageLength = 2048
	MaxPaddingLength = 4096

	// MaxUDPSize is not referenced in this section; presumably the largest
	// UDP message handled by callers — TODO confirm.
	MaxUDPSize = 4096

	// Largest values that fit in a QUIC variable-length integer encoded in
	// 1, 2, 4 and 8 bytes respectively (two bits of each encoding are used
	// for the length prefix, so an n-byte varint holds 8n-2 bits).
	maxVarInt1 = 1<<6 - 1
	maxVarInt2 = 1<<14 - 1
	maxVarInt4 = 1<<30 - 1
	maxVarInt8 = 1<<62 - 1
)
  22. // TCPRequest format:
  23. // 0x401 (QUIC varint)
  24. // Address length (QUIC varint)
  25. // Address (bytes)
  26. // Padding length (QUIC varint)
  27. // Padding (bytes)
  28. func WriteTCPRequest(w io.Writer, addr string) error {
  29. padding := tcpRequestPadding.String()
  30. paddingLen := len(padding)
  31. addrLen := len(addr)
  32. sz := int(quicvarint.Len(FrameTypeTCPRequest)) +
  33. int(quicvarint.Len(uint64(addrLen))) + addrLen +
  34. int(quicvarint.Len(uint64(paddingLen))) + paddingLen
  35. buf := make([]byte, sz)
  36. i := varintPut(buf, FrameTypeTCPRequest)
  37. i += varintPut(buf[i:], uint64(addrLen))
  38. i += copy(buf[i:], addr)
  39. i += varintPut(buf[i:], uint64(paddingLen))
  40. copy(buf[i:], padding)
  41. _, err := w.Write(buf)
  42. return err
  43. }
  44. // TCPResponse format:
  45. // Status (byte, 0=ok, 1=error)
  46. // Message length (QUIC varint)
  47. // Message (bytes)
  48. // Padding length (QUIC varint)
  49. // Padding (bytes)
  50. func ReadTCPResponse(r io.Reader) (bool, string, error) {
  51. var status [1]byte
  52. if _, err := io.ReadFull(r, status[:]); err != nil {
  53. return false, "", err
  54. }
  55. bReader := quicvarint.NewReader(r)
  56. msgLen, err := quicvarint.Read(bReader)
  57. if err != nil {
  58. return false, "", err
  59. }
  60. if msgLen > MaxMessageLength {
  61. return false, "", errors.New("invalid message length")
  62. }
  63. var msgBuf []byte
  64. // No message is fine
  65. if msgLen > 0 {
  66. msgBuf = make([]byte, msgLen)
  67. _, err = io.ReadFull(r, msgBuf)
  68. if err != nil {
  69. return false, "", err
  70. }
  71. }
  72. paddingLen, err := quicvarint.Read(bReader)
  73. if err != nil {
  74. return false, "", err
  75. }
  76. if paddingLen > MaxPaddingLength {
  77. return false, "", errors.New("invalid padding length")
  78. }
  79. if paddingLen > 0 {
  80. _, err = io.CopyN(io.Discard, r, int64(paddingLen))
  81. if err != nil {
  82. return false, "", err
  83. }
  84. }
  85. return status[0] == 0, string(msgBuf), nil
  86. }
  // UDPMessage is one UDP datagram (or fragment of one) in its wire framing.
// Encoded field order:
//
//	Session ID      uint32, big-endian
//	Packet ID       uint16, big-endian
//	Fragment ID     uint8
//	Fragment count  uint8
//	Address length  QUIC varint
//	Address         bytes
//	Data            all remaining bytes (no length prefix)
type UDPMessage struct {
	SessionID uint32 // 4 bytes on the wire
	PacketID  uint16 // 2 bytes on the wire
	FragID    uint8  // 1 byte on the wire
	FragCount uint8  // 1 byte on the wire
	Addr      string // varint length prefix followed by the bytes
	Data      []byte // payload, runs to the end of the message
}
  103. func (m *UDPMessage) HeaderSize() int {
  104. lAddr := len(m.Addr)
  105. return 4 + 2 + 1 + 1 + int(quicvarint.Len(uint64(lAddr))) + lAddr
  106. }
  107. func (m *UDPMessage) Size() int {
  108. return m.HeaderSize() + len(m.Data)
  109. }
  110. func (m *UDPMessage) Serialize(buf []byte) int {
  111. // Make sure the buffer is big enough
  112. if len(buf) < m.Size() {
  113. return -1
  114. }
  115. // binary.BigEndian.PutUint32(buf, m.SessionID)
  116. binary.BigEndian.PutUint16(buf[4:], m.PacketID)
  117. buf[6] = m.FragID
  118. buf[7] = m.FragCount
  119. i := varintPut(buf[8:], uint64(len(m.Addr)))
  120. i += copy(buf[8+i:], m.Addr)
  121. i += copy(buf[8+i:], m.Data)
  122. return 8 + i
  123. }
  124. func ParseUDPMessage(msg []byte) (*UDPMessage, error) {
  125. m := &UDPMessage{}
  126. buf := bytes.NewBuffer(msg)
  127. if err := binary.Read(buf, binary.BigEndian, &m.SessionID); err != nil {
  128. return nil, err
  129. }
  130. if err := binary.Read(buf, binary.BigEndian, &m.PacketID); err != nil {
  131. return nil, err
  132. }
  133. if err := binary.Read(buf, binary.BigEndian, &m.FragID); err != nil {
  134. return nil, err
  135. }
  136. if err := binary.Read(buf, binary.BigEndian, &m.FragCount); err != nil {
  137. return nil, err
  138. }
  139. lAddr, err := quicvarint.Read(buf)
  140. if err != nil {
  141. return nil, err
  142. }
  143. if lAddr == 0 || lAddr > MaxMessageLength {
  144. return nil, errors.New("invalid address length")
  145. }
  146. bs := buf.Bytes()
  147. if len(bs) <= int(lAddr) {
  148. // We use <= instead of < here as we expect at least one byte of data after the address
  149. return nil, errors.New("invalid message length")
  150. }
  151. m.Addr = string(bs[:lAddr])
  152. m.Data = bs[lAddr:]
  153. return m, nil
  154. }
  155. // varintPut is like quicvarint.Append, but instead of appending to a slice,
  156. // it writes to a fixed-size buffer. Returns the number of bytes written.
  157. func varintPut(b []byte, i uint64) int {
  158. if i <= maxVarInt1 {
  159. b[0] = uint8(i)
  160. return 1
  161. }
  162. if i <= maxVarInt2 {
  163. b[0] = uint8(i>>8) | 0x40
  164. b[1] = uint8(i)
  165. return 2
  166. }
  167. if i <= maxVarInt4 {
  168. b[0] = uint8(i>>24) | 0x80
  169. b[1] = uint8(i >> 16)
  170. b[2] = uint8(i >> 8)
  171. b[3] = uint8(i)
  172. return 4
  173. }
  174. if i <= maxVarInt8 {
  175. b[0] = uint8(i>>56) | 0xc0
  176. b[1] = uint8(i >> 48)
  177. b[2] = uint8(i >> 40)
  178. b[3] = uint8(i >> 32)
  179. b[4] = uint8(i >> 24)
  180. b[5] = uint8(i >> 16)
  181. b[6] = uint8(i >> 8)
  182. b[7] = uint8(i)
  183. return 8
  184. }
  185. panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i))
  186. }