David Fifield 6 лет назад
Родитель
Сommit
a6c891c5ae
3 измененных файлов с 226 добавлено и 3 удалено
  1. 207 0
      dnstt-client/doh.go
  2. 15 3
      dnstt-client/main.go
  3. 4 0
      dnstt-client/udp.go

+ 207 - 0
dnstt-client/doh.go

@@ -0,0 +1,207 @@
+package main
+
+import (
+	"bytes"
+	"crypto/rand"
+	"encoding/binary"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"log"
+	"net/http"
+	"net/url"
+	"time"
+
+	"www.bamsoftware.com/git/dnstt.git/dns"
+	"www.bamsoftware.com/git/dnstt.git/turbotunnel"
+)
+
+type DoHPacketConn struct {
+	clientID  turbotunnel.ClientID
+	domain    dns.Name
+	urlString string
+	pollChan  chan struct{}
+	sendChan  chan []byte
+	*turbotunnel.QueuePacketConn
+}
+
+func NewDoHPacketConn(urlString string, domain dns.Name) (*DoHPacketConn, error) {
+	u, err := url.Parse(urlString)
+	if err != nil {
+		return nil, err
+	}
+	if u.Scheme != "https" {
+		return nil, fmt.Errorf("bad URL scheme %+q", u.Scheme)
+	}
+	// Generate a new random ClientID.
+	var clientID turbotunnel.ClientID
+	rand.Read(clientID[:])
+	c := &DoHPacketConn{
+		clientID:        clientID,
+		domain:          domain,
+		urlString:       urlString,
+		pollChan:        make(chan struct{}),
+		sendChan:        make(chan []byte, 32),
+		QueuePacketConn: turbotunnel.NewQueuePacketConn(clientID, idleTimeout),
+	}
+	go func() {
+		err := c.sendLoop()
+		if err != nil {
+			log.Printf("sendLoop: %v", err)
+		}
+	}()
+	for i := 0; i < 10; i++ {
+		go func() {
+			for p := range c.sendChan {
+				err := c.send(p)
+				if err != nil {
+					log.Printf("sender thread: %v", err)
+				}
+			}
+		}()
+	}
+	return c, nil
+}
+
+// send sends a single packet in an HTTP request.
+func (c *DoHPacketConn) send(p []byte) error {
+	var decoded []byte
+	{
+		if len(p) >= 224 {
+			return fmt.Errorf("too long")
+		}
+		var buf bytes.Buffer
+		// ClientID
+		buf.Write(c.clientID[:])
+		// Padding / cache inhibition
+		buf.WriteByte(224 + numPadding)
+		io.CopyN(&buf, rand.Reader, numPadding)
+		// Packet contents
+		if len(p) > 0 {
+			buf.WriteByte(byte(len(p)))
+			buf.Write(p)
+		}
+		decoded = buf.Bytes()
+	}
+
+	encoded := make([]byte, base32Encoding.EncodedLen(len(decoded)))
+	base32Encoding.Encode(encoded, decoded)
+	labels := chunks(encoded, 63)
+	labels = append(labels, c.domain...)
+	name, err := dns.NewName(labels)
+	if err != nil {
+		return err
+	}
+
+	var id uint16
+	binary.Read(rand.Reader, binary.BigEndian, &id)
+	query := &dns.Message{
+		ID:    id,
+		Flags: 0x0100, // QR = 0, RD = 1
+		Question: []dns.Question{
+			{
+				Name:  name,
+				Type:  dns.RRTypeTXT,
+				Class: dns.ClassIN,
+			},
+		},
+		// EDNS(0)
+		Additional: []dns.RR{
+			{
+				Name:  dns.Name{},
+				Type:  dns.RRTypeOPT,
+				Class: 4096, // requestor's UDP payload size
+				TTL:   0,    // extended RCODE and flags
+				Data:  []byte{},
+			},
+		},
+	}
+	buf, err := query.WireFormat()
+	if err != nil {
+		return err
+	}
+
+	req, err := http.NewRequest("POST", c.urlString, bytes.NewReader(buf))
+	if err != nil {
+		return err
+	}
+	req.Header.Set("Accept", "application/dns-message")
+	req.Header.Set("Content-Type", "application/dns-message")
+	resp, err := http.DefaultTransport.RoundTrip(req)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusOK || resp.Header.Get("Content-Type") != "application/dns-message" {
+		return fmt.Errorf("unexpected response")
+	}
+	body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 64000))
+	if err != nil {
+		// Don't report an error if we at least managed to send.
+		return nil
+	}
+	// Got a response. Try to parse it as a DNS message.
+	dnsResp, err := dns.MessageFromWireFormat(body)
+	if err != nil {
+		log.Printf("MessageFromWireFormat: %v", err)
+		return nil
+	}
+	payload := dnsResponsePayload(&dnsResp, c.domain)
+	// Reading anything gives sendLoop license to poll immediately.
+	if len(payload) > 0 {
+		select {
+		case c.pollChan <- struct{}{}:
+		default:
+		}
+		select {
+		case c.pollChan <- struct{}{}:
+		default:
+		}
+	}
+	// Pull out the packets contained in the payload.
+	r := bytes.NewReader(payload)
+	for {
+		p, err := nextPacket(r)
+		if err != nil {
+			break
+		}
+		c.QueuePacketConn.QueueIncoming(p, dummyAddr{})
+	}
+	return nil
+}
+
+func (c *DoHPacketConn) sendLoop() error {
+	pollDelay := initPollDelay
+	pollTimer := time.NewTimer(pollDelay)
+	for {
+		var p []byte
+		select {
+		case <-c.pollChan:
+			if !pollTimer.Stop() {
+				<-pollTimer.C
+			}
+			p = nil
+		case p = <-c.QueuePacketConn.OutgoingQueue(dummyAddr{}):
+			if !pollTimer.Stop() {
+				<-pollTimer.C
+			}
+			pollDelay = initPollDelay
+		case <-pollTimer.C:
+			p = nil
+			pollDelay = time.Duration(float64(pollDelay) * pollDelayMultiplier)
+			if pollDelay > maxPollDelay {
+				pollDelay = maxPollDelay
+			}
+		}
+		pollTimer.Reset(pollDelay)
+		select {
+		case c.sendChan <- p:
+		default:
+		}
+	}
+}
+
+func (c *DoHPacketConn) Close() error {
+	close(c.sendChan) // TODO
+	return c.QueuePacketConn.Close()
+}

