Browse Source

Fix IPv6 address formatting with ports

Rod Hynes 4 years ago
parent
commit
195ca8d7fa

+ 2 - 2
psiphon/LookupIP.go

@@ -25,8 +25,8 @@ package psiphon
 import (
 	"context"
 	std_errors "errors"
-	"fmt"
 	"net"
+	"strconv"
 	"syscall"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
@@ -133,7 +133,7 @@ func bindLookupIP(
 	}
 
 	netConn, err := dialer.DialContext(
-		ctx, "udp", fmt.Sprintf("%s:%d", ipAddr.String(), DNS_PORT))
+		ctx, "udp", net.JoinHostPort(ipAddr.String(), strconv.Itoa(DNS_PORT)))
 	if err != nil {
 		return nil, errors.Trace(err)
 	}

+ 2 - 8
psiphon/TCPConn_bind.go

@@ -24,10 +24,8 @@ package psiphon
 
 import (
 	"context"
-	"fmt"
 	"math/rand"
 	"net"
-	"strconv"
 	"syscall"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
@@ -37,11 +35,7 @@ import (
 func tcpDial(ctx context.Context, addr string, config *DialConfig) (net.Conn, error) {
 
 	// Get the remote IP and port, resolving a domain name if necessary
-	host, strPort, err := net.SplitHostPort(addr)
-	if err != nil {
-		return nil, errors.Trace(err)
-	}
-	port, err := strconv.Atoi(strPort)
+	host, port, err := net.SplitHostPort(addr)
 	if err != nil {
 		return nil, errors.Trace(err)
 	}
@@ -126,7 +120,7 @@ func tcpDial(ctx context.Context, addr string, config *DialConfig) (net.Conn, er
 		}
 
 		conn, err := dialer.DialContext(
-			ctx, "tcp", fmt.Sprintf("%s:%d", ipAddrs[index].String(), port))
+			ctx, "tcp", net.JoinHostPort(ipAddrs[index].String(), port))
 		if err != nil {
 			lastErr = errors.Trace(err)
 			continue

+ 10 - 11
psiphon/dialParameters.go

@@ -23,7 +23,6 @@ import (
 	"bytes"
 	"crypto/md5"
 	"encoding/binary"
-	"fmt"
 	"net"
 	"net/http"
 	"strconv"
@@ -522,7 +521,7 @@ func MakeDialParameters(
 				return nil, errors.Trace(err)
 			}
 
-			dialParams.MeekDialAddress = fmt.Sprintf("%s:443", dialParams.MeekFrontingDialAddress)
+			dialParams.MeekDialAddress = net.JoinHostPort(dialParams.MeekFrontingDialAddress, "443")
 			dialParams.MeekHostHeader = dialParams.MeekFrontingHost
 
 			// For a FrontingSpec, an SNI value of "" indicates to disable/omit SNI, so
@@ -656,14 +655,14 @@ func MakeDialParameters(
 			if serverEntry.MeekServerPort == 80 {
 				dialParams.MeekHostHeader = hostname
 			} else {
-				dialParams.MeekHostHeader = fmt.Sprintf("%s:%d", hostname, serverEntry.MeekServerPort)
+				dialParams.MeekHostHeader = net.JoinHostPort(
+					hostname, strconv.Itoa(serverEntry.MeekServerPort))
 			}
 		} else if protocol.TunnelProtocolUsesQUIC(dialParams.TunnelProtocol) {
 
-			dialParams.QUICDialSNIAddress = fmt.Sprintf(
-				"%s:%d",
+			dialParams.QUICDialSNIAddress = net.JoinHostPort(
 				selectHostName(dialParams.TunnelProtocol, p),
-				serverEntry.SshObfuscatedQUICPort)
+				strconv.Itoa(serverEntry.SshObfuscatedQUICPort))
 		}
 	}
 
@@ -748,12 +747,12 @@ func MakeDialParameters(
 		protocol.TUNNEL_PROTOCOL_CONJURE_OBFUSCATED_SSH,
 		protocol.TUNNEL_PROTOCOL_QUIC_OBFUSCATED_SSH:
 
-		dialParams.DirectDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, dialPortNumber)
+		dialParams.DirectDialAddress = net.JoinHostPort(serverEntry.IpAddress, dialParams.DialPortNumber)
 
 	case protocol.TUNNEL_PROTOCOL_FRONTED_MEEK,
 		protocol.TUNNEL_PROTOCOL_FRONTED_MEEK_QUIC_OBFUSCATED_SSH:
 
-		dialParams.MeekDialAddress = fmt.Sprintf("%s:%d", dialParams.MeekFrontingDialAddress, dialPortNumber)
+		dialParams.MeekDialAddress = net.JoinHostPort(dialParams.MeekFrontingDialAddress, dialParams.DialPortNumber)
 		dialParams.MeekHostHeader = dialParams.MeekFrontingHost
 		if serverEntry.MeekFrontingDisableSNI {
 			dialParams.MeekSNIServerName = ""
@@ -765,14 +764,14 @@ func MakeDialParameters(
 
 	case protocol.TUNNEL_PROTOCOL_FRONTED_MEEK_HTTP:
 
-		dialParams.MeekDialAddress = fmt.Sprintf("%s:%d", dialParams.MeekFrontingDialAddress, dialPortNumber)
+		dialParams.MeekDialAddress = net.JoinHostPort(dialParams.MeekFrontingDialAddress, dialParams.DialPortNumber)
 		dialParams.MeekHostHeader = dialParams.MeekFrontingHost
 		// For FRONTED HTTP, the Host header cannot be transformed.
 		dialParams.MeekTransformedHostName = false
 
 	case protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK:
 
-		dialParams.MeekDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, dialPortNumber)
+		dialParams.MeekDialAddress = net.JoinHostPort(serverEntry.IpAddress, dialParams.DialPortNumber)
 		if !dialParams.MeekTransformedHostName {
 			if dialPortNumber == 80 {
 				dialParams.MeekHostHeader = serverEntry.IpAddress
@@ -784,7 +783,7 @@ func MakeDialParameters(
 	case protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS,
 		protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK_SESSION_TICKET:
 
-		dialParams.MeekDialAddress = fmt.Sprintf("%s:%d", serverEntry.IpAddress, dialPortNumber)
+		dialParams.MeekDialAddress = net.JoinHostPort(serverEntry.IpAddress, dialParams.DialPortNumber)
 		if !dialParams.MeekTransformedHostName {
 			// Note: IP address in SNI field will be omitted.
 			dialParams.MeekSNIServerName = serverEntry.IpAddress

+ 1 - 1
psiphon/httpProxy.go

@@ -100,7 +100,7 @@ func NewHttpProxy(
 	listenIP string) (proxy *HttpProxy, err error) {
 
 	listener, err := net.Listen(
-		"tcp", fmt.Sprintf("%s:%d", listenIP, config.LocalHttpProxyPort))
+		"tcp", net.JoinHostPort(listenIP, strconv.Itoa(config.LocalHttpProxyPort)))
 	if err != nil {
 		if IsAddressInUseError(err) {
 			NoticeHttpProxyPortInUse(config.LocalHttpProxyPort)

+ 4 - 0
psiphon/server/log.go

@@ -314,6 +314,10 @@ func InitLogging(config *Config) (retErr error) {
 	return retErr
 }
 
+func IsLogLevelDebug() bool {
+	return log.Logger.Level == logrus.DebugLevel
+}
+
 func init() {
 
 	// Suppress standard "log" package logging performed by other packages.

+ 3 - 3
psiphon/server/server_test.go

@@ -55,7 +55,7 @@ import (
 
 var serverIPAddress, testDataDirName string
 var mockWebServerURL, mockWebServerExpectedResponse string
-var mockWebServerPort = 8080
+var mockWebServerPort = "8080"
 
 func TestMain(m *testing.M) {
 	flag.Parse()
@@ -93,7 +93,7 @@ func runMockWebServer() (string, string) {
 	serveMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
 		w.Write([]byte(responseBody))
 	})
-	webServerAddress := fmt.Sprintf("%s:%d", serverIPAddress, mockWebServerPort)
+	webServerAddress := net.JoinHostPort(serverIPAddress, mockWebServerPort)
 	server := &http.Server{
 		Addr:    webServerAddress,
 		Handler: serveMux,
@@ -2106,7 +2106,7 @@ func paveTrafficRulesFile(
 		t.Fatalf("unexpected intLookupThreshold")
 	}
 
-	TCPPorts := fmt.Sprintf("%d", mockWebServerPort)
+	TCPPorts := mockWebServerPort
 	UDPPorts := "53, 123, 10001, 10002, 10003, 10004, 10005, 10006, 10007, 10008, 10009, 10010"
 
 	allowTCPPorts := TCPPorts

+ 2 - 2
psiphon/server/tunnelServer.go

@@ -145,8 +145,8 @@ func (server *TunnelServer) Run() error {
 
 	for tunnelProtocol, listenPort := range support.Config.TunnelProtocolPorts {
 
-		localAddress := fmt.Sprintf(
-			"%s:%d", support.Config.ServerIPAddress, listenPort)
+		localAddress := net.JoinHostPort(
+			support.Config.ServerIPAddress, strconv.Itoa(listenPort))
 
 		var listener net.Listener
 		var BPFProgramName string

+ 11 - 6
psiphon/server/udp.go

@@ -25,6 +25,7 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"strconv"
 	"sync"
 	"sync/atomic"
 
@@ -237,10 +238,14 @@ func (mux *udpgwPortForwardMultiplexer) run() {
 			// Can't defer sshClient.closedPortForward() here;
 			// relayDownstream will call sshClient.closedPortForward()
 
-			log.WithTraceFields(
-				LogFields{
-					"remoteAddr": fmt.Sprintf("%s:%d", dialIP.String(), dialPort),
-					"connID":     message.connID}).Debug("dialing")
+			// Pre-check log level to avoid overhead of rendering log for
+			// every DNS query and other UDP port forward.
+			if IsLogLevelDebug() {
+				log.WithTraceFields(
+					LogFields{
+						"remoteAddr": net.JoinHostPort(dialIP.String(), strconv.Itoa(dialPort)),
+						"connID":     message.connID}).Debug("dialing")
+			}
 
 			udpConn, err := net.DialUDP(
 				"udp", nil, &net.UDPAddr{IP: dialIP, Port: dialPort})
@@ -463,8 +468,8 @@ func (portForward *udpgwPortForward) relayDownstream() {
 
 	log.WithTraceFields(
 		LogFields{
-			"remoteAddr": fmt.Sprintf("%s:%d",
-				net.IP(portForward.remoteIP).String(), portForward.remotePort),
+			"remoteAddr": net.JoinHostPort(
+				net.IP(portForward.remoteIP).String(), strconv.Itoa(int(portForward.remotePort))),
 			"bytesUp":   bytesUp,
 			"bytesDown": bytesDown,
 			"connID":    portForward.connID}).Debug("exiting")

+ 4 - 3
psiphon/server/webServer.go

@@ -22,11 +22,11 @@ package server
 import (
 	"crypto/tls"
 	"encoding/json"
-	"fmt"
 	"io/ioutil"
 	golanglog "log"
 	"net"
 	"net/http"
+	"strconv"
 	"sync"
 	"time"
 
@@ -101,8 +101,9 @@ func RunWebServer(
 		},
 	}
 
-	localAddress := fmt.Sprintf("%s:%d",
-		support.Config.ServerIPAddress, support.Config.WebServerPort)
+	localAddress := net.JoinHostPort(
+		support.Config.ServerIPAddress,
+		strconv.Itoa(support.Config.WebServerPort))
 
 	listener, err := net.Listen("tcp", localAddress)
 	if err != nil {

+ 2 - 2
psiphon/socksProxy.go

@@ -20,8 +20,8 @@
 package psiphon
 
 import (
-	"fmt"
 	"net"
+	"strconv"
 	"strings"
 	"sync"
 
@@ -53,7 +53,7 @@ func NewSocksProxy(
 	listenIP string) (proxy *SocksProxy, err error) {
 
 	listener, err := socks.ListenSocks(
-		"tcp", fmt.Sprintf("%s:%d", listenIP, config.LocalSocksProxyPort))
+		"tcp", net.JoinHostPort(listenIP, strconv.Itoa(config.LocalSocksProxyPort)))
 	if err != nil {
 		if IsAddressInUseError(err) {
 			NoticeSocksProxyPortInUse(config.LocalSocksProxyPort)