|
|
@@ -3,6 +3,7 @@ package wireguard
|
|
|
import (
|
|
|
"context"
|
|
|
"fmt"
|
|
|
+ "io"
|
|
|
"net/netip"
|
|
|
"runtime"
|
|
|
"strconv"
|
|
|
@@ -10,12 +11,17 @@ import (
|
|
|
"sync"
|
|
|
"time"
|
|
|
|
|
|
+ "github.com/xtls/xray-core/common/buf"
|
|
|
"github.com/xtls/xray-core/common/errors"
|
|
|
"github.com/xtls/xray-core/common/log"
|
|
|
"github.com/xtls/xray-core/common/net"
|
|
|
"github.com/xtls/xray-core/proxy/wireguard/gvisortun"
|
|
|
+ "gvisor.dev/gvisor/pkg/buffer"
|
|
|
"gvisor.dev/gvisor/pkg/tcpip"
|
|
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
|
|
+ "gvisor.dev/gvisor/pkg/tcpip/checksum"
|
|
|
+ "gvisor.dev/gvisor/pkg/tcpip/header"
|
|
|
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
|
|
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
|
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
|
|
"gvisor.dev/gvisor/pkg/waiter"
|
|
|
@@ -138,7 +144,7 @@ func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, erro
|
|
|
|
|
|
func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) {
|
|
|
out := &gvisorNet{}
|
|
|
- tun, n, stack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil)
|
|
|
+ tun, n, gstack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
@@ -147,60 +153,236 @@ func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousMo
|
|
|
// handler is only used for promiscuous mode
|
|
|
// capture all packets and send to handler
|
|
|
|
|
|
- tcpForwarder := tcp.NewForwarder(stack, 0, 65535, func(r *tcp.ForwarderRequest) {
|
|
|
+ tcpForwarder := tcp.NewForwarder(gstack, 0, 65535, func(r *tcp.ForwarderRequest) {
|
|
|
go func(r *tcp.ForwarderRequest) {
|
|
|
- var (
|
|
|
- wq waiter.Queue
|
|
|
- id = r.ID()
|
|
|
- )
|
|
|
+ var wq waiter.Queue
|
|
|
+ var id = r.ID()
|
|
|
|
|
|
- // Perform a TCP three-way handshake.
|
|
|
ep, err := r.CreateEndpoint(&wq)
|
|
|
if err != nil {
|
|
|
errors.LogError(context.Background(), err.String())
|
|
|
r.Complete(true)
|
|
|
return
|
|
|
}
|
|
|
- r.Complete(false)
|
|
|
- defer ep.Close()
|
|
|
|
|
|
- // enable tcp keep-alive to prevent hanging connections
|
|
|
- ep.SocketOptions().SetKeepAlive(true)
|
|
|
+ options := ep.SocketOptions()
|
|
|
+ options.SetKeepAlive(false)
|
|
|
+ options.SetReuseAddress(true)
|
|
|
+ options.SetReusePort(true)
|
|
|
|
|
|
- // local address is actually destination
|
|
|
handler(net.TCPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)), gonet.NewTCPConn(&wq, ep))
|
|
|
+
|
|
|
+ ep.Close()
|
|
|
+ r.Complete(false)
|
|
|
}(r)
|
|
|
})
|
|
|
- stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
|
|
-
|
|
|
- udpForwarder := udp.NewForwarder(stack, func(r *udp.ForwarderRequest) bool {
|
|
|
- go func(r *udp.ForwarderRequest) {
|
|
|
- var (
|
|
|
- wq waiter.Queue
|
|
|
- id = r.ID()
|
|
|
- )
|
|
|
-
|
|
|
- ep, err := r.CreateEndpoint(&wq)
|
|
|
- if err != nil {
|
|
|
- errors.LogError(context.Background(), err.String())
|
|
|
- return
|
|
|
- }
|
|
|
- defer ep.Close()
|
|
|
+ gstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
|
|
|
|
|
- // prevents hanging connections and ensure timely release
|
|
|
- ep.SocketOptions().SetLinger(tcpip.LingerOption{
|
|
|
- Enabled: true,
|
|
|
- Timeout: 15 * time.Second,
|
|
|
- })
|
|
|
-
|
|
|
- handler(net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)), gonet.NewUDPConn(&wq, ep))
|
|
|
- }(r)
|
|
|
+ manager := &udpManager{
|
|
|
+ stack: gstack,
|
|
|
+ handler: handler,
|
|
|
+ m: make(map[string]*udpConn),
|
|
|
+ }
|
|
|
|
|
|
+ gstack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
|
|
+ data := pkt.Clone().Data().AsRange().ToSlice()
|
|
|
+ // if len(data) == 0 {
|
|
|
+ // return false
|
|
|
+ // }
|
|
|
+ src := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort))
|
|
|
+ dst := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))
|
|
|
+ manager.feed(src, dst, data)
|
|
|
return true
|
|
|
})
|
|
|
- stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
|
|
}
|
|
|
|
|
|
out.tun, out.net = tun, n
|
|
|
return out, nil
|
|
|
}
|
|
|
+
|
|
|
+type udpManager struct {
|
|
|
+ stack *stack.Stack
|
|
|
+ handler func(dest net.Destination, conn net.Conn)
|
|
|
+ m map[string]*udpConn
|
|
|
+ mutex sync.RWMutex
|
|
|
+}
|
|
|
+
|
|
|
+func (m *udpManager) feed(src net.Destination, dst net.Destination, data []byte) {
|
|
|
+ m.mutex.RLock()
|
|
|
+ uc, ok := m.m[src.NetAddr()]
|
|
|
+ if ok {
|
|
|
+ select {
|
|
|
+ case uc.ch <- data:
|
|
|
+ default:
|
|
|
+ }
|
|
|
+ m.mutex.RUnlock()
|
|
|
+ return
|
|
|
+ }
|
|
|
+ m.mutex.RUnlock()
|
|
|
+
|
|
|
+ m.mutex.Lock()
|
|
|
+ defer m.mutex.Unlock()
|
|
|
+
|
|
|
+ uc, ok = m.m[src.NetAddr()]
|
|
|
+ if !ok {
|
|
|
+ uc = &udpConn{
|
|
|
+ ch: make(chan []byte, 1024),
|
|
|
+ src: src,
|
|
|
+ dst: dst,
|
|
|
+ }
|
|
|
+ uc.writeFunc = m.writeRawUDPPacket
|
|
|
+ uc.closeFunc = func() {
|
|
|
+ m.mutex.Lock()
|
|
|
+ m.close(uc)
|
|
|
+ m.mutex.Unlock()
|
|
|
+ }
|
|
|
+ m.m[src.NetAddr()] = uc
|
|
|
+ go m.handler(dst, uc)
|
|
|
+ }
|
|
|
+
|
|
|
+ select {
|
|
|
+ case uc.ch <- data:
|
|
|
+ default:
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (m *udpManager) close(uc *udpConn) {
|
|
|
+ if !uc.closed {
|
|
|
+ uc.closed = true
|
|
|
+ close(uc.ch)
|
|
|
+ delete(m.m, uc.src.NetAddr())
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (m *udpManager) writeRawUDPPacket(payload []byte, src net.Destination, dst net.Destination) error {
|
|
|
+ udpLen := header.UDPMinimumSize + len(payload)
|
|
|
+ srcIP := tcpip.AddrFromSlice(src.Address.IP())
|
|
|
+ dstIP := tcpip.AddrFromSlice(dst.Address.IP())
|
|
|
+
|
|
|
+ // build packet with appropriate IP header size
|
|
|
+ isIPv4 := dst.Address.Family().IsIPv4()
|
|
|
+ ipHdrSize := header.IPv6MinimumSize
|
|
|
+ ipProtocol := header.IPv6ProtocolNumber
|
|
|
+ if isIPv4 {
|
|
|
+ ipHdrSize = header.IPv4MinimumSize
|
|
|
+ ipProtocol = header.IPv4ProtocolNumber
|
|
|
+ }
|
|
|
+
|
|
|
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
|
|
+ ReserveHeaderBytes: ipHdrSize + header.UDPMinimumSize,
|
|
|
+ Payload: buffer.MakeWithData(payload),
|
|
|
+ })
|
|
|
+ defer pkt.DecRef()
|
|
|
+
|
|
|
+ // Build UDP header
|
|
|
+ udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
|
|
|
+ udpHdr.Encode(&header.UDPFields{
|
|
|
+ SrcPort: uint16(src.Port),
|
|
|
+ DstPort: uint16(dst.Port),
|
|
|
+ Length: uint16(udpLen),
|
|
|
+ })
|
|
|
+
|
|
|
+ // Calculate and set UDP checksum
|
|
|
+ xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, srcIP, dstIP, uint16(udpLen))
|
|
|
+ udpHdr.SetChecksum(^udpHdr.CalculateChecksum(checksum.Checksum(payload, xsum)))
|
|
|
+
|
|
|
+ // Build IP header
|
|
|
+ if isIPv4 {
|
|
|
+ ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
|
|
|
+ ipHdr.Encode(&header.IPv4Fields{
|
|
|
+ TotalLength: uint16(header.IPv4MinimumSize + udpLen),
|
|
|
+ TTL: 64,
|
|
|
+ Protocol: uint8(header.UDPProtocolNumber),
|
|
|
+ SrcAddr: srcIP,
|
|
|
+ DstAddr: dstIP,
|
|
|
+ })
|
|
|
+ ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
|
|
+ } else {
|
|
|
+ ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
|
|
|
+ ipHdr.Encode(&header.IPv6Fields{
|
|
|
+ PayloadLength: uint16(udpLen),
|
|
|
+ TransportProtocol: header.UDPProtocolNumber,
|
|
|
+ HopLimit: 64,
|
|
|
+ SrcAddr: srcIP,
|
|
|
+ DstAddr: dstIP,
|
|
|
+ })
|
|
|
+ }
|
|
|
+
|
|
|
+ // dispatch the packet
|
|
|
+ err := m.stack.WriteRawPacket(1, ipProtocol, buffer.MakeWithView(pkt.ToView()))
|
|
|
+ if err != nil {
|
|
|
+ return errors.New("failed to write raw udp packet back to stack err ", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+type udpConn struct {
|
|
|
+ ch chan []byte
|
|
|
+ src net.Destination
|
|
|
+ dst net.Destination
|
|
|
+ writeFunc func(payload []byte, src net.Destination, dst net.Destination) error
|
|
|
+ closeFunc func()
|
|
|
+ closed bool
|
|
|
+}
|
|
|
+
|
|
|
+func (c *udpConn) Read(p []byte) (int, error) {
|
|
|
+ b, ok := <-c.ch
|
|
|
+ if !ok {
|
|
|
+ return 0, io.EOF
|
|
|
+ }
|
|
|
+ n := copy(p, b)
|
|
|
+ if n != len(b) {
|
|
|
+ return 0, io.ErrShortBuffer
|
|
|
+ }
|
|
|
+ return n, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (c *udpConn) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
|
|
+ for i, b := range mb {
|
|
|
+ dst := c.dst
|
|
|
+ if b.UDP != nil {
|
|
|
+ dst = *b.UDP
|
|
|
+ }
|
|
|
+ err := c.writeFunc(b.Bytes(), dst, c.src)
|
|
|
+ if err != nil {
|
|
|
+ buf.ReleaseMulti(mb[i:])
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ b.Release()
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (c *udpConn) Write(p []byte) (int, error) {
|
|
|
+ err := c.writeFunc(p, c.dst, c.src)
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+ return len(p), nil
|
|
|
+}
|
|
|
+
|
|
|
+func (c *udpConn) Close() error {
|
|
|
+ c.closeFunc()
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (c *udpConn) LocalAddr() net.Addr {
|
|
|
+ return c.src.RawNetAddr() // fake
|
|
|
+}
|
|
|
+
|
|
|
+func (c *udpConn) RemoteAddr() net.Addr {
|
|
|
+ return c.src.RawNetAddr() // src
|
|
|
+}
|
|
|
+
|
|
|
+func (c *udpConn) SetDeadline(t time.Time) error {
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (c *udpConn) SetReadDeadline(t time.Time) error {
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (c *udpConn) SetWriteDeadline(t time.Time) error {
|
|
|
+ return nil
|
|
|
+}
|