Просмотр исходного кода

TLS ECH client: Add `echForceQuery` config (#4947)

https://github.com/XTLS/Xray-core/pull/4947#issuecomment-3124359776
风扇滑翔翼 10 месяцев назад
Родитель
Сommit
b2829219a0

+ 4 - 3
infra/conf/transport_internet.go

@@ -412,8 +412,9 @@ type TLSConfig struct {
 	MasterKeyLog                         string           `json:"masterKeyLog"`
 	ServerNameToVerify                   string           `json:"serverNameToVerify"`
 	VerifyPeerCertInNames                []string         `json:"verifyPeerCertInNames"`
-	ECHConfigList                        string           `json:"echConfigList"`
 	ECHServerKeys                        string           `json:"echServerKeys"`
+	ECHConfigList                        string           `json:"echConfigList"`
+	ECHForceQuery                        bool             `json:"echForceQuery"`
 }
 
 // Build implements Buildable.
@@ -485,8 +486,6 @@ func (c *TLSConfig) Build() (proto.Message, error) {
 	}
 	config.VerifyPeerCertInNames = c.VerifyPeerCertInNames
 
-	config.EchConfigList = c.ECHConfigList
-
 	if c.ECHServerKeys != "" {
 		EchPrivateKey, err := base64.StdEncoding.DecodeString(c.ECHServerKeys)
 		if err != nil {
@@ -494,6 +493,8 @@ func (c *TLSConfig) Build() (proto.Message, error) {
 		}
 		config.EchServerKeys = EchPrivateKey
 	}
+	config.EchForceQuery = c.ECHForceQuery
+	config.EchConfigList = c.ECHConfigList
 
 	return config, nil
 }

+ 30 - 20
transport/internet/tls/config.pb.go

@@ -217,8 +217,9 @@ type Config struct {
 	// @Document After allow_insecure (automatically), if the server's cert can't be verified by any of these names, pinned_peer_certificate_chain_sha256 will be tried.
 	// @Critical
 	VerifyPeerCertInNames []string `protobuf:"bytes,17,rep,name=verify_peer_cert_in_names,json=verifyPeerCertInNames,proto3" json:"verify_peer_cert_in_names,omitempty"`
-	EchConfigList         string   `protobuf:"bytes,18,opt,name=ech_config_list,json=echConfigList,proto3" json:"ech_config_list,omitempty"`
-	EchServerKeys         []byte   `protobuf:"bytes,19,opt,name=ech_server_keys,json=echServerKeys,proto3" json:"ech_server_keys,omitempty"`
+	EchServerKeys         []byte   `protobuf:"bytes,18,opt,name=ech_server_keys,json=echServerKeys,proto3" json:"ech_server_keys,omitempty"`
+	EchConfigList         string   `protobuf:"bytes,19,opt,name=ech_config_list,json=echConfigList,proto3" json:"ech_config_list,omitempty"`
+	EchForceQuery         bool     `protobuf:"varint,20,opt,name=ech_force_query,json=echForceQuery,proto3" json:"ech_force_query,omitempty"`
 }
 
 func (x *Config) Reset() {
@@ -363,6 +364,13 @@ func (x *Config) GetVerifyPeerCertInNames() []string {
 	return nil
 }
 
+func (x *Config) GetEchServerKeys() []byte {
+	if x != nil {
+		return x.EchServerKeys
+	}
+	return nil
+}
+
 func (x *Config) GetEchConfigList() string {
 	if x != nil {
 		return x.EchConfigList
@@ -370,11 +378,11 @@ func (x *Config) GetEchConfigList() string {
 	return ""
 }
 
-func (x *Config) GetEchServerKeys() []byte {
+func (x *Config) GetEchForceQuery() bool {
 	if x != nil {
-		return x.EchServerKeys
+		return x.EchForceQuery
 	}
-	return nil
+	return false
 }
 
 var File_transport_internet_tls_config_proto protoreflect.FileDescriptor
@@ -408,7 +416,7 @@ var file_transport_internet_tls_config_proto_rawDesc = []byte{
 	0x4e, 0x43, 0x49, 0x50, 0x48, 0x45, 0x52, 0x4d, 0x45, 0x4e, 0x54, 0x10, 0x00, 0x12, 0x14, 0x0a,
 	0x10, 0x41, 0x55, 0x54, 0x48, 0x4f, 0x52, 0x49, 0x54, 0x59, 0x5f, 0x56, 0x45, 0x52, 0x49, 0x46,
 	0x59, 0x10, 0x01, 0x12, 0x13, 0x0a, 0x0f, 0x41, 0x55, 0x54, 0x48, 0x4f, 0x52, 0x49, 0x54, 0x59,
-	0x5f, 0x49, 0x53, 0x53, 0x55, 0x45, 0x10, 0x02, 0x22, 0xea, 0x06, 0x0a, 0x06, 0x43, 0x6f, 0x6e,
+	0x5f, 0x49, 0x53, 0x53, 0x55, 0x45, 0x10, 0x02, 0x22, 0x92, 0x07, 0x0a, 0x06, 0x43, 0x6f, 0x6e,
 	0x66, 0x69, 0x67, 0x12, 0x25, 0x0a, 0x0e, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x5f, 0x69, 0x6e, 0x73,
 	0x65, 0x63, 0x75, 0x72, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x61, 0x6c, 0x6c,
 	0x6f, 0x77, 0x49, 0x6e, 0x73, 0x65, 0x63, 0x75, 0x72, 0x65, 0x12, 0x4a, 0x0a, 0x0b, 0x63, 0x65,
@@ -458,20 +466,22 @@ var file_transport_internet_tls_config_proto_rawDesc = []byte{
 	0x65, 0x72, 0x69, 0x66, 0x79, 0x5f, 0x70, 0x65, 0x65, 0x72, 0x5f, 0x63, 0x65, 0x72, 0x74, 0x5f,
 	0x69, 0x6e, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x18, 0x11, 0x20, 0x03, 0x28, 0x09, 0x52, 0x15,
 	0x76, 0x65, 0x72, 0x69, 0x66, 0x79, 0x50, 0x65, 0x65, 0x72, 0x43, 0x65, 0x72, 0x74, 0x49, 0x6e,
-	0x4e, 0x61, 0x6d, 0x65, 0x73, 0x12, 0x26, 0x0a, 0x0f, 0x65, 0x63, 0x68, 0x5f, 0x63, 0x6f, 0x6e,
-	0x66, 0x69, 0x67, 0x5f, 0x6c, 0x69, 0x73, 0x74, 0x18, 0x12, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d,
-	0x65, 0x63, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x26, 0x0a,
-	0x0f, 0x65, 0x63, 0x68, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x6b, 0x65, 0x79, 0x73,
-	0x18, 0x13, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0d, 0x65, 0x63, 0x68, 0x53, 0x65, 0x72, 0x76, 0x65,
-	0x72, 0x4b, 0x65, 0x79, 0x73, 0x42, 0x73, 0x0a, 0x1f, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61,
-	0x79, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x69, 0x6e, 0x74, 0x65,
-	0x72, 0x6e, 0x65, 0x74, 0x2e, 0x74, 0x6c, 0x73, 0x50, 0x01, 0x5a, 0x30, 0x67, 0x69, 0x74, 0x68,
-	0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f, 0x78, 0x72, 0x61, 0x79,
-	0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f,
-	0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2f, 0x74, 0x6c, 0x73, 0xaa, 0x02, 0x1b, 0x58,
-	0x72, 0x61, 0x79, 0x2e, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x49, 0x6e,
-	0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x54, 0x6c, 0x73, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
-	0x6f, 0x33,
+	0x4e, 0x61, 0x6d, 0x65, 0x73, 0x12, 0x26, 0x0a, 0x0f, 0x65, 0x63, 0x68, 0x5f, 0x73, 0x65, 0x72,
+	0x76, 0x65, 0x72, 0x5f, 0x6b, 0x65, 0x79, 0x73, 0x18, 0x12, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0d,
+	0x65, 0x63, 0x68, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x73, 0x12, 0x26, 0x0a,
+	0x0f, 0x65, 0x63, 0x68, 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x5f, 0x6c, 0x69, 0x73, 0x74,
+	0x18, 0x13, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x65, 0x63, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69,
+	0x67, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x26, 0x0a, 0x0f, 0x65, 0x63, 0x68, 0x5f, 0x66, 0x6f, 0x72,
+	0x63, 0x65, 0x5f, 0x71, 0x75, 0x65, 0x72, 0x79, 0x18, 0x14, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d,
+	0x65, 0x63, 0x68, 0x46, 0x6f, 0x72, 0x63, 0x65, 0x51, 0x75, 0x65, 0x72, 0x79, 0x42, 0x73, 0x0a,
+	0x1f, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70,
+	0x6f, 0x72, 0x74, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x74, 0x6c, 0x73,
+	0x50, 0x01, 0x5a, 0x30, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78,
+	0x74, 0x6c, 0x73, 0x2f, 0x78, 0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x74, 0x72,
+	0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74,
+	0x2f, 0x74, 0x6c, 0x73, 0xaa, 0x02, 0x1b, 0x58, 0x72, 0x61, 0x79, 0x2e, 0x54, 0x72, 0x61, 0x6e,
+	0x73, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x2e, 0x54,
+	0x6c, 0x73, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
 }
 
 var (

+ 5 - 3
transport/internet/tls/config.proto

@@ -92,7 +92,9 @@ message Config {
   */
   repeated string verify_peer_cert_in_names = 17;
 
-  string ech_config_list = 18;
+  bytes ech_server_keys = 18;
 
-  bytes ech_server_keys = 19;
-}
+  string ech_config_list = 19;
+
+  bool ech_force_query = 20;
+}

+ 52 - 31
transport/internet/tls/ech.go

@@ -32,8 +32,26 @@ func ApplyECH(c *Config, config *tls.Config) error {
 	nameToQuery := c.ServerName
 	var DNSServer string
 
+	// for server
+	if len(c.EchServerKeys) != 0 {
+		KeySets, err := ConvertToGoECHKeys(c.EchServerKeys)
+		if err != nil {
+			return errors.New("Failed to unmarshal ECHKeySetList: ", err)
+		}
+		config.EncryptedClientHelloKeys = KeySets
+	}
+
 	// for client
 	if len(c.EchConfigList) != 0 {
+		defer func() {
+			// if failed to get ECHConfig, use an invalid one to make connection fail
+			if err != nil {
+				if c.EchForceQuery {
+					ECHConfig = []byte{1, 1, 4, 5, 1, 4}
+				}
+			}
+			config.EncryptedClientHelloConfigList = ECHConfig
+		}()
 		// direct base64 config
 		if strings.Contains(c.EchConfigList, "://") {
 			// query config from dns
@@ -51,7 +69,7 @@ func ApplyECH(c *Config, config *tls.Config) error {
 			if nameToQuery == "" {
 				return errors.New("Using DNS for ECH Config needs serverName or use Server format example.com+https://1.1.1.1/dns-query")
 			}
-			ECHConfig, err = QueryRecord(nameToQuery, DNSServer)
+			ECHConfig, err = QueryRecord(nameToQuery, DNSServer, c.EchForceQuery)
 			if err != nil {
 				return err
 			}
@@ -61,17 +79,6 @@ func ApplyECH(c *Config, config *tls.Config) error {
 				return errors.New("Failed to unmarshal ECHConfigList: ", err)
 			}
 		}
-
-		config.EncryptedClientHelloConfigList = ECHConfig
-	}
-
-	// for server
-	if len(c.EchServerKeys) != 0 {
-		KeySets, err := ConvertToGoECHKeys(c.EchServerKeys)
-		if err != nil {
-			return errors.New("Failed to unmarshal ECHKeySetList: ", err)
-		}
-		config.EncryptedClientHelloKeys = KeySets
 	}
 
 	return nil
@@ -86,9 +93,11 @@ type ECHConfigCache struct {
 type echConfigRecord struct {
 	config []byte
 	expire time.Time
+	err    error
 }
 
 var (
+	// key value must be like this: "example.com|udp://1.1.1.1"
 	GlobalECHConfigCache = utils.NewTypedSyncMap[string, *ECHConfigCache]()
 	clientForECHDOH      = utils.NewTypedSyncMap[string, *http.Client]()
 )
@@ -96,7 +105,7 @@ var (
 // Update updates the ECH config for given domain and server.
 // this method is concurrent safe, only one update request will be sent, others get the cache.
 // if isLockedUpdate is true, it will not try to acquire the lock.
-func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate bool) ([]byte, error) {
+func (c *ECHConfigCache) Update(domain string, server string, forceQuery bool, isLockedUpdate bool) ([]byte, error) {
 	if !isLockedUpdate {
 		c.UpdateLock.Lock()
 		defer c.UpdateLock.Unlock()
@@ -105,13 +114,23 @@ func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate boo
 	configRecord := c.configRecord.Load()
 	if configRecord.expire.After(time.Now()) {
 		errors.LogDebug(context.Background(), "Cache hit for domain after double check: ", domain)
-		return configRecord.config, nil
+		return configRecord.config, configRecord.err
 	}
 	// Query ECH config from DNS server
 	errors.LogDebug(context.Background(), "Trying to query ECH config for domain: ", domain, " with ECH server: ", server)
 	echConfig, ttl, err := dnsQuery(server, domain)
 	if err != nil {
-		return nil, err
+		if forceQuery {
+			return nil, err
+		} else {
+			configRecord = &echConfigRecord{
+				config: nil,
+				expire: time.Now().Add(10 * time.Minute),
+				err:    err,
+			}
+			c.configRecord.Store(configRecord)
+			return echConfig, err
+		}
 	}
 	configRecord = &echConfigRecord{
 		config: echConfig,
@@ -123,30 +142,31 @@ func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate boo
 
 // QueryRecord returns the ECH config for given domain.
 // If the record is not in cache or expired, it will query the DNS server and update the cache.
-func QueryRecord(domain string, server string) ([]byte, error) {
-	echConfigCache, ok := GlobalECHConfigCache.Load(domain)
+func QueryRecord(domain string, server string, forceQuery bool) ([]byte, error) {
+	GlobalECHConfigCacheKey := domain + "|" + server
+	echConfigCache, ok := GlobalECHConfigCache.Load(GlobalECHConfigCacheKey)
 	if !ok {
 		echConfigCache = &ECHConfigCache{}
 		echConfigCache.configRecord.Store(&echConfigRecord{})
-		echConfigCache, _ = GlobalECHConfigCache.LoadOrStore(domain, echConfigCache)
+		echConfigCache, _ = GlobalECHConfigCache.LoadOrStore(GlobalECHConfigCacheKey, echConfigCache)
 	}
 	configRecord := echConfigCache.configRecord.Load()
 	if configRecord.expire.After(time.Now()) {
 		errors.LogDebug(context.Background(), "Cache hit for domain: ", domain)
-		return configRecord.config, nil
+		return configRecord.config, configRecord.err
 	}
 
 	// If expire is zero value, it means we are in initial state, wait for the query to finish
 	// otherwise return old value immediately and update in a goroutine
 	// but if the cache is too old, wait for update
 	if configRecord.expire == (time.Time{}) || configRecord.expire.Add(time.Hour*6).Before(time.Now()) {
-		return echConfigCache.Update(domain, server, false)
+		return echConfigCache.Update(domain, server, false, forceQuery)
 	} else {
 		// If someone already acquired the lock, it means it is updating, do not start another update goroutine
 		if echConfigCache.UpdateLock.TryLock() {
 			go func() {
 				defer echConfigCache.UpdateLock.Unlock()
-				echConfigCache.Update(domain, server, true)
+				echConfigCache.Update(domain, server, true, forceQuery)
 			}()
 		}
 		return configRecord.config, nil
@@ -165,7 +185,7 @@ func dnsQuery(server string, domain string) ([]byte, uint32, error) {
 		m.Id = 0
 		msg, err := m.Pack()
 		if err != nil {
-			return []byte{}, 0, err
+			return nil, 0, err
 		}
 		var client *http.Client
 		if client, _ = clientForECHDOH.Load(server); client == nil {
@@ -194,20 +214,20 @@ func dnsQuery(server string, domain string) ([]byte, uint32, error) {
 		}
 		req, err := http.NewRequest("POST", server, bytes.NewReader(msg))
 		if err != nil {
-			return []byte{}, 0, err
+			return nil, 0, err
 		}
 		req.Header.Set("Content-Type", "application/dns-message")
 		resp, err := client.Do(req)
 		if err != nil {
-			return []byte{}, 0, err
+			return nil, 0, err
 		}
 		defer resp.Body.Close()
 		respBody, err := io.ReadAll(resp.Body)
 		if err != nil {
-			return []byte{}, 0, err
+			return nil, 0, err
 		}
 		if resp.StatusCode != http.StatusOK {
-			return []byte{}, 0, errors.New("query failed with response code:", resp.StatusCode)
+			return nil, 0, errors.New("query failed with response code:", resp.StatusCode)
 		}
 		dnsResolve = respBody
 	} else if strings.HasPrefix(server, "udp://") { // for classic udp dns server
@@ -231,24 +251,25 @@ func dnsQuery(server string, domain string) ([]byte, uint32, error) {
 			}
 		}()
 		if err != nil {
-			return []byte{}, 0, err
+			return nil, 0, err
 		}
 		msg, err := m.Pack()
 		if err != nil {
-			return []byte{}, 0, err
+			return nil, 0, err
 		}
 		conn.Write(msg)
 		udpResponse := make([]byte, 512)
+		conn.SetReadDeadline(time.Now().Add(5 * time.Second))
 		_, err = conn.Read(udpResponse)
 		if err != nil {
-			return []byte{}, 0, err
+			return nil, 0, err
 		}
 		dnsResolve = udpResponse
 	}
 	respMsg := new(dns.Msg)
 	err := respMsg.Unpack(dnsResolve)
 	if err != nil {
-		return []byte{}, 0, errors.New("failed to unpack dns response for ECH: ", err)
+		return nil, 0, errors.New("failed to unpack dns response for ECH: ", err)
 	}
 	if len(respMsg.Answer) > 0 {
 		for _, answer := range respMsg.Answer {
@@ -262,7 +283,7 @@ func dnsQuery(server string, domain string) ([]byte, uint32, error) {
 			}
 		}
 	}
-	return []byte{}, 0, errors.New("no ech record found")
+	return nil, 0, errors.New("no ech record found")
 }
 
 // reference github.com/OmarTariq612/goech

+ 51 - 5
transport/internet/tls/ech_test.go

@@ -1,4 +1,4 @@
-package tls_test
+package tls
 
 import (
 	"io"
@@ -8,13 +8,12 @@ import (
 	"testing"
 
 	"github.com/xtls/xray-core/common"
-	. "github.com/xtls/xray-core/transport/internet/tls"
 )
 
 func TestECHDial(t *testing.T) {
 	config := &Config{
-		ServerName:    "encryptedsni.com",
-		EchConfigList: "udp://1.1.1.1",
+		ServerName:    "cloudflare.com",
+		EchConfigList: "encryptedsni.com+udp://1.1.1.1",
 	}
 	// test concurrent Dial(to test cache problem)
 	wg := sync.WaitGroup{}
@@ -28,7 +27,7 @@ func TestECHDial(t *testing.T) {
 					TLSClientConfig: TLSConfig,
 				},
 			}
-			resp, err := client.Get("https://encryptedsni.com/cdn-cgi/trace")
+			resp, err := client.Get("https://cloudflare.com/cdn-cgi/trace")
 			common.Must(err)
 			defer resp.Body.Close()
 			body, err := io.ReadAll(resp.Body)
@@ -40,4 +39,51 @@ func TestECHDial(t *testing.T) {
 		}()
 	}
 	wg.Wait()
+	// check cache
+	echConfigCache, ok := GlobalECHConfigCache.Load("encryptedsni.com|udp://1.1.1.1")
+	if !ok {
+		t.Error("ECH config cache not found")
+
+	}
+	ok = echConfigCache.UpdateLock.TryLock()
+	if !ok {
+		t.Error("ECH config cache dead lock detected")
+	}
+	echConfigCache.UpdateLock.Unlock()
+	configRecord := echConfigCache.configRecord.Load()
+	if configRecord == nil {
+		t.Error("ECH config record not found in cache")
+	}
+}
+
+func TestECHDialFail(t *testing.T) {
+	config := &Config{
+		ServerName:    "cloudflare.com",
+		EchConfigList: "udp://1.1.1.1",
+	}
+	TLSConfig := config.GetTLSConfig()
+	TLSConfig.NextProtos = []string{"http/1.1"}
+	client := &http.Client{
+		Transport: &http.Transport{
+			TLSClientConfig: TLSConfig,
+		},
+	}
+	resp, err := client.Get("https://cloudflare.com/cdn-cgi/trace")
+	common.Must(err)
+	defer resp.Body.Close()
+	_, err = io.ReadAll(resp.Body)
+	common.Must(err)
+	// check cache
+	echConfigCache, ok := GlobalECHConfigCache.Load("cloudflare.com|udp://1.1.1.1")
+	if !ok {
+		t.Error("ECH config cache not found")
+	}
+	configRecord := echConfigCache.configRecord.Load()
+	if configRecord == nil {
+		t.Error("ECH config record not found in cache")
+		return
+	}
+	if configRecord.err == nil {
+		t.Error("unexpected nil error in ECH config record")
+	}
 }