u_conn.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477
  1. // Copyright 2017 Google Inc. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package tls
  5. import (
  6. "bufio"
  7. "bytes"
  8. "crypto/cipher"
  9. "encoding/binary"
  10. "errors"
  11. "io"
  12. "net"
  13. "strconv"
  14. "sync"
  15. "sync/atomic"
  16. )
// UConn is a uTLS client connection. It embeds *Conn and adds control over
// the exact contents of the ClientHello that is sent during the handshake.
type UConn struct {
	*Conn

	// Extensions are the ClientHello extensions. ApplyConfig applies them to
	// the connection, and MarshalClientHello serializes them in slice order
	// (omitting an SNIExtension with an empty ServerName unless
	// IncludeEmptySNI is set).
	Extensions []TLSExtension

	// clientHelloID selects which ClientHello is produced; HelloGolang means
	// Go's default ClientHello (see BuildHandshakeState).
	clientHelloID ClientHelloID

	// HandshakeState holds the ClientHello (and optional session) that
	// Handshake uses; Handshake writes the post-handshake state back here.
	HandshakeState ClientHandshakeState

	// HandshakeStateBuilt is set by a successful BuildHandshakeState;
	// Handshake only calls BuildHandshakeState while it is false.
	HandshakeStateBuilt bool

	// IncludeEmptySNI indicates to include an SNI extension when the
	// ServerName is "". This is non-standard behavior. Common TLS
	// implementations (Go, BoringSSL, etc.) omit the SNI extension in
	// this case.
	//
	// One concrete instance is when the remote host name is an IP address;
	// https://tools.ietf.org/html/rfc6066#section-3 prohibits an SNI with an
	// IP address.
	//
	// Go's hostnameInSNI sets the ServerName to "":
	// https://github.com/golang/go/blob/release-branch.go1.9/src/crypto/tls/handshake_client.go#L804
	//
	// And then omits the SNI extension:
	// https://github.com/golang/go/blob/release-branch.go1.9/src/crypto/tls/handshake_messages.go#L150
	//
	// IncludeEmptySNI is set to true for test runs, as test data expects
	// empty SNI extensions.
	IncludeEmptySNI bool
}
  42. // UClient returns a new uTLS client, with behavior depending on clientHelloID.
  43. // Config CAN be nil, but make sure to eventually specify ServerName.
  44. func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID) *UConn {
  45. if config == nil {
  46. config = &Config{}
  47. }
  48. tlsConn := Conn{conn: conn, config: config, isClient: true}
  49. handshakeState := ClientHandshakeState{C: &tlsConn, Hello: &ClientHelloMsg{}}
  50. uconn := UConn{Conn: &tlsConn, clientHelloID: clientHelloID, HandshakeState: handshakeState}
  51. return &uconn
  52. }
  53. // BuildHandshakeState() overwrites most fields, therefore, it is advised to manually call this function,
  54. // if you need to inspect/change contents after parroting/making default Golang ClientHello.
  55. // Otherwise, there is no need to call this function explicitly.
  56. func (uconn *UConn) BuildHandshakeState() error {
  57. if uconn.clientHelloID == HelloGolang {
  58. // use default Golang ClientHello.
  59. hello, err := makeClientHello(uconn.config)
  60. if uconn.HandshakeState.Session != nil {
  61. // session is lost at makeClientHello(), let's reapply
  62. uconn.SetSessionState(uconn.HandshakeState.Session)
  63. }
  64. if err != nil {
  65. return err
  66. }
  67. uconn.HandshakeState.Hello = hello.getPublicPtr()
  68. } else {
  69. err := uconn.generateClientHelloConfig(uconn.clientHelloID)
  70. if err != nil {
  71. return err
  72. }
  73. err = uconn.ApplyConfig()
  74. if err != nil {
  75. return err
  76. }
  77. err = uconn.MarshalClientHello()
  78. if err != nil {
  79. return err
  80. }
  81. }
  82. uconn.HandshakeStateBuilt = true
  83. return nil
  84. }
  85. // If you want you session tickets to be reused - use same cache on following connections
  86. func (uconn *UConn) SetSessionState(session *ClientSessionState) {
  87. uconn.HandshakeState.Session = session
  88. if session != nil {
  89. uconn.HandshakeState.Hello.SessionTicket = session.sessionTicket
  90. }
  91. uconn.HandshakeState.Hello.TicketSupported = true
  92. for _, ext := range uconn.Extensions {
  93. st, ok := ext.(*SessionTicketExtension)
  94. if ok {
  95. st.Session = session
  96. }
  97. }
  98. }
  99. // If you want you session tickets to be reused - use same cache on following connections
  100. func (uconn *UConn) SetSessionCache(cache ClientSessionCache) {
  101. uconn.config.ClientSessionCache = cache
  102. uconn.HandshakeState.Hello.TicketSupported = true
  103. }
  104. // r has to be 32 bytes long
  105. func (uconn *UConn) SetClientRandom(r []byte) error {
  106. if len(r) != 32 {
  107. return errors.New("Incorrect client random length! Expected: 32, got: " + strconv.Itoa(len(r)))
  108. } else {
  109. uconn.HandshakeState.Hello.Random = make([]byte, 32)
  110. copy(uconn.HandshakeState.Hello.Random, r)
  111. return nil
  112. }
  113. }
  114. func (uconn *UConn) SetSNI(sni string) {
  115. hname := hostnameInSNI(sni)
  116. uconn.config.ServerName = hname
  117. for _, ext := range uconn.Extensions {
  118. sniExt, ok := ext.(*SNIExtension)
  119. if ok {
  120. sniExt.ServerName = hname
  121. }
  122. }
  123. }
  124. // Handshake runs the client handshake using given clientHandshakeState
  125. // Requires hs.hello, and, optionally, hs.session to be set.
  126. func (c *UConn) Handshake() error {
  127. // This code was copied almost as is from tls/conn.go
  128. // c.handshakeErr and c.handshakeComplete are protected by
  129. // c.handshakeMutex. In order to perform a handshake, we need to lock
  130. // c.in also and c.handshakeMutex must be locked after c.in.
  131. //
  132. // However, if a Read() operation is hanging then it'll be holding the
  133. // lock on c.in and so taking it here would cause all operations that
  134. // need to check whether a handshake is pending (such as Write) to
  135. // block.
  136. //
  137. // Thus we first take c.handshakeMutex to check whether a handshake is
  138. // needed.
  139. //
  140. // If so then, previously, this code would unlock handshakeMutex and
  141. // then lock c.in and handshakeMutex in the correct order to run the
  142. // handshake. The problem was that it was possible for a Read to
  143. // complete the handshake once handshakeMutex was unlocked and then
  144. // keep c.in while waiting for network data. Thus a concurrent
  145. // operation could be blocked on c.in.
  146. //
  147. // Thus handshakeCond is used to signal that a goroutine is committed
  148. // to running the handshake and other goroutines can wait on it if they
  149. // need. handshakeCond is protected by handshakeMutex.
  150. c.handshakeMutex.Lock()
  151. defer c.handshakeMutex.Unlock()
  152. for {
  153. if err := c.handshakeErr; err != nil {
  154. return err
  155. }
  156. if c.handshakeComplete {
  157. return nil
  158. }
  159. if c.handshakeCond == nil {
  160. break
  161. }
  162. c.handshakeCond.Wait()
  163. }
  164. // Set handshakeCond to indicate that this goroutine is committing to
  165. // running the handshake.
  166. c.handshakeCond = sync.NewCond(&c.handshakeMutex)
  167. c.handshakeMutex.Unlock()
  168. c.in.Lock()
  169. defer c.in.Unlock()
  170. c.handshakeMutex.Lock()
  171. // The handshake cannot have completed when handshakeMutex was unlocked
  172. // because this goroutine set handshakeCond.
  173. if c.handshakeErr != nil || c.handshakeComplete {
  174. panic("handshake should not have been able to complete after handshakeCond was set")
  175. }
  176. if !c.isClient {
  177. panic("Servers should not call ClientHandshakeWithState()")
  178. }
  179. if !c.HandshakeStateBuilt {
  180. err := c.BuildHandshakeState()
  181. if err != nil {
  182. return err
  183. }
  184. }
  185. privateState := c.HandshakeState.getPrivatePtr()
  186. c.handshakeErr = c.clientHandshakeWithState(privateState)
  187. c.HandshakeState = *privateState.getPublicPtr()
  188. if c.handshakeErr == nil {
  189. c.handshakes++
  190. } else {
  191. // If an error occurred during the hadshake try to flush the
  192. // alert that might be left in the buffer.
  193. c.flush()
  194. }
  195. if c.handshakeErr == nil && !c.handshakeComplete {
  196. panic("handshake should have had a result.")
  197. }
  198. // Wake any other goroutines that are waiting for this handshake to complete.
  199. c.handshakeCond.Broadcast()
  200. c.handshakeCond = nil
  201. return c.handshakeErr
  202. }
