packet_handler_map.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. package gquic
  2. import (
  3. "bytes"
  4. "fmt"
  5. "net"
  6. "sync"
  7. "time"
  8. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/protocol"
  9. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/utils"
  10. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/wire"
  11. )
  12. // The packetHandlerMap stores packetHandlers, identified by connection ID.
  13. // It is used:
  14. // * by the server to store sessions
  15. // * when multiplexing outgoing connections to store clients
  16. type packetHandlerMap struct {
  17. mutex sync.RWMutex
  18. conn net.PacketConn
  19. connIDLen int
  20. handlers map[string] /* string(ConnectionID)*/ packetHandler
  21. server unknownPacketHandler
  22. closed bool
  23. deleteClosedSessionsAfter time.Duration
  24. logger utils.Logger
  25. }
  26. var _ packetHandlerManager = &packetHandlerMap{}
  // newPacketHandlerMap builds a packetHandlerMap on top of conn and starts its
// receive loop (listen) in a new goroutine before returning it.
// connIDLen is the connection ID length used when parsing packet headers;
// removed entries are kept as nil tombstones for
// protocol.ClosedSessionDeleteTimeout.
func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger) packetHandlerManager {
	phm := new(packetHandlerMap)
	phm.conn = conn
	phm.connIDLen = connIDLen
	phm.handlers = make(map[string]packetHandler)
	phm.deleteClosedSessionsAfter = protocol.ClosedSessionDeleteTimeout
	phm.logger = logger

	go phm.listen()
	return phm
}
  38. func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) {
  39. h.mutex.Lock()
  40. h.handlers[string(id)] = handler
  41. h.mutex.Unlock()
  42. }
  43. func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
  44. h.removeByConnectionIDAsString(string(id))
  45. }
  // removeByConnectionIDAsString retires the handler stored under id.
//
// The entry is first replaced by a nil tombstone so that late packets for the
// closed session are dropped by handlePacket instead of being handed to the
// server. After deleteClosedSessionsAfter the tombstone itself is deleted.
//
// The delayed deletion only removes the entry if it is still a tombstone: if a
// handler was registered under the same id in the meantime (via Add), it is
// left in place rather than being silently dropped. (A pending timer from an
// earlier removal may still shorten a newer tombstone, which is harmless.)
func (h *packetHandlerMap) removeByConnectionIDAsString(id string) {
	h.mutex.Lock()
	h.handlers[id] = nil
	h.mutex.Unlock()

	time.AfterFunc(h.deleteClosedSessionsAfter, func() {
		h.mutex.Lock()
		defer h.mutex.Unlock()
		if handler, ok := h.handlers[id]; ok && handler == nil {
			delete(h.handlers, id)
		}
	})
}
  56. func (h *packetHandlerMap) SetServer(s unknownPacketHandler) {
  57. h.mutex.Lock()
  58. h.server = s
  59. h.mutex.Unlock()
  60. }
  // CloseServer unsets the server and closes every registered handler whose
// perspective is protocol.PerspectiveServer, retiring each one afterwards.
// The Close calls run concurrently; CloseServer returns once all of them
// (and the subsequent retirements) have finished.
func (h *packetHandlerMap) CloseServer() {
	var pending sync.WaitGroup

	h.mutex.Lock()
	h.server = nil
	for connID, sess := range h.handlers {
		if sess == nil || sess.GetPerspective() != protocol.PerspectiveServer {
			continue
		}
		pending.Add(1)
		go func(connID string, sess packetHandler) {
			defer pending.Done()
			// session.Close() blocks until the CONNECTION_CLOSE has been sent and
			// the run-loop has stopped. Its error is deliberately ignored: the
			// session is retired either way.
			_ = sess.Close()
			// Takes h.mutex, so this waits until the loop above has released it.
			h.removeByConnectionIDAsString(connID)
		}(connID, sess)
	}
	h.mutex.Unlock()

	pending.Wait()
}
  // close shuts the map down with error e. Only the first call has an effect;
// later calls return immediately. Every non-nil handler is destroyed with e
// (concurrently), the server, if one is set, is closed with e, and close
// waits for all destroy calls to return. It always returns nil.
func (h *packetHandlerMap) close(e error) error {
	h.mutex.Lock()
	if h.closed {
		h.mutex.Unlock()
		return nil
	}
	h.closed = true

	var destroyed sync.WaitGroup
	for _, sess := range h.handlers {
		if sess == nil {
			continue
		}
		destroyed.Add(1)
		go func(sess packetHandler) {
			defer destroyed.Done()
			sess.destroy(e)
		}(sess)
	}

	// [Psiphon]
	// The server must be closed only after h.mutex is released, otherwise
	// this deadlocks:
	//
	// sync.(*RWMutex).Lock
	// [...]/lucas-clemente/quic-go.(*packetHandlerMap).CloseServer
	// [...]/lucas-clemente/quic-go.(*server).closeWithMutex
	// [...]/lucas-clemente/quic-go.(*server).closeWithError
	// [...]/lucas-clemente/quic-go.(*packetHandlerMap).close
	// [...]/lucas-clemente/quic-go.(*packetHandlerMap).listen
	//
	// packetHandlerMap.CloseServer tries to take the very mutex held here.
	// Since packetHandlerMap and its mutex are shared by all client sessions,
	// that deadlock would hang the entire server. So grab a copy of the server
	// under the lock and call it after unlocking.
	server := h.server
	h.mutex.Unlock()

	if server != nil {
		server.closeWithError(e)
	}
	destroyed.Wait()
	return nil
}
  // listen is the receive loop started by newPacketHandlerMap. It reads
// datagrams from conn and dispatches each one through handlePacket. A
// non-temporary read error closes the map with that error and ends the loop;
// temporary errors are tolerated.
func (h *packetHandlerMap) listen() {
	for {
		data := *getPacketBuffer()
		data = data[:protocol.MaxReceivePacketSize]
		// The packet size should not exceed protocol.MaxReceivePacketSize bytes.
		// If it does, we only read a truncated packet, which will then end up undecryptable.
		n, addr, err := h.conn.ReadFrom(data)
		if err != nil {
			// [Psiphon]
			// Do not unconditionally shutdown: keep reading after temporary errors.
			if netErr, ok := err.(net.Error); !ok || !netErr.Temporary() {
				h.close(err)
				return
			}
			// net.PacketConn asks callers to process n > 0 bytes before looking at
			// the error, so a datagram that arrived with a temporary error is still
			// handled. With nothing read there is no packet (and addr may be nil),
			// so don't feed an empty buffer to handlePacket.
			if n == 0 {
				continue
			}
		}
		data = data[:n]

		if err := h.handlePacket(addr, data); err != nil {
			h.logger.Debugf("error handling packet from %s: %s", addr, err)
		}
	}
}
  // handlePacket parses the header of a single datagram received from addr and
// delivers it to the handler registered for its destination connection ID,
// or to the server when no handler is registered. Packets for a nil (removed)
// entry are dropped silently. A non-nil error means the packet was dropped.
func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
	rcvTime := time.Now()

	reader := bytes.NewReader(data)
	invHdr, err := wire.ParseInvariantHeader(reader, h.connIDLen)
	if err != nil {
		// Without a parseable header the packet cannot be routed; drop it.
		return fmt.Errorf("error parsing invariant header: %s", err)
	}

	h.mutex.RLock()
	handler, known := h.handlers[string(invHdr.DestConnectionID)]
	server := h.server
	h.mutex.RUnlock()

	var (
		sentBy  protocol.Perspective
		version protocol.VersionNumber
		deliver func(*receivedPacket)
	)
	switch {
	case known && handler == nil:
		// Late packet for a closed session.
		return nil
	case known:
		sentBy = handler.GetPerspective().Opposite()
		version = handler.GetVersion()
		deliver = handler.handlePacket
	case server == nil:
		return fmt.Errorf("received a packet with an unexpected connection ID %s", invHdr.DestConnectionID)
	default:
		// Unknown connection ID: the packet is from a client, for the server.
		sentBy = protocol.PerspectiveClient
		version = invHdr.Version
		deliver = server.handlePacket
	}

	hdr, err := invHdr.Parse(reader, sentBy, version)
	if err != nil {
		return fmt.Errorf("error parsing header: %s", err)
	}

	headerLen := len(data) - reader.Len()
	hdr.Raw = data[:headerLen]
	payload := data[headerLen:]
	if hdr.IsLongHeader && hdr.Version.UsesLengthInHeader() {
		if protocol.ByteCount(len(payload)) < hdr.PayloadLen {
			return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(payload), hdr.PayloadLen)
		}
		// Bytes past PayloadLen are ignored for now.
		payload = payload[:int(hdr.PayloadLen)]
		// TODO(#1312): implement parsing of compound packets
	}

	deliver(&receivedPacket{
		remoteAddr: addr,
		header:     hdr,
		data:       payload,
		rcvTime:    rcvTime,
	})
	return nil
}