Просмотр исходного кода

Fix: passthrough may be accessed concurrently

mirokuratczyk 2 лет назад
Родитель
Сommit
746c3a211d

+ 14 - 10
psiphon/common/transforms/httpNormalizer.go

@@ -26,6 +26,7 @@ import (
 	"net"
 	"net/textproto"
 	"strconv"
+	"sync/atomic"
 	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
@@ -119,9 +120,12 @@ type HTTPNormalizer struct {
 	// ValidateMeekCookieResult stores the result from calling
 	// validateMeekCookie.
 	ValidateMeekCookieResult []byte
-	// passthrough is set if the normalizer has established a passthrough
-	// session.
-	passthrough bool
+	// passthrough is set to 1 if the normalizer has established a passthrough
+	// session; otherwise 0.
+	// Note: may be accessed concurrently so must be get and set atomically.
+	// E.g. the net.Conn interface methods implemented by HTTPNormalizer may be
+	// called concurrent to each other.
+	passthrough int32
 	// passthroughDialer is used to establish any passthrough sessions.
 	passthroughDialer func(network, address string) (net.Conn, error)
 	// passthroughAddress is the passthrough address that will be used for any
@@ -166,7 +170,7 @@ func NewHTTPNormalizer(conn net.Conn) *HTTPNormalizer {
 // Warning: Does not handle chunked encoding. Must be called synchronously.
 func (t *HTTPNormalizer) Read(buffer []byte) (int, error) {
 
-	if t.passthrough {
+	if atomic.LoadInt32(&t.passthrough) == 1 {
 		return 0, io.EOF
 	}
 
@@ -544,7 +548,7 @@ func (t *HTTPNormalizer) startPassthrough(tunnelError error, logFields map[strin
 
 	go passthrough(t.Conn, t.passthroughAddress, t.passthroughDialer, t.b.Bytes())
 
-	t.passthrough = true
+	atomic.StoreInt32(&t.passthrough, 1)
 }
 
 func passthrough(conn net.Conn, address string, dialer func(network, address string) (net.Conn, error), buf []byte) {
@@ -589,35 +593,35 @@ func passthrough(conn net.Conn, address string, dialer func(network, address str
 }
 
 func (t *HTTPNormalizer) Write(b []byte) (n int, err error) {
-	if t.passthrough {
+	if atomic.LoadInt32(&t.passthrough) == 1 {
 		return 0, ErrPassthroughActive
 	}
 	return t.Conn.Write(b)
 }
 
 func (t *HTTPNormalizer) Close() error {
-	if t.passthrough {
+	if atomic.LoadInt32(&t.passthrough) == 1 {
 		return nil
 	}
 	return t.Conn.Close()
 }
 
 func (t *HTTPNormalizer) SetDeadline(tt time.Time) error {
-	if t.passthrough {
+	if atomic.LoadInt32(&t.passthrough) == 1 {
 		return nil
 	}
 	return t.Conn.SetDeadline(tt)
 }
 
 func (t *HTTPNormalizer) SetReadDeadline(tt time.Time) error {
-	if t.passthrough {
+	if atomic.LoadInt32(&t.passthrough) == 1 {
 		return nil
 	}
 	return t.Conn.SetReadDeadline(tt)
 }
 
 func (t *HTTPNormalizer) SetWriteDeadline(tt time.Time) error {
-	if t.passthrough {
+	if atomic.LoadInt32(&t.passthrough) == 1 {
 		return nil
 	}
 	return t.Conn.SetWriteDeadline(tt)

+ 20 - 5
psiphon/common/transforms/httpTransformer_test.go

@@ -585,21 +585,36 @@ func (c *testConn) Close() error {
 }
 
 func (c *testConn) LocalAddr() net.Addr {
-	return c.Conn.LocalAddr()
+	if c.Conn != nil {
+		return c.Conn.LocalAddr()
+	}
+	return &net.TCPAddr{}
 }
 
 func (c *testConn) RemoteAddr() net.Addr {
-	return c.Conn.RemoteAddr()
+	if c.Conn != nil {
+		return c.Conn.RemoteAddr()
+	}
+	return &net.TCPAddr{}
 }
 
 func (c *testConn) SetDeadline(t time.Time) error {
-	return c.Conn.SetDeadline(t)
+	if c.Conn != nil {
+		return c.Conn.SetDeadline(t)
+	}
+	return nil
 }
 
 func (c *testConn) SetReadDeadline(t time.Time) error {
-	return c.Conn.SetReadDeadline(t)
+	if c.Conn != nil {
+		return c.Conn.SetReadDeadline(t)
+	}
+	return nil
 }
 
 func (c *testConn) SetWriteDeadline(t time.Time) error {
-	return c.Conn.SetWriteDeadline(t)
+	if c.Conn != nil {
+		return c.Conn.SetWriteDeadline(t)
+	}
+	return nil
 }