Przeglądaj źródła

XHTTP transport: Some optimizations (#5803)

https://github.com/XTLS/Xray-core/pull/5801
https://github.com/XTLS/Xray-core/pull/5808

---------

Co-authored-by: Sergei Ozeranskii <[email protected]>
Co-authored-by: rufsieus <[email protected]>
风扇滑翔翼 2 miesięcy temu
rodzic
commit
c1b67a961e

+ 5 - 5
transport/internet/splithttp/browser_client.go

@@ -5,6 +5,7 @@ import (
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 
 
+	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/transport/internet/browser_dialer"
 	"github.com/xtls/xray-core/transport/internet/browser_dialer"
@@ -41,21 +42,20 @@ func (c *BrowserDialerClient) OpenStream(ctx context.Context, url string, sessio
 	return websocket.NewConnection(conn, dummyAddr, nil, 0), conn.RemoteAddr(), conn.LocalAddr(), nil
 	return websocket.NewConnection(conn, dummyAddr, nil, 0), conn.RemoteAddr(), conn.LocalAddr(), nil
 }
 }
 
 
-func (c *BrowserDialerClient) PostPacket(ctx context.Context, url string, sessionId string, seqStr string, body io.Reader, contentLength int64) error {
+func (c *BrowserDialerClient) PostPacket(ctx context.Context, url string, sessionId string, seqStr string, payload buf.MultiBuffer) error {
 	method := c.transportConfig.GetNormalizedUplinkHTTPMethod()
 	method := c.transportConfig.GetNormalizedUplinkHTTPMethod()
-	request, err := http.NewRequest(method, url, body)
+	request, err := http.NewRequest(method, url, nil)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	request.ContentLength = contentLength
-	err = c.transportConfig.FillPacketRequest(request, sessionId, seqStr)
+	err = c.transportConfig.FillPacketRequest(request, sessionId, seqStr, payload)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
 	var bytes []byte
 	var bytes []byte
-	if (request.Body != nil) {
+	if request.Body != nil {
 		bytes, err = io.ReadAll(request.Body)
 		bytes, err = io.ReadAll(request.Body)
 		if err != nil {
 		if err != nil {
 			return err
 			return err

+ 6 - 5
transport/internet/splithttp/client.go

@@ -10,6 +10,7 @@ import (
 	"sync"
 	"sync"
 
 
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common"
+	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/signal/done"
 	"github.com/xtls/xray-core/common/signal/done"
@@ -23,7 +24,7 @@ type DialerClient interface {
 	OpenStream(context.Context, string, string, io.Reader, bool) (io.ReadCloser, net.Addr, net.Addr, error)
 	OpenStream(context.Context, string, string, io.Reader, bool) (io.ReadCloser, net.Addr, net.Addr, error)
 
 
 	// ctx, url, sessionId, seqStr, body, contentLength
 	// ctx, url, sessionId, seqStr, body, contentLength
-	PostPacket(context.Context, string, string, string, io.Reader, int64) error
+	PostPacket(context.Context, string, string, string, buf.MultiBuffer) error
 }
 }
 
 
 // implements splithttp.DialerClient in terms of direct network connections
 // implements splithttp.DialerClient in terms of direct network connections
@@ -89,14 +90,13 @@ func (c *DefaultDialerClient) OpenStream(ctx context.Context, url string, sessio
 	return
 	return
 }
 }
 
 
-func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, sessionId string, seqStr string, body io.Reader, contentLength int64) error {
+func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, sessionId string, seqStr string, payload buf.MultiBuffer) error {
 	method := c.transportConfig.GetNormalizedUplinkHTTPMethod()
 	method := c.transportConfig.GetNormalizedUplinkHTTPMethod()
-	req, err := http.NewRequestWithContext(context.WithoutCancel(ctx), method, url, body)
+	req, err := http.NewRequestWithContext(context.WithoutCancel(ctx), method, url, nil)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	req.ContentLength = contentLength
-	c.transportConfig.FillPacketRequest(req, sessionId, seqStr)
+	c.transportConfig.FillPacketRequest(req, sessionId, seqStr, payload)
 
 
 	if c.httpVersion != "1.1" {
 	if c.httpVersion != "1.1" {
 		resp, err := c.client.Do(req)
 		resp, err := c.client.Do(req)
@@ -117,6 +117,7 @@ func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, sessio
 		// times, the body is already drained after the first
 		// times, the body is already drained after the first
 		// request
 		// request
 		requestBuff := new(bytes.Buffer)
 		requestBuff := new(bytes.Buffer)
+		requestBuff.Grow(512 + int(req.ContentLength))
 		common.Must(req.Write(requestBuff))
 		common.Must(req.Write(requestBuff))
 
 
 		var uploadConn any
 		var uploadConn any

+ 10 - 15
transport/internet/splithttp/config.go

@@ -8,6 +8,7 @@ import (
 	"strings"
 	"strings"
 
 
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common"
+	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/crypto"
 	"github.com/xtls/xray-core/common/crypto"
 	"github.com/xtls/xray-core/common/utils"
 	"github.com/xtls/xray-core/common/utils"
 	"github.com/xtls/xray-core/transport/internet"
 	"github.com/xtls/xray-core/transport/internet"
@@ -55,7 +56,6 @@ func (c *Config) GetRequestHeader() http.Header {
 	return header
 	return header
 }
 }
 
 
-
 func (c *Config) GetRequestHeaderWithPayload(payload []byte) http.Header {
 func (c *Config) GetRequestHeaderWithPayload(payload []byte) http.Header {
 	header := c.GetRequestHeader()
 	header := c.GetRequestHeader()
 
 
@@ -100,9 +100,9 @@ func (c *Config) WriteResponseHeader(writer http.ResponseWriter, requestMethod s
 	}
 	}
 
 
 	if c.GetNormalizedSessionPlacement() == PlacementCookie ||
 	if c.GetNormalizedSessionPlacement() == PlacementCookie ||
-	   c.GetNormalizedSeqPlacement() == PlacementCookie ||
-	   c.XPaddingPlacement == PlacementCookie ||
-	   c.GetNormalizedUplinkDataPlacement() == PlacementCookie {
+		c.GetNormalizedSeqPlacement() == PlacementCookie ||
+		c.XPaddingPlacement == PlacementCookie ||
+		c.GetNormalizedUplinkDataPlacement() == PlacementCookie {
 		writer.Header().Set("Access-Control-Allow-Credentials", "true")
 		writer.Header().Set("Access-Control-Allow-Credentials", "true")
 	}
 	}
 
 
@@ -322,22 +322,17 @@ func (c *Config) FillStreamRequest(request *http.Request, sessionId string, seqS
 	}
 	}
 }
 }
 
 
-func (c *Config) FillPacketRequest(request *http.Request, sessionId string, seqStr string) error {
+func (c *Config) FillPacketRequest(request *http.Request, sessionId string, seqStr string, payload buf.MultiBuffer) error {
 	dataPlacement := c.GetNormalizedUplinkDataPlacement()
 	dataPlacement := c.GetNormalizedUplinkDataPlacement()
 
 
 	if dataPlacement == PlacementBody || dataPlacement == PlacementAuto {
 	if dataPlacement == PlacementBody || dataPlacement == PlacementAuto {
 		request.Header = c.GetRequestHeader()
 		request.Header = c.GetRequestHeader()
+		request.Body = io.NopCloser(&buf.MultiBufferContainer{MultiBuffer: payload})
+		request.ContentLength = int64(payload.Len())
 	} else {
 	} else {
-		var data []byte
-		var err error
-		if request.Body != nil {
-			data, err = io.ReadAll(request.Body)
-			if err != nil {
-				return err
-			}
-		}
-		request.Body = nil
-		request.ContentLength = 0
+		data := make([]byte, payload.Len())
+		payload.Copy(data)
+		buf.ReleaseMulti(payload)
 		switch dataPlacement {
 		switch dataPlacement {
 		case PlacementHeader:
 		case PlacementHeader:
 			request.Header = c.GetRequestHeaderWithPayload(data)
 			request.Header = c.GetRequestHeaderWithPayload(data)

+ 1 - 2
transport/internet/splithttp/dialer.go

@@ -562,8 +562,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 						requestURL.String(),
 						requestURL.String(),
 						sessionId,
 						sessionId,
 						seqStr,
 						seqStr,
-						&buf.MultiBufferContainer{MultiBuffer: chunk},
-						int64(chunk.Len()),
+						chunk,
 					)
 					)
 					wroteRequest.Close()
 					wroteRequest.Close()
 					if err != nil {
 					if err != nil {

+ 27 - 5
transport/internet/splithttp/hub.go

@@ -18,6 +18,7 @@ import (
 	"github.com/apernet/quic-go/http3"
 	"github.com/apernet/quic-go/http3"
 	goreality "github.com/xtls/reality"
 	goreality "github.com/xtls/reality"
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/common"
+	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/net"
 	http_proto "github.com/xtls/xray-core/common/protocol/http"
 	http_proto "github.com/xtls/xray-core/common/protocol/http"
@@ -293,15 +294,36 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 
 
 		var bodyPayload []byte
 		var bodyPayload []byte
 		if dataPlacement == PlacementAuto || dataPlacement == PlacementBody {
 		if dataPlacement == PlacementAuto || dataPlacement == PlacementBody {
-			bodyPayload, err = io.ReadAll(io.LimitReader(request.Body, int64(scMaxEachPostBytes)+1))
-			if err != nil {
-				errors.LogInfoInner(context.Background(), err, "failed to upload (ReadAll)")
-				writer.WriteHeader(http.StatusInternalServerError)
+			var readErr error
+			if request.ContentLength > int64(scMaxEachPostBytes) {
+				errors.LogInfo(context.Background(), "Too large upload. scMaxEachPostBytes is set to ", scMaxEachPostBytes, "but request size exceed it. Adjust scMaxEachPostBytes on the server to be at least as large as client.")
+				writer.WriteHeader(http.StatusRequestEntityTooLarge)
+				return
+			}
+			if request.ContentLength > 0 {
+				bodyPayload = make([]byte, request.ContentLength)
+				_, readErr = io.ReadFull(request.Body, bodyPayload)
+			} else {
+				bodyPayload, readErr = buf.ReadAllToBytes(io.LimitReader(request.Body, int64(scMaxEachPostBytes)+1))
+			}
+			if readErr != nil {
+				errors.LogInfoInner(context.Background(), readErr, "failed to read body payload")
+				writer.WriteHeader(http.StatusBadRequest)
 				return
 				return
 			}
 			}
 		}
 		}
 
 
-		payload := slices.Concat(headerPayload, cookiePayload, bodyPayload)
+		var payload []byte
+		switch dataPlacement {
+		case PlacementHeader:
+			payload = headerPayload
+		case PlacementCookie:
+			payload = cookiePayload
+		case PlacementBody:
+			payload = bodyPayload
+		case PlacementAuto:
+			payload = slices.Concat(headerPayload, cookiePayload, bodyPayload)
+		}
 
 
 		if len(payload) > scMaxEachPostBytes {
 		if len(payload) > scMaxEachPostBytes {
 			errors.LogInfo(context.Background(), "Too large upload. scMaxEachPostBytes is set to ", scMaxEachPostBytes, "but request size exceed it. Adjust scMaxEachPostBytes on the server to be at least as large as client.")
 			errors.LogInfo(context.Background(), "Too large upload. scMaxEachPostBytes is set to ", scMaxEachPostBytes, "but request size exceed it. Adjust scMaxEachPostBytes on the server to be at least as large as client.")