dialer.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. package marionette
  2. import (
  3. "context"
  4. "errors"
  5. "net"
  6. "sync"
  7. "github.com/redjack/marionette/mar"
  8. "go.uber.org/zap"
  9. )
  10. var (
  11. // ErrDialerClosed is returned when trying to operate on a closed dialer.
  12. ErrDialerClosed = errors.New("marionette: dialer closed")
  13. )
  14. // Dialer represents a client-side dialer that communicates over the marionette protocol.
  15. type Dialer struct {
  16. mu sync.RWMutex
  17. addr string // Server hostport to connect to
  18. doc *mar.Document // Parsed MAR document
  19. fsm FSM // Associated FSM
  20. streamSet *StreamSet // Associated StreamSet
  21. // Close management
  22. ctx context.Context
  23. cancel func()
  24. closed bool
  25. wg sync.WaitGroup
  26. // Underlying NetDialer used for net connection.
  27. Dialer NetDialer
  28. }
  29. // NewDialer returns a new instance of Dialer.
  30. func NewDialer(doc *mar.Document, addr string, streamSet *StreamSet) *Dialer {
  31. // Run execution in a separate goroutine.
  32. d := &Dialer{
  33. addr: addr,
  34. doc: doc,
  35. streamSet: streamSet,
  36. Dialer: &net.Dialer{},
  37. }
  38. d.ctx, d.cancel = context.WithCancel(context.Background())
  39. return d
  40. }
  41. // Open initializes the underlying connection.
  42. func (d *Dialer) Open() error {
  43. conn, err := d.Dialer.DialContext(d.ctx, d.doc.Transport, net.JoinHostPort(d.addr, d.doc.Port))
  44. if err != nil {
  45. return err
  46. }
  47. d.fsm = NewFSM(d.doc, d.addr, PartyClient, conn, d.streamSet)
  48. d.wg.Add(1)
  49. go func() { defer d.wg.Done(); d.execute() }()
  50. return nil
  51. }
  52. // Close stops the dialer and its underlying connections.
  53. func (d *Dialer) Close() error {
  54. err := d.close()
  55. d.wg.Wait()
  56. return err
  57. }
  58. func (d *Dialer) close() (err error) {
  59. d.mu.Lock()
  60. d.closed = true
  61. err = d.fsm.Close()
  62. d.mu.Unlock()
  63. d.cancel()
  64. return err
  65. }
  66. // Closed returns true if the dialer has been closed.
  67. func (d *Dialer) Closed() bool {
  68. d.mu.RLock()
  69. closed := d.closed
  70. d.mu.RUnlock()
  71. return closed
  72. }
  73. // Dial returns a new stream from the dialer.
  74. func (d *Dialer) Dial() (net.Conn, error) {
  75. if d.Closed() {
  76. return nil, ErrDialerClosed
  77. }
  78. return d.streamSet.Create(), nil
  79. }
  80. // execute continually executes the FSM until the stream and dialer are closed.
  81. func (d *Dialer) execute() {
  82. defer d.close()
  83. for !d.Closed() {
  84. if err := d.fsm.Execute(d.ctx); err == ErrStreamClosed {
  85. continue
  86. } else if err != nil {
  87. Logger.Debug("dialer error", zap.Error(err))
  88. return
  89. }
  90. d.fsm.Reset()
  91. }
  92. }
  93. // NetDialer is an abstract dialer. net.Dialer implements the NetDialer interface.
  94. type NetDialer interface {
  95. Dial(network, address string) (net.Conn, error)
  96. DialContext(ctx context.Context, network, address string) (net.Conn, error)
  97. }