|
@@ -181,18 +181,19 @@ func readKeyFromFile(filename string) ([]byte, error) {
|
|
|
|
|
|
|
|
// handleStream bidirectionally connects a client stream with a TCP socket
|
|
// handleStream bidirectionally connects a client stream with a TCP socket
|
|
|
// addressed by upstream.
|
|
// addressed by upstream.
|
|
|
-func handleStream(stream *smux.Stream, upstream *net.TCPAddr, conv uint32) error {
|
|
|
|
|
- conn, err := net.DialTCP("tcp", nil, upstream)
|
|
|
|
|
|
|
+func handleStream(stream *smux.Stream, upstream string, conv uint32) error {
|
|
|
|
|
+ upstreamConn, err := net.Dial("tcp", upstream)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return fmt.Errorf("stream %08x:%d connect upstream: %v", conv, stream.ID(), err)
|
|
return fmt.Errorf("stream %08x:%d connect upstream: %v", conv, stream.ID(), err)
|
|
|
}
|
|
}
|
|
|
- defer conn.Close()
|
|
|
|
|
|
|
+ defer upstreamConn.Close()
|
|
|
|
|
+ upstreamTCPConn := upstreamConn.(*net.TCPConn)
|
|
|
|
|
|
|
|
var wg sync.WaitGroup
|
|
var wg sync.WaitGroup
|
|
|
wg.Add(2)
|
|
wg.Add(2)
|
|
|
go func() {
|
|
go func() {
|
|
|
defer wg.Done()
|
|
defer wg.Done()
|
|
|
- _, err := io.Copy(stream, conn)
|
|
|
|
|
|
|
+ _, err := io.Copy(stream, upstreamTCPConn)
|
|
|
if err == io.EOF {
|
|
if err == io.EOF {
|
|
|
// smux Stream.Write may return io.EOF.
|
|
// smux Stream.Write may return io.EOF.
|
|
|
err = nil
|
|
err = nil
|
|
@@ -200,12 +201,12 @@ func handleStream(stream *smux.Stream, upstream *net.TCPAddr, conv uint32) error
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
log.Printf("stream %08x:%d copy stream←upstream: %v", conv, stream.ID(), err)
|
|
log.Printf("stream %08x:%d copy stream←upstream: %v", conv, stream.ID(), err)
|
|
|
}
|
|
}
|
|
|
- conn.CloseRead()
|
|
|
|
|
|
|
+ upstreamTCPConn.CloseRead()
|
|
|
stream.Close()
|
|
stream.Close()
|
|
|
}()
|
|
}()
|
|
|
go func() {
|
|
go func() {
|
|
|
defer wg.Done()
|
|
defer wg.Done()
|
|
|
- _, err := io.Copy(conn, stream)
|
|
|
|
|
|
|
+ _, err := io.Copy(upstreamTCPConn, stream)
|
|
|
if err == io.EOF {
|
|
if err == io.EOF {
|
|
|
// smux Stream.WriteTo may return io.EOF.
|
|
// smux Stream.WriteTo may return io.EOF.
|
|
|
err = nil
|
|
err = nil
|
|
@@ -213,7 +214,7 @@ func handleStream(stream *smux.Stream, upstream *net.TCPAddr, conv uint32) error
|
|
|
if err != nil && err != io.ErrClosedPipe {
|
|
if err != nil && err != io.ErrClosedPipe {
|
|
|
log.Printf("stream %08x:%d copy upstream←stream: %v", conv, stream.ID(), err)
|
|
log.Printf("stream %08x:%d copy upstream←stream: %v", conv, stream.ID(), err)
|
|
|
}
|
|
}
|
|
|
- conn.CloseWrite()
|
|
|
|
|
|
|
+ upstreamTCPConn.CloseWrite()
|
|
|
}()
|
|
}()
|
|
|
wg.Wait()
|
|
wg.Wait()
|
|
|
|
|
|
|
@@ -222,7 +223,7 @@ func handleStream(stream *smux.Stream, upstream *net.TCPAddr, conv uint32) error
|
|
|
|
|
|
|
|
// acceptStreams wraps a KCP session in a Noise channel and an smux.Session,
|
|
// acceptStreams wraps a KCP session in a Noise channel and an smux.Session,
|
|
|
// then awaits smux streams. It passes each stream to handleStream.
|
|
// then awaits smux streams. It passes each stream to handleStream.
|
|
|
-func acceptStreams(conn *kcp.UDPSession, privkey, pubkey []byte, upstream *net.TCPAddr) error {
|
|
|
|
|
|
|
+func acceptStreams(conn *kcp.UDPSession, privkey, pubkey []byte, upstream string) error {
|
|
|
// Put a Noise channel on top of the KCP conn.
|
|
// Put a Noise channel on top of the KCP conn.
|
|
|
rw, err := noise.NewServer(conn, privkey, pubkey)
|
|
rw, err := noise.NewServer(conn, privkey, pubkey)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
@@ -263,7 +264,7 @@ func acceptStreams(conn *kcp.UDPSession, privkey, pubkey []byte, upstream *net.T
|
|
|
|
|
|
|
|
// acceptSessions listens for incoming KCP connections and passes them to
|
|
// acceptSessions listens for incoming KCP connections and passes them to
|
|
|
// acceptStreams.
|
|
// acceptStreams.
|
|
|
-func acceptSessions(ln *kcp.Listener, privkey, pubkey []byte, mtu int, upstream *net.TCPAddr) error {
|
|
|
|
|
|
|
+func acceptSessions(ln *kcp.Listener, privkey, pubkey []byte, mtu int, upstream string) error {
|
|
|
for {
|
|
for {
|
|
|
conn, err := ln.AcceptKCP()
|
|
conn, err := ln.AcceptKCP()
|
|
|
if err != nil {
|
|
if err != nil {
|
|
@@ -739,7 +740,7 @@ func computeMaxEncodedPayload(limit int) int {
|
|
|
return low
|
|
return low
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func run(privkey, pubkey []byte, domain dns.Name, upstream net.Addr, dnsConn net.PacketConn) error {
|
|
|
|
|
|
|
+func run(privkey, pubkey []byte, domain dns.Name, upstream string, dnsConn net.PacketConn) error {
|
|
|
defer dnsConn.Close()
|
|
defer dnsConn.Close()
|
|
|
|
|
|
|
|
log.Printf("pubkey %x", pubkey)
|
|
log.Printf("pubkey %x", pubkey)
|
|
@@ -770,7 +771,7 @@ func run(privkey, pubkey []byte, domain dns.Name, upstream net.Addr, dnsConn net
|
|
|
}
|
|
}
|
|
|
defer ln.Close()
|
|
defer ln.Close()
|
|
|
go func() {
|
|
go func() {
|
|
|
- err := acceptSessions(ln, privkey, pubkey, mtu, upstream.(*net.TCPAddr))
|
|
|
|
|
|
|
+ err := acceptSessions(ln, privkey, pubkey, mtu, upstream)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
log.Printf("acceptSessions: %v", err)
|
|
log.Printf("acceptSessions: %v", err)
|
|
|
}
|
|
}
|
|
@@ -842,10 +843,34 @@ Example:
|
|
|
fmt.Fprintf(os.Stderr, "invalid domain %+q: %v\n", flag.Arg(0), err)
|
|
fmt.Fprintf(os.Stderr, "invalid domain %+q: %v\n", flag.Arg(0), err)
|
|
|
os.Exit(1)
|
|
os.Exit(1)
|
|
|
}
|
|
}
|
|
|
- upstream, err := net.ResolveTCPAddr("tcp", flag.Arg(1))
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- fmt.Fprintf(os.Stderr, "cannot resolve %+q: %v\n", flag.Arg(1), err)
|
|
|
|
|
- os.Exit(1)
|
|
|
|
|
|
|
+ upstream := flag.Arg(1)
|
|
|
|
|
+ // We keep upstream as a string in order to eventually pass it
|
|
|
|
|
+ // to net.Dial in handleStream. But for the sake of displaying
|
|
|
|
|
+ // an error or warning at startup, rather than only when the
|
|
|
|
|
+ // first stream occurs, we apply some parsing and name
|
|
|
|
|
+ // resolution checks here.
|
|
|
|
|
+ {
|
|
|
|
|
+ upstreamHost, _, err := net.SplitHostPort(upstream)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ // host:port format is required in all cases, so
|
|
|
|
|
+ // this is a fatal error.
|
|
|
|
|
+ fmt.Fprintf(os.Stderr, "cannot parse upstream address %+q: %v\n", upstream, err)
|
|
|
|
|
+ os.Exit(1)
|
|
|
|
|
+ }
|
|
|
|
|
+ upstreamIPAddr, err := net.ResolveIPAddr("ip", upstreamHost)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ // Failure to resolve the host portion is only a
|
|
|
|
|
+ // warning. The name will be re-resolved on each
|
|
|
|
|
+ // net.Dial in handleStream.
|
|
|
|
|
+ log.Printf("warning: cannot resolve upstream host %+q: %v", upstreamHost, err)
|
|
|
|
|
+ } else if upstreamIPAddr.IP == nil {
|
|
|
|
|
+ // Handle the special case of an empty string
|
|
|
|
|
+ // for the host portion, which resolves to a nil
|
|
|
|
|
+ // IP. This is a fatal error as we will not be
|
|
|
|
|
+ // able to dial this address.
|
|
|
|
|
+ fmt.Fprintf(os.Stderr, "cannot parse upstream address %+q: missing host in address\n", upstream)
|
|
|
|
|
+ os.Exit(1)
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if udpAddr == "" {
|
|
if udpAddr == "" {
|