Просмотр исходного кода

Factor out a pattern for different kinds of remote address.

David Fifield 6 лет назад
Родитель
Сommit
aefe4f9971
1 измененных файлов с 84 добавлено и 56 удалено
  1. 84 56
      dnstt-client/main.go

+ 84 - 56
dnstt-client/main.go

@@ -149,62 +149,45 @@ func dnsNameCapacity(domain dns.Name) int {
 	return capacity
 }
 
-func run(domain dns.Name, localAddr, udpAddr string) error {
-	var sess *smux.Session
+func run(domain dns.Name, localAddr *net.TCPAddr, remoteAddr net.Addr, pconn net.PacketConn) error {
+	defer pconn.Close()
 
-	if udpAddr != "" {
-		addr, err := net.ResolveUDPAddr("udp", udpAddr)
-		if err != nil {
-			return err
-		}
-		udpConn, err := net.ListenUDP("udp", nil)
-		if err != nil {
-			return fmt.Errorf("opening UDP conn: %v", err)
-		}
-		defer udpConn.Close()
-
-		// Start up the virtual PacketConn for turbotunnel.
-		pconn := NewUDPPacketConn(udpConn, addr, domain)
-
-		// Open a KCP conn on the PacketConn.
-		conn, err := kcp.NewConn2(addr, nil, 0, 0, pconn)
-		if err != nil {
-			return fmt.Errorf("opening KCP conn: %v", err)
-		}
-		defer conn.Close()
-		// Permit coalescing the payloads of consecutive sends.
-		conn.SetStreamMode(true)
-		// Disable the dynamic congestion window (limit only by the
-		// maximum of local and remote static windows).
-		conn.SetNoDelay(
-			0, // default nodelay
-			0, // default interval
-			0, // default resend
-			1, // nc=1 => congestion window off
-		)
-		mtu := dnsNameCapacity(domain) - 8 - 1 - numPadding - 1 // clientid + padding length prefix + padding + data length prefix
-		if mtu < 80 {
-			return fmt.Errorf("domain %s leaves only %d bytes for payload", domain, mtu)
-		}
-		log.Printf("MTU %d\n", mtu)
-		if rc := conn.SetMtu(mtu); !rc {
-			panic(rc)
-		}
+	// Open a KCP conn on the PacketConn.
+	conn, err := kcp.NewConn2(remoteAddr, nil, 0, 0, pconn)
+	if err != nil {
+		return fmt.Errorf("opening KCP conn: %v", err)
+	}
+	defer conn.Close()
+	// Permit coalescing the payloads of consecutive sends.
+	conn.SetStreamMode(true)
+	// Disable the dynamic congestion window (limit only by the
+	// maximum of local and remote static windows).
+	conn.SetNoDelay(
+		0, // default nodelay
+		0, // default interval
+		0, // default resend
+		1, // nc=1 => congestion window off
+	)
+	mtu := dnsNameCapacity(domain) - 8 - 1 - numPadding - 1 // clientid + padding length prefix + padding + data length prefix
+	if mtu < 80 {
+		return fmt.Errorf("domain %s leaves only %d bytes for payload", domain, mtu)
+	}
+	log.Printf("MTU %d\n", mtu)
+	if rc := conn.SetMtu(mtu); !rc {
+		panic(rc)
+	}
 
-		// Start a smux session on the KCP conn.
-		smuxConfig := smux.DefaultConfig()
-		smuxConfig.Version = 2
-		smuxConfig.KeepAliveTimeout = idleTimeout
-		sess, err = smux.Client(conn, smuxConfig)
-		if err != nil {
-			return fmt.Errorf("opening smux session: %v", err)
-		}
-		defer sess.Close()
-	} else {
-		return fmt.Errorf("need a UDP address")
+	// Start a smux session on the KCP conn.
+	smuxConfig := smux.DefaultConfig()
+	smuxConfig.Version = 2
+	smuxConfig.KeepAliveTimeout = idleTimeout
+	sess, err := smux.Client(conn, smuxConfig)
+	if err != nil {
+		return fmt.Errorf("opening smux session: %v", err)
 	}
+	defer sess.Close()
 
-	ln, err := net.Listen("tcp", localAddr)
+	ln, err := net.ListenTCP("tcp", localAddr)
 	if err != nil {
 		return fmt.Errorf("opening local listener: %v", err)
 	}
@@ -234,7 +217,7 @@ func main() {
 		fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s -udp ADDR DOMAIN LOCALADDR\n", os.Args[0])
 		flag.PrintDefaults()
 	}
-	flag.StringVar(&udpAddr, "udp", "", "UDP port of DNS server")
+	flag.StringVar(&udpAddr, "udp", "", "address of UDP DNS resolver")
 	flag.Parse()
 
 	log.SetFlags(log.LstdFlags | log.LUTC)
@@ -245,12 +228,57 @@ func main() {
 	}
 	domain, err := dns.ParseName(flag.Arg(0))
 	if err != nil {
-		log.Printf("invalid domain %+q: %v\n", flag.Arg(0), err)
+		fmt.Fprintf(os.Stderr, "invalid domain %+q: %v\n", flag.Arg(0), err)
+		os.Exit(1)
+	}
+	localAddr, err := net.ResolveTCPAddr("tcp", flag.Arg(1))
+	if err != nil {
+		fmt.Fprintln(os.Stderr, err)
+		os.Exit(1)
+	}
+
+	// Iterate over the remote resolver address options and select one and
+	// only one.
+	var remoteAddr net.Addr
+	var pconn net.PacketConn
+	for _, opt := range []struct {
+		s string
+		f func(string) (net.Addr, net.PacketConn, error)
+	}{
+		// -udp
+		{udpAddr, func(s string) (net.Addr, net.PacketConn, error) {
+			addr, err := net.ResolveUDPAddr("udp", s)
+			if err != nil {
+				return nil, nil, err
+			}
+			udpConn, err := net.ListenUDP("udp", nil)
+			if err != nil {
+				return nil, nil, err
+			}
+			return addr, NewUDPPacketConn(udpConn, addr, domain), nil
+		}},
+	} {
+		if opt.s == "" {
+			continue
+		}
+		if pconn != nil {
+			fmt.Fprintf(os.Stderr, "the -udp option may be given only once\n")
+			os.Exit(1)
+		}
+		a, c, err := opt.f(opt.s)
+		if err != nil {
+			fmt.Fprintln(os.Stderr, err)
+			os.Exit(1)
+		}
+		remoteAddr = a
+		pconn = c
+	}
+	if pconn == nil {
+		fmt.Fprintf(os.Stderr, "the -udp option is required\n")
 		os.Exit(1)
 	}
-	localAddr := flag.Arg(1)
 
-	err = run(domain, localAddr, udpAddr)
+	err = run(domain, localAddr, remoteAddr, pconn)
 	if err != nil {
 		log.Fatal(err)
 	}