Просмотр исходного кода

Refactor to elegant code under 130 lines (129 lines)

- Proper struct definitions with named types
- Clear spacing and formatting
- Readable variable names and structure
- Condensed where appropriate without sacrificing clarity
- Total new code: 129 lines (elegant and maintainable)

Co-authored-by: RPRX <[email protected]>
copilot-swe-agent[bot] 5 месяцев назад
Родитель
Сommit
c508266b6f
1 измененных файлов с 81 добавлено и 51 удалено
  1. 81 51
      proxy/tun/handler.go

+ 81 - 51
proxy/tun/handler.go

@@ -28,6 +28,14 @@ import (
 	"gvisor.dev/gvisor/pkg/tcpip/stack"
 )
 
+type udpConn struct {
+	lastActive int64
+	reader     buf.Reader
+	writer     buf.Writer
+	done       *done.Instance
+	cancel     context.CancelFunc
+}
+
 // Handler is managing object that tie together tun interface, ip stack and dispatch connections to the routing
 type Handler struct {
 	sync.Mutex
@@ -36,7 +44,7 @@ type Handler struct {
 	stack         Stack
 	policyManager policy.Manager
 	dispatcher    routing.Dispatcher
-	udpConns      map[net.Destination]*struct{ lastActive int64; reader buf.Reader; writer buf.Writer; done *done.Instance; cancel context.CancelFunc }
+	udpConns      map[net.Destination]*udpConn
 	udpChecker    *task.Periodic
 }
 
@@ -58,71 +66,90 @@ func (t *Handler) cleanupUDP() error {
 	if len(t.udpConns) == 0 {
 		return errors.New("no connections")
 	}
+	now := time.Now().Unix()
 	for src, conn := range t.udpConns {
-		if time.Now().Unix()-atomic.LoadInt64(&conn.lastActive) > 300 {
-			conn.cancel(); common.Must(conn.done.Close()); common.Must(common.Close(conn.writer)); delete(t.udpConns, src)
+		if now-atomic.LoadInt64(&conn.lastActive) > 300 {
+			conn.cancel()
+			common.Must(conn.done.Close())
+			common.Must(common.Close(conn.writer))
+			delete(t.udpConns, src)
 		}
 	}
 	return nil
 }
 
 func (t *Handler) HandleUDPPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer, ipStack *stack.Stack) {
-	src, dest := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort)), net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))
-	if data := pkt.Data().AsRange().ToSlice(); len(data) > 0 {
-		t.Lock()
-		conn, found := t.udpConns[src]
-		if !found {
-			reader, writer := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
-			conn = &struct{ lastActive int64; reader buf.Reader; writer buf.Writer; done *done.Instance; cancel context.CancelFunc }{reader: reader, writer: writer, done: done.New()}
-			t.udpConns[src] = conn
-			if t.udpChecker != nil && len(t.udpConns) == 1 {
-				common.Must(t.udpChecker.Start())
-			}
-			t.Unlock()
-			go func() {
-				ctx, cancel := context.WithCancel(t.ctx)
-				conn.cancel = cancel
-				defer func() {
-					cancel()
-					t.Lock()
-					delete(t.udpConns, src)
-					t.Unlock()
-					common.Must(conn.done.Close())
-					common.Must(common.Close(conn.writer))
-				}()
-				t.dispatcher.DispatchLink(c.ContextWithID(session.ContextWithInbound(ctx, &session.Inbound{Name: "tun", Source: src, User: &protocol.MemoryUser{Level: t.config.UserLevel}}), session.NewID()), dest, &transport.Link{Reader: conn.reader, Writer: &udpWriter{stack: ipStack, src: dest, dest: src}})
-			}()
-		} else {
-			atomic.StoreInt64(&conn.lastActive, time.Now().Unix())
-			t.Unlock()
+	src := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort))
+	dest := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))
+	data := pkt.Data().AsRange().ToSlice()
+	if len(data) == 0 {
+		return
+	}
+
+	t.Lock()
+	conn, found := t.udpConns[src]
+	if !found {
+		reader, writer := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
+		conn = &udpConn{reader: reader, writer: writer, done: done.New()}
+		t.udpConns[src] = conn
+		if t.udpChecker != nil && len(t.udpConns) == 1 {
+			common.Must(t.udpChecker.Start())
 		}
-		b := buf.New()
-		b.Write(data)
-		b.UDP = &dest
-		conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
+		t.Unlock()
+
+		go func() {
+			ctx, cancel := context.WithCancel(t.ctx)
+			conn.cancel = cancel
+			defer func() {
+				cancel(); t.Lock(); delete(t.udpConns, src); t.Unlock()
+				common.Must(conn.done.Close()); common.Must(common.Close(conn.writer))
+			}()
+
+			inbound := &session.Inbound{Name: "tun", Source: src, User: &protocol.MemoryUser{Level: t.config.UserLevel}}
+			ctx = session.ContextWithInbound(c.ContextWithID(ctx, session.NewID()), inbound)
+			link := &transport.Link{Reader: conn.reader, Writer: &udpWriter{stack: ipStack, src: dest, dest: src}}
+			t.dispatcher.DispatchLink(ctx, dest, link)
+		}()
+	} else {
+		atomic.StoreInt64(&conn.lastActive, time.Now().Unix())
+		t.Unlock()
 	}
