Răsfoiți Sursa

Packetization and padding in the client→server direction.

David Fifield 6 ani în urmă
părinte
comite
933f17de78
2 a modificat fișierele cu 58 adăugiri și 9 ștergeri
  1. 22 4
      dnstt-client/main.go
  2. 36 5
      dnstt-server/main.go

+ 22 - 4
dnstt-client/main.go

@@ -24,6 +24,8 @@ const (
 	initPollDelay       = 100 * time.Millisecond
 	maxPollDelay        = 10 * time.Second
 	pollDelayMultiplier = 2.0
+	// How many bytes of random padding to insert into queries.
+	numPadding = 3
 )
 
 // A base32 encoding without padding.
@@ -96,9 +98,25 @@ func NewDNSPacketConn(conn net.PacketConn, addr net.Addr, domain dns.Name) *DNSP
 
 // send sends a single packet in a DNS query.
 func (c *DNSPacketConn) send(p []byte, addr net.Addr) error {
-	p = bytes.Join([][]byte{c.clientID[:], p}, nil)
-	encoded := make([]byte, base32Encoding.EncodedLen(len(p)))
-	base32Encoding.Encode(encoded, p)
+	var decoded []byte
+	{
+		if len(p) >= 224 {
+			return fmt.Errorf("too long")
+		}
+		var buf bytes.Buffer
+		// ClientID
+		buf.Write(c.clientID[:])
+		// Padding / cache inhibition
+		buf.WriteByte(224 + numPadding)
+		io.CopyN(&buf, rand.Reader, numPadding)
+		// Packet contents
+		buf.WriteByte(byte(len(p)))
+		buf.Write(p)
+		decoded = buf.Bytes()
+	}
+
+	encoded := make([]byte, base32Encoding.EncodedLen(len(decoded)))
+	base32Encoding.Encode(encoded, decoded)
 	labels := chunks(encoded, 63)
 	labels = append(labels, c.domain...)
 	name, err := dns.NewName(labels)
@@ -291,7 +309,7 @@ func run(domain dns.Name, localAddr, udpAddr string) error {
 			0, // default resend
 			1, // nc=1 => congestion window off
 		)
-		mtu := dnsNameCapacity(domain) - 8 // clientid
+		mtu := dnsNameCapacity(domain) - 8 - 1 - numPadding - 1 // clientid + padding length prefix + padding + data length prefix
 		if mtu < 80 {
 			return fmt.Errorf("domain %s leaves only %d bytes for payload", domain, mtu)
 		}

+ 36 - 5
dnstt-server/main.go

@@ -6,6 +6,7 @@ import (
 	"flag"
 	"fmt"
 	"io"
+	"io/ioutil"
 	"net"
 	"os"
 	"sync"
@@ -119,6 +120,33 @@ func acceptSessions(ln *kcp.Listener, mtu int, upstream *net.TCPAddr) error {
 	}
 }
 
+func nextPacket(r *bytes.Reader) ([]byte, error) {
+	eof := func(err error) error {
+		if err == io.EOF {
+			err = io.ErrUnexpectedEOF
+		}
+		return err
+	}
+
+	for {
+		prefix, err := r.ReadByte()
+		if err != nil {
+			return nil, err
+		}
+		if prefix >= 224 {
+			paddingLen := prefix - 224
+			_, err := io.CopyN(ioutil.Discard, r, int64(paddingLen))
+			if err != nil {
+				return nil, eof(err)
+			}
+			continue
+		}
+		p := make([]byte, int(prefix))
+		_, err = io.ReadFull(r, p)
+		return p, eof(err)
+	}
+}
+
 func responseFor(query *dns.Message, domain dns.Name, ttConn *turbotunnel.QueuePacketConn) *dns.Message {
 	resp := &dns.Message{
 		ID:       query.ID,
@@ -172,13 +200,16 @@ func responseFor(query *dns.Message, domain dns.Name, ttConn *turbotunnel.QueueP
 		resp.Flags |= dns.RcodeNameError
 		return resp
 	}
-	p := payload[len(clientID):]
 
-	// Feed the incoming packet to KCP. If there is nothing after the
-	// conversation ID, this is an empty polling request and we don't need
-	// to give it to KCP.
-	if len(p) > 0 {
+	// Discard padding and pull out the packets contained in the payload.
+	buf := bytes.NewReader(payload[len(clientID):])
+	for {
+		p, err := nextPacket(buf)
+		// Feed the incoming packet to KCP.
 		ttConn.QueueIncoming(p, clientID)
+		if err != nil {
+			break
+		}
 	}
 
 	// Send a downstream packet if any is available.