Przeglądaj źródła

Explicitly guard against SetNoticeWriter clobbering an active setting

Rod Hynes 7 miesięcy temu
rodzic
commit
14f2a20e0b

+ 5 - 3
ClientLibrary/clientlib/clientlib.go

@@ -24,7 +24,6 @@ import (
 	"encoding/json"
 	std_errors "errors"
 	"fmt"
-	"io"
 	"net"
 	"path/filepath"
 	"sync"
@@ -154,7 +153,7 @@ func StartTunnel(
 
 	// Set up notice handling. It is important to do this before config operations, as
 	// otherwise they will write notices to stderr.
-	psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver(
+	err := psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver(
 		func(notice []byte) {
 			var event NoticeEvent
 			err := json.Unmarshal(notice, &event)
@@ -193,6 +192,9 @@ func StartTunnel(
 				noticeReceiver(event)
 			}
 		}))
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
 
 	// Create a cancelable context that will be used for stopping the tunnel
 	tunnelCtx, cancelTunnelCtx := context.WithCancel(ctx)
@@ -212,7 +214,7 @@ func StartTunnel(
 		started.Store(false)
 		// Clear our notice receiver, as it is no longer needed and we should let it be
 		// garbage-collected.
-		psiphon.SetNoticeWriter(io.Discard)
+		psiphon.ResetNoticeWriter()
 	}
 
 	defer func() {

+ 7 - 1
ClientLibrary/clientlib/clientlib_test.go

@@ -449,10 +449,14 @@ func TestPsiphonTunnel_Dial(t *testing.T) {
 func TestStartTunnelNoOutput(t *testing.T) {
 	// Before starting the tunnel, set up a notice receiver. If it receives anything at
 	// all, that means that it would have been printed to stderr.
-	psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver(
+	err := psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver(
 		func(notice []byte) {
 			t.Fatalf("Received notice: %v", string(notice))
 		}))
+	if err != nil {
+		t.Fatalf("psiphon.SetNoticeWriter failed: %v", err)
+	}
+	defer psiphon.ResetNoticeWriter()
 
 	configJSON := setupConfig(t, false)
 
@@ -462,6 +466,8 @@ func TestStartTunnelNoOutput(t *testing.T) {
 	}
 	defer os.RemoveAll(testDataDirName)
 
+	psiphon.ResetNoticeWriter()
+
 	ctx := context.Background()
 
 	tunnel, err := StartTunnel(

+ 6 - 1
ConsoleClient/main.go

@@ -138,7 +138,12 @@ func main() {
 	if formatNotices {
 		noticeWriter = psiphon.NewNoticeConsoleRewriter(noticeWriter)
 	}
-	psiphon.SetNoticeWriter(noticeWriter)
+	err := psiphon.SetNoticeWriter(noticeWriter)
+	if err != nil {
+		fmt.Printf("error setting notice writer: %s\n", err)
+		os.Exit(1)
+	}
+	defer psiphon.ResetNoticeWriter()
 
 	// Handle required config file parameter
 

+ 18 - 13
MobileLibrary/psi/psi.go

@@ -27,8 +27,6 @@ package psi
 import (
 	"context"
 	"encoding/json"
-	"fmt"
-	"os"
 	"path/filepath"
 	"strings"
 	"sync"
@@ -36,6 +34,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/buildinfo"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tun"
 )
 
@@ -134,7 +133,7 @@ func Start(
 	defer controllerMutex.Unlock()
 
 	if controller != nil {
-		return fmt.Errorf("already started")
+		return errors.TraceNew("already started")
 	}
 
 	// Clients may toggle Stop/Start immediately to apply new config settings
@@ -154,7 +153,7 @@ func Start(
 
 	config, err := psiphon.LoadConfig([]byte(configJson))
 	if err != nil {
-		return fmt.Errorf("error loading configuration file: %s", err)
+		return errors.Trace(err)
 	}
 
 	// Set up callbacks.
@@ -179,13 +178,16 @@ func Start(
 
 	err = config.Commit(true)
 	if err != nil {
-		return fmt.Errorf("error committing configuration file: %s", err)
+		return errors.Trace(err)
 	}
 
-	psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver(
+	err = psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver(
 		func(notice []byte) {
 			wrappedProvider.Notice(string(notice))
 		}))
+	if err != nil {
+		return errors.Trace(err)
+	}
 
 	// BuildInfo is a diagnostic notice, so emit only after config.Commit
 	// sets EmitDiagnosticNotices.
@@ -194,7 +196,7 @@ func Start(
 
 	err = psiphon.OpenDataStore(config)
 	if err != nil {
-		return fmt.Errorf("error initializing datastore: %s", err)
+		return errors.Trace(err)
 	}
 
 	controllerCtx, stopController = context.WithCancel(context.Background())
@@ -236,7 +238,7 @@ func Start(
 		stopController()
 		embeddedServerListWaitGroup.Wait()
 		psiphon.CloseDataStore()
-		return fmt.Errorf("error initializing controller: %s", err)
+		return errors.Trace(err)
 	}
 
 	controllerWaitGroup = new(sync.WaitGroup)
@@ -264,7 +266,7 @@ func Stop() {
 		stopController = nil
 		controllerWaitGroup = nil
 		// Allow the provider to be garbage collected.
-		psiphon.SetNoticeWriter(os.Stderr)
+		psiphon.ResetNoticeWriter()
 	}
 }
 
@@ -425,15 +427,18 @@ func StartSendFeedback(
 	sendFeedbackSetNoticeWriter = noticeHandler != nil
 
 	if sendFeedbackSetNoticeWriter {
-		psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver(
+		err := psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver(
 			func(notice []byte) {
 				noticeHandler.Notice(string(notice))
 			}))
+		if err != nil {
+			return errors.Trace(err)
+		}
 	}
 
 	config, err := psiphon.LoadConfig([]byte(configJson))
 	if err != nil {
-		return fmt.Errorf("error loading configuration file: %s", err)
+		return errors.Trace(err)
 	}
 
 	// Set up callbacks.
@@ -461,7 +466,7 @@ func StartSendFeedback(
 
 	err = config.Commit(true)
 	if err != nil {
-		return fmt.Errorf("error committing configuration file: %s", err)
+		return errors.Trace(err)
 	}
 
 	sendFeedbackWaitGroup = new(sync.WaitGroup)
@@ -493,7 +498,7 @@ func StopSendFeedback() {
 		sendFeedbackWaitGroup = nil
 		if sendFeedbackSetNoticeWriter {
 			// Allow the notice handler to be garbage collected.
-			psiphon.SetNoticeWriter(os.Stderr)
+			psiphon.ResetNoticeWriter()
 		}
 		sendFeedbackSetNoticeWriter = false
 	}

+ 13 - 1
psiphon/controller_test.go

@@ -482,6 +482,13 @@ func controllerRun(t *testing.T, runConfig *controllerRunConfig) {
 
 	configJSON, _ = json.Marshal(modifyConfig)
 
+	// Don't print initial config setup notices
+	err = SetNoticeWriter(io.Discard)
+	if err != nil {
+		t.Fatalf("error setting notice writer: %s", err)
+	}
+	defer ResetNoticeWriter()
+
 	config, err := LoadConfig(configJSON)
 	if err != nil {
 		t.Fatalf("error processing configuration file: %s", err)
@@ -559,7 +566,8 @@ func controllerRun(t *testing.T, runConfig *controllerRunConfig) {
 	var clientUpgradeDownloadedBytesCount int32
 	var remoteServerListDownloadedBytesCount int32
 
-	SetNoticeWriter(NewNoticeReceiver(
+	ResetNoticeWriter()
+	err = SetNoticeWriter(NewNoticeReceiver(
 		func(notice []byte) {
 			// TODO: log notices without logging server IPs:
 			//fmt.Fprintf(os.Stderr, "%s\n", string(notice))
@@ -650,6 +658,10 @@ func controllerRun(t *testing.T, runConfig *controllerRunConfig) {
 				}
 			}
 		}))
+	if err != nil {
+		t.Fatalf("error setting notice writer: %s", err)
+	}
+	defer ResetNoticeWriter()
 
 	// Run controller, which establishes tunnels
 

+ 5 - 1
psiphon/dataStoreRecovery_test.go

@@ -82,7 +82,7 @@ func TestBoltResiliency(t *testing.T) {
 	noticeResetDatastore := make(chan struct{}, 1)
 	noticeDatastoreFailed := make(chan struct{}, 1)
 
-	SetNoticeWriter(NewNoticeReceiver(
+	err = SetNoticeWriter(NewNoticeReceiver(
 		func(notice []byte) {
 
 			noticeType, payload, err := GetNotice(notice)
@@ -127,6 +127,10 @@ func TestBoltResiliency(t *testing.T) {
 				fmt.Printf("%s\n", string(notice))
 			}
 		}))
+	if err != nil {
+		t.Fatalf("error setting notice writer: %s", err)
+	}
+	defer ResetNoticeWriter()
 
 	drainNoticeChannel := func(channel chan struct{}) {
 		for {

+ 11 - 2
psiphon/dialParameters_test.go

@@ -24,6 +24,7 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"os"
 	"reflect"
@@ -67,7 +68,11 @@ func runDialParametersAndReplay(t *testing.T, tunnelProtocol string) {
 	}
 	defer os.RemoveAll(testDataDirName)
 
-	SetNoticeWriter(ioutil.Discard)
+	err = SetNoticeWriter(io.Discard)
+	if err != nil {
+		t.Fatalf("error setting notice writer: %s", err)
+	}
+	defer ResetNoticeWriter()
 
 	clientConfig := &Config{
 		PropagationChannelId: "0",
@@ -868,7 +873,11 @@ func TestLimitTunnelDialPortNumbers(t *testing.T) {
 	}
 	defer os.RemoveAll(testDataDirName)
 
-	SetNoticeWriter(ioutil.Discard)
+	err = SetNoticeWriter(io.Discard)
+	if err != nil {
+		t.Fatalf("error setting notice writer: %s", err)
+	}
+	defer ResetNoticeWriter()
 
 	clientConfig := &Config{
 		PropagationChannelId: "0",

+ 6 - 1
psiphon/exchange_test.go

@@ -22,6 +22,7 @@ package psiphon
 import (
 	"encoding/base64"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"os"
 	"testing"
@@ -41,7 +42,11 @@ func TestServerEntryExchange(t *testing.T) {
 	}
 	defer os.RemoveAll(testDataDirName)
 
-	SetNoticeWriter(ioutil.Discard)
+	err = SetNoticeWriter(io.Discard)
+	if err != nil {
+		t.Fatalf("error setting notice writer: %s", err)
+	}
+	defer ResetNoticeWriter()
 
 	// Generate signing and exchange key material
 

+ 5 - 1
psiphon/limitProtocols_test.go

@@ -49,7 +49,7 @@ func TestLimitTunnelProtocols(t *testing.T) {
 	initialConnectingCount := 0
 	connectingCount := 0
 
-	SetNoticeWriter(NewNoticeReceiver(
+	err = SetNoticeWriter(NewNoticeReceiver(
 		func(notice []byte) {
 			noticeType, payload, err := GetNotice(notice)
 			if err != nil {
@@ -93,6 +93,10 @@ func TestLimitTunnelProtocols(t *testing.T) {
 				}
 			}
 		}))
+	if err != nil {
+		t.Fatalf("error setting notice writer: %s", err)
+	}
+	defer ResetNoticeWriter()
 
 	clientConfigJSON := `
     {

+ 5 - 1
psiphon/memory_test/memory_test.go

@@ -152,7 +152,7 @@ func runMemoryTest(t *testing.T, testMode int) {
 	memInspectionFrequency := 10 * time.Second
 	maxInuseBytes := uint64(10 * 1024 * 1024)
 
-	psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver(
+	err = psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver(
 		func(notice []byte) {
 			noticeType, payload, err := psiphon.GetNotice(notice)
 			if err != nil {
@@ -192,6 +192,10 @@ func runMemoryTest(t *testing.T, testMode int) {
 				}
 			}
 		}))
+	if err != nil {
+		t.Fatalf("error setting notice writer: %s", err)
+	}
+	defer psiphon.ResetNoticeWriter()
 
 	startController := func() {
 		controller, err = psiphon.NewController(config)

+ 19 - 1
psiphon/notice.go

@@ -112,12 +112,30 @@ func GetEmitNetworkParameters() bool {
 // - "timestamp": UTC timezone, RFC3339Milli format timestamp for notice event
 //
 // See the Notice* functions for details on each notice meaning and payload.
-func SetNoticeWriter(writer io.Writer) {
+//
+// SetNoticeWriter does not replace the writer and returns an error if a
+// non-default writer is already set.
+func SetNoticeWriter(writer io.Writer) error {
 
 	singletonNoticeLogger.mutex.Lock()
 	defer singletonNoticeLogger.mutex.Unlock()
 
+	if f, ok := singletonNoticeLogger.writer.(*os.File); !ok || f != os.Stderr {
+		return errors.TraceNew("notice writer already set")
+	}
+
 	singletonNoticeLogger.writer = writer
+
+	return nil
+}
+
+// ResetNoticeWriter resets the notice write to the default, stderr.
+func ResetNoticeWriter() {
+
+	singletonNoticeLogger.mutex.Lock()
+	defer singletonNoticeLogger.mutex.Unlock()
+
+	singletonNoticeLogger.writer = os.Stderr
 }
 
 // setNoticeFiles configures files for notice writing.

+ 5 - 1
psiphon/remoteServerList_test.go

@@ -425,7 +425,7 @@ func testObfuscatedRemoteServerLists(t *testing.T, omitMD5Sums bool) {
 
 	tunnelEstablished := make(chan struct{}, 1)
 
-	SetNoticeWriter(NewNoticeReceiver(
+	err = SetNoticeWriter(NewNoticeReceiver(
 		func(notice []byte) {
 
 			noticeType, payload, err := GetNotice(notice)
@@ -455,6 +455,10 @@ func testObfuscatedRemoteServerLists(t *testing.T, omitMD5Sums bool) {
 				fmt.Printf("%s\n", string(notice))
 			}
 		}))
+	if err != nil {
+		t.Fatalf("error setting notice writer: %s", err)
+	}
+	defer ResetNoticeWriter()
 
 	ctx, cancelFunc := context.WithCancel(context.Background())
 	defer cancelFunc()

+ 5 - 1
psiphon/server/passthrough_test.go

@@ -170,7 +170,7 @@ func testPassthrough(t *testing.T, legacy bool) {
 
 	tunnelEstablished := make(chan struct{}, 1)
 
-	psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver(
+	err = psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver(
 		func(notice []byte) {
 			noticeType, payload, err := psiphon.GetNotice(notice)
 			if err != nil {
@@ -183,6 +183,10 @@ func testPassthrough(t *testing.T, legacy bool) {
 				}
 			}
 		}))
+	if err != nil {
+		t.Fatalf("error setting notice writer: %s", err)
+	}
+	defer psiphon.ResetNoticeWriter()
 
 	ctx, cancelFunc := context.WithCancel(context.Background())
 	controllerWaitGroup := new(sync.WaitGroup)

+ 5 - 1
psiphon/server/replay_test.go

@@ -308,7 +308,7 @@ func runServerReplayClient(
 
 	tunnelEstablished := make(chan struct{}, 1)
 
-	psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver(
+	err = psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver(
 		func(notice []byte) {
 			noticeType, payload, err := psiphon.GetNotice(notice)
 			if err != nil {
@@ -321,6 +321,10 @@ func runServerReplayClient(
 				}
 			}
 		}))
+	if err != nil {
+		t.Fatalf("error setting notice writer: %s", err)
+	}
+	defer psiphon.ResetNoticeWriter()
 
 	ctx, cancelFunc := context.WithCancel(context.Background())
 	controllerWaitGroup := new(sync.WaitGroup)

+ 11 - 2
psiphon/server/server_test.go

@@ -1398,7 +1398,11 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 		jsonNetworkID)
 
 	// Don't print initial config setup notices
-	psiphon.SetNoticeWriter(io.Discard)
+	err = psiphon.SetNoticeWriter(io.Discard)
+	if err != nil {
+		t.Fatalf("error setting notice writer: %s", err)
+	}
+	defer psiphon.ResetNoticeWriter()
 
 	clientConfig, err := psiphon.LoadConfig([]byte(clientConfigJSON))
 	if err != nil {
@@ -1678,7 +1682,8 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 	untunneledPortForward := make(chan struct{}, 1)
 	discardTunnel := make(chan struct{}, 1)
 
-	psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver(
+	psiphon.ResetNoticeWriter()
+	err = psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver(
 		func(notice []byte) {
 
 			noticeType, payload, err := psiphon.GetNotice(notice)
@@ -1759,6 +1764,10 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
 				fmt.Printf("%s\n", string(notice))
 			}
 		}))
+	if err != nil {
+		t.Fatalf("error setting notice writer: %s", err)
+	}
+	defer psiphon.ResetNoticeWriter()
 
 	ctx, cancelFunc := context.WithCancel(context.Background())
 

+ 6 - 1
psiphon/server/sessionID_test.go

@@ -23,6 +23,7 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"os"
 	"path/filepath"
@@ -119,7 +120,11 @@ func TestDuplicateSessionID(t *testing.T) {
 	// Limitation: all tunnels still use one singleton datastore and notice
 	// handler.
 
-	psiphon.SetNoticeWriter(ioutil.Discard)
+	err = psiphon.SetNoticeWriter(io.Discard)
+	if err != nil {
+		t.Fatalf("error setting notice writer: %s", err)
+	}
+	defer psiphon.ResetNoticeWriter()
 
 	clientConfigJSONTemplate := `
     {

+ 5 - 1
psiphon/tactics_test.go

@@ -77,7 +77,7 @@ func TestStandAloneGetTactics(t *testing.T) {
 
 	gotTactics := int32(0)
 
-	SetNoticeWriter(NewNoticeReceiver(
+	err = SetNoticeWriter(NewNoticeReceiver(
 		func(notice []byte) {
 			noticeType, _, err := GetNotice(notice)
 			if err != nil {
@@ -88,6 +88,10 @@ func TestStandAloneGetTactics(t *testing.T) {
 				atomic.StoreInt32(&gotTactics, 1)
 			}
 		}))
+	if err != nil {
+		t.Fatalf("error setting notice writer: %s", err)
+	}
+	defer ResetNoticeWriter()
 
 	ctx, cancelFunc := context.WithTimeout(context.Background(), 30*time.Second)
 	defer cancelFunc()

+ 5 - 1
psiphon/userAgent_test.go

@@ -207,7 +207,7 @@ func attemptConnectionsWithUserAgent(
 	}
 	defer CloseDataStore()
 
-	SetNoticeWriter(NewNoticeReceiver(
+	err = SetNoticeWriter(NewNoticeReceiver(
 		func(notice []byte) {
 			noticeType, payload, err := GetNotice(notice)
 			if err != nil {
@@ -220,6 +220,10 @@ func attemptConnectionsWithUserAgent(
 				}
 			}
 		}))
+	if err != nil {
+		t.Fatalf("error setting notice writer: %s", err)
+	}
+	defer ResetNoticeWriter()
 
 	controller, err := NewController(clientConfig)
 	if err != nil {