Преглед изворни кода

Fix: use correct dial context in QUICTransporter

Rod Hynes пре 6 година
родитељ
комит
52de11dbab
2 измењених фајлова са 43 додато и 16 уклоњено
  1. 22 6
      psiphon/common/quic/quic.go
  2. 21 10
      psiphon/meekConn.go

+ 22 - 6
psiphon/common/quic/quic.go

@@ -502,25 +502,27 @@ func (conn *loggingPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
 // CloseIdleConnections.
 type QUICTransporter struct {
 	*h2quic.RoundTripper
-	ctx                  context.Context
-	udpDialer            func() (net.PacketConn, *net.UDPAddr, error)
+	udpDialer            func(ctx context.Context) (net.PacketConn, *net.UDPAddr, error)
 	quicSNIAddress       string
 	negotiateQUICVersion string
 	packetConn           atomic.Value
+
+	mutex sync.Mutex
+	ctx   context.Context
 }
 
 // NewQUICTransporter creates a new QUICTransporter.
 func NewQUICTransporter(
 	ctx context.Context,
-	udpDialer func() (net.PacketConn, *net.UDPAddr, error),
+	udpDialer func(ctx context.Context) (net.PacketConn, *net.UDPAddr, error),
 	quicSNIAddress string,
 	negotiateQUICVersion string) *QUICTransporter {
 
 	t := &QUICTransporter{
-		ctx:                  ctx,
 		udpDialer:            udpDialer,
 		quicSNIAddress:       quicSNIAddress,
 		negotiateQUICVersion: negotiateQUICVersion,
+		ctx:                  ctx,
 	}
 
 	t.RoundTripper = &h2quic.RoundTripper{Dial: t.dialQUIC}
@@ -528,6 +530,13 @@ func NewQUICTransporter(
 	return t
 }
 
+func (t *QUICTransporter) SetRequestContext(ctx context.Context) {
+	// Note: can't use sync.Value since underlying type of ctx changes.
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+	t.ctx = ctx
+}
+
 // CloseIdleConnections wraps h2quic.RoundTripper.Close, which provides the
 // necessary functionality for psiphon.transporter as used by
 // psiphon.MeekConn. Note that, unlike http.Transport.CloseIdleConnections,
@@ -564,13 +573,20 @@ func (t *QUICTransporter) dialQUIC(
 		Versions:         versions,
 	}
 
-	packetConn, remoteAddr, err := t.udpDialer()
+	t.mutex.Lock()
+	ctx := t.ctx
+	t.mutex.Unlock()
+	if ctx == nil {
+		ctx = context.Background()
+	}
+
+	packetConn, remoteAddr, err := t.udpDialer(ctx)
 	if err != nil {
 		return nil, common.ContextError(err)
 	}
 
 	session, err := quic_go.DialContext(
-		t.ctx,
+		ctx,
 		packetConn,
 		remoteAddr,
 		t.quicSNIAddress,

+ 21 - 10
psiphon/meekConn.go

@@ -226,9 +226,9 @@ func DialMeek(
 
 		scheme = "https"
 
-		udpDialer := func() (net.PacketConn, *net.UDPAddr, error) {
+		udpDialer := func(ctx context.Context) (net.PacketConn, *net.UDPAddr, error) {
 			packetConn, remoteAddr, err := NewUDPConn(
-				runCtx,
+				ctx,
 				meekConfig.DialAddress,
 				dialConfig)
 			if err != nil {
@@ -241,7 +241,7 @@ func DialMeek(
 		quicDialSNIAddress := fmt.Sprintf("%s:%s", meekConfig.SNIServerName, port)
 
 		transport = quic.NewQUICTransporter(
-			runCtx,
+			ctx,
 			udpDialer,
 			quicDialSNIAddress,
 			meekConfig.QUICVersion)
@@ -531,8 +531,10 @@ func DialMeek(
 type cachedTLSDialer struct {
 	usedCachedConn int32
 	cachedConn     net.Conn
-	requestCtx     atomic.Value
 	dialer         Dialer
+
+	mutex      sync.Mutex
+	requestCtx context.Context
 }
 
 func newCachedTLSDialer(cachedConn net.Conn, dialer Dialer) *cachedTLSDialer {
@@ -543,7 +545,10 @@ func newCachedTLSDialer(cachedConn net.Conn, dialer Dialer) *cachedTLSDialer {
 }
 
 func (c *cachedTLSDialer) setRequestContext(requestCtx context.Context) {
-	c.requestCtx.Store(requestCtx)
+	// Note: not using sync.Value since underlying type of requestCtx may change.
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+	c.requestCtx = requestCtx
 }
 
 func (c *cachedTLSDialer) dial(network, addr string) (net.Conn, error) {
@@ -552,10 +557,14 @@ func (c *cachedTLSDialer) dial(network, addr string) (net.Conn, error) {
 		c.cachedConn = nil
 		return conn, nil
 	}
-	ctx := c.requestCtx.Load().(context.Context)
+
+	c.mutex.Lock()
+	ctx := c.requestCtx
+	c.mutex.Unlock()
 	if ctx == nil {
 		ctx = context.Background()
 	}
+
 	return c.dialer(ctx, network, addr)
 }
 
@@ -656,7 +665,7 @@ func (meek *MeekConn) RoundTrip(
 	// Note:
 	//
 	// - multiple, concurrent RoundTrip calls are unsafe due to the
-	//   meek.cachedTLSDialer.setRequestContext call in newRequest
+	//   setRequestContext calls in newRequest.
 	//
 	// - concurrent Close and RoundTrip calls are unsafe as Close
 	//   does not synchronize with RoundTrip before calling
@@ -968,7 +977,7 @@ func (r *readCloseSignaller) AwaitClosed() bool {
 // tripper modes.
 //
 // newRequest is not safe for concurrent calls due to its use of
-// cachedTLSDialer.setRequestContext.
+// setRequestContext.
 //
 // The caller must call the returned cancelFunc.
 func (meek *MeekConn) newRequest(
@@ -990,8 +999,10 @@ func (meek *MeekConn) newRequest(
 			meek.clientParameters.Get().Duration(parameters.MeekRoundTripTimeout))
 	}
 
-	// Ensure TLS dials are made within the current request context.
-	if meek.cachedTLSDialer != nil {
+	// Ensure dials are made within the current request context.
+	if meek.isQUIC {
+		meek.transport.(*quic.QUICTransporter).SetRequestContext(requestCtx)
+	} else if meek.cachedTLSDialer != nil {
 		meek.cachedTLSDialer.setRequestContext(requestCtx)
 	}