Просмотр исходного кода

XHTTP transport: Some optimizations (#5803)

https://github.com/XTLS/Xray-core/pull/5801
https://github.com/XTLS/Xray-core/pull/5808

---------

Co-authored-by: Sergei Ozeranskii <[email protected]>
Co-authored-by: rufsieus <[email protected]>
风扇滑翔翼 2 месяцев назад
Родитель
Сommit
c1b67a961e

+ 5 - 5
transport/internet/splithttp/browser_client.go

@@ -5,6 +5,7 @@ import (
 	"io"
 	"net/http"
 
+	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/transport/internet/browser_dialer"
@@ -41,21 +42,20 @@ func (c *BrowserDialerClient) OpenStream(ctx context.Context, url string, sessio
 	return websocket.NewConnection(conn, dummyAddr, nil, 0), conn.RemoteAddr(), conn.LocalAddr(), nil
 }
 
-func (c *BrowserDialerClient) PostPacket(ctx context.Context, url string, sessionId string, seqStr string, body io.Reader, contentLength int64) error {
+func (c *BrowserDialerClient) PostPacket(ctx context.Context, url string, sessionId string, seqStr string, payload buf.MultiBuffer) error {
 	method := c.transportConfig.GetNormalizedUplinkHTTPMethod()
-	request, err := http.NewRequest(method, url, body)
+	request, err := http.NewRequest(method, url, nil)
 	if err != nil {
 		return err
 	}
 
-	request.ContentLength = contentLength
-	err = c.transportConfig.FillPacketRequest(request, sessionId, seqStr)
+	err = c.transportConfig.FillPacketRequest(request, sessionId, seqStr, payload)
 	if err != nil {
 		return err
 	}
 
 	var bytes []byte
-	if (request.Body != nil) {
+	if request.Body != nil {
 		bytes, err = io.ReadAll(request.Body)
 		if err != nil {
 			return err

+ 6 - 5
transport/internet/splithttp/client.go

@@ -10,6 +10,7 @@ import (
 	"sync"
 
 	"github.com/xtls/xray-core/common"
+	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/signal/done"
@@ -23,7 +24,7 @@ type DialerClient interface {
 	OpenStream(context.Context, string, string, io.Reader, bool) (io.ReadCloser, net.Addr, net.Addr, error)
 
 	// ctx, url, sessionId, seqStr, body, contentLength
-	PostPacket(context.Context, string, string, string, io.Reader, int64) error
+	PostPacket(context.Context, string, string, string, buf.MultiBuffer) error
 }
 
 // implements splithttp.DialerClient in terms of direct network connections
@@ -89,14 +90,13 @@ func (c *DefaultDialerClient) OpenStream(ctx context.Context, url string, sessio
 	return
 }
 
-func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, sessionId string, seqStr string, body io.Reader, contentLength int64) error {
+func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, sessionId string, seqStr string, payload buf.MultiBuffer) error {
 	method := c.transportConfig.GetNormalizedUplinkHTTPMethod()
-	req, err := http.NewRequestWithContext(context.WithoutCancel(ctx), method, url, body)
+	req, err := http.NewRequestWithContext(context.WithoutCancel(ctx), method, url, nil)
 	if err != nil {
 		return err
 	}
-	req.ContentLength = contentLength
-	c.transportConfig.FillPacketRequest(req, sessionId, seqStr)
+	c.transportConfig.FillPacketRequest(req, sessionId, seqStr, payload)
 
 	if c.httpVersion != "1.1" {
 		resp, err := c.client.Do(req)
@@ -117,6 +117,7 @@ func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, sessio
 		// times, the body is already drained after the first
 		// request
 		requestBuff := new(bytes.Buffer)
+		requestBuff.Grow(512 + int(req.ContentLength))
 		common.Must(req.Write(requestBuff))
 
 		var uploadConn any

+ 10 - 15
transport/internet/splithttp/config.go

@@ -8,6 +8,7 @@ import (
 	"strings"
 
 	"github.com/xtls/xray-core/common"
+	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/crypto"
 	"github.com/xtls/xray-core/common/utils"
 	"github.com/xtls/xray-core/transport/internet"
@@ -55,7 +56,6 @@ func (c *Config) GetRequestHeader() http.Header {
 	return header
 }
 
-
 func (c *Config) GetRequestHeaderWithPayload(payload []byte) http.Header {
 	header := c.GetRequestHeader()
 
@@ -100,9 +100,9 @@ func (c *Config) WriteResponseHeader(writer http.ResponseWriter, requestMethod s
 	}
 
 	if c.GetNormalizedSessionPlacement() == PlacementCookie ||
-	   c.GetNormalizedSeqPlacement() == PlacementCookie ||
-	   c.XPaddingPlacement == PlacementCookie ||
-	   c.GetNormalizedUplinkDataPlacement() == PlacementCookie {
+		c.GetNormalizedSeqPlacement() == PlacementCookie ||
+		c.XPaddingPlacement == PlacementCookie ||
+		c.GetNormalizedUplinkDataPlacement() == PlacementCookie {
 		writer.Header().Set("Access-Control-Allow-Credentials", "true")
 	}
 
@@ -322,22 +322,17 @@ func (c *Config) FillStreamRequest(request *http.Request, sessionId string, seqS
 	}
 }
 
-func (c *Config) FillPacketRequest(request *http.Request, sessionId string, seqStr string) error {
+func (c *Config) FillPacketRequest(request *http.Request, sessionId string, seqStr string, payload buf.MultiBuffer) error {
 	dataPlacement := c.GetNormalizedUplinkDataPlacement()
 
 	if dataPlacement == PlacementBody || dataPlacement == PlacementAuto {
 		request.Header = c.GetRequestHeader()
+		request.Body = io.NopCloser(&buf.MultiBufferContainer{MultiBuffer: payload})
+		request.ContentLength = int64(payload.Len())
 	} else {
-		var data []byte
-		var err error
-		if request.Body != nil {
-			data, err = io.ReadAll(request.Body)
-			if err != nil {
-				return err
-			}
-		}
-		request.Body = nil
-		request.ContentLength = 0
+		data := make([]byte, payload.Len())
+		payload.Copy(data)
+		buf.ReleaseMulti(payload)
 		switch dataPlacement {
 		case PlacementHeader:
 			request.Header = c.GetRequestHeaderWithPayload(data)

+ 1 - 2
transport/internet/splithttp/dialer.go

@@ -562,8 +562,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 						requestURL.String(),
 						sessionId,
 						seqStr,
-						&buf.MultiBufferContainer{MultiBuffer: chunk},
-						int64(chunk.Len()),
+						chunk,
 					)
 					wroteRequest.Close()
 					if err != nil {

+ 27 - 5
transport/internet/splithttp/hub.go

@@ -18,6 +18,7 @@ import (
 	"github.com/apernet/quic-go/http3"
 	goreality "github.com/xtls/reality"
 	"github.com/xtls/xray-core/common"
+	"github.com/xtls/xray-core/common/buf"
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/net"
 	http_proto "github.com/xtls/xray-core/common/protocol/http"
@@ -293,15 +294,36 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 
 		var bodyPayload []byte
 		if dataPlacement == PlacementAuto || dataPlacement == PlacementBody {
-			bodyPayload, err = io.ReadAll(io.LimitReader(request.Body, int64(scMaxEachPostBytes)+1))
-			if err != nil {
-				errors.LogInfoInner(context.Background(), err, "failed to upload (ReadAll)")
-				writer.WriteHeader(http.StatusInternalServerError)
+			var readErr error
+			if request.ContentLength > int64(scMaxEachPostBytes) {
+				errors.LogInfo(context.Background(), "Too large upload. scMaxEachPostBytes is set to ", scMaxEachPostBytes, "but request size exceed it. Adjust scMaxEachPostBytes on the server to be at least as large as client.")
+				writer.WriteHeader(http.StatusRequestEntityTooLarge)
+				return
+			}
+			if request.ContentLength > 0 {
+				bodyPayload = make([]byte, request.ContentLength)
+				_, readErr = io.ReadFull(request.Body, bodyPayload)
+			} else {
+				bodyPayload, readErr = buf.ReadAllToBytes(io.LimitReader(request.Body, int64(scMaxEachPostBytes)+1))
+			}
+			if readErr != nil {
+				errors.LogInfoInner(context.Background(), readErr, "failed to read body payload")
+				writer.WriteHeader(http.StatusBadRequest)
 				return
 			}
 		}
 
-		payload := slices.Concat(headerPayload, cookiePayload, bodyPayload)
+		var payload []byte
+		switch dataPlacement {
+		case PlacementHeader:
+			payload = headerPayload
+		case PlacementCookie:
+			payload = cookiePayload
+		case PlacementBody:
+			payload = bodyPayload
+		case PlacementAuto:
+			payload = slices.Concat(headerPayload, cookiePayload, bodyPayload)
+		}
 
 		if len(payload) > scMaxEachPostBytes {
 			errors.LogInfo(context.Background(), "Too large upload. scMaxEachPostBytes is set to ", scMaxEachPostBytes, "but request size exceed it. Adjust scMaxEachPostBytes on the server to be at least as large as client.")