Преглед изворни кода

XTLS: More separate uplink/downlink flags for splice copy (#4407)

- In 03131c72dbbfc13ba4ce8e1f9f65f43f3dda7372 new flags were added for uplink/downlink, but that was not suffcient
- Now that the traffic state contains all possible info
- Each inbound and outbound is responsible to set their own CanSpliceCopy flag. Note that this also open up more splice usage. E.g. socks in -> freedom out
- Fixes https://github.com/XTLS/Xray-core/issues/4033
yuhan6665 пре 1 година
родитељ
комит
eef74b2c7d
6 измењених фајлова са 132 додато и 67 уклоњено
  1. 1 0
      proxy/http/client.go
  2. 1 0
      proxy/http/server.go
  3. 117 58
      proxy/proxy.go
  4. 2 0
      proxy/socks/client.go
  5. 2 0
      proxy/socks/server.go
  6. 9 9
      proxy/vless/encoding/encoding.go

+ 1 - 0
proxy/http/client.go

@@ -151,6 +151,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 		return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
 	}
 	responseFunc := func() error {
+		ob.CanSpliceCopy = 1
 		defer timer.SetTimeout(p.Timeouts.UplinkOnly)
 		return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
 	}

+ 1 - 0
proxy/http/server.go

@@ -207,6 +207,7 @@ func (s *Server) handleConnect(ctx context.Context, _ *http.Request, reader *buf
 	}
 
 	responseDone := func() error {
+		inbound.CanSpliceCopy = 1
 		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
 
 		v2writer := buf.NewWriter(conn)

+ 117 - 58
proxy/proxy.go

@@ -107,19 +107,33 @@ type TrafficState struct {
 	IsTLS                  bool
 	Cipher                 uint16
 	RemainingServerHello   int32
+	Inbound                InboundState
+	Outbound               OutboundState
+}
 
+type InboundState struct {
 	// reader link state
 	WithinPaddingBuffers     bool
-	DownlinkReaderDirectCopy bool
 	UplinkReaderDirectCopy   bool
 	RemainingCommand         int32
 	RemainingContent         int32
 	RemainingPadding         int32
 	CurrentCommand           int
-
 	// write link state
 	IsPadding                bool
 	DownlinkWriterDirectCopy bool
+}
+
+type OutboundState struct {
+	// reader link state
+	WithinPaddingBuffers     bool
+	DownlinkReaderDirectCopy bool
+	RemainingCommand         int32
+	RemainingContent         int32
+	RemainingPadding         int32
+	CurrentCommand           int
+	// write link state
+	IsPadding                bool
 	UplinkWriterDirectCopy   bool
 }
 
@@ -132,16 +146,26 @@ func NewTrafficState(userUUID []byte) *TrafficState {
 		IsTLS:                    false,
 		Cipher:                   0,
 		RemainingServerHello:     -1,
-		WithinPaddingBuffers:     true,
-		DownlinkReaderDirectCopy: false,
-		UplinkReaderDirectCopy:   false,
-		RemainingCommand:         -1,
-		RemainingContent:         -1,
-		RemainingPadding:         -1,
-		CurrentCommand:           0,
-		IsPadding:                true,
-		DownlinkWriterDirectCopy: false,
-		UplinkWriterDirectCopy:   false,
+		Inbound: InboundState{
+			WithinPaddingBuffers:     true,
+			UplinkReaderDirectCopy:   false,
+			RemainingCommand:         -1,
+			RemainingContent:         -1,
+			RemainingPadding:         -1,
+			CurrentCommand:           0,
+			IsPadding:                true,
+			DownlinkWriterDirectCopy: false,
+		},
+		Outbound: OutboundState{
+			WithinPaddingBuffers:     true,
+			DownlinkReaderDirectCopy: false,
+			RemainingCommand:         -1,
+			RemainingContent:         -1,
+			RemainingPadding:         -1,
+			CurrentCommand:           0,
+			IsPadding:                true,
+			UplinkWriterDirectCopy:   false,
+		},
 	}
 }
 
