client.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637
  1. package stun
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "log"
  7. "net"
  8. "runtime"
  9. "sync"
  10. "sync/atomic"
  11. "time"
  12. )
  13. // Dial connects to the address on the named network and then
  14. // initializes Client on that connection, returning error if any.
  15. func Dial(network, address string) (*Client, error) {
  16. conn, err := net.Dial(network, address)
  17. if err != nil {
  18. return nil, err
  19. }
  20. return NewClient(conn)
  21. }
  22. // ErrNoConnection means that ClientOptions.Connection is nil.
  23. var ErrNoConnection = errors.New("no connection provided")
  24. // ClientOption sets some client option.
  25. type ClientOption func(c *Client)
  26. // WithHandler sets client handler which is called if Agent emits the Event
  27. // with TransactionID that is not currently registered by Client.
  28. // Useful for handling Data indications from TURN server.
  29. func WithHandler(h Handler) ClientOption {
  30. return func(c *Client) {
  31. c.handler = h
  32. }
  33. }
  34. // WithRTO sets client RTO as defined in STUN RFC.
  35. func WithRTO(rto time.Duration) ClientOption {
  36. return func(c *Client) {
  37. c.rto = int64(rto)
  38. }
  39. }
  40. // WithClock sets Clock of client, the source of current time.
  41. // Also clock is passed to default collector if set.
  42. func WithClock(clock Clock) ClientOption {
  43. return func(c *Client) {
  44. c.clock = clock
  45. }
  46. }
  47. // WithTimeoutRate sets RTO timer minimum resolution.
  48. func WithTimeoutRate(d time.Duration) ClientOption {
  49. return func(c *Client) {
  50. c.rtoRate = d
  51. }
  52. }
  53. // WithAgent sets client STUN agent.
  54. //
  55. // Defaults to agent implementation in current package,
  56. // see agent.go.
  57. func WithAgent(a ClientAgent) ClientOption {
  58. return func(c *Client) {
  59. c.a = a
  60. }
  61. }
  62. // WithCollector rests client timeout collector, the implementation
  63. // of ticker which calls function on each tick.
  64. func WithCollector(coll Collector) ClientOption {
  65. return func(c *Client) {
  66. c.collector = coll
  67. }
  68. }
  69. // WithNoConnClose prevents client from closing underlying connection when
  70. // the Close() method is called.
  71. func WithNoConnClose() ClientOption {
  72. return func(c *Client) {
  73. c.closeConn = false
  74. }
  75. }
  76. // WithNoRetransmit disables retransmissions and sets RTO to
  77. // defaultMaxAttempts * defaultRTO which will be effectively time out
  78. // if not set.
  79. //
  80. // Useful for TCP connections where transport handles RTO.
  81. func WithNoRetransmit(c *Client) {
  82. c.maxAttempts = 0
  83. if c.rto == 0 {
  84. c.rto = defaultMaxAttempts * int64(defaultRTO)
  85. }
  86. }
  87. const (
  88. defaultTimeoutRate = time.Millisecond * 5
  89. defaultRTO = time.Millisecond * 300
  90. defaultMaxAttempts = 7
  91. )
  92. // NewClient initializes new Client from provided options,
  93. // starting internal goroutines and using default options fields
  94. // if necessary. Call Close method after using Client to close conn and
  95. // release resources.
  96. //
  97. // The conn will be closed on Close call. Use WithNoConnClose option to
  98. // prevent that.
  99. //
  100. // Note that user should handle the protocol multiplexing, client does not
  101. // provide any API for it, so if you need to read application data, wrap the
  102. // connection with your (de-)multiplexer and pass the wrapper as conn.
  103. func NewClient(conn Connection, options ...ClientOption) (*Client, error) {
  104. c := &Client{
  105. close: make(chan struct{}),
  106. c: conn,
  107. clock: systemClock(),
  108. rto: int64(defaultRTO),
  109. rtoRate: defaultTimeoutRate,
  110. t: make(map[transactionID]*clientTransaction, 100),
  111. maxAttempts: defaultMaxAttempts,
  112. closeConn: true,
  113. }
  114. for _, o := range options {
  115. o(c)
  116. }
  117. if c.c == nil {
  118. return nil, ErrNoConnection
  119. }
  120. if c.a == nil {
  121. c.a = NewAgent(nil)
  122. }
  123. if err := c.a.SetHandler(c.handleAgentCallback); err != nil {
  124. return nil, err
  125. }
  126. if c.collector == nil {
  127. c.collector = &tickerCollector{
  128. close: make(chan struct{}),
  129. clock: c.clock,
  130. }
  131. }
  132. if err := c.collector.Start(c.rtoRate, func(t time.Time) {
  133. closedOrPanic(c.a.Collect(t))
  134. }); err != nil {
  135. return nil, err
  136. }
  137. c.wg.Add(1)
  138. go c.readUntilClosed()
  139. runtime.SetFinalizer(c, clientFinalizer)
  140. return c, nil
  141. }
  142. func clientFinalizer(c *Client) {
  143. if c == nil {
  144. return
  145. }
  146. err := c.Close()
  147. if errors.Is(err, ErrClientClosed) {
  148. return
  149. }
  150. if err == nil {
  151. log.Println("client: called finalizer on non-closed client") // nolint
  152. return
  153. }
  154. log.Println("client: called finalizer on non-closed client:", err) // nolint
  155. }
  156. // Connection wraps Reader, Writer and Closer interfaces.
  157. type Connection interface {
  158. io.Reader
  159. io.Writer
  160. io.Closer
  161. }
  162. // ClientAgent is Agent implementation that is used by Client to
  163. // process transactions.
  164. type ClientAgent interface {
  165. Process(*Message) error
  166. Close() error
  167. Start(id [TransactionIDSize]byte, deadline time.Time) error
  168. Stop(id [TransactionIDSize]byte) error
  169. Collect(time.Time) error
  170. SetHandler(h Handler) error
  171. }
