Browse Source

Update gotapdance with context support and bug fixes

Rod Hynes 7 years ago
parent
commit
4155cf30a7

+ 27 - 66
psiphon/common/tapdance/tapdance.go

@@ -46,8 +46,6 @@ import (
 
 const (
 	READ_PROXY_PROTOCOL_HEADER_TIMEOUT = 5 * time.Second
-	REDIAL_TCP_TIMEOUT_MIN             = 10 * time.Second
-	REDIAL_TCP_TIMEOUT_MAX             = 15 * time.Second
 )
 
 func init() {
@@ -96,7 +94,8 @@ func Listen(address string) (*Listener, error) {
 // all pending dials and established conns immediately. This ensures that
 // blocking calls within refraction_networking_tapdance, such as tls.Handhake,
 // are interrupted:
-// E.g., https://github.com/sergeyfrolov/gotapdance/blob/4581c3f01ac46b90ed4b58cce9c0438f732bf915/tapdance/conn_raw.go#L274
+// E.g., https://github.com/sergeyfrolov/gotapdance/blob/2ce6ef6667d52f7391a92fd8ec9dffb97ec4e2e8/tapdance/conn_raw.go#L260
+// (...preceeding SetDeadline is insufficient for immediate cancellation.)
 type dialManager struct {
 	tcpDialer func(ctx context.Context, network, address string) (net.Conn, error)
 
@@ -110,48 +109,43 @@ type dialManager struct {
 }
 
 func newDialManager(
-	tcpDialer func(ctx context.Context, network, address string) (net.Conn, error),
-	initialDialCtx context.Context) *dialManager {
+	tcpDialer func(ctx context.Context, network, address string) (net.Conn, error)) *dialManager {
 
 	runCtx, stopRunning := context.WithCancel(context.Background())
 
 	return &dialManager{
-		tcpDialer:      tcpDialer,
-		initialDialCtx: initialDialCtx,
-		runCtx:         runCtx,
-		stopRunning:    stopRunning,
-		conns:          common.NewConns(),
+		tcpDialer:   tcpDialer,
+		runCtx:      runCtx,
+		stopRunning: stopRunning,
+		conns:       common.NewConns(),
 	}
 }
 
-func (manager *dialManager) dial(network, address string) (net.Conn, error) {
+func (manager *dialManager) dial(ctx context.Context, network, address string) (net.Conn, error) {
 
 	if network != "tcp" {
 		return nil, common.ContextError(fmt.Errorf("unsupported network: %s", network))
 	}
 
 	// The context for this dial is either:
-	// - manager.initialDialCtx during the initial tapdance.Dial, in which case
-	//   this is Psiphon tunnel establishment, which has an externally specified
-	//   timeout.
+	// - ctx, during the initial tapdance.DialContext, when this is Psiphon tunnel
+	//   establishment.
 	// - manager.runCtx after the initial tapdance.Dial completes, in which case
 	//   this is a Tapdance protocol reconnection that occurs periodically for
-	//   already established tunnels; this uses an internal timeout.
+	//   already established tunnels.
 
 	manager.ctxMutex.Lock()
-	var ctx context.Context
-	var cancelFunc context.CancelFunc
 	if manager.useRunCtx {
-		// Random timeout replicates tapdance client behavior with stock dialer:
-		// https://github.com/sergeyfrolov/gotapdance/blob/4581c3f01ac46b90ed4b58cce9c0438f732bf915/tapdance/conn_raw.go#L246
-		timeout, err := common.MakeSecureRandomPeriod(REDIAL_TCP_TIMEOUT_MIN, REDIAL_TCP_TIMEOUT_MAX)
-		if err != nil {
-			manager.ctxMutex.Unlock()
-			return nil, common.ContextError(err)
+
+		// Preserve the random timeout configured by the tapdance client:
+		// https://github.com/sergeyfrolov/gotapdance/blob/2ce6ef6667d52f7391a92fd8ec9dffb97ec4e2e8/tapdance/conn_raw.go#L219
+		deadline, ok := ctx.Deadline()
+		if !ok {
+			return nil, common.ContextError(fmt.Errorf("unexpected nil deadline"))
 		}
-		ctx, cancelFunc = context.WithTimeout(manager.runCtx, timeout)
-	} else {
-		ctx = manager.initialDialCtx
+		var cancelFunc context.CancelFunc
+		ctx, cancelFunc = context.WithDeadline(manager.runCtx, deadline)
+		defer cancelFunc()
 	}
 	manager.ctxMutex.Unlock()
 
@@ -160,10 +154,6 @@ func (manager *dialManager) dial(network, address string) (net.Conn, error) {
 		return nil, common.ContextError(err)
 	}
 
-	if cancelFunc != nil {
-		cancelFunc()
-	}
-
 	conn = &managedConn{
 		Conn:    conn,
 		manager: manager,
@@ -253,51 +243,22 @@ func Dial(
 		return nil, common.ContextError(errors.New("dial context has no timeout"))
 	}
 
-	manager := newDialManager(netDialer.DialContext, ctx)
+	manager := newDialManager(netDialer.DialContext)
 
-	type tapdanceDialResult struct {
-		conn net.Conn
-		err  error
+	tapdanceDialer := &refraction_networking_tapdance.Dialer{
+		TcpDialer: manager.dial,
 	}
 
-	resultChannel := make(chan tapdanceDialResult)
-
-	go func() {
-		tapdanceDialer := &refraction_networking_tapdance.Dialer{
-			TcpDialer: manager.dial,
-		}
-
-		conn, err := tapdanceDialer.Dial("tcp", address)
-		if err != nil {
-			err = common.ContextError(err)
-		}
-
-		resultChannel <- tapdanceDialResult{
-			conn: conn,
-			err:  err,
-		}
-	}()
-
-	var result tapdanceDialResult
-
-	select {
-	case result = <-resultChannel:
-	case <-ctx.Done():
-		result.err = ctx.Err()
-		// Interrupt the goroutine
-		manager.close()
-		<-resultChannel
-	}
-
-	if result.err != nil {
+	conn, err := tapdanceDialer.DialContext(ctx, "tcp", address)
+	if err != nil {
 		manager.close()
-		return nil, common.ContextError(result.err)
+		return nil, common.ContextError(err)
 	}
 
 	manager.startUsingRunCtx()
 
 	return &tapdanceConn{
-		Conn:    result.conn,
+		Conn:    conn,
 		manager: manager,
 	}, nil
 }

+ 1 - 1
psiphon/net.go

@@ -163,7 +163,7 @@ func (d *NetDialer) Dial(network, address string) (net.Conn, error) {
 func (d *NetDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
 	switch network {
 	case "tcp":
-		return d.dialTCP(context.Background(), "tcp", address)
+		return d.dialTCP(ctx, "tcp", address)
 	default:
 		return nil, common.ContextError(fmt.Errorf("unsupported network: %s", network))
 	}

+ 68 - 31
vendor/github.com/sergeyfrolov/bsbuffer/bsbuffer.go

@@ -7,6 +7,7 @@ package bsbuffer
 import (
 	"bytes"
 	"io"
+	"io/ioutil"
 	"sync"
 )
 
@@ -15,15 +16,17 @@ import (
 // S - Safe - Supports arbitrary amount of readers and writers.
 // Could be unblocked and turned into SBuffer.
 type BSBuffer struct {
-	sync.Mutex
-	bufIn  bytes.Buffer
-	bufOut bytes.Buffer
-	r      *io.PipeReader
-	w      *io.PipeWriter
+	mu sync.Mutex
 
-	hasData    chan struct{}
-	engineExit chan struct{}
-	unblocked  bool
+	bufBlocked   bytes.Buffer // used before Unblock() is called
+	bufUnblocked bytes.Buffer // used after Unblock() is called
+
+	r *io.PipeReader
+	w *io.PipeWriter
+
+	unblocked  chan struct{} // closed on unblocking
+	engineExit chan struct{} // after unblocking, engine will wrap up, close this and exit
+	hasData    chan struct{} // never closed
 
 	unblockOnce sync.Once
 }
@@ -35,25 +38,51 @@ func NewBSBuffer() *BSBuffer {
 	bsb.r, bsb.w = io.Pipe()
 
 	bsb.hasData = make(chan struct{}, 1)
+	bsb.unblocked = make(chan struct{})
 	bsb.engineExit = make(chan struct{})
 	go bsb.engine()
 	return bsb
 }
 
+// # How this is supposed to work #
+// (all operations, except piped ones, are locked)
+//
+// before Unblock:
+//    Write stores data to bufBlocked
+//    engine copies data from bufBlocked, writes to pipe
+//    Read reads from pipe
+// after Unblock:
+//    Write still writes data to bufBlocked
+//    engine will copy data from bufBlocked to bufUnblocked and close `engineExit`
+//    Read reads from pipe
+// after engineExit is closed:
+//    Write writes to bufUnblocked
+//    Read reads from bufUnblocked
+
 func (b *BSBuffer) engine() {
 	for {
 		select {
 		case _ = <-b.hasData:
-			b.Lock()
-			b.bufOut.ReadFrom(&b.bufIn)
-			_, err := b.bufOut.WriteTo(b.w)
-			if b.unblocked || err != nil {
-				b.r.Close()
+			b.mu.Lock()
+			buf, _ := ioutil.ReadAll(&b.bufBlocked)
+			b.mu.Unlock()
+			n, _ := b.w.Write(buf) // blocking, unless Unblock was called
+			select {
+			case _ = <-b.unblocked:
+				b.mu.Lock()
+				// copy from buf whatever wasn't written to the pipe
+				b.bufUnblocked.Write(buf[n:])
+
+				// copy everything from bufBlocked to bufUnblocked
+				// bufBlocked shouldn't be touched after engineExit is closed
+				// and we have the Lock.
+				b.bufUnblocked.Write(b.bufBlocked.Bytes())
+
 				close(b.engineExit)
-				b.Unlock()
+				b.mu.Unlock()
 				return
+			default:
 			}
-			b.Unlock()
 		}
 	}
 }
@@ -62,7 +91,7 @@ func (b *BSBuffer) engine() {
 // If the write end is closed with an error, that error is returned as err; otherwise err is EOF.
 // Supports multiple concurrent goroutines and p is valid forever.
 func (b *BSBuffer) Read(p []byte) (n int, err error) {
-	n, err = b.r.Read(p)
+	n, err = b.r.Read(p) // blocking, unless Unblock was called
 	if err != nil {
 		if n != 0 {
 			// There might be remaining data in underlying buffer, and we want user to
@@ -71,9 +100,9 @@ func (b *BSBuffer) Read(p []byte) (n int, err error) {
 		} else {
 			// Unblocked and no data in engine.
 			// Operate as SafeBuffer
-			b.Lock()
-			n, err = b.bufOut.Read(p)
-			b.Unlock()
+			b.mu.Lock()
+			n, err = b.bufUnblocked.Read(p)
+			b.mu.Unlock()
 		}
 	}
 	return
@@ -87,20 +116,22 @@ func (b *BSBuffer) Write(p []byte) (n int, err error) {
 	if len(p) == 0 {
 		return 0, nil
 	}
-	b.Lock()
-	if b.unblocked {
-		// Wait for engine to exit and operate as Safe Buffer.
-		_ = <-b.engineExit
-		n, err = b.bufOut.Write(p)
-	} else {
+
+	b.mu.Lock()
+	select {
+	case _ = <-b.engineExit:
+		n, err = b.bufUnblocked.Write(p)
+		b.mu.Unlock()
+	default:
 		// Push data to engine and wake it up, if needed.
-		n, err = b.bufIn.Write(p)
+		n, err = b.bufBlocked.Write(p)
 		select {
 		case b.hasData <- struct{}{}:
 		default:
 		}
+		b.mu.Unlock()
 	}
-	b.Unlock()
+
 	return
 }
 
@@ -108,10 +139,16 @@ func (b *BSBuffer) Write(p []byte) (n int, err error) {
 // Unblock() is safe to call multiple times.
 func (b *BSBuffer) Unblock() {
 	b.unblockOnce.Do(func() {
-		b.Lock()
-		b.unblocked = true
+		// closing the pipes will make engine and reads non-blocking
 		b.w.Close()
-		close(b.hasData)
-		b.Unlock()
+		b.r.Close()
+
+		b.mu.Lock()
+		close(b.unblocked)
+		select {
+		case b.hasData <- struct{}{}:
+		default:
+		}
+		b.mu.Unlock()
 	})
 }

+ 0 - 27
vendor/github.com/sergeyfrolov/gotapdance/tapdance/common.go

@@ -127,33 +127,6 @@ var tapDanceSupportedCiphers = []uint16{
 	tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
 }
 
-func forceSupportedCiphersFirst(suites []uint16) []uint16 {
-	swapSuites := func(i, j int) {
-		if i == j {
-			return
-		}
-		tmp := suites[j]
-		suites[j] = suites[i]
-		suites[i] = tmp
-	}
-	lastSupportedCipherIdx := 0
-	for i := range suites {
-		for _, supportedS := range tapDanceSupportedCiphers {
-			if suites[i] == supportedS {
-				swapSuites(i, lastSupportedCipherIdx)
-				lastSupportedCipherIdx += 1
-			}
-		}
-	}
-	alwaysSuggestedSuite := tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
-	for i := range suites {
-		if suites[i] == alwaysSuggestedSuite {
-			return suites
-		}
-	}
-	return append([]uint16{alwaysSuggestedSuite}, suites[lastSupportedCipherIdx:]...)
-}
-
 // How much time to sleep on trying to connect to decoys to prevent overwhelming them
 func sleepBeforeConnect(attempt int) (waitTime <-chan time.Time) {
 	if attempt >= 2 { // return nil for first 2 attempts

+ 10 - 5
vendor/github.com/sergeyfrolov/gotapdance/tapdance/conn_dual.go

@@ -1,6 +1,7 @@
 package tapdance
 
 import (
+	"context"
 	"crypto/rand"
 	"errors"
 	"net"
@@ -19,7 +20,7 @@ type DualConn struct {
 }
 
 // returns TapDance connection that utilizes 2 flows underneath: reader and writer
-func dialSplitFlow(customDialer func(string, string) (net.Conn, error)) (net.Conn, error) {
+func dialSplitFlow(ctx context.Context, customDialer func(context.Context, string, string) (net.Conn, error)) (net.Conn, error) {
 	dualConn := DualConn{sessionId: sessionsTotal.GetAndInc()}
 	stationPubkey := Assets().GetPubkey()
 
@@ -29,7 +30,9 @@ func dialSplitFlow(customDialer func(string, string) (net.Conn, error)) (net.Con
 	rawRConn := makeTdRaw(tagHttpGetIncomplete,
 		stationPubkey[:],
 		remoteConnId[:])
-	rawRConn.customDialer = customDialer
+	if customDialer != nil {
+		rawRConn.TcpDialer = customDialer
+	}
 	rawRConn.sessionId = dualConn.sessionId
 	rawRConn.strIdSuffix = "R"
 
@@ -38,7 +41,7 @@ func dialSplitFlow(customDialer func(string, string) (net.Conn, error)) (net.Con
 	if err != nil {
 		return nil, err
 	}
-	err = dualConn.readerConn.Dial()
+	err = dualConn.readerConn.DialContext(ctx)
 	if err != nil {
 		return nil, err
 	}
@@ -57,7 +60,9 @@ func dialSplitFlow(customDialer func(string, string) (net.Conn, error)) (net.Con
 	rawWConn := makeTdRaw(tagHttpPostIncomplete,
 		stationPubkey[:],
 		remoteConnId[:])
-	rawWConn.customDialer = customDialer
+	if customDialer != nil {
+		rawRConn.TcpDialer = customDialer
+	}
 	rawWConn.sessionId = dualConn.sessionId
 	rawWConn.strIdSuffix = "W"
 	rawWConn.decoySpec = rawRConn.decoySpec
@@ -68,7 +73,7 @@ func dialSplitFlow(customDialer func(string, string) (net.Conn, error)) (net.Con
 		dualConn.readerConn.closeWithErrorOnce(err)
 		return nil, err
 	}
-	err = dualConn.writerConn.Dial()
+	err = dualConn.writerConn.DialContext(ctx)
 	if err != nil {
 		dualConn.readerConn.closeWithErrorOnce(err)
 		return nil, err

+ 6 - 8
vendor/github.com/sergeyfrolov/gotapdance/tapdance/conn_flow.go

@@ -7,6 +7,7 @@ TODO: confirm that all writes are recorded towards data limit
 package tapdance
 
 import (
+	"context"
 	"crypto/rand"
 	"encoding/binary"
 	"encoding/hex"
@@ -82,10 +83,10 @@ func makeTdFlow(flow flowType, tdRaw *tdRawConn) (*TapdanceFlowConn, error) {
 
 // Dial establishes direct connection to TapDance station proxy.
 // Users are expected to send HTTP CONNECT request next.
-func (flowConn *TapdanceFlowConn) Dial() error {
+func (flowConn *TapdanceFlowConn) DialContext(ctx context.Context) error {
 	if flowConn.tdRaw.tlsConn == nil {
 		// if still hasn't dialed
-		err := flowConn.tdRaw.Dial()
+		err := flowConn.tdRaw.DialContext(ctx)
 		if err != nil {
 			return err
 		}
@@ -434,7 +435,7 @@ func (flowConn *TapdanceFlowConn) actOnReadError(err error) error {
 			err == io.ErrUnexpectedEOF {
 			Logger().Infoln(flowConn.tdRaw.idStr() + " reconnect: FIN is unexpected")
 		}
-		err = flowConn.tdRaw.Redial()
+		err = flowConn.tdRaw.RedialContext(context.Background())
 		if flowConn.flowType != flowReadOnly {
 			// wake up writer engine
 			select {
@@ -458,11 +459,8 @@ func (flowConn *TapdanceFlowConn) actOnReadError(err error) error {
 			return io.EOF
 		} // else: proceed and exit as a crash
 	}
-	if flowConn.closeErr != nil {
-		return flowConn.closeErr
-	}
-	Logger().Infoln(flowConn.tdRaw.idStr() + " crashing due to " + err.Error())
-	return io.ErrUnexpectedEOF
+
+	return flowConn.closeWithErrorOnce(err)
 }
 
 // Sets read deadline to {when raw connection was establihsed} + {timeout} - {small random value}

+ 26 - 15
vendor/github.com/sergeyfrolov/gotapdance/tapdance/conn_raw.go

@@ -13,6 +13,7 @@ import (
 	"sync"
 	"time"
 
+	"context"
 	"github.com/golang/protobuf/proto"
 	"github.com/refraction-networking/utls"
 )
@@ -28,7 +29,7 @@ type tdRawConn struct {
 	sessionId   uint64
 	strIdSuffix string
 
-	customDialer func(string, string) (net.Conn, error)
+	TcpDialer func(context.Context, string, string) (net.Conn, error)
 
 	decoySpec     pb.TLSDecoySpec
 	establishedAt time.Time
@@ -58,16 +59,16 @@ func makeTdRaw(handshakeType tdTagType,
 	return tdRaw
 }
 
-func (tdRaw *tdRawConn) Redial() error {
-	tdRaw.flowId += 1
-	return tdRaw.dial(true)
+func (tdRaw *tdRawConn) DialContext(ctx context.Context) error {
+	return tdRaw.dial(ctx, false)
 }
 
-func (tdRaw *tdRawConn) Dial() error {
-	return tdRaw.dial(false)
+func (tdRaw *tdRawConn) RedialContext(ctx context.Context) error {
+	tdRaw.flowId += 1
+	return tdRaw.dial(ctx, true)
 }
 
-func (tdRaw *tdRawConn) dial(reconnect bool) error {
+func (tdRaw *tdRawConn) dial(ctx context.Context, reconnect bool) error {
 	var maxConnectionAttempts int
 	var err error
 
@@ -95,6 +96,8 @@ func (tdRaw *tdRawConn) dial(reconnect bool) error {
 		if waitTime := sleepBeforeConnect(i); waitTime != nil {
 			select {
 			case <-waitTime:
+			case <-ctx.Done():
+				return context.Canceled
 			case <-tdRaw.closed:
 				return errors.New("Closed")
 			}
@@ -112,7 +115,7 @@ func (tdRaw *tdRawConn) dial(reconnect bool) error {
 			}
 		}
 
-		err = tdRaw.tryDialOnce(expectedTransition)
+		err = tdRaw.tryDialOnce(ctx, expectedTransition)
 		if err == nil {
 			return err
 		}
@@ -122,10 +125,10 @@ func (tdRaw *tdRawConn) dial(reconnect bool) error {
 	return err
 }
 
-func (tdRaw *tdRawConn) tryDialOnce(expectedTransition pb.S2C_Transition) (err error) {
+func (tdRaw *tdRawConn) tryDialOnce(ctx context.Context, expectedTransition pb.S2C_Transition) (err error) {
 	Logger().Infoln(tdRaw.idStr() + " Attempting to connect to decoy " +
 		tdRaw.decoySpec.GetHostname() + " (" + tdRaw.decoySpec.GetIpv4AddrStr() + ")")
-	err = tdRaw.establishTLStoDecoy()
+	err = tdRaw.establishTLStoDecoy(ctx)
 	if err != nil {
 		Logger().Errorf(tdRaw.idStr() + " establishTLStoDecoy(" +
 			tdRaw.decoySpec.GetHostname() + "," + tdRaw.decoySpec.GetIpv4AddrStr() +
@@ -234,16 +237,23 @@ func (tdRaw *tdRawConn) tryDialOnce(expectedTransition pb.S2C_Transition) (err e
 	return nil
 }
 
-func (tdRaw *tdRawConn) establishTLStoDecoy() (err error) {
+func (tdRaw *tdRawConn) establishTLStoDecoy(ctx context.Context) (err error) {
 	var dialConn net.Conn
-	if tdRaw.customDialer != nil {
-		dialConn, err = tdRaw.customDialer("tcp", tdRaw.decoySpec.GetIpv4AddrStr())
+	deadline, deadlineAlreadySet := ctx.Deadline()
+	if !deadlineAlreadySet {
+		deadline = time.Now().Add(getRandomDuration(deadlineTCPtoDecoyMin, deadlineTCPtoDecoyMax))
+	}
+	childCtx, childCancelFunc := context.WithDeadline(ctx, deadline)
+	defer childCancelFunc()
+
+	if tdRaw.TcpDialer != nil {
+		dialConn, err = tdRaw.TcpDialer(childCtx, "tcp", tdRaw.decoySpec.GetIpv4AddrStr())
 		if err != nil {
 			return err
 		}
 	} else {
-		dialConn, err = net.DialTimeout("tcp", tdRaw.decoySpec.GetIpv4AddrStr(),
-			getRandomDuration(deadlineTCPtoDecoyMin, deadlineTCPtoDecoyMax))
+		d := net.Dialer{}
+		dialConn, err = d.DialContext(childCtx, "tcp", tdRaw.decoySpec.GetIpv4AddrStr())
 		if err != nil {
 			return err
 		}
@@ -271,6 +281,7 @@ func (tdRaw *tdRawConn) establishTLStoDecoy() (err error) {
 		dialConn.Close()
 		return
 	}
+	tdRaw.tlsConn.SetDeadline(deadline)
 	err = tdRaw.tlsConn.Handshake()
 	if err != nil {
 		dialConn.Close()

+ 18 - 14
vendor/github.com/sergeyfrolov/gotapdance/tapdance/dialer.go

@@ -2,6 +2,7 @@ package tapdance
 
 import (
 	"bufio"
+	"context"
 	"errors"
 	"fmt"
 	"net"
@@ -10,12 +11,10 @@ import (
 
 var sessionsTotal CounterUint64
 
-// Dialer contains options for establishing TapDance connection.
+// Dialer contains options and implements advanced functions for establishing TapDance connection.
 type Dialer struct {
-	// TODO?: add Context support(not as a field, it has to "flow through program like river")
-	// https://medium.com/@cep21/how-to-correctly-use-context-context-in-go-1-7-8f2c0fafdf39
 	SplitFlows bool
-	TcpDialer  func(string, string) (net.Conn, error)
+	TcpDialer  func(context.Context, string, string) (net.Conn, error)
 }
 
 // Dial connects to the address on the named network.
@@ -33,6 +32,12 @@ func Dial(network, address string) (net.Conn, error) {
 }
 
 // Dial connects to the address on the named network.
+func (d *Dialer) Dial(network, address string) (net.Conn, error) {
+	return d.DialContext(context.Background(), network, address)
+}
+
+// DialContext connects to the address on the named network using the provided context.
+// Long deadline is strongly advised, since tapdance will try multiple decoys.
 //
 // The only supported network at this time: "tcp".
 // The address has the form "host:port".
@@ -41,7 +46,7 @@ func Dial(network, address string) (net.Conn, error) {
 // To avoid abuse, only certain whitelisted ports are allowed.
 //
 // Example: Dial("tcp", "golang.org:80")
-func (d *Dialer) Dial(network, address string) (net.Conn, error) {
+func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
 	if network != "tcp" {
 		return nil, &net.OpError{Op: "dial", Net: network, Err: net.UnknownNetworkError(network)}
 	}
@@ -50,7 +55,7 @@ func (d *Dialer) Dial(network, address string) (net.Conn, error) {
 		return nil, err
 	}
 
-	flow, err := d.DialProxy()
+	flow, err := d.DialProxyContext(ctx)
 	if err != nil {
 		return nil, err
 	}
@@ -74,21 +79,20 @@ func (d *Dialer) Dial(network, address string) (net.Conn, error) {
 
 // DialProxy establishes direct connection to TapDance station proxy.
 // Users are expected to send HTTP CONNECT request next.
-func DialProxy() (net.Conn, error) {
-	var d Dialer
-	return d.DialProxy()
+func (d *Dialer) DialProxy() (net.Conn, error) {
+	return d.DialProxyContext(context.Background())
 }
 
-// DialProxy establishes direct connection to TapDance station proxy.
+// DialProxy establishes direct connection to TapDance station proxy using the provided context.
 // Users are expected to send HTTP CONNECT request next.
-func (d *Dialer) DialProxy() (net.Conn, error) {
+func (d *Dialer) DialProxyContext(ctx context.Context) (net.Conn, error) {
 	if !d.SplitFlows {
 		flow, err := makeTdFlow(flowBidirectional, nil)
 		if err != nil {
 			return nil, err
 		}
-		flow.tdRaw.customDialer = d.TcpDialer
-		return flow, flow.Dial()
+		flow.tdRaw.TcpDialer = d.TcpDialer
+		return flow, flow.DialContext(ctx)
 	}
-	return dialSplitFlow(d.TcpDialer)
+	return dialSplitFlow(ctx, d.TcpDialer)
 }

+ 8 - 8
vendor/vendor.json

@@ -460,22 +460,22 @@
 			"revisionTime": "2017-01-28T01:21:29Z"
 		},
 		{
-			"checksumSHA1": "KY4600ldPI8LeOagY1P7QAvsJgU=",
+			"checksumSHA1": "Hj4pJ8jepJQ64sTPVJKlBXGC53Y=",
 			"path": "github.com/sergeyfrolov/bsbuffer",
-			"revision": "1049e53e3f9ee6f3ea4d6b2714729563ee493193",
-			"revisionTime": "2017-07-10T02:15:16Z"
+			"revision": "94e85abb850729a5f54f383e8175e62931d04748",
+			"revisionTime": "2018-09-03T21:38:11Z"
 		},
 		{
 			"checksumSHA1": "dOVMxadkUJFjdc8Ed9vbfYwvZzE=",
 			"path": "github.com/sergeyfrolov/gotapdance/protobuf",
-			"revision": "4581c3f01ac46b90ed4b58cce9c0438f732bf915",
-			"revisionTime": "2018-07-17T02:09:26Z"
+			"revision": "2ceeda9fef5bf3609cd3d1b04d4785ffac83d87c",
+			"revisionTime": "2018-09-05T22:38:24Z"
 		},
 		{
-			"checksumSHA1": "Ai6CK0f71QwLKjr2IJ5JUz1RXfg=",
+			"checksumSHA1": "6PTbPuGiX2bJxURtlm4LCLzwZwk=",
 			"path": "github.com/sergeyfrolov/gotapdance/tapdance",
-			"revision": "5b5c507e165050668a074c51fbd5f45544e6c475",
-			"revisionTime": "2018-08-15T19:09:24Z"
+			"revision": "2ceeda9fef5bf3609cd3d1b04d4785ffac83d87c",
+			"revisionTime": "2018-09-05T22:38:24Z"
 		},
 		{
 			"checksumSHA1": "Egp3n8yTaAuVtrA14LJrTWDgkO4=",