ech.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. package tls
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/tls"
  6. "encoding/base64"
  7. "encoding/binary"
  8. "fmt"
  9. "io"
  10. "net/http"
  11. "net/url"
  12. "strings"
  13. "sync"
  14. "sync/atomic"
  15. "time"
  16. utls "github.com/refraction-networking/utls"
  17. "github.com/xtls/xray-core/common/crypto"
  18. dns2 "github.com/xtls/xray-core/features/dns"
  19. "golang.org/x/net/http2"
  20. "github.com/miekg/dns"
  21. "github.com/xtls/xray-core/common/errors"
  22. "github.com/xtls/xray-core/common/net"
  23. "github.com/xtls/xray-core/common/utils"
  24. "github.com/xtls/xray-core/transport/internet"
  25. "golang.org/x/crypto/cryptobyte"
  26. )
  27. func ApplyECH(c *Config, config *tls.Config) error {
  28. var ECHConfig []byte
  29. var err error
  30. var nameToQuery string
  31. if net.ParseAddress(config.ServerName).Family().IsDomain() {
  32. nameToQuery = config.ServerName
  33. }
  34. var DNSServer string
  35. // for server
  36. if len(c.EchServerKeys) != 0 {
  37. KeySets, err := ConvertToGoECHKeys(c.EchServerKeys)
  38. if err != nil {
  39. return errors.New("Failed to unmarshal ECHKeySetList: ", err)
  40. }
  41. config.EncryptedClientHelloKeys = KeySets
  42. }
  43. // for client
  44. if len(c.EchConfigList) != 0 {
  45. ECHForceQuery := c.EchForceQuery
  46. switch ECHForceQuery {
  47. case "none", "half", "full":
  48. case "":
  49. ECHForceQuery = "none" // default to none
  50. default:
  51. panic("Invalid ECHForceQuery: " + c.EchForceQuery)
  52. }
  53. defer func() {
  54. // if failed to get ECHConfig, use an invalid one to make connection fail
  55. if err != nil || len(ECHConfig) == 0 {
  56. if ECHForceQuery == "full" {
  57. ECHConfig = []byte{1, 1, 4, 5, 1, 4}
  58. }
  59. }
  60. config.EncryptedClientHelloConfigList = ECHConfig
  61. }()
  62. // direct base64 config
  63. if strings.Contains(c.EchConfigList, "://") {
  64. // query config from dns
  65. parts := strings.Split(c.EchConfigList, "+")
  66. if len(parts) == 2 {
  67. // parse ECH DNS server in format of "example.com+https://1.1.1.1/dns-query"
  68. nameToQuery = parts[0]
  69. DNSServer = parts[1]
  70. } else if len(parts) == 1 {
  71. // normal format
  72. DNSServer = parts[0]
  73. } else {
  74. return errors.New("Invalid ECH DNS server format: ", c.EchConfigList)
  75. }
  76. if nameToQuery == "" {
  77. return errors.New("Using DNS for ECH Config needs serverName or use Server format example.com+https://1.1.1.1/dns-query")
  78. }
  79. ECHConfig, err = QueryRecord(nameToQuery, DNSServer, c.EchForceQuery, c.EchSocketSettings)
  80. if err != nil {
  81. return errors.New("Failed to query ECH DNS record for domain: ", nameToQuery, " at server: ", DNSServer).Base(err)
  82. }
  83. } else {
  84. ECHConfig, err = base64.StdEncoding.DecodeString(c.EchConfigList)
  85. if err != nil {
  86. return errors.New("Failed to unmarshal ECHConfigList: ", err)
  87. }
  88. }
  89. }
  90. return nil
  91. }
  92. type ECHConfigCache struct {
  93. configRecord atomic.Pointer[echConfigRecord]
  94. // updateLock is not for preventing concurrent read/write, but for preventing concurrent update
  95. UpdateLock sync.Mutex
  96. }
  97. type echConfigRecord struct {
  98. config []byte
  99. expire time.Time
  100. err error
  101. }
  102. var (
  103. // The keys for both maps must be generated by ECHCacheKey().
  104. GlobalECHConfigCache = utils.NewTypedSyncMap[string, *ECHConfigCache]()
  105. clientForECHDOH = utils.NewTypedSyncMap[string, *http.Client]()
  106. )
  107. // sockopt can be nil if not specified.
  108. // if for clientForECHDOH, domain can be empty.
  109. func ECHCacheKey(server, domain string, sockopt *internet.SocketConfig) string {
  110. return server + "|" + domain + "|" + fmt.Sprintf("%p", sockopt)
  111. }
  112. // Update updates the ECH config for given domain and server.
  113. // this method is concurrent safe, only one update request will be sent, others get the cache.
  114. // if isLockedUpdate is true, it will not try to acquire the lock.
  115. func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate bool, forceQuery string, sockopt *internet.SocketConfig) ([]byte, error) {
  116. if !isLockedUpdate {
  117. c.UpdateLock.Lock()
  118. defer c.UpdateLock.Unlock()
  119. }
  120. // Double check cache after acquiring lock
  121. configRecord := c.configRecord.Load()
  122. if configRecord.expire.After(time.Now()) && configRecord.err == nil {
  123. errors.LogDebug(context.Background(), "Cache hit for domain after double check: ", domain)
  124. return configRecord.config, configRecord.err
  125. }
  126. // Query ECH config from DNS server
  127. errors.LogDebug(context.Background(), "Trying to query ECH config for domain: ", domain, " with ECH server: ", server)
  128. echConfig, ttl, err := dnsQuery(server, domain, sockopt)
  129. // if in "full", directly return
  130. if err != nil && forceQuery == "full" {
  131. return nil, err
  132. }
  133. if ttl == 0 {
  134. ttl = dns2.DefaultTTL
  135. }
  136. configRecord = &echConfigRecord{
  137. config: echConfig,
  138. expire: time.Now().Add(time.Duration(ttl) * time.Second),
  139. err: err,
  140. }
  141. c.configRecord.Store(configRecord)
  142. return configRecord.config, configRecord.err
  143. }
  144. // QueryRecord returns the ECH config for given domain.
  145. // If the record is not in cache or expired, it will query the DNS server and update the cache.
  146. func QueryRecord(domain string, server string, forceQuery string, sockopt *internet.SocketConfig) ([]byte, error) {
  147. GlobalECHConfigCacheKey := ECHCacheKey(server, domain, sockopt)
  148. echConfigCache, ok := GlobalECHConfigCache.Load(GlobalECHConfigCacheKey)
  149. if !ok {
  150. echConfigCache = &ECHConfigCache{}
  151. echConfigCache.configRecord.Store(&echConfigRecord{})
  152. echConfigCache, _ = GlobalECHConfigCache.LoadOrStore(GlobalECHConfigCacheKey, echConfigCache)
  153. }
  154. configRecord := echConfigCache.configRecord.Load()
  155. if configRecord.expire.After(time.Now()) && (configRecord.err == nil || forceQuery == "none") {
  156. errors.LogDebug(context.Background(), "Cache hit for domain: ", domain)
  157. return configRecord.config, configRecord.err
  158. }
  159. // If expire is zero value, it means we are in initial state, wait for the query to finish
  160. // otherwise return old value immediately and update in a goroutine
  161. // but if the cache is too old, wait for update
  162. if configRecord.expire == (time.Time{}) || configRecord.expire.Add(time.Hour*6).Before(time.Now()) {
  163. return echConfigCache.Update(domain, server, false, forceQuery, sockopt)
  164. } else {
  165. // If someone already acquired the lock, it means it is updating, do not start another update goroutine
  166. if echConfigCache.UpdateLock.TryLock() {
  167. go func() {
  168. defer echConfigCache.UpdateLock.Unlock()
  169. echConfigCache.Update(domain, server, true, forceQuery, sockopt)
  170. }()
  171. }
  172. return configRecord.config, configRecord.err
  173. }
  174. }
  175. // dnsQuery is the real func for sending type65 query for given domain to given DNS server.
  176. // return ECH config, TTL and error
  177. func dnsQuery(server string, domain string, sockopt *internet.SocketConfig) ([]byte, uint32, error) {
  178. m := new(dns.Msg)
  179. var dnsResolve []byte
  180. m.SetQuestion(dns.Fqdn(domain), dns.TypeHTTPS)
  181. // for DOH server
  182. if strings.HasPrefix(server, "https://") || strings.HasPrefix(server, "h2c://") {
  183. h2c := strings.HasPrefix(server, "h2c://")
  184. m.SetEdns0(4096, false) // 4096 is the buffer size, false means no DNSSEC
  185. padding := &dns.EDNS0_PADDING{Padding: make([]byte, int(crypto.RandBetween(100, 300)))}
  186. if opt := m.IsEdns0(); opt != nil {
  187. opt.Option = append(opt.Option, padding)
  188. }
  189. // always 0 in DOH
  190. m.Id = 0
  191. msg, err := m.Pack()
  192. if err != nil {
  193. return nil, 0, err
  194. }
  195. var client *http.Client
  196. serverKey := ECHCacheKey(server, "", sockopt)
  197. if client, _ = clientForECHDOH.Load(serverKey); client == nil {
  198. // All traffic sent by core should via xray's internet.DialSystem
  199. // This involves the behavior of some Android VPN GUI clients
  200. tr := &http2.Transport{
  201. IdleConnTimeout: net.ConnIdleTimeout,
  202. ReadIdleTimeout: net.ChromeH2KeepAlivePeriod,
  203. DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
  204. dest, err := net.ParseDestination(network + ":" + addr)
  205. if err != nil {
  206. return nil, err
  207. }
  208. var conn net.Conn
  209. conn, err = internet.DialSystem(ctx, dest, sockopt)
  210. if err != nil {
  211. return nil, err
  212. }
  213. if !h2c {
  214. u, err := url.Parse(server)
  215. if err != nil {
  216. return nil, err
  217. }
  218. conn = utls.UClient(conn, &utls.Config{ServerName: u.Hostname()}, utls.HelloChrome_Auto)
  219. if err := conn.(*utls.UConn).HandshakeContext(ctx); err != nil {
  220. return nil, err
  221. }
  222. }
  223. return conn, nil
  224. },
  225. }
  226. c := &http.Client{
  227. Timeout: 30 * time.Second,
  228. Transport: tr,
  229. }
  230. client, _ = clientForECHDOH.LoadOrStore(serverKey, c)
  231. }
  232. req, err := http.NewRequest("POST", server, bytes.NewReader(msg))
  233. if err != nil {
  234. return nil, 0, err
  235. }
  236. req.Header.Set("Accept", "application/dns-message")
  237. req.Header.Set("Content-Type", "application/dns-message")
  238. req.Header.Set("User-Agent", utils.ChromeUA)
  239. req.Header.Set("X-Padding", utils.H2Base62Pad(crypto.RandBetween(100, 1000)))
  240. resp, err := client.Do(req)
  241. if err != nil {
  242. return nil, 0, err
  243. }
  244. defer resp.Body.Close()
  245. respBody, err := io.ReadAll(resp.Body)
  246. if err != nil {
  247. return nil, 0, err
  248. }
  249. if resp.StatusCode != http.StatusOK {
  250. return nil, 0, errors.New("query failed with response code:", resp.StatusCode)
  251. }
  252. dnsResolve = respBody
  253. } else if strings.HasPrefix(server, "udp://") { // for classic udp dns server
  254. udpServerAddr := server[len("udp://"):]
  255. // default port 53 if not specified
  256. if !strings.Contains(udpServerAddr, ":") {
  257. udpServerAddr = udpServerAddr + ":53"
  258. }
  259. dest, err := net.ParseDestination("udp" + ":" + udpServerAddr)
  260. if err != nil {
  261. return nil, 0, errors.New("failed to parse udp dns server ", udpServerAddr, " for ECH: ", err)
  262. }
  263. dnsTimeoutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  264. defer cancel()
  265. // use xray's internet.DialSystem as mentioned above
  266. conn, err := internet.DialSystem(dnsTimeoutCtx, dest, sockopt)
  267. if err != nil {
  268. return nil, 0, err
  269. }
  270. defer func() {
  271. err := conn.Close()
  272. if err != nil {
  273. errors.LogDebug(context.Background(), "Failed to close connection: ", err)
  274. }
  275. }()
  276. msg, err := m.Pack()
  277. if err != nil {
  278. return nil, 0, err
  279. }
  280. conn.Write(msg)
  281. udpResponse := make([]byte, 512)
  282. conn.SetReadDeadline(time.Now().Add(5 * time.Second))
  283. _, err = conn.Read(udpResponse)
  284. if err != nil {
  285. return nil, 0, err
  286. }
  287. dnsResolve = udpResponse
  288. }
  289. respMsg := new(dns.Msg)
  290. err := respMsg.Unpack(dnsResolve)
  291. if err != nil {
  292. return nil, 0, errors.New("failed to unpack dns response for ECH: ", err)
  293. }
  294. if len(respMsg.Answer) > 0 {
  295. for _, answer := range respMsg.Answer {
  296. if https, ok := answer.(*dns.HTTPS); ok && https.Hdr.Name == dns.Fqdn(domain) {
  297. for _, v := range https.Value {
  298. if echConfig, ok := v.(*dns.SVCBECHConfig); ok {
  299. errors.LogDebug(context.Background(), "Get ECH config:", echConfig.String(), " TTL:", respMsg.Answer[0].Header().Ttl)
  300. return echConfig.ECH, answer.Header().Ttl, nil
  301. }
  302. }
  303. }
  304. }
  305. }
  306. // empty is valid, means no ECH config found
  307. return nil, dns2.DefaultTTL, nil
  308. }
  309. var ErrInvalidLen = errors.New("goech: invalid length")
  310. func ConvertToGoECHKeys(data []byte) ([]tls.EncryptedClientHelloKey, error) {
  311. var keys []tls.EncryptedClientHelloKey
  312. s := cryptobyte.String(data)
  313. for !s.Empty() {
  314. if len(s) < 2 {
  315. return keys, ErrInvalidLen
  316. }
  317. keyLength := int(binary.BigEndian.Uint16(s[:2]))
  318. if len(s) < keyLength+4 {
  319. return keys, ErrInvalidLen
  320. }
  321. configLength := int(binary.BigEndian.Uint16(s[keyLength+2 : keyLength+4]))
  322. if len(s) < 2+keyLength+2+configLength {
  323. return keys, ErrInvalidLen
  324. }
  325. child := cryptobyte.String(s[:2+keyLength+2+configLength])
  326. var (
  327. sk, config cryptobyte.String
  328. )
  329. if !child.ReadUint16LengthPrefixed(&sk) || !child.ReadUint16LengthPrefixed(&config) || !child.Empty() {
  330. return keys, ErrInvalidLen
  331. }
  332. if !s.Skip(2 + keyLength + 2 + configLength) {
  333. return keys, ErrInvalidLen
  334. }
  335. keys = append(keys, tls.EncryptedClientHelloKey{
  336. Config: config,
  337. PrivateKey: sk,
  338. })
  339. }
  340. return keys, nil
  341. }