Browse Source

Fix: fragmentor now passes CloseWrite through to underlying conn

Rod Hynes 7 years ago
parent
commit
55180feb7e

+ 7 - 0
psiphon/common/fragmentor/fragmentor.go

@@ -341,6 +341,13 @@ func (c *Conn) Write(buffer []byte) (int, error) {
 	return totalBytesWritten, nil
 }
 
+func (c *Conn) CloseWrite() error {
+	if closeWriter, ok := c.Conn.(common.CloseWriter); ok {
+		return closeWriter.CloseWrite()
+	}
+	return common.ContextError(errors.New("underlying conn is not a CloseWriter"))
+}
+
 func (c *Conn) Close() (err error) {
 	if !atomic.CompareAndSwapInt32(&c.isClosed, 0, 1) {
 		return nil

+ 6 - 0
psiphon/common/net.go

@@ -43,6 +43,12 @@ type Closer interface {
 	IsClosed() bool
 }
 
+// CloseWriter defines the interface to a type, typically
+// a net.TCPConn, that implements CloseWrite.
+type CloseWriter interface {
+	CloseWrite() error
+}
+
 // TerminateHTTPConnection sends a 404 response to a client and also closes
 // the persistent connection.
 func TerminateHTTPConnection(

+ 12 - 7
psiphon/common/tapdance/tapdance.go

@@ -154,6 +154,15 @@ func (manager *dialManager) dial(ctx context.Context, network, address string) (
 		return nil, common.ContextError(err)
 	}
 
+	// Fail immediately if CloseWrite isn't available in the underlying dialed
+	// conn. The equivalent check in managedConn.CloseWrite isn't fatal and
+	// tapdance will run in a degraded state.
+	// Limitation: if the underlying conn _also_ passes through CloseWrite, this
+	// check may be insufficient.
+	if _, ok := conn.(common.CloseWriter); !ok {
+		return nil, common.ContextError(errors.New("underlying conn is not a CloseWriter"))
+	}
+
 	conn = &managedConn{
 		Conn:    conn,
 		manager: manager,
@@ -184,17 +193,13 @@ type managedConn struct {
 	manager *dialManager
 }
 
-type closeWriter interface {
-	CloseWrite() error
-}
-
 // CloseWrite exposes the net.TCPConn.CloseWrite() functionality
 // required by tapdance.
 func (conn *managedConn) CloseWrite() error {
-	if closeWriter, ok := conn.Conn.(closeWriter); ok {
+	if closeWriter, ok := conn.Conn.(common.CloseWriter); ok {
 		return closeWriter.CloseWrite()
 	}
-	return common.ContextError(errors.New("dialedConn is not a closeWriter"))
+	return common.ContextError(errors.New("underlying conn is not a CloseWriter"))
 }
 
 func (conn *managedConn) Close() error {
@@ -227,7 +232,7 @@ func (conn *tapdanceConn) IsClosed() bool {
 //
 // The Tapdance station config assets are read from dataDirectory/"tapdance".
 // When no config is found, default assets are paved. ctx is expected to have
-// a timeout for the  dial.
+// a timeout for the dial.
 func Dial(
 	ctx context.Context,
 	dataDirectory string,