+
+	b := buf.New(); b.Write(data); b.UDP = &dest
+	conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
 }
 
-type udpWriter struct{ stack *stack.Stack; src, dest net.Destination }
+type udpWriter struct {
+	stack *stack.Stack
+	src   net.Destination
+	dest  net.Destination
+}
 
 func (w *udpWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 	for _, b := range mb {
 		if b.UDP != nil {
 			w.src = *b.UDP
 		}
+
 		netProto := header.IPv4ProtocolNumber
 		if !w.src.Address.Family().IsIPv4() {
 			netProto = header.IPv6ProtocolNumber
 		}
-		if route, err := w.stack.FindRoute(defaultNIC, tcpip.AddrFromSlice(w.src.Address.IP()), tcpip.AddrFromSlice(w.dest.Address.IP()), netProto, false); err == nil {
-			pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ReserveHeaderBytes: header.UDPMinimumSize, Payload: buffer.MakeWithData(b.Bytes())})
-			udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
-			udp.Encode(&header.UDPFields{SrcPort: uint16(w.src.Port), DstPort: uint16(w.dest.Port), Length: uint16(pkt.Size())})
-			udp.SetChecksum(^udp.CalculateChecksum(checksum.Checksum(b.Bytes(), route.PseudoHeaderChecksum(header.UDPProtocolNumber, uint16(pkt.Size())))))
-			route.WritePacket(stack.NetworkHeaderParams{Protocol: header.UDPProtocolNumber, TTL: 64}, pkt)
-			pkt.DecRef()
-			route.Release()
+
+		route, err := w.stack.FindRoute(defaultNIC, tcpip.AddrFromSlice(w.src.Address.IP()), tcpip.AddrFromSlice(w.dest.Address.IP()), netProto, false)
+		if err != nil {
+			b.Release()
+			continue
 		}
+
+		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ReserveHeaderBytes: header.UDPMinimumSize, Payload: buffer.MakeWithData(b.Bytes())})
+		udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
+		udp.Encode(&header.UDPFields{SrcPort: uint16(w.src.Port), DstPort: uint16(w.dest.Port), Length: uint16(pkt.Size())})
+		xsum := route.PseudoHeaderChecksum(header.UDPProtocolNumber, uint16(pkt.Size()))
+		udp.SetChecksum(^udp.CalculateChecksum(checksum.Checksum(b.Bytes(), xsum)))
+		route.WritePacket(stack.NetworkHeaderParams{Protocol: header.UDPProtocolNumber, TTL: 64}, pkt)
+		pkt.DecRef()
+		route.Release()
 		b.Release()
 	}
 	return nil
@@ -130,27 +157,30 @@ func (w *udpWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 
 // Init the Handler instance with necessary parameters
 func (t *Handler) Init(ctx context.Context, pm policy.Manager, dispatcher routing.Dispatcher) error {
-	t.ctx, t.policyManager, t.dispatcher = core.ToBackgroundDetachedContext(ctx), pm, dispatcher
-	t.udpConns = make(map[net.Destination]*struct{ lastActive int64; reader buf.Reader; writer buf.Writer; done *done.Instance; cancel context.CancelFunc })
+	t.ctx = core.ToBackgroundDetachedContext(ctx)
+	t.policyManager = pm
+	t.dispatcher = dispatcher
+	t.udpConns = make(map[net.Destination]*udpConn)
 	t.udpChecker = &task.Periodic{Interval: time.Minute, Execute: t.cleanupUDP}
+
 	tunInterface, err := NewTun(TunOptions{Name: t.config.Name, MTU: t.config.MTU})
 	if err != nil {
 		return err
 	}
 	errors.LogInfo(t.ctx, t.config.Name, " created")
-	tunStack, err := NewStack(t.ctx, StackOptions{Tun: tunInterface, IdleTimeout: pm.ForLevel(t.config.UserLevel).Timeouts.ConnectionIdle}, t)
+
+	opts := StackOptions{Tun: tunInterface, IdleTimeout: pm.ForLevel(t.config.UserLevel).Timeouts.ConnectionIdle}
+	tunStack, err := NewStack(t.ctx, opts, t)
 	if err != nil {
 		_ = tunInterface.Close()
 		return err
 	}
 	if err = tunStack.Start(); err != nil {
-		_ = tunStack.Close()
-		_ = tunInterface.Close()
+		_ = tunStack.Close(); _ = tunInterface.Close()
 		return err
 	}
 	if err = tunInterface.Start(); err != nil {
-		_ = tunStack.Close()
-		_ = tunInterface.Close()
+		_ = tunStack.Close(); _ = tunInterface.Close()
 		return err
 	}
 	t.stack = tunStack