Przeglądaj źródła

Factor out separate sending and receiving threads.

David Fifield 6 lat temu
rodzic
commit
2050034fb1
2 zmienionych plików z 95 dodań i 75 usunięć
  1. 5 0
      dns/dns.go
  2. 90 75
      dnstt-server/main.go

+ 5 - 0
dns/dns.go

@@ -141,6 +141,11 @@ type Message struct {
 	Additional []RR
 }
 
+// Rcode extracts the RCODE part of the Flags field.
+func (msg *Message) Rcode() uint16 {
+	return msg.Flags & 0x000f
+}
+
 // Question represents the question section of a message.
 //
 // https://tools.ietf.org/html/rfc1035#section-4.1.2

+ 90 - 75
dnstt-server/main.go

@@ -179,7 +179,9 @@ func nextPacket(r *bytes.Reader) ([]byte, error) {
 	}
 }
 
-func responseFor(query *dns.Message, domain dns.Name, ttConn *turbotunnel.QueuePacketConn) *dns.Message {
+func responseFor(query *dns.Message, domain dns.Name) (*dns.Message, turbotunnel.ClientID, []byte) {
+	var clientID turbotunnel.ClientID
+
 	resp := &dns.Message{
 		ID:       query.ID,
 		Flags:    0x8400, // QR = 1, AA = 1, RCODE = no error
@@ -188,30 +190,31 @@ func responseFor(query *dns.Message, domain dns.Name, ttConn *turbotunnel.QueueP
 
 	if query.Flags&0x8000 != 0 {
 		// QR != 0, this is not a query. Don't even send a response.
-		return nil
+		return nil, clientID, nil
 	}
 	if query.Flags&0x7800 != 0 {
 		// We don't support OPCODE != QUERY.
 		resp.Flags |= dns.RcodeNotImplemented
-		return resp
+		return resp, clientID, nil
 	}
 
 	if len(query.Question) != 1 {
 		// There must be exactly one question.
 		resp.Flags |= dns.RcodeFormatError
-		return resp
+		return resp, clientID, nil
 	}
 	question := query.Question[0]
 	if question.Type != dns.RRTypeTXT {
-		// We only support QTYPE == TXT. Send an empty response.
-		return resp
+		// We only support QTYPE == TXT.
+		resp.Flags |= dns.RcodeNotImplemented
+		return resp, clientID, nil
 	}
 
 	prefix, ok := question.Name.TrimSuffix(domain)
 	if !ok {
 		// Not a name we are authoritative for.
 		resp.Flags |= dns.RcodeNameError
-		return resp
+		return resp, clientID, nil
 	}
 
 	encoded := bytes.ToUpper(bytes.Join(prefix, nil))
@@ -220,86 +223,46 @@ func responseFor(query *dns.Message, domain dns.Name, ttConn *turbotunnel.QueueP
 	if err != nil {
 		// Base32 error, make like the name doesn't exist.
 		resp.Flags |= dns.RcodeNameError
-		return resp
+		return resp, clientID, nil
 	}
 	payload = payload[:n]
 
 	// Now extract the ClientID.
-	var clientID turbotunnel.ClientID
 	n = copy(clientID[:], payload)
 	if n < len(clientID) {
 		// Payload is not long enough to contain a ClientID.
 		resp.Flags |= dns.RcodeNameError
-		return resp
-	}
-
-	// Discard padding and pull out the packets contained in the payload.
-	buf := bytes.NewReader(payload[len(clientID):])
-	for {
-		p, err := nextPacket(buf)
-		// Feed the incoming packet to KCP.
-		ttConn.QueueIncoming(p, clientID)
-		if err != nil {
-			break
-		}
-	}
-
-	// Send a downstream packet if any is available.
-	// TODO: can bundle multiple packets here.
-	select {
-	case p := <-ttConn.OutgoingQueue(clientID):
-		resp.Answer = append(resp.Answer, dns.RR{
-			Name: question.Name,
-			Type: dns.RRTypeTXT,
-			TTL:  responseTTL,
-			Data: dns.EncodeRDataTXT(p),
-		})
-	default:
+		return resp, clientID, nil
 	}
 
-	return resp
+	return resp, clientID, payload[len(clientID):]
 }
 
-func handle(p []byte, addr net.Addr, dnsConn net.PacketConn, domain dns.Name, ttConn *turbotunnel.QueuePacketConn) error {
-	query, err := dns.MessageFromWireFormat(p)
-	if err != nil {
-		return fmt.Errorf("parsing DNS query: %v", err)
-	}
-
-	resp := responseFor(&query, domain, ttConn)
-	if resp != nil {
-		buf, err := resp.WireFormat()
-		if err != nil {
-			return err
-		}
-		_, err = dnsConn.WriteTo(buf, addr)
-		if err != nil {
-			return err
-		}
-	}
-
-	return nil
+// record represents a response set up with metadata appropriate for a response
+// to a previously received query. recvLoop sends instances of this type to
+// sendLoop via a channel. sendLoop may optionally fill in the response's Answer
+// section before sending it.
+type record struct {
+	Resp     *dns.Message
+	Addr     net.Addr
+	ClientID turbotunnel.ClientID
 }
 
 func loop(dnsConn net.PacketConn, domain dns.Name, ttConn *turbotunnel.QueuePacketConn) error {
-	type taggedPacket struct {
-		P    []byte
-		Addr net.Addr
-	}
+	ch := make(chan record, 100)
+	defer close(ch)
 
-	handleChan := make(chan taggedPacket, 64)
-	defer close(handleChan)
 	go func() {
-		for tp := range handleChan {
-			p := tp.P
-			addr := tp.Addr
-			err := handle(p, addr, dnsConn, domain, ttConn)
-			if err != nil {
-				log.Printf("handle from %v: %v\n", addr, err)
-			}
+		err := sendLoop(dnsConn, ttConn, ch)
+		if err != nil {
+			log.Printf("sendLoop: %v", err)
 		}
 	}()
 
+	return recvLoop(domain, dnsConn, ttConn, ch)
+}
+
+func recvLoop(domain dns.Name, dnsConn net.PacketConn, ttConn *turbotunnel.QueuePacketConn, ch chan<- record) error {
 	for {
 		// One byte longer than we want, to check for truncation.
 		var buf [513]byte
@@ -312,18 +275,70 @@ func loop(dnsConn net.PacketConn, domain dns.Name, ttConn *turbotunnel.QueuePack
 			return err
 		}
 		if n == len(buf) {
-			log.Printf("ReadFrom: truncated packet")
+			log.Printf("%v: ReadFrom: truncated packet", addr)
+			continue
+		}
+
+		// Got a UDP packet. Try to parse it as a DNS message.
+		query, err := dns.MessageFromWireFormat(buf[:n])
+		if err != nil {
+			log.Printf("%v: parsing DNS query: %v", addr, err)
+			continue
+		}
+
+		resp, clientID, payload := responseFor(&query, domain)
+		// If a response is called for, pass it to sendLoop via the channel.
+		if resp != nil {
+			select {
+			case ch <- record{resp, addr, clientID}:
+			default:
+			}
+		}
+		// Discard padding and pull out the packets contained in the payload.
+		r := bytes.NewReader(payload)
+		for {
+			p, err := nextPacket(r)
+			if err != nil {
+				break
+			}
+			// Feed the incoming packet to KCP.
+			ttConn.QueueIncoming(p, clientID)
+		}
+	}
+}
+
+func sendLoop(dnsConn net.PacketConn, ttConn *turbotunnel.QueuePacketConn, ch <-chan record) error {
+	for rec := range ch {
+		if rec.Resp.Rcode() == dns.RcodeNoError && len(rec.Resp.Question) == 1 {
+			// If it's a non-error response, we can fill the Answer
+			// section with downstream packets.
+			// TODO: can bundle multiple packets here.
+			select {
+			case p := <-ttConn.OutgoingQueue(rec.ClientID):
+				rec.Resp.Answer = append(rec.Resp.Answer, dns.RR{
+					Name: rec.Resp.Question[0].Name,
+					Type: dns.RRTypeTXT,
+					TTL:  responseTTL,
+					Data: dns.EncodeRDataTXT(p),
+				})
+			default:
+			}
+		}
+		buf, err := rec.Resp.WireFormat()
+		if err != nil {
+			log.Printf("resp WireFormat: %v", err)
 			continue
 		}
-		// Copy the packet data into its own buffer.
-		p := make([]byte, n)
-		copy(p, buf[:n])
-		select {
-		case handleChan <- taggedPacket{p, addr}:
-		default:
-			// Drop incoming packets if channel is full.
+		_, err = dnsConn.WriteTo(buf, rec.Addr)
+		if err != nil {
+			if err, ok := err.(net.Error); ok && err.Temporary() {
+				log.Printf("WriteTo temporary error: %v", err)
+				continue
+			}
+			return err
 		}
 	}
+	return nil
 }
 
 type dummyAddr struct{}