|
|
@@ -14,18 +14,12 @@ import (
|
|
|
"github.com/xtls/xray-core/transport/internet"
|
|
|
)
|
|
|
|
|
|
-const udpBufferSize = 1700 // max MTU for WireGuard
|
|
|
-
|
|
|
-var bufferPool = sync.Pool{
|
|
|
- New: func() any {
|
|
|
- return make([]byte, udpBufferSize)
|
|
|
- },
|
|
|
-}
|
|
|
-
|
|
|
-// netReadInfo holds the result of a read operation from a specific endpoint
|
|
|
type netReadInfo struct {
|
|
|
+ // status
|
|
|
+ waiter sync.WaitGroup
|
|
|
+ // param
|
|
|
+ buff []byte
|
|
|
// result
|
|
|
- buff []byte
|
|
|
bytes int
|
|
|
endpoint conn.Endpoint
|
|
|
err error
|
|
|
@@ -36,8 +30,8 @@ type netBind struct {
|
|
|
dns dns.Client
|
|
|
dnsOption dns.IPOption
|
|
|
|
|
|
- workers int
|
|
|
- responseRecv chan *netReadInfo // responses from all endpoints flow through here
|
|
|
+ workers int
|
|
|
+ readQueue chan *netReadInfo
|
|
|
}
|
|
|
|
|
|
// SetMark implements conn.Bind
|
|
|
@@ -85,7 +79,7 @@ func (bind *netBind) BatchSize() int {
|
|
|
|
|
|
// Open implements conn.Bind
|
|
|
func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
|
|
- bind.responseRecv = make(chan *netReadInfo)
|
|
|
+ bind.readQueue = make(chan *netReadInfo)
|
|
|
|
|
|
fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
|
|
|
defer func() {
|
|
|
@@ -95,16 +89,14 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
|
|
}
|
|
|
}()
|
|
|
|
|
|
- r, ok := <-bind.responseRecv
|
|
|
+ r, ok := <-bind.readQueue
|
|
|
if !ok {
|
|
|
return 0, errors.New("channel closed")
|
|
|
}
|
|
|
|
|
|
copy(bufs[0], r.buff[:r.bytes])
|
|
|
- sizes[0] = r.bytes
|
|
|
- eps[0] = r.endpoint
|
|
|
- // Return buffer to pool
|
|
|
- bufferPool.Put(r.buff)
|
|
|
+ sizes[0], eps[0] = r.bytes, r.endpoint
|
|
|
+ r.waiter.Done()
|
|
|
return 1, r.err
|
|
|
}
|
|
|
workers := bind.workers
|
|
|
@@ -121,8 +113,8 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
|
|
|
|
|
// Close implements conn.Bind
|
|
|
func (bind *netBind) Close() error {
|
|
|
- if bind.responseRecv != nil {
|
|
|
- close(bind.responseRecv)
|
|
|
+ if bind.readQueue != nil {
|
|
|
+ close(bind.readQueue)
|
|
|
}
|
|
|
return nil
|
|
|
}
|
|
|
@@ -142,12 +134,12 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
|
|
|
}
|
|
|
endpoint.conn = c
|
|
|
|
|
|
- go func(responseRecv chan<- *netReadInfo, endpoint *netEndpoint, c net.Conn) {
|
|
|
+ go func(readQueue chan<- *netReadInfo, endpoint *netEndpoint) {
|
|
|
defer func() {
|
|
|
- _ = recover() // gracefully handle send on closed channel
|
|
|
+ _ = recover() // handle send on closed channel
|
|
|
}()
|
|
|
for {
|
|
|
- buff := bufferPool.Get().([]byte)
|
|
|
+ buff := make([]byte, 1700)
|
|
|
i, err := c.Read(buff)
|
|
|
|
|
|
if i > 3 {
|
|
|
@@ -156,18 +148,21 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
|
|
|
buff[3] = 0
|
|
|
}
|
|
|
|
|
|
- responseRecv <- &netReadInfo{
|
|
|
+ r := &netReadInfo{
|
|
|
buff: buff,
|
|
|
bytes: i,
|
|
|
endpoint: endpoint,
|
|
|
err: err,
|
|
|
}
|
|
|
+ r.waiter.Add(1)
|
|
|
+ readQueue <- r
|
|
|
+ r.waiter.Wait()
|
|
|
if err != nil {
|
|
|
endpoint.conn = nil
|
|
|
return
|
|
|
}
|
|
|
}
|
|
|
- }(bind.responseRecv, endpoint, c)
|
|
|
+ }(bind.readQueue, endpoint)
|
|
|
|
|
|
return nil
|
|
|
}
|