Просмотр исходного кода

Bundle multiple packets downstream.

David Fifield 6 лет назад
Родитель
Сommit
8aace665d5
1 измененных файлов с 54 добавлено и 16 удалено
  1. 54 16
      dnstt-server/main.go

+ 54 - 16
dnstt-server/main.go

@@ -250,7 +250,7 @@ type record struct {
 }
 
 func loop(dnsConn net.PacketConn, domain dns.Name, ttConn *turbotunnel.QueuePacketConn) error {
-	ch := make(chan record, 100)
+	ch := make(chan *record, 100)
 	defer close(ch)
 
 	go func() {
@@ -263,7 +263,7 @@ func loop(dnsConn net.PacketConn, domain dns.Name, ttConn *turbotunnel.QueuePack
 	return recvLoop(domain, dnsConn, ttConn, ch)
 }
 
-func recvLoop(domain dns.Name, dnsConn net.PacketConn, ttConn *turbotunnel.QueuePacketConn, ch chan<- record) error {
+func recvLoop(domain dns.Name, dnsConn net.PacketConn, ttConn *turbotunnel.QueuePacketConn, ch chan<- *record) error {
 	for {
 		// One byte longer than we want, to check for truncation.
 		var buf [513]byte
@@ -291,7 +291,7 @@ func recvLoop(domain dns.Name, dnsConn net.PacketConn, ttConn *turbotunnel.Queue
 		// If a response is called for, pass it to sendLoop via the channel.
 		if resp != nil {
 			select {
-			case ch <- record{resp, addr, clientID}:
+			case ch <- &record{resp, addr, clientID}:
 			default:
 			}
 		}
@@ -308,30 +308,68 @@ func recvLoop(domain dns.Name, dnsConn net.PacketConn, ttConn *turbotunnel.Queue
 	}
 }
 
-func sendLoop(dnsConn net.PacketConn, ttConn *turbotunnel.QueuePacketConn, ch <-chan record) error {
-	for rec := range ch {
+func sendLoop(dnsConn net.PacketConn, ttConn *turbotunnel.QueuePacketConn, ch <-chan *record) error {
+	var nextRec *record
+	var nextP []byte
+	for {
+		rec := nextRec
+		nextRec = nil
+
+		if rec == nil {
+			var ok bool
+			rec, ok = <-ch
+			if !ok {
+				break
+			}
+		}
+
 		if rec.Resp.Rcode() == dns.RcodeNoError && len(rec.Resp.Question) == 1 {
 			// If it's a non-error response, we can fill the Answer
 			// section with downstream packets.
-			// TODO: can bundle multiple packets here.
-			var buf bytes.Buffer
-			select {
-			case p := <-ttConn.OutgoingQueue(rec.ClientID):
-				n := uint16(len(p))
-				if int(n) != len(p) {
-					panic(len(p))
+			var payload bytes.Buffer
+
+			limit := 4096
+			if len(nextP) > 0 {
+				limit -= 2 + len(nextP)
+				binary.Write(&payload, binary.BigEndian, uint16(len(nextP)))
+				payload.Write(nextP)
+			}
+			nextP = nil
+
+		loop:
+			for {
+				select {
+				case p := <-ttConn.OutgoingQueue(rec.ClientID):
+					if int(uint16(len(p))) != len(p) {
+						panic(len(p))
+					}
+					if 2+len(p) > limit {
+						// Save this packet to send in
+						// the next response.
+						nextP = p
+						break loop
+					}
+					limit -= 2 + len(p)
+					binary.Write(&payload, binary.BigEndian, uint16(len(p)))
+					payload.Write(p)
+				case nextRec = <-ch:
+					// If there's another response waiting
+					// to be sent, wait no longer for a
+					// payload for this one.
+					break loop
+				default:
+					break loop
 				}
-				binary.Write(&buf, binary.BigEndian, n)
-				buf.Write(p)
-			default:
 			}
+
 			rec.Resp.Answer = append(rec.Resp.Answer, dns.RR{
 				Name: rec.Resp.Question[0].Name,
 				Type: dns.RRTypeTXT,
 				TTL:  responseTTL,
-				Data: dns.EncodeRDataTXT(buf.Bytes()),
+				Data: dns.EncodeRDataTXT(payload.Bytes()),
 			})
 		}
+
 		buf, err := rec.Resp.WireFormat()
 		if err != nil {
 			log.Printf("resp WireFormat: %v", err)