|
|
@@ -5,7 +5,6 @@ import (
|
|
|
"errors"
|
|
|
"net/netip"
|
|
|
"strconv"
|
|
|
- "sync"
|
|
|
|
|
|
"golang.zx2c4.com/wireguard/conn"
|
|
|
|
|
|
@@ -14,12 +13,10 @@ import (
|
|
|
"github.com/xtls/xray-core/transport/internet"
|
|
|
)
|
|
|
|
|
|
+// netReadInfo holds the result of a read operation from a specific endpoint
|
|
|
type netReadInfo struct {
|
|
|
- // status
|
|
|
- waiter sync.WaitGroup
|
|
|
- // param
|
|
|
- buff []byte
|
|
|
// result
|
|
|
+ buff []byte
|
|
|
bytes int
|
|
|
endpoint conn.Endpoint
|
|
|
err error
|
|
|
@@ -30,8 +27,8 @@ type netBind struct {
|
|
|
dns dns.Client
|
|
|
dnsOption dns.IPOption
|
|
|
|
|
|
- workers int
|
|
|
- readQueue chan *netReadInfo
|
|
|
+ workers int
|
|
|
+ responseRecv chan *netReadInfo // responses from all endpoints flow through here
|
|
|
}
|
|
|
|
|
|
// SetMark implements conn.Bind
|
|
|
@@ -79,7 +76,7 @@ func (bind *netBind) BatchSize() int {
|
|
|
|
|
|
// Open implements conn.Bind
|
|
|
func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
|
|
- bind.readQueue = make(chan *netReadInfo)
|
|
|
+ bind.responseRecv = make(chan *netReadInfo)
|
|
|
|
|
|
fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
|
|
|
defer func() {
|
|
|
@@ -89,13 +86,14 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
|
|
}
|
|
|
}()
|
|
|
|
|
|
- r := &netReadInfo{
|
|
|
- buff: bufs[0],
|
|
|
+ r, ok := <-bind.responseRecv
|
|
|
+ if !ok {
|
|
|
+ return 0, errors.New("channel closed")
|
|
|
}
|
|
|
- r.waiter.Add(1)
|
|
|
- bind.readQueue <- r
|
|
|
- r.waiter.Wait() // wait read goroutine done, or we will miss the result
|
|
|
- sizes[0], eps[0] = r.bytes, r.endpoint
|
|
|
+
|
|
|
+ copy(bufs[0], r.buff[:r.bytes])
|
|
|
+ sizes[0] = r.bytes
|
|
|
+ eps[0] = r.endpoint
|
|
|
return 1, r.err
|
|
|
}
|
|
|
workers := bind.workers
|
|
|
@@ -112,8 +110,8 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
|
|
|
|
|
// Close implements conn.Bind
|
|
|
func (bind *netBind) Close() error {
|
|
|
- if bind.readQueue != nil {
|
|
|
- close(bind.readQueue)
|
|
|
+ if bind.responseRecv != nil {
|
|
|
+ close(bind.responseRecv)
|
|
|
}
|
|
|
return nil
|
|
|
}
|
|
|
@@ -133,30 +131,32 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
|
|
|
}
|
|
|
endpoint.conn = c
|
|
|
|
|
|
- go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) {
|
|
|
+ go func(responseRecv chan<- *netReadInfo, endpoint *netEndpoint, c net.Conn) {
|
|
|
+ defer func() {
|
|
|
+ _ = recover() // gracefully handle send on closed channel
|
|
|
+ }()
|
|
|
for {
|
|
|
- v, ok := <-readQueue
|
|
|
- if !ok {
|
|
|
- return
|
|
|
- }
|
|
|
- i, err := c.Read(v.buff)
|
|
|
+ buff := make([]byte, 1700) // max MTU for WireGuard
|
|
|
+ i, err := c.Read(buff)
|
|
|
|
|
|
if i > 3 {
|
|
|
- v.buff[1] = 0
|
|
|
- v.buff[2] = 0
|
|
|
- v.buff[3] = 0
|
|
|
+ buff[1] = 0
|
|
|
+ buff[2] = 0
|
|
|
+ buff[3] = 0
|
|
|
}
|
|
|
|
|
|
- v.bytes = i
|
|
|
- v.endpoint = endpoint
|
|
|
- v.err = err
|
|
|
- v.waiter.Done()
|
|
|
+ responseRecv <- &netReadInfo{
|
|
|
+ buff: buff,
|
|
|
+ bytes: i,
|
|
|
+ endpoint: endpoint,
|
|
|
+ err: err,
|
|
|
+ }
|
|
|
if err != nil {
|
|
|
endpoint.conn = nil
|
|
|
return
|
|
|
}
|
|
|
}
|
|
|
- }(bind.readQueue, endpoint)
|
|
|
+ }(bind.responseRecv, endpoint, c)
|
|
|
|
|
|
return nil
|
|
|
}
|