Fangliding 2 месяцев назад
Родитель
Сommit
8b86c6b041
2 измененных файлов с 18 добавлено и 13 удалено
  1. 13 8
      proxy/wireguard/bind.go
  2. 5 5
      proxy/wireguard/server.go

+ 13 - 8
proxy/wireguard/bind.go

@@ -10,6 +10,7 @@ import (
 	"golang.zx2c4.com/wireguard/conn"
 	"golang.zx2c4.com/wireguard/device"
 
+	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/features/dns"
@@ -17,7 +18,7 @@ import (
 )
 
 type netReadInfo struct {
-	buff     []byte
+	buff     *buf.Buffer
 	endpoint conn.Endpoint
 }
 
@@ -82,7 +83,8 @@ func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
 	fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
 		select {
 		case r := <-bind.readQueue:
-			sizes[0], eps[0] = copy(bufs[0], r.buff), r.endpoint
+			sizes[0], eps[0] = copy(bufs[0], r.buff.Bytes()), r.endpoint
+			r.buff.Release()
 			return 1, nil
 		case <-bind.closedCh:
 			errors.LogDebug(context.Background(), "recv func closed")
@@ -130,27 +132,30 @@ func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
 
 	go func() {
 		for {
-			buff := make([]byte, device.MaxMessageSize)
-			n, err := c.Read(buff)
+			buff := buf.NewWithSize(device.MaxMessageSize)
+			n, err := buff.ReadFrom(c)
 
 			if err != nil {
+				buff.Release()
 				endpoint.conn = nil
 				c.Close()
 				return
 			}
 
+			rawBytes := buff.Bytes()
 			if n > 3 {
-				buff[1] = 0
-				buff[2] = 0
-				buff[3] = 0
+				rawBytes[1] = 0
+				rawBytes[2] = 0
+				rawBytes[3] = 0
 			}
 
 			select {
 			case bind.readQueue <- &netReadInfo{
-				buff:     buff[:n],
+				buff:     buff,
 				endpoint: endpoint,
 			}:
 			case <-bind.closedCh:
+				buff.Release()
 				endpoint.conn = nil
 				c.Close()
 				return

+ 5 - 5
proxy/wireguard/server.go

@@ -101,17 +101,17 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
 		}
 
 		for i, b := range mb {
-			buff := b.Bytes()
 
+			rawBytes := b.Bytes()
 			if b.Len() > 3 {
-				buff[1] = 0
-				buff[2] = 0
-				buff[3] = 0
+				rawBytes[1] = 0
+				rawBytes[2] = 0
+				rawBytes[3] = 0
 			}
 
 			select {
 			case s.bindServer.readQueue <- &netReadInfo{
-				buff:     buff,
+				buff:     b,
 				endpoint: nep,
 			}:
 			case <-s.bindServer.closedCh: