tcp_packet_conn.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package ice
  4. import (
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net"
  9. "sync"
  10. "sync/atomic"
  11. "time"
  12. "github.com/pion/logging"
  13. "github.com/pion/transport/v2/packetio"
  14. )
  15. type bufferedConn struct {
  16. net.Conn
  17. buf *packetio.Buffer
  18. logger logging.LeveledLogger
  19. closed int32
  20. }
  21. func newBufferedConn(conn net.Conn, bufSize int, logger logging.LeveledLogger) net.Conn {
  22. buf := packetio.NewBuffer()
  23. if bufSize > 0 {
  24. buf.SetLimitSize(bufSize)
  25. }
  26. bc := &bufferedConn{
  27. Conn: conn,
  28. buf: buf,
  29. logger: logger,
  30. }
  31. go bc.writeProcess()
  32. return bc
  33. }
  34. func (bc *bufferedConn) Write(b []byte) (int, error) {
  35. n, err := bc.buf.Write(b)
  36. if err != nil {
  37. return n, err
  38. }
  39. return n, nil
  40. }
  41. func (bc *bufferedConn) writeProcess() {
  42. pktBuf := make([]byte, receiveMTU)
  43. for atomic.LoadInt32(&bc.closed) == 0 {
  44. n, err := bc.buf.Read(pktBuf)
  45. if errors.Is(err, io.EOF) {
  46. return
  47. }
  48. if err != nil {
  49. bc.logger.Warnf("Failed to read from buffer: %s", err)
  50. continue
  51. }
  52. if _, err := bc.Conn.Write(pktBuf[:n]); err != nil {
  53. bc.logger.Warnf("Failed to write: %s", err)
  54. continue
  55. }
  56. }
  57. }
  58. func (bc *bufferedConn) Close() error {
  59. atomic.StoreInt32(&bc.closed, 1)
  60. _ = bc.buf.Close()
  61. return bc.Conn.Close()
  62. }
  63. type tcpPacketConn struct {
  64. params *tcpPacketParams
  65. // conns is a map of net.Conns indexed by remote net.Addr.String()
  66. conns map[string]net.Conn
  67. recvChan chan streamingPacket
  68. mu sync.Mutex
  69. wg sync.WaitGroup
  70. closedChan chan struct{}
  71. closeOnce sync.Once
  72. aliveTimer *time.Timer
  73. }
  74. type streamingPacket struct {
  75. Data []byte
  76. RAddr net.Addr
  77. Err error
  78. }
  79. type tcpPacketParams struct {
  80. ReadBuffer int
  81. LocalAddr net.Addr
  82. Logger logging.LeveledLogger
  83. WriteBuffer int
  84. AliveDuration time.Duration
  85. }
  86. func newTCPPacketConn(params tcpPacketParams) *tcpPacketConn {
  87. p := &tcpPacketConn{
  88. params: &params,
  89. conns: map[string]net.Conn{},
  90. recvChan: make(chan streamingPacket, params.ReadBuffer),
  91. closedChan: make(chan struct{}),
  92. }
  93. if params.AliveDuration > 0 {
  94. p.aliveTimer = time.AfterFunc(params.AliveDuration, func() {
  95. p.params.Logger.Warn("close tcp packet conn by alive timeout")
  96. _ = p.Close()
  97. })
  98. }
  99. return p
  100. }
  101. func (t *tcpPacketConn) ClearAliveTimer() {
  102. t.mu.Lock()
  103. if t.aliveTimer != nil {
  104. t.aliveTimer.Stop()
  105. }
  106. t.mu.Unlock()
  107. }
  108. func (t *tcpPacketConn) AddConn(conn net.Conn, firstPacketData []byte) error {
  109. t.params.Logger.Infof("Added connection: %s remote %s to local %s", conn.RemoteAddr().Network(), conn.RemoteAddr(), conn.LocalAddr())
  110. t.mu.Lock()
  111. defer t.mu.Unlock()
  112. select {
  113. case <-t.closedChan:
  114. return io.ErrClosedPipe
  115. default:
  116. }
  117. if _, ok := t.conns[conn.RemoteAddr().String()]; ok {
  118. return fmt.Errorf("%w: %s", errConnectionAddrAlreadyExist, conn.RemoteAddr().String())
  119. }
  120. if t.params.WriteBuffer > 0 {
  121. conn = newBufferedConn(conn, t.params.WriteBuffer, t.params.Logger)
  122. }
  123. t.conns[conn.RemoteAddr().String()] = conn
  124. t.wg.Add(1)
  125. go func() {
  126. defer t.wg.Done()
  127. if firstPacketData != nil {
  128. select {
  129. case <-t.closedChan:
  130. // NOTE: recvChan can fill up and never drain in edge
  131. // cases while closing a connection, which can cause the
  132. // packetConn to never finish closing. Bail out early
  133. // here to prevent that.
  134. return
  135. case t.recvChan <- streamingPacket{firstPacketData, conn.RemoteAddr(), nil}:
  136. }
  137. }
  138. t.startReading(conn)
  139. }()
  140. return nil
  141. }
  142. func (t *tcpPacketConn) startReading(conn net.Conn) {
  143. buf := make([]byte, receiveMTU)
  144. for {
  145. n, err := readStreamingPacket(conn, buf)
  146. if err != nil {
  147. t.params.Logger.Warnf("Failed to read streaming packet: %s", err)
  148. t.handleRecv(streamingPacket{nil, conn.RemoteAddr(), err})
  149. t.removeConn(conn)
  150. return
  151. }
  152. data := make([]byte, n)
  153. copy(data, buf[:n])
  154. t.handleRecv(streamingPacket{data, conn.RemoteAddr(), nil})
  155. }
  156. }
  157. func (t *tcpPacketConn) handleRecv(pkt streamingPacket) {
  158. t.mu.Lock()
  159. recvChan := t.recvChan
  160. if t.isClosed() {
  161. recvChan = nil
  162. }
  163. t.mu.Unlock()
  164. select {
  165. case recvChan <- pkt:
  166. case <-t.closedChan:
  167. }
  168. }
  169. func (t *tcpPacketConn) isClosed() bool {
  170. select {
  171. case <-t.closedChan:
  172. return true
  173. default:
  174. return false
  175. }
  176. }
  177. // WriteTo is for passive and s-o candidates.
  178. func (t *tcpPacketConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) {
  179. pkt, ok := <-t.recvChan
  180. if !ok {
  181. return 0, nil, io.ErrClosedPipe
  182. }
  183. if pkt.Err != nil {
  184. return 0, pkt.RAddr, pkt.Err
  185. }
  186. if cap(b) < len(pkt.Data) {
  187. return 0, pkt.RAddr, io.ErrShortBuffer
  188. }
  189. n = len(pkt.Data)
  190. copy(b, pkt.Data[:n])
  191. return n, pkt.RAddr, err
  192. }
  193. // WriteTo is for active and s-o candidates.
  194. func (t *tcpPacketConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
  195. t.mu.Lock()
  196. conn, ok := t.conns[rAddr.String()]
  197. t.mu.Unlock()
  198. if !ok {
  199. return 0, io.ErrClosedPipe
  200. }
  201. n, err = writeStreamingPacket(conn, buf)
  202. if err != nil {
  203. t.params.Logger.Tracef("%w %s", errWrite, rAddr)
  204. return n, err
  205. }
  206. return n, err
  207. }
  208. func (t *tcpPacketConn) closeAndLogError(closer io.Closer) {
  209. err := closer.Close()
  210. if err != nil {
  211. t.params.Logger.Warnf("%v: %s", errClosingConnection, err)
  212. }
  213. }
  214. func (t *tcpPacketConn) removeConn(conn net.Conn) {
  215. t.mu.Lock()
  216. defer t.mu.Unlock()
  217. t.closeAndLogError(conn)
  218. delete(t.conns, conn.RemoteAddr().String())
  219. }
  220. func (t *tcpPacketConn) Close() error {
  221. t.mu.Lock()
  222. var shouldCloseRecvChan bool
  223. t.closeOnce.Do(func() {
  224. close(t.closedChan)
  225. shouldCloseRecvChan = true
  226. if t.aliveTimer != nil {
  227. t.aliveTimer.Stop()
  228. }
  229. })
  230. for _, conn := range t.conns {
  231. t.closeAndLogError(conn)
  232. delete(t.conns, conn.RemoteAddr().String())
  233. }
  234. t.mu.Unlock()
  235. t.wg.Wait()
  236. if shouldCloseRecvChan {
  237. close(t.recvChan)
  238. }
  239. return nil
  240. }
  241. func (t *tcpPacketConn) LocalAddr() net.Addr {
  242. return t.params.LocalAddr
  243. }
  244. func (t *tcpPacketConn) SetDeadline(time.Time) error {
  245. return nil
  246. }
  247. func (t *tcpPacketConn) SetReadDeadline(time.Time) error {
  248. return nil
  249. }
  250. func (t *tcpPacketConn) SetWriteDeadline(time.Time) error {
  251. return nil
  252. }
  253. func (t *tcpPacketConn) CloseChannel() <-chan struct{} {
  254. return t.closedChan
  255. }
  256. func (t *tcpPacketConn) String() string {
  257. return fmt.Sprintf("tcpPacketConn{LocalAddr: %s}", t.params.LocalAddr)
  258. }