|
|
@@ -26,6 +26,7 @@ import (
|
|
|
"net"
|
|
|
"net/textproto"
|
|
|
"strconv"
|
|
|
+ "sync/atomic"
|
|
|
"time"
|
|
|
|
|
|
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
|
|
|
@@ -119,9 +120,12 @@ type HTTPNormalizer struct {
|
|
|
// ValidateMeekCookieResult stores the result from calling
|
|
|
// validateMeekCookie.
|
|
|
ValidateMeekCookieResult []byte
|
|
|
- // passthrough is set if the normalizer has established a passthrough
|
|
|
- // session.
|
|
|
- passthrough bool
|
|
|
+ // passthrough is set to 1 if the normalizer has established a passthrough
|
|
|
+ // session; otherwise 0.
|
|
|
+ // Note: may be accessed concurrently so must be get and set atomically.
|
|
|
+ // E.g. the net.Conn interface methods implemented by HTTPNormalizer may be
|
|
|
+ // called concurrent to each other.
|
|
|
+ passthrough int32
|
|
|
// passthroughDialer is used to establish any passthrough sessions.
|
|
|
passthroughDialer func(network, address string) (net.Conn, error)
|
|
|
// passthroughAddress is the passthrough address that will be used for any
|
|
|
@@ -166,7 +170,7 @@ func NewHTTPNormalizer(conn net.Conn) *HTTPNormalizer {
|
|
|
// Warning: Does not handle chunked encoding. Must be called synchronously.
|
|
|
func (t *HTTPNormalizer) Read(buffer []byte) (int, error) {
|
|
|
|
|
|
- if t.passthrough {
|
|
|
+ if atomic.LoadInt32(&t.passthrough) == 1 {
|
|
|
return 0, io.EOF
|
|
|
}
|
|
|
|
|
|
@@ -544,7 +548,7 @@ func (t *HTTPNormalizer) startPassthrough(tunnelError error, logFields map[strin
|
|
|
|
|
|
go passthrough(t.Conn, t.passthroughAddress, t.passthroughDialer, t.b.Bytes())
|
|
|
|
|
|
- t.passthrough = true
|
|
|
+ atomic.StoreInt32(&t.passthrough, 1)
|
|
|
}
|
|
|
|
|
|
func passthrough(conn net.Conn, address string, dialer func(network, address string) (net.Conn, error), buf []byte) {
|
|
|
@@ -589,35 +593,35 @@ func passthrough(conn net.Conn, address string, dialer func(network, address str
|
|
|
}
|
|
|
|
|
|
func (t *HTTPNormalizer) Write(b []byte) (n int, err error) {
|
|
|
- if t.passthrough {
|
|
|
+ if atomic.LoadInt32(&t.passthrough) == 1 {
|
|
|
return 0, ErrPassthroughActive
|
|
|
}
|
|
|
return t.Conn.Write(b)
|
|
|
}
|
|
|
|
|
|
func (t *HTTPNormalizer) Close() error {
|
|
|
- if t.passthrough {
|
|
|
+ if atomic.LoadInt32(&t.passthrough) == 1 {
|
|
|
return nil
|
|
|
}
|
|
|
return t.Conn.Close()
|
|
|
}
|
|
|
|
|
|
func (t *HTTPNormalizer) SetDeadline(tt time.Time) error {
|
|
|
- if t.passthrough {
|
|
|
+ if atomic.LoadInt32(&t.passthrough) == 1 {
|
|
|
return nil
|
|
|
}
|
|
|
return t.Conn.SetDeadline(tt)
|
|
|
}
|
|
|
|
|
|
func (t *HTTPNormalizer) SetReadDeadline(tt time.Time) error {
|
|
|
- if t.passthrough {
|
|
|
+ if atomic.LoadInt32(&t.passthrough) == 1 {
|
|
|
return nil
|
|
|
}
|
|
|
return t.Conn.SetReadDeadline(tt)
|
|
|
}
|
|
|
|
|
|
func (t *HTTPNormalizer) SetWriteDeadline(tt time.Time) error {
|
|
|
- if t.passthrough {
|
|
|
+ if atomic.LoadInt32(&t.passthrough) == 1 {
|
|
|
return nil
|
|
|
}
|
|
|
return t.Conn.SetWriteDeadline(tt)
|