// Write writes data to the connection, first running (or waiting for) the
// handshake via c.Handshake if it has not completed.
//
// Copy-pasted from tls.Conn in its entirety. But c.Handshake() is now utls' one, not tls.
//
// It returns errClosed if the connection has been marked closed, errShutdown
// after close_notify has been sent, and alertInternalError if the handshake
// did not complete. Any record-write error is stored on c.out, and every
// later Write returns it.
func (c *UConn) Write(b []byte) (int, error) {
	// interlock with Close below
	// Bit 0 of activeCall set means the connection is closed (errClosed is
	// returned); each in-flight Write adds 2 via CAS and removes it on return.
	// NOTE(review): Close is not visible here; this reading is based on the
	// bit tests below only.
	for {
		x := atomic.LoadInt32(&c.activeCall)
		if x&1 != 0 {
			return 0, errClosed
		}
		if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) {
			defer atomic.AddInt32(&c.activeCall, -2)
			break
		}
	}

	if err := c.Handshake(); err != nil {
		return 0, err
	}

	c.out.Lock()
	defer c.out.Unlock()

	// A previous write failure is sticky.
	if err := c.out.err; err != nil {
		return 0, err
	}

	if !c.handshakeComplete {
		return 0, alertInternalError
	}

	if c.closeNotifySent {
		return 0, errShutdown
	}

	// SSL 3.0 and TLS 1.0 are susceptible to a chosen-plaintext
	// attack when using block mode ciphers due to predictable IVs.
	// This can be prevented by splitting each Application Data
	// record into two records, effectively randomizing the IV.
	//
	// http://www.openssl.org/~bodo/tls-cbc.txt
	// https://bugzilla.mozilla.org/show_bug.cgi?id=665814
	// http://www.imperialviolet.org/2012/01/15/beastfollowup.html

	// m counts the byte already sent in the split-off first record, so the
	// returned total covers all of b.
	var m int
	if len(b) > 1 && c.vers <= VersionTLS10 {
		if _, ok := c.out.cipher.(cipher.BlockMode); ok {
			n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
			if err != nil {
				return n, c.out.setErrorLocked(err)
			}
			m, b = 1, b[1:]
		}
	}

	n, err := c.writeRecordLocked(recordTypeApplicationData, b)
	return n + m, c.out.setErrorLocked(err)
}
  252. // c.out.Mutex <= L; c.handshakeMutex <= L.
  253. func (c *UConn) clientHandshakeWithState(hs *clientHandshakeState) error {
  254. // This code was copied almost as is from tls/handshake_client.go
  255. if c.config == nil {
  256. c.config = &Config{}
  257. }
  258. // This may be a renegotiation handshake, in which case some fields
  259. // need to be reset.
  260. c.didResume = false
  261. if len(c.config.ServerName) == 0 && !c.config.InsecureSkipVerify {
  262. return errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config")
  263. }
  264. nextProtosLength := 0
  265. for _, proto := range c.config.NextProtos {
  266. if l := len(proto); l == 0 || l > 255 {
  267. return errors.New("tls: invalid NextProtos value")
  268. } else {
  269. nextProtosLength += 1 + l
  270. }
  271. }
  272. if nextProtosLength > 0xffff {
  273. return errors.New("tls: NextProtos values too large")
  274. }
  275. var session *ClientSessionState
  276. sessionCache := c.config.ClientSessionCache
  277. cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
  278. // If sessionCache is set but session itself isn't - try to retrieve session from cache
  279. if sessionCache != nil && hs.session == nil {
  280. hs.hello.ticketSupported = true
  281. // Session resumption is not allowed if renegotiating because
  282. // renegotiation is primarily used to allow a client to send a client
  283. // certificate, which would be skipped if session resumption occurred.
  284. if c.handshakes == 0 {
  285. // Try to resume a previously negotiated TLS session, if
  286. // available.
  287. candidateSession, ok := sessionCache.Get(cacheKey)
  288. if ok {
  289. // Check that the ciphersuite/version used for the
  290. // previous session are still valid.
  291. cipherSuiteOk := false
  292. for _, id := range hs.hello.cipherSuites {
  293. if id == candidateSession.cipherSuite {
  294. cipherSuiteOk = true
  295. break
  296. }
  297. }
  298. versOk := candidateSession.vers >= c.config.minVersion() &&
  299. candidateSession.vers <= c.config.maxVersion()
  300. if versOk && cipherSuiteOk {
  301. session = candidateSession
  302. }
  303. if session != nil {
  304. hs.hello.sessionTicket = session.sessionTicket
  305. // A random session ID is used to detect when the
  306. // server accepted the ticket and is resuming a session
  307. // (see RFC 5077).
  308. hs.hello.sessionId = make([]byte, 16)
  309. if _, err := io.ReadFull(c.config.rand(), hs.hello.sessionId); err != nil {
  310. return errors.New("tls: short read from Rand: " + err.Error())
  311. }
  312. }
  313. hs.session = session
  314. }
  315. }
  316. }
  317. if err := hs.handshake(); err != nil {
  318. return err
  319. }
  320. // If we had a successful handshake and hs.session is different from the one already cached - cache a new one
  321. if sessionCache != nil && hs.session != nil && hs.session != session {
  322. sessionCache.Put(cacheKey, hs.session)
  323. }
  324. return nil
  325. }
  326. func (uconn *UConn) ApplyConfig() error {
  327. for _, ext := range uconn.Extensions {
  328. err := ext.writeToUConn(uconn)
  329. if err != nil {
  330. return err
  331. }
  332. }
  333. return nil
  334. }
  335. func (uconn *UConn) MarshalClientHello() error {
  336. hello := uconn.HandshakeState.Hello
  337. headerLength := 2 + 32 + 1 + len(hello.SessionId) +
  338. 2 + len(hello.CipherSuites)*2 +
  339. 1 + len(hello.CompressionMethods)
  340. extensions := make([]TLSExtension, 0, len(uconn.Extensions))
  341. for _, ext := range uconn.Extensions {
  342. if SNI, ok := ext.(*SNIExtension); !ok ||
  343. len(SNI.ServerName) > 0 ||
  344. uconn.IncludeEmptySNI {
  345. extensions = append(extensions, ext)
  346. }
  347. }
  348. extensionsLen := 0
  349. var paddingExt *utlsPaddingExtension
  350. for _, ext := range extensions {
  351. if pe, ok := ext.(*utlsPaddingExtension); !ok {
  352. // If not padding - just add length of extension to total length
  353. extensionsLen += ext.Len()
  354. } else {
  355. // If padding - process it later
  356. if paddingExt == nil {
  357. paddingExt = pe
  358. } else {
  359. return errors.New("Multiple padding extensions!")
  360. }
  361. }
  362. }
  363. if paddingExt != nil {
  364. // determine padding extension presence and length
  365. paddingExt.Update(headerLength + 4 + extensionsLen + 2)
  366. extensionsLen += paddingExt.Len()
  367. }
  368. helloLen := headerLength
  369. if len(extensions) > 0 {
  370. helloLen += 2 + extensionsLen // 2 bytes for extensions' length
  371. }
  372. helloBuffer := bytes.Buffer{}
  373. bufferedWriter := bufio.NewWriterSize(&helloBuffer, helloLen+4) // 1 byte for tls record type, 3 for length
  374. // We use buffered Writer to avoid checking write errors after every Write(): whenever first error happens
  375. // Write() will become noop, and error will be accessible via Flush(), which is called once in the end
  376. binary.Write(bufferedWriter, binary.BigEndian, typeClientHello)
  377. helloLenBytes := []byte{byte(helloLen >> 16), byte(helloLen >> 8), byte(helloLen)} // poor man's uint24
  378. binary.Write(bufferedWriter, binary.BigEndian, helloLenBytes)
  379. binary.Write(bufferedWriter, binary.BigEndian, hello.Vers)
  380. binary.Write(bufferedWriter, binary.BigEndian, hello.Random)
  381. binary.Write(bufferedWriter, binary.BigEndian, uint8(len(hello.SessionId)))
  382. binary.Write(bufferedWriter, binary.BigEndian, hello.SessionId)
  383. binary.Write(bufferedWriter, binary.BigEndian, uint16(len(hello.CipherSuites)<<1))
  384. for _, suite := range hello.CipherSuites {
  385. binary.Write(bufferedWriter, binary.BigEndian, suite)
  386. }
  387. binary.Write(bufferedWriter, binary.BigEndian, uint8(len(hello.CompressionMethods)))
  388. binary.Write(bufferedWriter, binary.BigEndian, hello.CompressionMethods)
  389. if len(extensions) > 0 {
  390. binary.Write(bufferedWriter, binary.BigEndian, uint16(extensionsLen))
  391. for _, ext := range extensions {
  392. bufferedWriter.ReadFrom(ext)
  393. }
  394. }
  395. if helloBuffer.Len() != 4+helloLen {
  396. return errors.New("utls: unexpected ClientHello length. Expected: " + strconv.Itoa(4+helloLen) +
  397. ". Got: " + strconv.Itoa(helloBuffer.Len()))
  398. }
  399. err := bufferedWriter.Flush()
  400. if err != nil {
  401. return err
  402. }
  403. hello.Raw = helloBuffer.Bytes()
  404. return nil
  405. }
  406. // get current state of cipher and encrypt zeros to get keystream
  407. func (uconn *UConn) GetOutKeystream(length int) ([]byte, error) {
  408. zeros := make([]byte, length)
  409. if outCipher, ok := uconn.out.cipher.(cipher.AEAD); ok {
  410. // AEAD.Seal() does not mutate internal state, other ciphers might
  411. return outCipher.Seal(nil, uconn.out.seq[:], zeros, nil), nil
  412. }
  413. return nil, errors.New("Could not convert OutCipher to cipher.AEAD")
  414. }