marionette.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. // +build MARIONETTE
  2. /*
  3. * Copyright (c) 2018, Psiphon Inc.
  4. * All rights reserved.
  5. *
  6. * This program is free software: you can redistribute it and/or modify
  7. * it under the terms of the GNU General Public License as published by
  8. * the Free Software Foundation, either version 3 of the License, or
  9. * (at your option) any later version.
  10. *
  11. * This program is distributed in the hope that it will be useful,
  12. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  13. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  14. * GNU General Public License for more details.
  15. *
  16. * You should have received a copy of the GNU General Public License
  17. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  18. *
  19. */
  20. /*
  21. Package marionette wraps github.com/redjack/marionette with net.Listener and
  22. net.Conn types that provide a drop-in replacement for net.TCPConn.
  23. Each marionette session has exactly one stream, which is the equivalent of a TCP
  24. stream.
  25. */
  26. package marionette
  27. import (
  28. "context"
  29. "net"
  30. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
  31. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  32. redjack_marionette "github.com/redjack/marionette"
  33. "github.com/redjack/marionette/mar"
  34. _ "github.com/redjack/marionette/plugins"
  35. "go.uber.org/zap"
  36. )
  37. func init() {
  38. // Override the Logger initialized by redjack_marionette.init()
  39. redjack_marionette.Logger = zap.NewNop()
  40. }
  41. // Enabled indicates if Marionette functionality is enabled.
  42. func Enabled() bool {
  43. return true
  44. }
  45. // Listener is a net.Listener.
  46. type Listener struct {
  47. net.Listener
  48. }
  49. // Listen creates a new Marionette Listener. The address input should not
  50. // include a port number as the port is defined in the Marionette format.
  51. func Listen(address, format string) (net.Listener, error) {
  52. data, err := mar.ReadFormat(format)
  53. if err != nil {
  54. return nil, errors.Trace(err)
  55. }
  56. doc, err := mar.Parse(redjack_marionette.PartyServer, data)
  57. if err != nil {
  58. return nil, errors.Trace(err)
  59. }
  60. listener, err := redjack_marionette.Listen(doc, address)
  61. if err != nil {
  62. return nil, errors.Trace(err)
  63. }
  64. return &Listener{Listener: listener}, nil
  65. }
  66. // Dial establishes a new Marionette session and stream to the server
  67. // specified by address. The address input should not include a port number as
  68. // that's defined in the Marionette format.
  69. func Dial(
  70. ctx context.Context,
  71. netDialer common.NetDialer,
  72. format string,
  73. address string) (net.Conn, error) {
  74. data, err := mar.ReadFormat(format)
  75. if err != nil {
  76. return nil, errors.Trace(err)
  77. }
  78. doc, err := mar.Parse(redjack_marionette.PartyClient, data)
  79. if err != nil {
  80. return nil, errors.Trace(err)
  81. }
  82. streamSet := redjack_marionette.NewStreamSet()
  83. dialer := redjack_marionette.NewDialer(doc, address, streamSet)
  84. dialer.Dialer = netDialer
  85. err = dialer.Open()
  86. if err != nil {
  87. streamSet.Close()
  88. return nil, errors.Trace(err)
  89. }
  90. // dialer.Dial does not block on network I/O
  91. conn, err := dialer.Dial()
  92. if err != nil {
  93. streamSet.Close()
  94. dialer.Close()
  95. return nil, errors.Trace(err)
  96. }
  97. return &Conn{
  98. Conn: conn,
  99. streamSet: streamSet,
  100. dialer: dialer,
  101. }, nil
  102. }
  103. // Conn is a net.Conn and psiphon/common.Closer.
  104. type Conn struct {
  105. net.Conn
  106. streamSet *redjack_marionette.StreamSet
  107. dialer *redjack_marionette.Dialer
  108. }
  109. func (conn *Conn) Close() error {
  110. if conn.IsClosed() {
  111. return nil
  112. }
  113. retErr := conn.Conn.Close()
  114. err := conn.streamSet.Close()
  115. if retErr == nil && err != nil {
  116. retErr = err
  117. }
  118. err = conn.dialer.Close()
  119. if retErr == nil && err != nil {
  120. retErr = err
  121. }
  122. return retErr
  123. }
  124. func (conn *Conn) IsClosed() bool {
  125. return conn.dialer.Closed()
  126. }