Просмотр исходного кода

Add packet tunnel tunneled DNS test

Rod Hynes 5 лет назад
Родитель
Сommit
3bfe2207c0
1 измененных файлов с 85 добавлено и 10 удалено
  1. 85 10
      psiphon/common/tun/tun_test.go

+ 85 - 10
psiphon/common/tun/tun_test.go

@@ -37,6 +37,7 @@ import (
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/stacktrace"
+	"github.com/miekg/dns"
 )
 
 const (
@@ -57,10 +58,6 @@ func TestTunneledTCPIPv6(t *testing.T) {
 	testTunneledTCP(t, true)
 }
 
-func TestTunneledDNS(t *testing.T) {
-	t.Skip("TODO: test DNS tunneling")
-}
-
 func TestSessionExpiry(t *testing.T) {
 	t.Skip("TODO: test short session TTLs actually persist/expire as expected")
 }
@@ -81,7 +78,8 @@ func testTunneledTCP(t *testing.T, useIPv6 bool) {
 	// - starts a packet tunnel server that uses a unix domain socket for client channels
 	// - starts CONCURRENT_CLIENT_COUNT concurrent clients
 	// - each client runs a packet tunnel client connected to the server unix domain socket
-	//   and establishes a TCP client connection to the TCP through the packet tunnel
+	// - one client first performs a tunneled DNS query against an external DNS server
+	// - clients establish a TCP client connection to the TCP server through the packet tunnel
 	// - each TCP client transfers TCP_RELAY_TOTAL_SIZE bytes to the TCP server
 	// - the test checks that all data echoes back correctly and that the server packet
 	//   metrics reflects the expected amount of data transferred through the tunnel
@@ -136,7 +134,7 @@ func testTunneledTCP(t *testing.T, useIPv6 bool) {
 	results := make(chan error, CONCURRENT_CLIENT_COUNT)
 
 	for i := 0; i < CONCURRENT_CLIENT_COUNT; i++ {
-		go func() {
+		go func(clientNum int) {
 
 			testClient, err := startTestClient(
 				useIPv6, MTU, []string{testTCPServer.getListenerIPAddress()})
@@ -145,6 +143,18 @@ func testTunneledTCP(t *testing.T, useIPv6 bool) {
 				return
 			}
 
+			// Test one tunneled DNS query.
+
+			if clientNum == 0 {
+				err = testDNSClient(
+					useIPv6,
+					testClient.tunClient.device.Name())
+				if err != nil {
+					results <- fmt.Errorf("testDNSClient failed: %s", err)
+					return
+				}
+			}
+
 			// The TCP client will bind to the packet tunnel client tun
 			// device and connect to the TCP server. With the bind to
 			// device, TCP packets will flow through the packet tunnel
@@ -244,7 +254,7 @@ func testTunneledTCP(t *testing.T, useIPv6 bool) {
 			}
 
 			results <- nil
-		}()
+		}(i)
 	}
 
 	for i := 0; i < CONCURRENT_CLIENT_COUNT; i++ {
@@ -325,14 +335,20 @@ func startTestServer(
 
 	logger := newTestLogger(true)
 
-	noDNSResolvers := func() []net.IP { return make([]net.IP, 0) }
+	getDNSResolverIPv4Addresses := func() []net.IP {
+		return []net.IP{net.ParseIP("8.8.8.8")}
+	}
+
+	getDNSResolverIPv6Addresses := func() []net.IP {
+		return []net.IP{net.ParseIP("2001:4860:4860::8888")}
+	}
 
 	config := &ServerConfig{
 		Logger:                          logger,
 		SudoNetworkConfigCommands:       os.Getenv("TUN_TEST_SUDO") != "",
 		AllowNoIPv6NetworkConfiguration: !useIPv6,
-		GetDNSResolverIPv4Addresses:     noDNSResolvers,
-		GetDNSResolverIPv6Addresses:     noDNSResolvers,
+		GetDNSResolverIPv4Addresses:     getDNSResolverIPv4Addresses,
+		GetDNSResolverIPv6Addresses:     getDNSResolverIPv6Addresses,
 		MTU:                             MTU,
 		AllowBogons:                     true,
 	}
@@ -477,6 +493,8 @@ func startTestClient(
 
 	// Assumes IP addresses are available on test host
 
+	// TODO: assign unique IP to each testClient?
+
 	config := &ClientConfig{
 		Logger:                          logger,
 		SudoNetworkConfigCommands:       os.Getenv("TUN_TEST_SUDO") != "",
@@ -687,6 +705,63 @@ func (client *testTCPClient) stop() {
 	client.conn.Close()
 }
 
+func testDNSClient(useIPv6 bool, tunDeviceName string) error {
+
+	var ipv4 [4]byte
+	var ipv6 [16]byte
+	var domain int
+	var sockAddr syscall.Sockaddr
+
+	if !useIPv6 {
+		copy(ipv4[:], transparentDNSResolverIPv4Address)
+		domain = syscall.AF_INET
+		sockAddr = &syscall.SockaddrInet4{Addr: ipv4, Port: portNumberDNS}
+	} else {
+		copy(ipv6[:], transparentDNSResolverIPv6Address)
+		domain = syscall.AF_INET6
+		sockAddr = &syscall.SockaddrInet6{Addr: ipv6, Port: portNumberDNS}
+	}
+
+	socketFd, err := syscall.Socket(domain, syscall.SOCK_DGRAM, 0)
+	if err != nil {
+		return err
+	}
+	defer syscall.Close(socketFd)
+
+	err = BindToDevice(socketFd, tunDeviceName)
+	if err != nil {
+		return err
+	}
+
+	err = syscall.Connect(socketFd, sockAddr)
+	if err != nil {
+		return err
+	}
+
+	file := os.NewFile(uintptr(socketFd), "")
+	conn, err := net.FileConn(file)
+	file.Close()
+	if err != nil {
+		return err
+	}
+	defer conn.Close()
+
+	dnsConn := &dns.Conn{Conn: conn}
+	defer dnsConn.Close()
+
+	query := new(dns.Msg)
+	query.SetQuestion(dns.Fqdn("www.example.org"), dns.TypeA)
+	query.RecursionDesired = true
+
+	dnsConn.WriteMsg(query)
+	_, err = dnsConn.ReadMsg()
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
 type testLogger struct {
 	packetMetrics chan common.LogFields
 }