|
|
@@ -39,7 +39,14 @@ type client struct {
|
|
|
dialOnce sync.Once
|
|
|
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
|
|
|
|
|
|
- session quic.Session
|
|
|
+ // [Psiphon]
|
|
|
+ // Fix Close-while-dialing race condition by synchronizing access to
|
|
|
+ // client.session and adding a closed flag to indicate if the client was
|
|
|
+ // closed while a dial was in progress.
|
|
|
+ sessionMutex sync.Mutex
|
|
|
+ closed bool
|
|
|
+ session quic.Session
|
|
|
+
|
|
|
headerStream quic.Stream
|
|
|
headerErr *qerr.QuicError
|
|
|
headerErrored chan struct{} // this channel is closed if an error occurs on the header stream
|
|
|
@@ -84,15 +91,32 @@ func newClient(
|
|
|
// dial dials the connection
|
|
|
func (c *client) dial() error {
|
|
|
var err error
|
|
|
+ var session quic.Session
|
|
|
if c.dialer != nil {
|
|
|
- c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
|
|
|
+ session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
|
|
|
} else {
|
|
|
- c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
|
|
+ session, err = dialAddr(c.hostname, c.tlsConf, c.config)
|
|
|
}
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
+ // [Psiphon]
|
|
|
+ // Only this write and the Close reads of c.session require synchronization.
|
|
|
+ // After this point, it's safe to concurrently read c.session as it is not
|
|
|
+ // rewritten.
|
|
|
+ c.sessionMutex.Lock()
|
|
|
+ closed := c.closed
|
|
|
+ if !closed {
|
|
|
+ c.session = session
|
|
|
+ }
|
|
|
+ c.sessionMutex.Unlock()
|
|
|
+ if closed {
|
|
|
+ session.Close()
|
|
|
+ return errors.New("closed while dialing")
|
|
|
+ }
|
|
|
+ // [Psiphon]
|
|
|
+
|
|
|
// once the version has been negotiated, open the header stream
|
|
|
c.headerStream, err = c.session.OpenStream()
|
|
|
if err != nil {
|
|
|
@@ -276,18 +300,34 @@ func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (e
|
|
|
}
|
|
|
|
|
|
func (c *client) closeWithError(e error) error {
|
|
|
- if c.session == nil {
|
|
|
+
|
|
|
+ // [Psiphon]
|
|
|
+ c.sessionMutex.Lock()
|
|
|
+ session := c.session
|
|
|
+ c.closed = true
|
|
|
+ c.sessionMutex.Unlock()
|
|
|
+ // [Psiphon]
|
|
|
+
|
|
|
+ if session == nil {
|
|
|
return nil
|
|
|
}
|
|
|
- return c.session.CloseWithError(quic.ErrorCode(qerr.InternalError), e)
|
|
|
+ return session.CloseWithError(quic.ErrorCode(qerr.InternalError), e)
|
|
|
}
|
|
|
|
|
|
// Close closes the client
|
|
|
func (c *client) Close() error {
|
|
|
- if c.session == nil {
|
|
|
+
|
|
|
+ // [Psiphon]
|
|
|
+ c.sessionMutex.Lock()
|
|
|
+ session := c.session
|
|
|
+ c.closed = true
|
|
|
+ c.sessionMutex.Unlock()
|
|
|
+ // [Psiphon]
|
|
|
+
|
|
|
+ if session == nil {
|
|
|
return nil
|
|
|
}
|
|
|
- return c.session.Close()
|
|
|
+ return session.Close()
|
|
|
}
|
|
|
|
|
|
// copied from net/transport.go
|