|
|
@@ -212,91 +212,53 @@ func (t *Handler) cleanupUDPConns() error {
|
|
|
|
|
|
// writeUDPPacket writes a UDP packet back to the gVisor stack with custom source address
|
|
|
func (t *Handler) writeUDPPacket(ipStack *stack.Stack, data []byte, dest, source net.Destination) (int, error) {
|
|
|
- // Build UDP+IP packet with proper headers using gVisor's header builders
|
|
|
-
|
|
|
- // Determine IP version
|
|
|
- var ipHdrLen, udpHdrLen int
|
|
|
- isIPv4 := dest.Address.Family().IsIPv4()
|
|
|
-
|
|
|
- if isIPv4 {
|
|
|
- ipHdrLen = header.IPv4MinimumSize
|
|
|
+ // Create a route from dest (our local) to source (remote)
|
|
|
+ var netProto tcpip.NetworkProtocolNumber
|
|
|
+ if dest.Address.Family().IsIPv4() {
|
|
|
+ netProto = header.IPv4ProtocolNumber
|
|
|
} else {
|
|
|
- ipHdrLen = header.IPv6MinimumSize
|
|
|
+ netProto = header.IPv6ProtocolNumber
|
|
|
}
|
|
|
- udpHdrLen = header.UDPMinimumSize
|
|
|
|
|
|
- totalLen := ipHdrLen + udpHdrLen + len(data)
|
|
|
- packet := make([]byte, totalLen)
|
|
|
+ route, err := ipStack.FindRoute(
|
|
|
+ defaultNIC,
|
|
|
+ tcpip.AddrFromSlice(dest.Address.IP()),
|
|
|
+ tcpip.AddrFromSlice(source.Address.IP()),
|
|
|
+ netProto,
|
|
|
+ false,
|
|
|
+ )
|
|
|
+ if err != nil {
|
|
|
+ return 0, errors.New("failed to find route: " + err.String())
|
|
|
+ }
|
|
|
+ defer route.Release()
|
|
|
+
|
|
|
+ // Create packet buffer with UDP payload
|
|
|
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
|
|
+ ReserveHeaderBytes: header.UDPMinimumSize,
|
|
|
+ Payload: buffer.MakeWithData(data),
|
|
|
+ })
|
|
|
+ defer pkt.DecRef()
|
|
|
|
|
|
// Build UDP header
|
|
|
- udpHeader := header.UDP(packet[ipHdrLen:])
|
|
|
+ udpHeader := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
|
|
|
+ length := uint16(pkt.Size())
|
|
|
udpHeader.Encode(&header.UDPFields{
|
|
|
- SrcPort: uint16(dest.Port), // Source is the original destination
|
|
|
- DstPort: uint16(source.Port), // Destination is the original source
|
|
|
- Length: uint16(udpHdrLen + len(data)),
|
|
|
+ SrcPort: uint16(dest.Port),
|
|
|
+ DstPort: uint16(source.Port),
|
|
|
+ Length: length,
|
|
|
})
|
|
|
|
|
|
- // Copy payload
|
|
|
- copy(packet[ipHdrLen+udpHdrLen:], data)
|
|
|
-
|
|
|
- // Build IP header and calculate checksums
|
|
|
- if isIPv4 {
|
|
|
- ipv4Header := header.IPv4(packet)
|
|
|
- ipv4Header.Encode(&header.IPv4Fields{
|
|
|
- TOS: 0,
|
|
|
- TotalLength: uint16(totalLen),
|
|
|
- ID: 0,
|
|
|
- Flags: 0,
|
|
|
- FragmentOffset: 0,
|
|
|
- TTL: 64,
|
|
|
- Protocol: uint8(header.UDPProtocolNumber),
|
|
|
- SrcAddr: tcpip.AddrFromSlice(dest.Address.IP()),
|
|
|
- DstAddr: tcpip.AddrFromSlice(source.Address.IP()),
|
|
|
- })
|
|
|
- ipv4Header.SetChecksum(^ipv4Header.CalculateChecksum())
|
|
|
-
|
|
|
- // Calculate UDP checksum
|
|
|
- xsum := header.PseudoHeaderChecksum(
|
|
|
- header.UDPProtocolNumber,
|
|
|
- tcpip.AddrFromSlice(dest.Address.IP()),
|
|
|
- tcpip.AddrFromSlice(source.Address.IP()),
|
|
|
- uint16(udpHdrLen+len(data)),
|
|
|
- )
|
|
|
- xsum = checksum.Checksum(data, xsum)
|
|
|
- udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
|
|
|
- } else {
|
|
|
- ipv6Header := header.IPv6(packet)
|
|
|
- ipv6Header.Encode(&header.IPv6Fields{
|
|
|
- TrafficClass: 0,
|
|
|
- FlowLabel: 0,
|
|
|
- PayloadLength: uint16(udpHdrLen + len(data)),
|
|
|
- TransportProtocol: header.UDPProtocolNumber,
|
|
|
- HopLimit: 64,
|
|
|
- SrcAddr: tcpip.AddrFromSlice(dest.Address.IP()),
|
|
|
- DstAddr: tcpip.AddrFromSlice(source.Address.IP()),
|
|
|
- })
|
|
|
-
|
|
|
- // Calculate UDP checksum for IPv6
|
|
|
- xsum := header.PseudoHeaderChecksum(
|
|
|
- header.UDPProtocolNumber,
|
|
|
- tcpip.AddrFromSlice(dest.Address.IP()),
|
|
|
- tcpip.AddrFromSlice(source.Address.IP()),
|
|
|
- uint16(udpHdrLen+len(data)),
|
|
|
- )
|
|
|
- xsum = checksum.Checksum(data, xsum)
|
|
|
- udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
|
|
|
- }
|
|
|
-
|
|
|
- // Write packet to stack
|
|
|
- var proto tcpip.NetworkProtocolNumber
|
|
|
- if isIPv4 {
|
|
|
- proto = header.IPv4ProtocolNumber
|
|
|
- } else {
|
|
|
- proto = header.IPv6ProtocolNumber
|
|
|
- }
|
|
|
+ // Calculate checksum
|
|
|
+ xsum := route.PseudoHeaderChecksum(header.UDPProtocolNumber, length)
|
|
|
+ xsum = checksum.Checksum(data, xsum)
|
|
|
+ udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
|
|
|
|
|
|
- buf := buffer.MakeWithData(packet)
|
|
|
- if err := ipStack.WriteRawPacket(defaultNIC, proto, buf); err != nil {
|
|
|
+ // Write packet through route
|
|
|
+ if err := route.WritePacket(stack.NetworkHeaderParams{
|
|
|
+ Protocol: header.UDPProtocolNumber,
|
|
|
+ TTL: 64,
|
|
|
+ TOS: 0,
|
|
|
+ }, pkt); err != nil {
|
|
|
return 0, errors.New("failed to write packet: " + err.String())
|
|
|
}
|
|
|
|