瀏覽代碼

annoying pre alloc

Fangliding 3 月之前
父節點
當前提交
b587d1e94a
共有 2 個文件被更改,包括 16 次插入4 次删除
  1. 1 0
      transport/internet/splithttp/client.go
  2. 15 4
      transport/internet/splithttp/hub.go

+ 1 - 0
transport/internet/splithttp/client.go

@@ -117,6 +117,7 @@ func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, sessio
 		// times, the body is already drained after the first
 		// request
 		requestBuff := new(bytes.Buffer)
+		requestBuff.Grow(512 + int(req.ContentLength))
 		common.Must(req.Write(requestBuff))
 
 		var uploadConn any

+ 15 - 4
transport/internet/splithttp/hub.go

@@ -294,10 +294,21 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
 
 		var bodyPayload []byte
 		if dataPlacement == PlacementAuto || dataPlacement == PlacementBody {
-			bodyPayload, err = buf.ReadAllToBytes(io.LimitReader(request.Body, int64(scMaxEachPostBytes)+1))
-			if err != nil {
-				errors.LogInfoInner(context.Background(), err, "failed to upload (ReadAll)")
-				writer.WriteHeader(http.StatusInternalServerError)
+			var readErr error
+			if request.ContentLength > int64(scMaxEachPostBytes) {
+				errors.LogInfo(context.Background(), "Too large upload. scMaxEachPostBytes is set to ", scMaxEachPostBytes, "but request size exceed it. Adjust scMaxEachPostBytes on the server to be at least as large as client.")
+				writer.WriteHeader(http.StatusRequestEntityTooLarge)
+				return
+			}
+			if request.ContentLength > 0 {
+				bodyPayload = make([]byte, request.ContentLength)
+				_, readErr = io.ReadFull(request.Body, bodyPayload)
+			} else {
+				bodyPayload, readErr = buf.ReadAllToBytes(io.LimitReader(request.Body, int64(scMaxEachPostBytes)+1))
+			}
+			if readErr != nil {
+				errors.LogInfoInner(context.Background(), readErr, "failed to read body payload")
+				writer.WriteHeader(http.StatusBadRequest)
 				return
 			}
 		}