- // +build MARIONETTE
- /*
- * Copyright (c) 2018, Psiphon Inc.
- * All rights reserved.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see <http://www.gnu.org/licenses/>.
- *
- */
- /*
- Package marionette wraps github.com/redjack/marionette with net.Listener and
- net.Conn types that provide a drop-in replacement for net.TCPConn.
- Each marionette session has exactly one stream, which is the equivalent of a TCP
- stream.
- */
- package marionette
- import (
- "context"
- "net"
- "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
- "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
- redjack_marionette "github.com/redjack/marionette"
- "github.com/redjack/marionette/mar"
- _ "github.com/redjack/marionette/plugins"
- "go.uber.org/zap"
- )
- func init() {
- // Override the Logger initialized by redjack_marionette.init()
- redjack_marionette.Logger = zap.NewNop()
- }
- // Enabled indicates if Marionette functionality is enabled.
- func Enabled() bool {
- return true
- }
- // Listener is a net.Listener.
- type Listener struct {
- net.Listener
- }
- // Listen creates a new Marionette Listener. The address input should not
- // include a port number as the port is defined in the Marionette format.
- func Listen(address, format string) (net.Listener, error) {
- data, err := mar.ReadFormat(format)
- if err != nil {
- return nil, errors.Trace(err)
- }
- doc, err := mar.Parse(redjack_marionette.PartyServer, data)
- if err != nil {
- return nil, errors.Trace(err)
- }
- listener, err := redjack_marionette.Listen(doc, address)
- if err != nil {
- return nil, errors.Trace(err)
- }
- return &Listener{Listener: listener}, nil
- }
- // Dial establishes a new Marionette session and stream to the server
- // specified by address. The address input should not include a port number as
- // that's defined in the Marionette format.
- func Dial(
- ctx context.Context,
- netDialer common.NetDialer,
- format string,
- address string) (net.Conn, error) {
- data, err := mar.ReadFormat(format)
- if err != nil {
- return nil, errors.Trace(err)
- }
- doc, err := mar.Parse(redjack_marionette.PartyClient, data)
- if err != nil {
- return nil, errors.Trace(err)
- }
- streamSet := redjack_marionette.NewStreamSet()
- dialer := redjack_marionette.NewDialer(doc, address, streamSet)
- dialer.Dialer = netDialer
- err = dialer.Open()
- if err != nil {
- streamSet.Close()
- return nil, errors.Trace(err)
- }
- // dialer.Dial does not block on network I/O
- conn, err := dialer.Dial()
- if err != nil {
- streamSet.Close()
- dialer.Close()
- return nil, errors.Trace(err)
- }
- return &Conn{
- Conn: conn,
- streamSet: streamSet,
- dialer: dialer,
- }, nil
- }
- // Conn is a net.Conn and psiphon/common.Closer.
- type Conn struct {
- net.Conn
- streamSet *redjack_marionette.StreamSet
- dialer *redjack_marionette.Dialer
- }
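- // Close closes the stream, the stream set, and the dialer, and returns the
- // first error encountered. Close is a no-op when IsClosed reports true.
- func (conn *Conn) Close() error {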
- if conn.IsClosed() {
- return nil
- }
- retErr := conn.Conn.Close()
- err := conn.streamSet.Close()
- if retErr == nil && err != nil {
- retErr = err
- }
- err = conn.dialer.Close()
- if retErr == nil && err != nil {
- retErr = err
- }
- return retErr
- }
- // IsClosed reports whether the underlying dialer has been closed.
- func (conn *Conn) IsClosed() bool {
- return conn.dialer.Closed()
- }