// Client simulates "connection" to STUN server.
//
// It keeps in-flight transactions in t, lets the agent and collector time
// them out, and reads responses from c in a goroutine started by NewClient.
// Build it with NewClient or Dial and release it with Close.
type Client struct {
	// rto is the retransmission timeout in nanoseconds, accessed with
	// sync/atomic (SetRTO, Start).
	// NOTE(review): it is presumably kept as the first field so that 64-bit
	// atomic operations are aligned on 32-bit platforms; do not reorder.
	rto int64 // time.Duration
	a   ClientAgent
	c   Connection
	// close is closed by Close to stop readUntilClosed.
	close   chan struct{}
	rtoRate time.Duration
	// maxAttempts is read atomically in handleAgentCallback.
	maxAttempts int32
	closed      bool
	closeConn   bool // should call c.Close() while closing
	// wg tracks the readUntilClosed goroutine.
	wg        sync.WaitGroup
	clock     Clock
	handler   Handler
	collector Collector
	t         map[transactionID]*clientTransaction
	// mux guards closed and t
	mux sync.RWMutex
}
// clientTransaction represents a transaction in progress, as stored in
// Client.t. When the transaction succeeds or fails, h is called with the
// final event; handle guarantees h runs at most once.
// Apart from calls (updated atomically), concurrent access is invalid.
type clientTransaction struct {
	id      transactionID
	attempt int32 // retransmissions already performed
	calls   int32 // number of handle invocations; only the first reaches h
	h       Handler
	start   time.Time     // when Start registered the transaction
	rto     time.Duration // base timeout, scaled by attempt in nextTimeout
	raw     []byte        // copy of the request, re-sent on retransmission
}
  203. func (t *clientTransaction) handle(e Event) {
  204. if atomic.AddInt32(&t.calls, 1) == 1 {
  205. t.h(e)
  206. }
  207. }
  208. var clientTransactionPool = &sync.Pool{ // nolint:gochecknoglobals
  209. New: func() interface{} {
  210. return &clientTransaction{
  211. raw: make([]byte, 1500),
  212. }
  213. },
  214. }
  215. func acquireClientTransaction() *clientTransaction {
  216. return clientTransactionPool.Get().(*clientTransaction) //nolint:forcetypeassert
  217. }
  218. func putClientTransaction(t *clientTransaction) {
  219. t.raw = t.raw[:0]
  220. t.start = time.Time{}
  221. t.attempt = 0
  222. t.id = transactionID{}
  223. clientTransactionPool.Put(t)
  224. }
  225. func (t *clientTransaction) nextTimeout(now time.Time) time.Time {
  226. return now.Add(time.Duration(t.attempt+1) * t.rto)
  227. }
  228. // start registers transaction.
  229. //
  230. // Could return ErrClientClosed, ErrTransactionExists.
  231. func (c *Client) start(t *clientTransaction) error {
  232. c.mux.Lock()
  233. defer c.mux.Unlock()
  234. if c.closed {
  235. return ErrClientClosed
  236. }
  237. _, exists := c.t[t.id]
  238. if exists {
  239. return ErrTransactionExists
  240. }
  241. c.t[t.id] = t
  242. return nil
  243. }
  244. // Clock abstracts the source of current time.
  245. type Clock interface {
  246. Now() time.Time
  247. }
  248. type systemClockService struct{}
  249. func (systemClockService) Now() time.Time { return time.Now() }
  250. func systemClock() systemClockService {
  251. return systemClockService{}
  252. }
  253. // SetRTO sets current RTO value.
  254. func (c *Client) SetRTO(rto time.Duration) {
  255. atomic.StoreInt64(&c.rto, int64(rto))
  256. }
  257. // StopErr occurs when Client fails to stop transaction while
  258. // processing error.
  259. // nolint:errname
  260. type StopErr struct {
  261. Err error // value returned by Stop()
  262. Cause error // error that caused Stop() call
  263. }
  264. func (e StopErr) Error() string {
  265. return fmt.Sprintf("error while stopping due to %s: %s", sprintErr(e.Cause), sprintErr(e.Err))
  266. }
  267. // CloseErr indicates client close failure.
  268. // nolint:errname
  269. type CloseErr struct {
  270. AgentErr error
  271. ConnectionErr error
  272. }
  273. func sprintErr(err error) string {
  274. if err == nil {
  275. return "<nil>" // nolint:goconst
  276. }
  277. return err.Error()
  278. }
  279. func (c CloseErr) Error() string {
  280. return fmt.Sprintf("failed to close: %s (connection), %s (agent)", sprintErr(c.ConnectionErr), sprintErr(c.AgentErr))
  281. }
  282. func (c *Client) readUntilClosed() {
  283. defer c.wg.Done()
  284. m := new(Message)
  285. m.Raw = make([]byte, 1024)
  286. for {
  287. select {
  288. case <-c.close:
  289. return
  290. default:
  291. }
  292. _, err := m.ReadFrom(c.c)
  293. if err == nil {
  294. if pErr := c.a.Process(m); errors.Is(pErr, ErrAgentClosed) {
  295. return
  296. }
  297. }
  298. }
  299. }
  300. func closedOrPanic(err error) {
  301. if err == nil || errors.Is(err, ErrAgentClosed) {
  302. return
  303. }
  304. panic(err) // nolint
  305. }
  306. type tickerCollector struct {
  307. close chan struct{}
  308. wg sync.WaitGroup
  309. clock Clock
  310. }
  311. // Collector calls function f with constant rate.
  312. //
  313. // The simple Collector is ticker which calls function on each tick.
  314. type Collector interface {
  315. Start(rate time.Duration, f func(now time.Time)) error
  316. Close() error
  317. }
  318. func (a *tickerCollector) Start(rate time.Duration, f func(now time.Time)) error {
  319. t := time.NewTicker(rate)
  320. a.wg.Add(1)
  321. go func() {
  322. defer a.wg.Done()
  323. for {
  324. select {
  325. case <-a.close:
  326. t.Stop()
  327. return
  328. case <-t.C:
  329. f(a.clock.Now())
  330. }
  331. }
  332. }()
  333. return nil
  334. }
  335. func (a *tickerCollector) Close() error {
  336. close(a.close)
  337. a.wg.Wait()
  338. return nil
  339. }
  340. // ErrClientClosed indicates that client is closed.
  341. var ErrClientClosed = errors.New("client is closed")
  342. // Close stops internal connection and agent, returning CloseErr on error.
  343. func (c *Client) Close() error {
  344. if err := c.checkInit(); err != nil {
  345. return err
  346. }
  347. c.mux.Lock()
  348. if c.closed {
  349. c.mux.Unlock()
  350. return ErrClientClosed
  351. }
  352. c.closed = true
  353. c.mux.Unlock()
  354. if closeErr := c.collector.Close(); closeErr != nil {
  355. return closeErr
  356. }
  357. var connErr error
  358. agentErr := c.a.Close()
  359. if c.closeConn {
  360. connErr = c.c.Close()
  361. }
  362. close(c.close)
  363. c.wg.Wait()
  364. if agentErr == nil && connErr == nil {
  365. return nil
  366. }
  367. return CloseErr{
  368. AgentErr: agentErr,
  369. ConnectionErr: connErr,
  370. }
  371. }
  372. // Indicate sends indication m to server. Shorthand to Start call
  373. // with zero deadline and callback.
  374. func (c *Client) Indicate(m *Message) error {
  375. return c.Start(m, nil)
  376. }
  377. // callbackWaitHandler blocks on wait() call until callback is called.
  378. type callbackWaitHandler struct {
  379. handler Handler
  380. callback func(event Event)
  381. cond *sync.Cond
  382. processed bool
  383. }
  384. func (s *callbackWaitHandler) HandleEvent(e Event) {
  385. s.cond.L.Lock()
  386. if s.callback == nil {
  387. panic("s.callback is nil") // nolint
  388. }
  389. s.callback(e)
  390. s.processed = true
  391. s.cond.Broadcast()
  392. s.cond.L.Unlock()
  393. }
  394. func (s *callbackWaitHandler) wait() {
  395. s.cond.L.Lock()
  396. for !s.processed {
  397. s.cond.Wait()
  398. }
  399. s.processed = false
  400. s.callback = nil
  401. s.cond.L.Unlock()
  402. }
  403. func (s *callbackWaitHandler) setCallback(f func(event Event)) {
  404. if f == nil {
  405. panic("f is nil") // nolint
  406. }
  407. s.cond.L.Lock()
  408. s.callback = f
  409. if s.handler == nil {
  410. s.handler = s.HandleEvent
  411. }
  412. s.cond.L.Unlock()
  413. }
  414. var callbackWaitHandlerPool = sync.Pool{ // nolint:gochecknoglobals
  415. New: func() interface{} {
  416. return &callbackWaitHandler{
  417. cond: sync.NewCond(new(sync.Mutex)),
  418. }
  419. },
  420. }
  421. // ErrClientNotInitialized means that client connection or agent is nil.
  422. var ErrClientNotInitialized = errors.New("client not initialized")
  423. func (c *Client) checkInit() error {
  424. if c == nil || c.c == nil || c.a == nil || c.close == nil {
  425. return ErrClientNotInitialized
  426. }
  427. return nil
  428. }
  429. // Do is Start wrapper that waits until callback is called. If no callback
  430. // provided, Indicate is called instead.
  431. //
  432. // Do has cpu overhead due to blocking, see BenchmarkClient_Do.
  433. // Use Start method for less overhead.
  434. func (c *Client) Do(m *Message, f func(Event)) error {
  435. if err := c.checkInit(); err != nil {
  436. return err
  437. }
  438. if f == nil {
  439. return c.Indicate(m)
  440. }
  441. h := callbackWaitHandlerPool.Get().(*callbackWaitHandler) //nolint:forcetypeassert
  442. h.setCallback(f)
  443. defer func() {
  444. callbackWaitHandlerPool.Put(h)
  445. }()
  446. if err := c.Start(m, h.handler); err != nil {
  447. return err
  448. }
  449. h.wait()
  450. return nil
  451. }
  452. func (c *Client) delete(id transactionID) {
  453. c.mux.Lock()
  454. if c.t != nil {
  455. delete(c.t, id)
  456. }
  457. c.mux.Unlock()
  458. }
  459. type buffer struct {
  460. buf []byte
  461. }
  462. var bufferPool = &sync.Pool{ // nolint:gochecknoglobals
  463. New: func() interface{} {
  464. return &buffer{buf: make([]byte, 2048)}
  465. },
  466. }
  467. func (c *Client) handleAgentCallback(e Event) {
  468. c.mux.Lock()
  469. if c.closed {
  470. c.mux.Unlock()
  471. return
  472. }
  473. t, found := c.t[e.TransactionID]
  474. if found {
  475. delete(c.t, t.id)
  476. }
  477. c.mux.Unlock()
  478. if !found {
  479. if c.handler != nil && !errors.Is(e.Error, ErrTransactionStopped) {
  480. c.handler(e)
  481. }
  482. // Ignoring.
  483. return
  484. }
  485. if atomic.LoadInt32(&c.maxAttempts) <= t.attempt || e.Error == nil {
  486. // Transaction completed.
  487. t.handle(e)
  488. putClientTransaction(t)
  489. return
  490. }
  491. // Doing re-transmission.
  492. t.attempt++
  493. b := bufferPool.Get().(*buffer) //nolint:forcetypeassert
  494. b.buf = b.buf[:copy(b.buf[:cap(b.buf)], t.raw)]
  495. defer bufferPool.Put(b)
  496. var (
  497. now = c.clock.Now()
  498. timeOut = t.nextTimeout(now)
  499. id = t.id
  500. )
  501. // Starting client transaction.
  502. if startErr := c.start(t); startErr != nil {
  503. c.delete(id)
  504. e.Error = startErr
  505. t.handle(e)
  506. putClientTransaction(t)
  507. return
  508. }
  509. // Starting agent transaction.
  510. if startErr := c.a.Start(id, timeOut); startErr != nil {
  511. c.delete(id)
  512. e.Error = startErr
  513. t.handle(e)
  514. putClientTransaction(t)
  515. return
  516. }
  517. // Writing message to connection again.
  518. _, writeErr := c.c.Write(b.buf)
  519. if writeErr != nil {
  520. c.delete(id)
  521. e.Error = writeErr
  522. // Stopping agent transaction instead of waiting until it's deadline.
  523. // This will call handleAgentCallback with "ErrTransactionStopped" error
  524. // which will be ignored.
  525. if stopErr := c.a.Stop(id); stopErr != nil {
  526. // Failed to stop agent transaction. Wrapping the error in StopError.
  527. e.Error = StopErr{
  528. Err: stopErr,
  529. Cause: writeErr,
  530. }
  531. }
  532. t.handle(e)
  533. putClientTransaction(t)
  534. return
  535. }
  536. }
  537. // Start starts transaction (if h set) and writes message to server, handler
  538. // is called asynchronously.
  539. func (c *Client) Start(m *Message, h Handler) error {
  540. if err := c.checkInit(); err != nil {
  541. return err
  542. }
  543. c.mux.RLock()
  544. closed := c.closed
  545. c.mux.RUnlock()
  546. if closed {
  547. return ErrClientClosed
  548. }
  549. if h != nil {
  550. // Starting transaction only if h is set. Useful for indications.
  551. t := acquireClientTransaction()
  552. t.id = m.TransactionID
  553. t.start = c.clock.Now()
  554. t.h = h
  555. t.rto = time.Duration(atomic.LoadInt64(&c.rto))
  556. t.attempt = 0
  557. t.raw = append(t.raw[:0], m.Raw...)
  558. t.calls = 0
  559. d := t.nextTimeout(t.start)
  560. if err := c.start(t); err != nil {
  561. return err
  562. }
  563. if err := c.a.Start(m.TransactionID, d); err != nil {
  564. return err
  565. }
  566. }
  567. _, err := m.WriteTo(c.c)
  568. if err != nil && h != nil {
  569. c.delete(m.TransactionID)
  570. // Stopping transaction instead of waiting until deadline.
  571. if stopErr := c.a.Stop(m.TransactionID); stopErr != nil {
  572. return StopErr{
  573. Err: stopErr,
  574. Cause: err,
  575. }
  576. }
  577. }
  578. return err
  579. }