Przeglądaj źródła

Refactor code to use DispatchLink() in vmess inbound

- Always apply NoTerminationSignal
yuhan6665 9 miesięcy temu
rodzic
commit
b14d5407e5
1 zmienionych plików z 16 dodań i 76 usunięć
  1. 16 76
      proxy/vmess/inbound/inbound.go

+ 16 - 76
proxy/vmess/inbound/inbound.go

@@ -14,8 +14,6 @@ import (
 	"github.com/xtls/xray-core/common/net"
 	"github.com/xtls/xray-core/common/protocol"
 	"github.com/xtls/xray-core/common/session"
-	"github.com/xtls/xray-core/common/signal"
-	"github.com/xtls/xray-core/common/task"
 	"github.com/xtls/xray-core/common/uuid"
 	"github.com/xtls/xray-core/core"
 	feature_inbound "github.com/xtls/xray-core/features/inbound"
@@ -23,6 +21,7 @@ import (
 	"github.com/xtls/xray-core/features/routing"
 	"github.com/xtls/xray-core/proxy/vmess"
 	"github.com/xtls/xray-core/proxy/vmess/encoding"
+	"github.com/xtls/xray-core/transport"
 	"github.com/xtls/xray-core/transport/internet/stat"
 )
 
@@ -184,44 +183,6 @@ func (h *Handler) RemoveUser(ctx context.Context, email string) error {
 	return nil
 }
 
-func transferResponse(timer signal.ActivityUpdater, session *encoding.ServerSession, request *protocol.RequestHeader, response *protocol.ResponseHeader, input buf.Reader, output *buf.BufferedWriter) error {
-	session.EncodeResponseHeader(response, output)
-
-	bodyWriter, err := session.EncodeResponseBody(request, output)
-	if err != nil {
-		return errors.New("failed to start decoding response").Base(err)
-	}
-	{
-		// Optimize for small response packet
-		data, err := input.ReadMultiBuffer()
-		if err != nil {
-			return err
-		}
-
-		if err := bodyWriter.WriteMultiBuffer(data); err != nil {
-			return err
-		}
-	}
-
-	if err := output.SetBuffered(false); err != nil {
-		return err
-	}
-
-	if err := buf.Copy(input, bodyWriter, buf.UpdateActivity(timer)); err != nil {
-		return err
-	}
-
-	account := request.User.Account.(*vmess.MemoryAccount)
-
-	if request.Option.Has(protocol.RequestOptionChunkStream) && !account.NoTerminationSignal {
-		if err := bodyWriter.WriteMultiBuffer(buf.MultiBuffer{}); err != nil {
-			return err
-		}
-	}
-
-	return nil
-}
-
 // Process implements proxy.Inbound.Process().
 func (h *Handler) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error {
 	sessionPolicy := h.policyManager.ForLevel(0)
@@ -275,49 +236,28 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
 	inbound.CanSpliceCopy = 3
 	inbound.User = request.User
 
-	sessionPolicy = h.policyManager.ForLevel(request.User.Level)
-
-	ctx, cancel := context.WithCancel(ctx)
-	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
-
-	ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)
-	link, err := dispatcher.Dispatch(ctx, request.Destination())
+	bodyReader, err := svrSession.DecodeRequestBody(request, reader)
 	if err != nil {
-		return errors.New("failed to dispatch request to ", request.Destination()).Base(err)
+		return errors.New("failed to start decoding").Base(err)
 	}
 
-	requestDone := func() error {
-		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
-
-		bodyReader, err := svrSession.DecodeRequestBody(request, reader)
-		if err != nil {
-			return errors.New("failed to start decoding").Base(err)
-		}
-		if err := buf.Copy(bodyReader, link.Writer, buf.UpdateActivity(timer)); err != nil {
-			return errors.New("failed to transfer request").Base(err)
-		}
-		return nil
+	writer := buf.NewBufferedWriter(buf.NewWriter(connection))
+	response := &protocol.ResponseHeader{
+		Command: h.generateCommand(ctx, request),
 	}
-
-	responseDone := func() error {
-		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
-
-		writer := buf.NewBufferedWriter(buf.NewWriter(connection))
-		defer writer.Flush()
-
-		response := &protocol.ResponseHeader{
-			Command: h.generateCommand(ctx, request),
-		}
-		return transferResponse(timer, svrSession, request, response, link.Reader, writer)
+	svrSession.EncodeResponseHeader(response, writer)
+	bodyWriter, err := svrSession.EncodeResponseBody(request, writer)
+	if err != nil {
+		return errors.New("failed to start decoding response").Base(err)
 	}
+	writer.SetFlushNext()
 
-	requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
-	if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
-		common.Interrupt(link.Reader)
-		common.Interrupt(link.Writer)
-		return errors.New("connection ends").Base(err)
+	if err := dispatcher.DispatchLink(ctx, request.Destination(), &transport.Link{
+		Reader: bodyReader,
+		Writer: bodyWriter},
+	); err != nil {
+		return errors.New("failed to dispatch request").Base(err)
 	}
-
 	return nil
 }