+ 15 - 3
dnstt-client/main.go

@@ -210,13 +210,20 @@ func run(domain dns.Name, localAddr *net.TCPAddr, remoteAddr net.Addr, pconn net
 	}
 }
 
+type dummyAddr struct{}
+
+func (addr dummyAddr) Network() string { return "dummy" }
+func (addr dummyAddr) String() string  { return "dummy" }
+
 func main() {
+	var dohURL string
 	var udpAddr string
 
 	flag.Usage = func() {
-		fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s -udp ADDR DOMAIN LOCALADDR\n", os.Args[0])
+		fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [-doh URL|-udp ADDR] DOMAIN LOCALADDR\n", os.Args[0])
 		flag.PrintDefaults()
 	}
+	flag.StringVar(&dohURL, "doh", "", "URL of DoH resolver")
 	flag.StringVar(&udpAddr, "udp", "", "address of UDP DNS resolver")
 	flag.Parse()
 
@@ -245,6 +252,11 @@ func main() {
 		s string
 		f func(string) (net.Addr, net.PacketConn, error)
 	}{
+		// -doh
+		{dohURL, func(s string) (net.Addr, net.PacketConn, error) {
+			c, err := NewDoHPacketConn(dohURL, domain)
+			return dummyAddr{}, c, err
+		}},
 		// -udp
 		{udpAddr, func(s string) (net.Addr, net.PacketConn, error) {
 			addr, err := net.ResolveUDPAddr("udp", s)
@@ -262,7 +274,7 @@ func main() {
 			continue
 		}
 		if pconn != nil {
-			fmt.Fprintf(os.Stderr, "the -udp option may be given only once\n")
+			fmt.Fprintf(os.Stderr, "only one of -doh and -udp may be given\n")
 			os.Exit(1)
 		}
 		a, c, err := opt.f(opt.s)
@@ -274,7 +286,7 @@ func main() {
 		pconn = c
 	}
 	if pconn == nil {
-		fmt.Fprintf(os.Stderr, "the -udp option is required\n")
+		fmt.Fprintf(os.Stderr, "one of -doh or -udp is required\n")
 		os.Exit(1)
 	}
 

+ 4 - 0
dnstt-client/udp.go

@@ -72,6 +72,10 @@ func (c *UDPPacketConn) recvLoop(udpConn net.PacketConn) error {
 			case c.pollChan <- struct{}{}:
 			default:
 			}
+			select {
+			case c.pollChan <- struct{}{}:
+			default:
+			}
 		}
 
 		// Pull out the packets contained in the payload.