@@ -166,28 +190,43 @@ func NewVisionReader(reader buf.Reader, state *TrafficState, isUplink bool, cont
 func (w *VisionReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
 	buffer, err := w.Reader.ReadMultiBuffer()
 	if !buffer.IsEmpty() {
-		if w.trafficState.WithinPaddingBuffers || w.trafficState.NumberOfPacketToFilter > 0 {
+		var withinPaddingBuffers *bool
+		var remainingContent *int32
+		var remainingPadding *int32
+		var currentCommand *int
+		var switchToDirectCopy *bool
+		if w.isUplink {
+			withinPaddingBuffers = &w.trafficState.Inbound.WithinPaddingBuffers
+			remainingContent = &w.trafficState.Inbound.RemainingContent
+			remainingPadding = &w.trafficState.Inbound.RemainingPadding
+			currentCommand = &w.trafficState.Inbound.CurrentCommand
+			switchToDirectCopy = &w.trafficState.Inbound.UplinkReaderDirectCopy
+		} else {
+			withinPaddingBuffers = &w.trafficState.Outbound.WithinPaddingBuffers
+			remainingContent = &w.trafficState.Outbound.RemainingContent
+			remainingPadding = &w.trafficState.Outbound.RemainingPadding
+			currentCommand = &w.trafficState.Outbound.CurrentCommand
+			switchToDirectCopy = &w.trafficState.Outbound.DownlinkReaderDirectCopy
+		}
+
+		if *withinPaddingBuffers || w.trafficState.NumberOfPacketToFilter > 0 {
 			mb2 := make(buf.MultiBuffer, 0, len(buffer))
 			for _, b := range buffer {
-				newbuffer := XtlsUnpadding(b, w.trafficState, w.ctx)
+				newbuffer := XtlsUnpadding(b, w.trafficState, w.isUplink, w.ctx)
 				if newbuffer.Len() > 0 {
 					mb2 = append(mb2, newbuffer)
 				}
 			}
 			buffer = mb2
-			if w.trafficState.RemainingContent > 0 || w.trafficState.RemainingPadding > 0 || w.trafficState.CurrentCommand == 0 {
-				w.trafficState.WithinPaddingBuffers = true
-			} else if w.trafficState.CurrentCommand == 1 {
-				w.trafficState.WithinPaddingBuffers = false
-			} else if w.trafficState.CurrentCommand == 2 {
-				w.trafficState.WithinPaddingBuffers = false
-				if w.isUplink {
-					w.trafficState.UplinkReaderDirectCopy = true
-				} else {
-					w.trafficState.DownlinkReaderDirectCopy = true
-				}
+			if *remainingContent > 0 || *remainingPadding > 0 || *currentCommand == 0 {
+				*withinPaddingBuffers = true
+			} else if *currentCommand == 1 {
+				*withinPaddingBuffers = false
+			} else if *currentCommand == 2 {
+				*withinPaddingBuffers = false
+				*switchToDirectCopy = true
 			} else {
-				errors.LogInfo(w.ctx, "XtlsRead unknown command ", w.trafficState.CurrentCommand, buffer.Len())
+				errors.LogInfo(w.ctx, "XtlsRead unknown command ", *currentCommand, buffer.Len())
 			}
 		}
 		if w.trafficState.NumberOfPacketToFilter > 0 {
@@ -223,7 +262,16 @@ func (w *VisionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 	if w.trafficState.NumberOfPacketToFilter > 0 {
 		XtlsFilterTls(mb, w.trafficState, w.ctx)
 	}
-	if w.trafficState.IsPadding {
+	var isPadding *bool
+	var switchToDirectCopy *bool
+	if w.isUplink {
+		isPadding = &w.trafficState.Outbound.IsPadding
+		switchToDirectCopy = &w.trafficState.Outbound.UplinkWriterDirectCopy
+	} else {
+		isPadding = &w.trafficState.Inbound.IsPadding
+		switchToDirectCopy = &w.trafficState.Inbound.DownlinkWriterDirectCopy
+	}
+	if *isPadding {
 		if len(mb) == 1 && mb[0] == nil {
 			mb[0] = XtlsPadding(nil, CommandPaddingContinue, &w.writeOnceUserUUID, true, w.ctx) // we do a long padding to hide vless header
 			return w.Writer.WriteMultiBuffer(mb)
@@ -233,11 +281,7 @@ func (w *VisionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 		for i, b := range mb {
 			if w.trafficState.IsTLS && b.Len() >= 6 && bytes.Equal(TlsApplicationDataStart, b.BytesTo(3)) {
 				if w.trafficState.EnableXtls {
-					if w.isUplink {
-						w.trafficState.UplinkWriterDirectCopy = true
-					} else {
-						w.trafficState.DownlinkWriterDirectCopy = true
-					}
+					*switchToDirectCopy = true
 				}
 				var command byte = CommandPaddingContinue
 				if i == len(mb)-1 {
@@ -247,16 +291,16 @@ func (w *VisionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 					}
 				}
 				mb[i] = XtlsPadding(b, command, &w.writeOnceUserUUID, true, w.ctx)
-				w.trafficState.IsPadding = false // padding going to end
+				*isPadding = false // padding going to end
 				longPadding = false
 				continue
 			} else if !w.trafficState.IsTLS12orAbove && w.trafficState.NumberOfPacketToFilter <= 1 { // For compatibility with earlier vision receiver, we finish padding 1 packet early
-				w.trafficState.IsPadding = false
+				*isPadding = false
 				mb[i] = XtlsPadding(b, CommandPaddingEnd, &w.writeOnceUserUUID, longPadding, w.ctx)
 				break
 			}
 			var command byte = CommandPaddingContinue
-			if i == len(mb)-1 && !w.trafficState.IsPadding {
+			if i == len(mb)-1 && !*isPadding {
 				command = CommandPaddingEnd
 				if w.trafficState.EnableXtls {
 					command = CommandPaddingDirect
@@ -343,38 +387,53 @@ func XtlsPadding(b *buf.Buffer, command byte, userUUID *[]byte, longPadding bool
 }
 
 // XtlsUnpadding remove padding and parse command
-func XtlsUnpadding(b *buf.Buffer, s *TrafficState, ctx context.Context) *buf.Buffer {
-	if s.RemainingCommand == -1 && s.RemainingContent == -1 && s.RemainingPadding == -1 { // initial state
+func XtlsUnpadding(b *buf.Buffer, s *TrafficState, isUplink bool, ctx context.Context) *buf.Buffer {
+	var remainingCommand *int32
+	var remainingContent *int32
+	var remainingPadding *int32
+	var currentCommand *int
+	if isUplink {
+		remainingCommand = &s.Inbound.RemainingCommand
+		remainingContent = &s.Inbound.RemainingContent
+		remainingPadding = &s.Inbound.RemainingPadding
+		currentCommand = &s.Inbound.CurrentCommand
+	} else {
+		remainingCommand = &s.Outbound.RemainingCommand
+		remainingContent = &s.Outbound.RemainingContent
+		remainingPadding = &s.Outbound.RemainingPadding
+		currentCommand = &s.Outbound.CurrentCommand
+	}
+	if *remainingCommand == -1 && *remainingContent == -1 && *remainingPadding == -1 { // initial state
 		if b.Len() >= 21 && bytes.Equal(s.UserUUID, b.BytesTo(16)) {
 			b.Advance(16)
-			s.RemainingCommand = 5
+			*remainingCommand = 5
 		} else {
 			return b
 		}
 	}
 	newbuffer := buf.New()
 	for b.Len() > 0 {
-		if s.RemainingCommand > 0 {
+		if *remainingCommand > 0 {
 			data, err := b.ReadByte()
 			if err != nil {
 				return newbuffer
 			}
-			switch s.RemainingCommand {
+			switch *remainingCommand {
 			case 5:
-				s.CurrentCommand = int(data)
+				*currentCommand = int(data)
 			case 4:
-				s.RemainingContent = int32(data) << 8
+				*remainingContent = int32(data) << 8
 			case 3:
-				s.RemainingContent = s.RemainingContent | int32(data)
+				*remainingContent = *remainingContent | int32(data)
 			case 2:
-				s.RemainingPadding = int32(data) << 8
+				*remainingPadding = int32(data) << 8
 			case 1:
-				s.RemainingPadding = s.RemainingPadding | int32(data)
-				errors.LogInfo(ctx, "Xtls Unpadding new block, content ", s.RemainingContent, " padding ", s.RemainingPadding, " command ", s.CurrentCommand)
+				*remainingPadding = *remainingPadding | int32(data)
+				errors.LogInfo(ctx, "Xtls Unpadding new block, content ", *remainingContent, " padding ", *remainingPadding, " command ", *currentCommand)
 			}
-			s.RemainingCommand--
-		} else if s.RemainingContent > 0 {
-			len := s.RemainingContent
+			*remainingCommand--
+		} else if *remainingContent > 0 {
+			len := *remainingContent
 			if b.Len() < len {
 				len = b.Len()
 			}
@@ -383,22 +442,22 @@ func XtlsUnpadding(b *buf.Buffer, s *TrafficState, ctx context.Context) *buf.Buf
 				return newbuffer
 			}
 			newbuffer.Write(data)
-			s.RemainingContent -= len
+			*remainingContent -= len
 		} else { // remainingPadding > 0
-			len := s.RemainingPadding
+			len := *remainingPadding
 			if b.Len() < len {
 				len = b.Len()
 			}
 			b.Advance(len)
-			s.RemainingPadding -= len
+			*remainingPadding -= len
 		}
-		if s.RemainingCommand <= 0 && s.RemainingContent <= 0 && s.RemainingPadding <= 0 { // this block done
-			if s.CurrentCommand == 0 {
-				s.RemainingCommand = 5
+		if *remainingCommand <= 0 && *remainingContent <= 0 && *remainingPadding <= 0 { // this block done
+			if *currentCommand == 0 {
+				*remainingCommand = 5
 			} else {
-				s.RemainingCommand = -1 // set to initial state
-				s.RemainingContent = -1
-				s.RemainingPadding = -1
+				*remainingCommand = -1 // set to initial state
+				*remainingContent = -1
+				*remainingPadding = -1
 				if b.Len() > 0 { // shouldn't happen
 					newbuffer.Write(b.Bytes())
 				}

+ 2 - 0
proxy/socks/client.go

@@ -146,6 +146,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 			return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer))
 		}
 		responseFunc = func() error {
+			ob.CanSpliceCopy = 1
 			defer timer.SetTimeout(p.Timeouts.UplinkOnly)
 			return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
 		}
@@ -161,6 +162,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 			return buf.Copy(link.Reader, writer, buf.UpdateActivity(timer))
 		}
 		responseFunc = func() error {
+			ob.CanSpliceCopy = 1
 			defer timer.SetTimeout(p.Timeouts.UplinkOnly)
 			reader := &UDPReader{Reader: udpConn}
 			return buf.Copy(reader, link.Writer, buf.UpdateActivity(timer))

+ 2 - 0
proxy/socks/server.go

@@ -199,6 +199,7 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
 	}
 
 	responseDone := func() error {
+		inbound.CanSpliceCopy = 1
 		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
 
 		v2writer := buf.NewWriter(writer)
@@ -256,6 +257,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
 	if inbound != nil && inbound.Source.IsValid() {
 		errors.LogInfo(ctx, "client UDP connection from ", inbound.Source)
 	}
+	inbound.CanSpliceCopy = 1
 
 	var dest *net.Destination
 

+ 9 - 9
proxy/vless/encoding/encoding.go

@@ -175,16 +175,16 @@ func DecodeResponseHeader(reader io.Reader, request *protocol.RequestHeader) (*A
 func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer, conn net.Conn, input *bytes.Reader, rawInput *bytes.Buffer, trafficState *proxy.TrafficState, ob *session.Outbound, isUplink bool, ctx context.Context) error {
 	err := func() error {
 		for {
-			if isUplink && trafficState.UplinkReaderDirectCopy || !isUplink && trafficState.DownlinkReaderDirectCopy {
+			if isUplink && trafficState.Inbound.UplinkReaderDirectCopy || !isUplink && trafficState.Outbound.DownlinkReaderDirectCopy {
 				var writerConn net.Conn
 				var inTimer *signal.ActivityTimer
 				if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Conn != nil {
 					writerConn = inbound.Conn
 					inTimer = inbound.Timer
-					if inbound.CanSpliceCopy == 2 {
+					if isUplink && inbound.CanSpliceCopy == 2 {
 						inbound.CanSpliceCopy = 1
 					}
-					if ob != nil && ob.CanSpliceCopy == 2 { // ob need to be passed in due to context can change
+					if !isUplink && ob != nil && ob.CanSpliceCopy == 2 { // ob need to be passed in due to context can change
 						ob.CanSpliceCopy = 1
 					}
 				}
@@ -193,7 +193,7 @@ func XtlsRead(reader buf.Reader, writer buf.Writer, timer *signal.ActivityTimer,
 			buffer, err := reader.ReadMultiBuffer()
 			if !buffer.IsEmpty() {
 				timer.Update()
-				if isUplink && trafficState.UplinkReaderDirectCopy || !isUplink && trafficState.DownlinkReaderDirectCopy {
+				if isUplink && trafficState.Inbound.UplinkReaderDirectCopy || !isUplink && trafficState.Outbound.DownlinkReaderDirectCopy {
 					// XTLS Vision processes struct TLS Conn's input and rawInput
 					if inputBuffer, err := buf.ReadFrom(input); err == nil {
 						if !inputBuffer.IsEmpty() {
@@ -227,12 +227,12 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate
 		var ct stats.Counter
 		for {
 			buffer, err := reader.ReadMultiBuffer()
-			if isUplink && trafficState.UplinkWriterDirectCopy || !isUplink && trafficState.DownlinkWriterDirectCopy {
+			if isUplink && trafficState.Outbound.UplinkWriterDirectCopy || !isUplink && trafficState.Inbound.DownlinkWriterDirectCopy {
 				if inbound := session.InboundFromContext(ctx); inbound != nil {
-					if inbound.CanSpliceCopy == 2 {
+					if !isUplink && inbound.CanSpliceCopy == 2 {
 						inbound.CanSpliceCopy = 1
 					}
-					if ob != nil && ob.CanSpliceCopy == 2 {
+					if isUplink && ob != nil && ob.CanSpliceCopy == 2 {
 						ob.CanSpliceCopy = 1
 					}
 				}
@@ -240,9 +240,9 @@ func XtlsWrite(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdate
 				writer = buf.NewWriter(rawConn)
 				ct = writerCounter
 				if isUplink {
-					trafficState.UplinkWriterDirectCopy = false
+					trafficState.Outbound.UplinkWriterDirectCopy = false
 				} else {
-					trafficState.DownlinkWriterDirectCopy = false
+					trafficState.Inbound.DownlinkWriterDirectCopy = false
 				}
 			}
 			if !buffer.IsEmpty() {