Rod Hynes 4 месяцев назад
Родитель
Сommit
16222863b5

+ 9 - 0
psiphon/common/crypto/ssh/connection.go

@@ -94,6 +94,15 @@ type connection struct {
 }
 
 func (c *connection) Close() error {
+
+	// [Psiphon]
+	// Ensure handshakeTransport.interrupt is invoked.
+	// handshakeTransport.Close also closes the underlying network
+	// connection, so c.sshConn.conn.Close in not necessary in this case.
+	if c.transport != nil {
+		return c.transport.Close()
+	}
+
 	return c.sshConn.conn.Close()
 }
 

+ 49 - 1
psiphon/common/crypto/ssh/handshake.go

@@ -133,6 +133,11 @@ type handshakeTransport struct {
 	// strictMode indicates if the other side of the handshake indicated
 	// that we should be following the strict KEX protocol restrictions.
 	strictMode bool
+
+	// [Psiphon]
+	// Unblocks readLoop blocked on sending to incoming channel.
+	doSignalCloseReadLoop sync.Once
+	signalCloseReadLoop   chan struct{}
 }
 
 type pendingKex struct {
@@ -150,6 +155,9 @@ func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion,
 		startKex:      make(chan *pendingKex),
 		kexLoopDone:   make(chan struct{}),
 
+		// [Psiphon]
+		signalCloseReadLoop: make(chan struct{}),
+
 		config: config,
 	}
 	t.writeCond = sync.NewCond(&t.mu)
@@ -249,7 +257,21 @@ func (t *handshakeTransport) readLoop() {
 		if !(t.sessionID == nil && t.strictMode) && (p[0] == msgIgnore || p[0] == msgDebug) {
 			continue
 		}
-		t.incoming <- p
+
+		// [Psiphon]
+		// Add a closed signal case to interrupt readLoop when blocked on
+		// sending to incoming.
+		closed := false
+		select {
+		case t.incoming <- p:
+		case <-t.signalCloseReadLoop:
+			closed = true
+		}
+		if closed {
+			t.readError = io.EOF
+			close(t.incoming)
+			break
+		}
 	}
 
 	// Stop writers too.
@@ -1024,6 +1046,10 @@ func (t *handshakeTransport) Close() error {
 	// and close t.startKex, which will shut down kexLoop if running.
 	err := t.conn.Close()
 
+	// [Psiphon]
+	// Interrupt any blocked readers or writers.
+	t.interrupt(err)
+
 	// Wait for the kexLoop goroutine to complete.
 	// At that point we know that the readLoop goroutine is complete too,
 	// because kexLoop itself waits for readLoop to close the startKex channel.
@@ -1032,6 +1058,28 @@ func (t *handshakeTransport) Close() error {
 	return err
 }
 
+// [Psiphon]
+// interrupt unblocks any goroutines waiting on readLoop/writePacket when
+// the underlying transport is shutting down and a KEX may be in progress.
+func (t *handshakeTransport) interrupt(err error) {
+
+	if err == nil {
+		err = io.EOF
+	}
+
+	// Interrupt readLoop if blocked on sending to t.incoming.
+	t.doSignalCloseReadLoop.Do(func() {
+		close(t.signalCloseReadLoop)
+	})
+
+	// Interrupt writePacket if blocked on t.writeCond.Wait awaiting a KEX.
+	// Call recordWriteError to ensure t.writeError is set, if not already;
+	// and unconditionally Broadcast as well, in case the condition in
+	// recordWriteError skips that.
+	t.recordWriteError(err)
+	t.writeCond.Broadcast()
+}
+
 func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
 	if debugHandshake {
 		log.Printf("%s entered key exchange", t.id())

+ 6 - 0
psiphon/common/crypto/ssh/mux.go

@@ -190,6 +190,12 @@ func (m *mux) loop() {
 		err = m.onePacket()
 	}
 
+	// [Psiphon]
+	// Interrupt any blocked readers or writers before closing channels.
+	if t, ok := m.conn.(*handshakeTransport); ok {
+		t.interrupt(err)
+	}
+
 	for _, ch := range m.chanList.dropAll() {
 		ch.close()
 	}