Просмотр исходного кода

Use net.Dial, rather than net.DialTCP, to dial upstream.

The usual use case for upstream is that it is a localhost IP address and
port, but it may also be a hostname and port. net.DialTCP resolves the
hostname once and for all, and only uses one of the hostname's IP
addresses if there are more than one. net.Dial will try all the IP
addresses in turn until it is able to establish a connection.

Now upstream is kept as a string variable all the way through the call
chain. For the sake of usability, we try resolving the address with
net.ResolveTCPAddr in main, to emit an error or warning right away,
rather than deferring it to the first stream.
David Fifield 5 лет назад
Родитель
Сommit
6e5ba30abf
1 измененных файлов с 40 добавлено и 15 удалено
  1. 40 15
      dnstt-server/main.go

+ 40 - 15
dnstt-server/main.go

@@ -181,18 +181,19 @@ func readKeyFromFile(filename string) ([]byte, error) {
 
 
 // handleStream bidirectionally connects a client stream with a TCP socket
 // handleStream bidirectionally connects a client stream with a TCP socket
 // addressed by upstream.
 // addressed by upstream.
-func handleStream(stream *smux.Stream, upstream *net.TCPAddr, conv uint32) error {
-	conn, err := net.DialTCP("tcp", nil, upstream)
+func handleStream(stream *smux.Stream, upstream string, conv uint32) error {
+	upstreamConn, err := net.Dial("tcp", upstream)
 	if err != nil {
 	if err != nil {
 		return fmt.Errorf("stream %08x:%d connect upstream: %v", conv, stream.ID(), err)
 		return fmt.Errorf("stream %08x:%d connect upstream: %v", conv, stream.ID(), err)
 	}
 	}
-	defer conn.Close()
+	defer upstreamConn.Close()
+	upstreamTCPConn := upstreamConn.(*net.TCPConn)
 
 
 	var wg sync.WaitGroup
 	var wg sync.WaitGroup
 	wg.Add(2)
 	wg.Add(2)
 	go func() {
 	go func() {
 		defer wg.Done()
 		defer wg.Done()
-		_, err := io.Copy(stream, conn)
+		_, err := io.Copy(stream, upstreamTCPConn)
 		if err == io.EOF {
 		if err == io.EOF {
 			// smux Stream.Write may return io.EOF.
 			// smux Stream.Write may return io.EOF.
 			err = nil
 			err = nil
@@ -200,12 +201,12 @@ func handleStream(stream *smux.Stream, upstream *net.TCPAddr, conv uint32) error
 		if err != nil {
 		if err != nil {
 			log.Printf("stream %08x:%d copy stream←upstream: %v", conv, stream.ID(), err)
 			log.Printf("stream %08x:%d copy stream←upstream: %v", conv, stream.ID(), err)
 		}
 		}
-		conn.CloseRead()
+		upstreamTCPConn.CloseRead()
 		stream.Close()
 		stream.Close()
 	}()
 	}()
 	go func() {
 	go func() {
 		defer wg.Done()
 		defer wg.Done()
-		_, err := io.Copy(conn, stream)
+		_, err := io.Copy(upstreamTCPConn, stream)
 		if err == io.EOF {
 		if err == io.EOF {
 			// smux Stream.WriteTo may return io.EOF.
 			// smux Stream.WriteTo may return io.EOF.
 			err = nil
 			err = nil
@@ -213,7 +214,7 @@ func handleStream(stream *smux.Stream, upstream *net.TCPAddr, conv uint32) error
 		if err != nil && err != io.ErrClosedPipe {
 		if err != nil && err != io.ErrClosedPipe {
 			log.Printf("stream %08x:%d copy upstream←stream: %v", conv, stream.ID(), err)
 			log.Printf("stream %08x:%d copy upstream←stream: %v", conv, stream.ID(), err)
 		}
 		}
-		conn.CloseWrite()
+		upstreamTCPConn.CloseWrite()
 	}()
 	}()
 	wg.Wait()
 	wg.Wait()
 
 
@@ -222,7 +223,7 @@ func handleStream(stream *smux.Stream, upstream *net.TCPAddr, conv uint32) error
 
 
 // acceptStreams wraps a KCP session in a Noise channel and an smux.Session,
 // acceptStreams wraps a KCP session in a Noise channel and an smux.Session,
 // then awaits smux streams. It passes each stream to handleStream.
 // then awaits smux streams. It passes each stream to handleStream.
-func acceptStreams(conn *kcp.UDPSession, privkey, pubkey []byte, upstream *net.TCPAddr) error {
+func acceptStreams(conn *kcp.UDPSession, privkey, pubkey []byte, upstream string) error {
 	// Put a Noise channel on top of the KCP conn.
 	// Put a Noise channel on top of the KCP conn.
 	rw, err := noise.NewServer(conn, privkey, pubkey)
 	rw, err := noise.NewServer(conn, privkey, pubkey)
 	if err != nil {
 	if err != nil {
@@ -263,7 +264,7 @@ func acceptStreams(conn *kcp.UDPSession, privkey, pubkey []byte, upstream *net.T
 
 
 // acceptSessions listens for incoming KCP connections and passes them to
 // acceptSessions listens for incoming KCP connections and passes them to
 // acceptStreams.
 // acceptStreams.
-func acceptSessions(ln *kcp.Listener, privkey, pubkey []byte, mtu int, upstream *net.TCPAddr) error {
+func acceptSessions(ln *kcp.Listener, privkey, pubkey []byte, mtu int, upstream string) error {
 	for {
 	for {
 		conn, err := ln.AcceptKCP()
 		conn, err := ln.AcceptKCP()
 		if err != nil {
 		if err != nil {
@@ -739,7 +740,7 @@ func computeMaxEncodedPayload(limit int) int {
 	return low
 	return low
 }
 }
 
 
-func run(privkey, pubkey []byte, domain dns.Name, upstream net.Addr, dnsConn net.PacketConn) error {
+func run(privkey, pubkey []byte, domain dns.Name, upstream string, dnsConn net.PacketConn) error {
 	defer dnsConn.Close()
 	defer dnsConn.Close()
 
 
 	log.Printf("pubkey %x", pubkey)
 	log.Printf("pubkey %x", pubkey)
@@ -770,7 +771,7 @@ func run(privkey, pubkey []byte, domain dns.Name, upstream net.Addr, dnsConn net
 	}
 	}
 	defer ln.Close()
 	defer ln.Close()
 	go func() {
 	go func() {
-		err := acceptSessions(ln, privkey, pubkey, mtu, upstream.(*net.TCPAddr))
+		err := acceptSessions(ln, privkey, pubkey, mtu, upstream)
 		if err != nil {
 		if err != nil {
 			log.Printf("acceptSessions: %v", err)
 			log.Printf("acceptSessions: %v", err)
 		}
 		}
@@ -842,10 +843,34 @@ Example:
 			fmt.Fprintf(os.Stderr, "invalid domain %+q: %v\n", flag.Arg(0), err)
 			fmt.Fprintf(os.Stderr, "invalid domain %+q: %v\n", flag.Arg(0), err)
 			os.Exit(1)
 			os.Exit(1)
 		}
 		}
-		upstream, err := net.ResolveTCPAddr("tcp", flag.Arg(1))
-		if err != nil {
-			fmt.Fprintf(os.Stderr, "cannot resolve %+q: %v\n", flag.Arg(1), err)
-			os.Exit(1)
+		upstream := flag.Arg(1)
+		// We keep upstream as a string in order to eventually pass it
+		// to net.Dial in handleStream. But for the sake of displaying
+		// an error or warning at startup, rather than only when the
+		// first stream occurs, we apply some parsing and name
+		// resolution checks here.
+		{
+			upstreamHost, _, err := net.SplitHostPort(upstream)
+			if err != nil {
+				// host:port format is required in all cases, so
+				// this is a fatal error.
+				fmt.Fprintf(os.Stderr, "cannot parse upstream address %+q: %v\n", upstream, err)
+				os.Exit(1)
+			}
+			upstreamIPAddr, err := net.ResolveIPAddr("ip", upstreamHost)
+			if err != nil {
+				// Failure to resolve the host portion is only a
+				// warning. The name will be re-resolved on each
+				// net.Dial in handleStream.
+				log.Printf("warning: cannot resolve upstream host %+q: %v", upstreamHost, err)
+			} else if upstreamIPAddr.IP == nil {
+				// Handle the special case of an empty string
+				// for the host portion, which resolves to a nil
+				// IP. This is a fatal error as we will not be
+				// able to dial this address.
+				fmt.Fprintf(os.Stderr, "cannot parse upstream address %+q: missing host in address\n", upstream)
+				os.Exit(1)
+			}
 		}
 		}
 
 
 		if udpAddr == "" {
 		if udpAddr == "" {