Преглед изворни кода

XHTTP, WS, HU: Forbid "host" in `headers`, read `serverName` instead (#4142)

WebSocket's config files should be updated ASAP.
RPRX пре 1 година
родитељ
комит
a2b773135a

+ 1 - 3
common/reflect/marshal_test.go

@@ -204,9 +204,7 @@ func getConfig() string {
 			  "security": "none",
 			  "wsSettings": {
 				"path": "/?ed=2048",
-				"headers": {
-				  "Host": "bing.com"
-				}
+				"host": "bing.com"
 			  }
 			}
 		  }

+ 17 - 23
infra/conf/transport_internet.go

@@ -163,13 +163,13 @@ func (c *WebSocketConfig) Build() (proto.Message, error) {
 			path = u.String()
 		}
 	}
-	// If http host is not set in the Host field, but in headers field, we add it to Host Field here.
-	// If we don't do that, http host will be overwritten as address.
-	// Host priority: Host field > headers field > address.
-	if c.Host == "" && c.Headers["host"] != "" {
-		c.Host = c.Headers["host"]
-	} else if c.Host == "" && c.Headers["Host"] != "" {
-		c.Host = c.Headers["Host"]
+	// Priority (client): host > serverName > address
+	for k, v := range c.Headers {
+		errors.PrintDeprecatedFeatureWarning(`"host" in "headers"`, `independent "host"`)
+		if c.Host == "" {
+			c.Host = v
+		}
+		delete(c.Headers, k)
 	}
 	config := &websocket.Config{
 		Path:                path,
@@ -202,15 +202,11 @@ func (c *HttpUpgradeConfig) Build() (proto.Message, error) {
 			path = u.String()
 		}
 	}
-	// If http host is not set in the Host field, but in headers field, we add it to Host Field here.
-	// If we don't do that, http host will be overwritten as address.
-	// Host priority: Host field > headers field > address.
-	if c.Host == "" && c.Headers["host"] != "" {
-		c.Host = c.Headers["host"]
-		delete(c.Headers, "host")
-	} else if c.Host == "" && c.Headers["Host"] != "" {
-		c.Host = c.Headers["Host"]
-		delete(c.Headers, "Host")
+	// Priority (client): host > serverName > address
+	for k := range c.Headers {
+		if strings.ToLower(k) == "host" {
+			return nil, errors.New(`"headers" can't contain "host"`)
+		}
 	}
 	config := &httpupgrade.Config{
 		Path:                path,
@@ -274,13 +270,11 @@ func (c *SplitHTTPConfig) Build() (proto.Message, error) {
 		c = &extra
 	}
 
-	// If http host is not set in the Host field, but in headers field, we add it to Host Field here.
-	// If we don't do that, http host will be overwritten as address.
-	// Host priority: Host field > headers field > address.
-	if c.Host == "" && c.Headers["host"] != "" {
-		c.Host = c.Headers["host"]
-	} else if c.Host == "" && c.Headers["Host"] != "" {
-		c.Host = c.Headers["Host"]
+	// Priority (client): host > serverName > address
+	for k := range c.Headers {
+		if strings.ToLower(k) == "host" {
+			return nil, errors.New(`"headers" can't contain "host"`)
+		}
 	}
 
 	if c.Xmux.MaxConnections != nil && c.Xmux.MaxConnections.To > 0 && c.Xmux.MaxConcurrency != nil && c.Xmux.MaxConcurrency.To > 0 {

+ 1 - 6
infra/conf/xray_test.go

@@ -48,9 +48,7 @@ func TestXrayConfig(t *testing.T) {
 					"streamSettings": {
 						"network": "ws",
 						"wsSettings": {
-							"headers": {
-								"host": "example.domain"
-							},
+							"host": "example.domain",
 							"path": ""
 						},
 						"tlsSettings": {
@@ -139,9 +137,6 @@ func TestXrayConfig(t *testing.T) {
 										ProtocolName: "websocket",
 										Settings: serial.ToTypedMessage(&websocket.Config{
 											Host: "example.domain",
-											Header: map[string]string{
-												"host": "example.domain",
-											},
 										}),
 									},
 								},

+ 11 - 5
transport/internet/httpupgrade/dialer.go

@@ -53,9 +53,10 @@ func dialhttpUpgrade(ctx context.Context, dest net.Destination, streamSettings *
 
 	var conn net.Conn
 	var requestURL url.URL
-	if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
-		tlsConfig := config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1"))
-		if fingerprint := tls.GetFingerprint(config.Fingerprint); fingerprint != nil {
+	tConfig := tls.ConfigFromStreamSettings(streamSettings)
+	if tConfig != nil {
+		tlsConfig := tConfig.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1"))
+		if fingerprint := tls.GetFingerprint(tConfig.Fingerprint); fingerprint != nil {
 			conn = tls.UClient(pconn, tlsConfig, fingerprint)
 			if err := conn.(*tls.UConn).WebsocketHandshakeContext(ctx); err != nil {
 				return nil, err
@@ -69,12 +70,17 @@ func dialhttpUpgrade(ctx context.Context, dest net.Destination, streamSettings *
 		requestURL.Scheme = "http"
 	}
 
-	requestURL.Host = dest.NetAddr()
+	requestURL.Host = transportConfiguration.Host
+	if requestURL.Host == "" && tConfig != nil {
+		requestURL.Host = tConfig.ServerName
+	}
+	if requestURL.Host == "" {
+		requestURL.Host = dest.Address.String()
+	}
 	requestURL.Path = transportConfiguration.GetNormalizedPath()
 	req := &http.Request{
 		Method: http.MethodGet,
 		URL:    &requestURL,
-		Host:   transportConfiguration.Host,
 		Header: make(http.Header),
 	}
 	for key, value := range transportConfiguration.Header {

+ 19 - 4
transport/internet/splithttp/dialer.go

@@ -259,8 +259,14 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 		requestURL.Scheme = "http"
 	}
 	requestURL.Host = transportConfiguration.Host
+	if requestURL.Host == "" && tlsConfig != nil {
+		requestURL.Host = tlsConfig.ServerName
+	}
+	if requestURL.Host == "" && realityConfig != nil {
+		requestURL.Host = realityConfig.ServerName
+	}
 	if requestURL.Host == "" {
-		requestURL.Host = dest.NetAddr()
+		requestURL.Host = dest.Address.String()
 	}
 
 	sessionIdUuid := uuid.New()
@@ -279,16 +285,25 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 		}
 		globalDialerAccess.Unlock()
 		memory2 := streamSettings.DownloadSettings
-		httpClient2, muxRes2 = getHTTPClient(ctx, *memory2.Destination, memory2) // just panic
-		if tls.ConfigFromStreamSettings(memory2) != nil || reality.ConfigFromStreamSettings(memory2) != nil {
+		dest2 := *memory2.Destination // just panic
+		httpClient2, muxRes2 = getHTTPClient(ctx, dest2, memory2)
+		tlsConfig2 := tls.ConfigFromStreamSettings(memory2)
+		realityConfig2 := reality.ConfigFromStreamSettings(memory2)
+		if tlsConfig2 != nil || realityConfig2 != nil {
 			requestURL2.Scheme = "https"
 		} else {
 			requestURL2.Scheme = "http"
 		}
 		config2 := memory2.ProtocolSettings.(*Config)
 		requestURL2.Host = config2.Host
+		if requestURL2.Host == "" && tlsConfig2 != nil {
+			requestURL2.Host = tlsConfig2.ServerName
+		}
+		if requestURL2.Host == "" && realityConfig2 != nil {
+			requestURL2.Host = realityConfig2.ServerName
+		}
 		if requestURL2.Host == "" {
-			requestURL2.Host = memory2.Destination.NetAddr()
+			requestURL2.Host = dest2.Address.String()
 		}
 		requestURL2.Path = config2.GetNormalizedPath() + sessionIdUuid.String()
 		requestURL2.RawQuery = config2.GetNormalizedQuery()

+ 0 - 1
transport/internet/websocket/config.go

@@ -23,7 +23,6 @@ func (c *Config) GetRequestHeader() http.Header {
 	for k, v := range c.Header {
 		header.Add(k, v)
 	}
-	header.Set("Host", c.Host)
 	return header
 }
 

+ 12 - 3
transport/internet/websocket/dialer.go

@@ -58,11 +58,12 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in
 
 	protocol := "ws"
 
-	if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
+	tConfig := tls.ConfigFromStreamSettings(streamSettings)
+	if tConfig != nil {
 		protocol = "wss"
-		tlsConfig := config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1"))
+		tlsConfig := tConfig.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1"))
 		dialer.TLSClientConfig = tlsConfig
-		if fingerprint := tls.GetFingerprint(config.Fingerprint); fingerprint != nil {
+		if fingerprint := tls.GetFingerprint(tConfig.Fingerprint); fingerprint != nil {
 			dialer.NetDialTLSContext = func(_ context.Context, _, addr string) (gonet.Conn, error) {
 				// Like the NetDial in the dialer
 				pconn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
@@ -103,6 +104,14 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in
 	}
 
 	header := wsSettings.GetRequestHeader()
+	// See dialer.DialContext()
+	header.Set("Host", wsSettings.Host)
+	if header.Get("Host") == "" && tConfig != nil {
+		header.Set("Host", tConfig.ServerName)
+	}
+	if header.Get("Host") == "" {
+		header.Set("Host", dest.Address.String())
+	}
 	if ed != nil {
 		// RawURLEncoding is support by both V2Ray/V2Fly and XRay.
 		header.Set("Sec-WebSocket-Protocol", base64.RawURLEncoding.EncodeToString(ed))