|
@@ -2,27 +2,23 @@ package wireguard
|
|
|
|
|
|
|
|
import (
|
|
import (
|
|
|
"context"
|
|
"context"
|
|
|
- "errors"
|
|
|
|
|
|
|
+ gonet "net"
|
|
|
"net/netip"
|
|
"net/netip"
|
|
|
|
|
+ "runtime"
|
|
|
"strconv"
|
|
"strconv"
|
|
|
- "sync"
|
|
|
|
|
|
|
|
|
|
"golang.zx2c4.com/wireguard/conn"
|
|
"golang.zx2c4.com/wireguard/conn"
|
|
|
|
|
+ "golang.zx2c4.com/wireguard/device"
|
|
|
|
|
|
|
|
|
|
+ "github.com/xtls/xray-core/common/errors"
|
|
|
"github.com/xtls/xray-core/common/net"
|
|
"github.com/xtls/xray-core/common/net"
|
|
|
"github.com/xtls/xray-core/features/dns"
|
|
"github.com/xtls/xray-core/features/dns"
|
|
|
"github.com/xtls/xray-core/transport/internet"
|
|
"github.com/xtls/xray-core/transport/internet"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
type netReadInfo struct {
|
|
type netReadInfo struct {
|
|
|
- // status
|
|
|
|
|
- waiter sync.WaitGroup
|
|
|
|
|
- // param
|
|
|
|
|
- buff []byte
|
|
|
|
|
- // result
|
|
|
|
|
- bytes int
|
|
|
|
|
|
|
+ buff []byte
|
|
|
endpoint conn.Endpoint
|
|
endpoint conn.Endpoint
|
|
|
- err error
|
|
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// reduce duplicated code
|
|
// reduce duplicated code
|
|
@@ -32,6 +28,7 @@ type netBind struct {
|
|
|
|
|
|
|
|
workers int
|
|
workers int
|
|
|
readQueue chan *netReadInfo
|
|
readQueue chan *netReadInfo
|
|
|
|
|
+ closedCh chan struct{}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// SetMark implements conn.Bind
|
|
// SetMark implements conn.Bind
|
|
@@ -79,27 +76,23 @@ func (bind *netBind) BatchSize() int {
|
|
|
|
|
|
|
|
// Open implements conn.Bind
|
|
// Open implements conn.Bind
|
|
|
func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
|
func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
|
|
- bind.readQueue = make(chan *netReadInfo)
|
|
|
|
|
|
|
+ bind.closedCh = make(chan struct{})
|
|
|
|
|
+ errors.LogDebug(context.Background(), "bind opened")
|
|
|
|
|
|
|
|
fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
|
|
fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
|
|
|
- defer func() {
|
|
|
|
|
- if r := recover(); r != nil {
|
|
|
|
|
- n = 0
|
|
|
|
|
- err = errors.New("channel closed")
|
|
|
|
|
- }
|
|
|
|
|
- }()
|
|
|
|
|
-
|
|
|
|
|
- r, ok := <-bind.readQueue
|
|
|
|
|
- if !ok {
|
|
|
|
|
- return 0, errors.New("channel closed")
|
|
|
|
|
|
|
+ select {
|
|
|
|
|
+ case r := <-bind.readQueue:
|
|
|
|
|
+ sizes[0], eps[0] = copy(bufs[0], r.buff), r.endpoint
|
|
|
|
|
+ return 1, nil
|
|
|
|
|
+ case <-bind.closedCh:
|
|
|
|
|
+ errors.LogDebug(context.Background(), "recv func closed")
|
|
|
|
|
+ return 0, gonet.ErrClosed
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
- copy(bufs[0], r.buff[:r.bytes])
|
|
|
|
|
- sizes[0], eps[0] = r.bytes, r.endpoint
|
|
|
|
|
- r.waiter.Done()
|
|
|
|
|
- return 1, r.err
|
|
|
|
|
}
|
|
}
|
|
|
workers := bind.workers
|
|
workers := bind.workers
|
|
|
|
|
+ if workers <= 0 {
|
|
|
|
|
+ workers = runtime.NumCPU()
|
|
|
|
|
+ }
|
|
|
if workers <= 0 {
|
|
if workers <= 0 {
|
|
|
workers = 1
|
|
workers = 1
|
|
|
}
|
|
}
|
|
@@ -113,8 +106,9 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
|
|
|
|
|
|
|
// Close implements conn.Bind
|
|
// Close implements conn.Bind
|
|
|
func (bind *netBind) Close() error {
|
|
func (bind *netBind) Close() error {
|
|
|
- if bind.readQueue != nil {
|
|
|
|
|
- close(bind.readQueue)
|
|
|
|
|
|
|
+ errors.LogDebug(context.Background(), "bind closed")
|
|
|
|
|
+ if bind.closedCh != nil {
|
|
|
|
|
+ close(bind.closedCh)
|
|
|
}
|
|
}
|
|
|
return nil
|
|
return nil
|
|
|
}
|
|
}
|
|
@@ -134,35 +128,35 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
|
|
|
}
|
|
}
|
|
|
endpoint.conn = c
|
|
endpoint.conn = c
|
|
|
|
|
|
|
|
- go func(readQueue chan<- *netReadInfo, endpoint *netEndpoint) {
|
|
|
|
|
- defer func() {
|
|
|
|
|
- _ = recover() // handle send on closed channel
|
|
|
|
|
- }()
|
|
|
|
|
|
|
+ go func() {
|
|
|
for {
|
|
for {
|
|
|
- buff := make([]byte, 1700)
|
|
|
|
|
- i, err := c.Read(buff)
|
|
|
|
|
|
|
+ buff := make([]byte, device.MaxMessageSize)
|
|
|
|
|
+ n, err := c.Read(buff)
|
|
|
|
|
+
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ endpoint.conn = nil
|
|
|
|
|
+ c.Close()
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- if i > 3 {
|
|
|
|
|
|
|
+ if n > 3 {
|
|
|
buff[1] = 0
|
|
buff[1] = 0
|
|
|
buff[2] = 0
|
|
buff[2] = 0
|
|
|
buff[3] = 0
|
|
buff[3] = 0
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- r := &netReadInfo{
|
|
|
|
|
- buff: buff,
|
|
|
|
|
- bytes: i,
|
|
|
|
|
|
|
+ select {
|
|
|
|
|
+ case bind.readQueue <- &netReadInfo{
|
|
|
|
|
+ buff: buff[:n],
|
|
|
endpoint: endpoint,
|
|
endpoint: endpoint,
|
|
|
- err: err,
|
|
|
|
|
- }
|
|
|
|
|
- r.waiter.Add(1)
|
|
|
|
|
- readQueue <- r
|
|
|
|
|
- r.waiter.Wait()
|
|
|
|
|
- if err != nil {
|
|
|
|
|
|
|
+ }:
|
|
|
|
|
+ case <-bind.closedCh:
|
|
|
endpoint.conn = nil
|
|
endpoint.conn = nil
|
|
|
|
|
+ c.Close()
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
- }(bind.readQueue, endpoint)
|
|
|
|
|
|
|
+ }()
|
|
|
|
|
|
|
|
return nil
|
|
return nil
|
|
|
}
|
|
}
|
|
@@ -206,7 +200,8 @@ func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if nend.conn == nil {
|
|
if nend.conn == nil {
|
|
|
- return errors.New("connection not open yet")
|
|
|
|
|
|
|
+ errors.LogDebug(context.Background(), nend.dst.NetAddr(), " send on closed peer")
|
|
|
|
|
+ return errors.New("peer closed")
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
for _, buff := range buff {
|
|
for _, buff := range buff {
|