conn.go 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package dtls
  4. import (
  5. "context"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net"
  10. "sync"
  11. "sync/atomic"
  12. "time"
  13. "github.com/pion/dtls/v2/internal/closer"
  14. "github.com/pion/dtls/v2/pkg/crypto/elliptic"
  15. "github.com/pion/dtls/v2/pkg/crypto/signaturehash"
  16. "github.com/pion/dtls/v2/pkg/protocol"
  17. "github.com/pion/dtls/v2/pkg/protocol/alert"
  18. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  19. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  20. "github.com/pion/logging"
  21. "github.com/pion/transport/v2/connctx"
  22. "github.com/pion/transport/v2/deadline"
  23. "github.com/pion/transport/v2/replaydetector"
  24. )
  25. const (
  26. initialTickerInterval = time.Second
  27. cookieLength = 20
  28. sessionLength = 32
  29. defaultNamedCurve = elliptic.X25519
  30. inboundBufferSize = 8192
  31. // Default replay protection window is specified by RFC 6347 Section 4.1.2.6
  32. defaultReplayProtectionWindow = 64
  33. )
  34. func invalidKeyingLabels() map[string]bool {
  35. return map[string]bool{
  36. "client finished": true,
  37. "server finished": true,
  38. "master secret": true,
  39. "key expansion": true,
  40. }
  41. }
// Conn represents a DTLS connection
//
// A Conn is built by createConn (reached through Dial, Client, Server and
// their *WithContext variants), which only returns it after the handshake
// has completed.
type Conn struct {
	// lock is taken exclusively by writePackets and shared by ConnectionState.
	lock           sync.RWMutex     // Internal lock (must not be public)
	nextConn       connctx.ConnCtx  // Embedded Conn, typically a udpconn we read/write from
	fragmentBuffer *fragmentBuffer  // out-of-order and missing fragment handling
	handshakeCache *handshakeCache  // caching of handshake messages for verifyData generation
	decrypted      chan interface{} // Decrypted Application Data or error, pull by calling `Read`
	state          State            // Internal state

	// maximumTransmissionUnit bounds combined datagrams (compactRawPackets)
	// and handshake fragment bodies (fragmentHandshake).
	maximumTransmissionUnit int

	// handshakeCompletedSuccessfully stores a struct{ bool }; it is set from
	// the FSM's onFlightState callback and read without taking lock.
	handshakeCompletedSuccessfully atomic.Value

	// encryptedPackets queues records that could not be processed when they
	// arrived (next-epoch records, or records needing a cipher suite that is
	// not initialized yet); handleQueuedPackets drains it.
	encryptedPackets [][]byte

	// connectionClosedByUser records that Close has been called; it is read
	// and written in close() under closeLock.
	connectionClosedByUser bool
	closeLock              sync.Mutex
	// closed is closed by close(); its Done channel is selected on by the
	// read/write paths to detect shutdown.
	closed *closer.Closer

	// handshakeLoopsFinished tracks the two goroutines started by handshake
	// (FSM runner and read loop); Close waits on it.
	handshakeLoopsFinished sync.WaitGroup

	// Deadlines are enforced by this layer only (see SetReadDeadline and
	// SetWriteDeadline); writeDeadline is also passed to writePackets as its
	// context.
	readDeadline  *deadline.Deadline
	writeDeadline *deadline.Deadline

	log logging.LeveledLogger

	// reading is created with capacity 1 in createConn.
	// NOTE(review): not used anywhere in this file's visible code — confirm it is still needed.
	reading chan struct{}

	// handshakeRecv carries a "done" channel from readAndBuffer each time
	// handshake records arrive; readAndBuffer then waits on done. Exposed to
	// the FSM (presumably) through recvHandshake.
	handshakeRecv chan chan struct{}

	// Cancel functions for the FSM goroutine's and the read loop's contexts;
	// installed by handshake and both invoked by close.
	cancelHandshaker      func()
	cancelHandshakeReader func()

	fsm *handshakeFSM

	// replayProtectionWindow is the window size given to replaydetector.New
	// for each epoch's replay detector.
	replayProtectionWindow uint
}
// createConn validates config, wraps nextConn in a new Conn and runs the DTLS
// handshake before handing the Conn back.
//
// ctx bounds the handshake (it is passed straight to c.handshake). isClient
// selects which side of the handshake this Conn plays. When initialState is
// non-nil it replaces the freshly built State, and the handshake FSM starts in
// handshakeFinished (flight5 for a client, flight6 for a server) instead of
// handshakePreparing (flight1 / flight0). Presumably this resumes a previously
// established state; confirm with callers.
//
// Errors come from validateConfig, a nil nextConn (errNilNextConn),
// cipher-suite or signature-scheme parsing, the server-side certificate lookup
// (errNoCertificates is tolerated), or the handshake itself. No Conn is
// returned together with an error.
func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
	err := validateConfig(config)
	if err != nil {
		return nil, err
	}

	if nextConn == nil {
		return nil, errNilNextConn
	}

	cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
	if err != nil {
		return nil, err
	}

	signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
	if err != nil {
		return nil, err
	}

	// Unset (zero / non-positive) tunables fall back to package defaults:
	// retransmission interval, logger, MTU and replay window.
	workerInterval := initialTickerInterval
	if config.FlightInterval != 0 {
		workerInterval = config.FlightInterval
	}

	loggerFactory := config.LoggerFactory
	if loggerFactory == nil {
		loggerFactory = logging.NewDefaultLoggerFactory()
	}

	logger := loggerFactory.NewLogger("dtls")

	mtu := config.MTU
	if mtu <= 0 {
		mtu = defaultMTU
	}

	replayProtectionWindow := config.ReplayProtectionWindow
	if replayProtectionWindow <= 0 {
		replayProtectionWindow = defaultReplayProtectionWindow
	}

	c := &Conn{
		nextConn:                connctx.New(nextConn),
		fragmentBuffer:          newFragmentBuffer(),
		handshakeCache:          newHandshakeCache(),
		maximumTransmissionUnit: mtu,
		decrypted:               make(chan interface{}, 1),
		log:                     logger,
		readDeadline:            deadline.New(),
		writeDeadline:           deadline.New(),
		reading:                 make(chan struct{}, 1),
		handshakeRecv:           make(chan chan struct{}),
		closed:                  closer.NewCloser(),
		// No-op placeholder until c.handshake installs the real cancel func.
		cancelHandshaker:       func() {},
		replayProtectionWindow: uint(replayProtectionWindow),
		state: State{
			isClient: isClient,
		},
	}

	c.setRemoteEpoch(0)
	c.setLocalEpoch(0)

	serverName := config.ServerName
	// Do not allow the use of an IP address literal as an SNI value.
	// See RFC 6066, Section 3.
	if net.ParseIP(serverName) != nil {
		serverName = ""
	}

	curves := config.EllipticCurves
	if len(curves) == 0 {
		curves = defaultCurves
	}

	hsCfg := &handshakeConfig{
		localPSKCallback:            config.PSK,
		localPSKIdentityHint:        config.PSKIdentityHint,
		localCipherSuites:           cipherSuites,
		localSignatureSchemes:       signatureSchemes,
		extendedMasterSecret:        config.ExtendedMasterSecret,
		localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
		serverName:                  serverName,
		supportedProtocols:          config.SupportedProtocols,
		clientAuth:                  config.ClientAuth,
		localCertificates:           config.Certificates,
		insecureSkipVerify:          config.InsecureSkipVerify,
		verifyPeerCertificate:       config.VerifyPeerCertificate,
		verifyConnection:            config.VerifyConnection,
		rootCAs:                     config.RootCAs,
		clientCAs:                   config.ClientCAs,
		customCipherSuites:          config.CustomCipherSuites,
		retransmitInterval:          workerInterval,
		log:                         logger,
		initialEpoch:                0,
		keyLogWriter:                config.KeyLogWriter,
		sessionStore:                config.SessionStore,
		ellipticCurves:              curves,
		localGetCertificate:         config.GetCertificate,
		localGetClientCertificate:   config.GetClientCertificate,
		insecureSkipHelloVerify:     config.InsecureSkipVerifyHello,

		// [Psiphon]
		// Conjure DTLS support, from: https://github.com/mingyech/dtls/commit/a56eccc1
		customClientHelloRandom: config.CustomClientHelloRandom,
	}

	// rfc5246#section-7.4.3
	// In addition, the hash and signature algorithms MUST be compatible
	// with the key in the server's end-entity certificate.
	if !isClient {
		// Probe with an empty ClientHelloInfo; having no certificate at all
		// (errNoCertificates) is not an error at this point.
		cert, err := hsCfg.getCertificate(&ClientHelloInfo{})
		if err != nil && !errors.Is(err, errNoCertificates) {
			return nil, err
		}
		hsCfg.localCipherSuites = filterCipherSuitesForCertificate(cert, cipherSuites)
	}

	var initialFlight flightVal
	var initialFSMState handshakeState

	if initialState != nil {
		// The side is taken from c.state as built above (i.e. from the
		// isClient argument), before c.state is replaced by *initialState.
		if c.state.isClient {
			initialFlight = flight5
		} else {
			initialFlight = flight6
		}
		initialFSMState = handshakeFinished

		c.state = *initialState
	} else {
		if c.state.isClient {
			initialFlight = flight1
		} else {
			initialFlight = flight0
		}
		initialFSMState = handshakePreparing
	}
	// Do handshake
	if err := c.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
		return nil, err
	}

	c.log.Trace("Handshake Completed")

	return c, nil
}
  195. // Dial connects to the given network address and establishes a DTLS connection on top.
  196. // Connection handshake will timeout using ConnectContextMaker in the Config.
  197. // If you want to specify the timeout duration, use DialWithContext() instead.
  198. func Dial(network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
  199. ctx, cancel := config.connectContextMaker()
  200. defer cancel()
  201. return DialWithContext(ctx, network, raddr, config)
  202. }
  203. // Client establishes a DTLS connection over an existing connection.
  204. // Connection handshake will timeout using ConnectContextMaker in the Config.
  205. // If you want to specify the timeout duration, use ClientWithContext() instead.
  206. func Client(conn net.Conn, config *Config) (*Conn, error) {
  207. ctx, cancel := config.connectContextMaker()
  208. defer cancel()
  209. return ClientWithContext(ctx, conn, config)
  210. }
  211. // Server listens for incoming DTLS connections.
  212. // Connection handshake will timeout using ConnectContextMaker in the Config.
  213. // If you want to specify the timeout duration, use ServerWithContext() instead.
  214. func Server(conn net.Conn, config *Config) (*Conn, error) {
  215. ctx, cancel := config.connectContextMaker()
  216. defer cancel()
  217. return ServerWithContext(ctx, conn, config)
  218. }
  219. // DialWithContext connects to the given network address and establishes a DTLS connection on top.
  220. func DialWithContext(ctx context.Context, network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
  221. pConn, err := net.DialUDP(network, nil, raddr)
  222. if err != nil {
  223. return nil, err
  224. }
  225. return ClientWithContext(ctx, pConn, config)
  226. }
  227. // ClientWithContext establishes a DTLS connection over an existing connection.
  228. func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
  229. switch {
  230. case config == nil:
  231. return nil, errNoConfigProvided
  232. case config.PSK != nil && config.PSKIdentityHint == nil:
  233. return nil, errPSKAndIdentityMustBeSetForClient
  234. }
  235. return createConn(ctx, conn, config, true, nil)
  236. }
  237. // ServerWithContext listens for incoming DTLS connections.
  238. func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
  239. if config == nil {
  240. return nil, errNoConfigProvided
  241. }
  242. return createConn(ctx, conn, config, false, nil)
  243. }
  244. // Read reads data from the connection.
  245. func (c *Conn) Read(p []byte) (n int, err error) {
  246. if !c.isHandshakeCompletedSuccessfully() {
  247. return 0, errHandshakeInProgress
  248. }
  249. select {
  250. case <-c.readDeadline.Done():
  251. return 0, errDeadlineExceeded
  252. default:
  253. }
  254. for {
  255. select {
  256. case <-c.readDeadline.Done():
  257. return 0, errDeadlineExceeded
  258. case out, ok := <-c.decrypted:
  259. if !ok {
  260. return 0, io.EOF
  261. }
  262. switch val := out.(type) {
  263. case ([]byte):
  264. if len(p) < len(val) {
  265. return 0, errBufferTooSmall
  266. }
  267. copy(p, val)
  268. return len(val), nil
  269. case (error):
  270. return 0, val
  271. }
  272. }
  273. }
  274. }
  275. // Write writes len(p) bytes from p to the DTLS connection
  276. func (c *Conn) Write(p []byte) (int, error) {
  277. if c.isConnectionClosed() {
  278. return 0, ErrConnClosed
  279. }
  280. select {
  281. case <-c.writeDeadline.Done():
  282. return 0, errDeadlineExceeded
  283. default:
  284. }
  285. if !c.isHandshakeCompletedSuccessfully() {
  286. return 0, errHandshakeInProgress
  287. }
  288. return len(p), c.writePackets(c.writeDeadline, []*packet{
  289. {
  290. record: &recordlayer.RecordLayer{
  291. Header: recordlayer.Header{
  292. Epoch: c.state.getLocalEpoch(),
  293. Version: protocol.Version1_2,
  294. },
  295. Content: &protocol.ApplicationData{
  296. Data: p,
  297. },
  298. },
  299. shouldEncrypt: true,
  300. },
  301. })
  302. }
  303. // Close closes the connection.
  304. func (c *Conn) Close() error {
  305. err := c.close(true) //nolint:contextcheck
  306. c.handshakeLoopsFinished.Wait()
  307. return err
  308. }
  309. // ConnectionState returns basic DTLS details about the connection.
  310. // Note that this replaced the `Export` function of v1.
  311. func (c *Conn) ConnectionState() State {
  312. c.lock.RLock()
  313. defer c.lock.RUnlock()
  314. return *c.state.clone()
  315. }
  316. // SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile
  317. func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) {
  318. profile := c.state.getSRTPProtectionProfile()
  319. if profile == 0 {
  320. return 0, false
  321. }
  322. return profile, true
  323. }
  324. func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error {
  325. c.lock.Lock()
  326. defer c.lock.Unlock()
  327. var rawPackets [][]byte
  328. for _, p := range pkts {
  329. if h, ok := p.record.Content.(*handshake.Handshake); ok {
  330. handshakeRaw, err := p.record.Marshal()
  331. if err != nil {
  332. return err
  333. }
  334. c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)",
  335. srvCliStr(c.state.isClient), h.Header.Type.String(),
  336. p.record.Header.Epoch, h.Header.MessageSequence)
  337. c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient)
  338. rawHandshakePackets, err := c.processHandshakePacket(p, h)
  339. if err != nil {
  340. return err
  341. }
  342. rawPackets = append(rawPackets, rawHandshakePackets...)
  343. } else {
  344. rawPacket, err := c.processPacket(p)
  345. if err != nil {
  346. return err
  347. }
  348. rawPackets = append(rawPackets, rawPacket)
  349. }
  350. }
  351. if len(rawPackets) == 0 {
  352. return nil
  353. }
  354. compactedRawPackets := c.compactRawPackets(rawPackets)
  355. for _, compactedRawPackets := range compactedRawPackets {
  356. if _, err := c.nextConn.WriteContext(ctx, compactedRawPackets); err != nil {
  357. return netError(err)
  358. }
  359. }
  360. return nil
  361. }
  362. func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte {
  363. // avoid a useless copy in the common case
  364. if len(rawPackets) == 1 {
  365. return rawPackets
  366. }
  367. combinedRawPackets := make([][]byte, 0)
  368. currentCombinedRawPacket := make([]byte, 0)
  369. for _, rawPacket := range rawPackets {
  370. if len(currentCombinedRawPacket) > 0 && len(currentCombinedRawPacket)+len(rawPacket) >= c.maximumTransmissionUnit {
  371. combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
  372. currentCombinedRawPacket = []byte{}
  373. }
  374. currentCombinedRawPacket = append(currentCombinedRawPacket, rawPacket...)
  375. }
  376. combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
  377. return combinedRawPackets
  378. }
  379. func (c *Conn) processPacket(p *packet) ([]byte, error) {
  380. epoch := p.record.Header.Epoch
  381. for len(c.state.localSequenceNumber) <= int(epoch) {
  382. c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
  383. }
  384. seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
  385. if seq > recordlayer.MaxSequenceNumber {
  386. // RFC 6347 Section 4.1.0
  387. // The implementation must either abandon an association or rehandshake
  388. // prior to allowing the sequence number to wrap.
  389. return nil, errSequenceNumberOverflow
  390. }
  391. p.record.Header.SequenceNumber = seq
  392. rawPacket, err := p.record.Marshal()
  393. if err != nil {
  394. return nil, err
  395. }
  396. if p.shouldEncrypt {
  397. var err error
  398. rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
  399. if err != nil {
  400. return nil, err
  401. }
  402. }
  403. return rawPacket, nil
  404. }
  405. func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]byte, error) {
  406. rawPackets := make([][]byte, 0)
  407. handshakeFragments, err := c.fragmentHandshake(h)
  408. if err != nil {
  409. return nil, err
  410. }
  411. epoch := p.record.Header.Epoch
  412. for len(c.state.localSequenceNumber) <= int(epoch) {
  413. c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
  414. }
  415. for _, handshakeFragment := range handshakeFragments {
  416. seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
  417. if seq > recordlayer.MaxSequenceNumber {
  418. return nil, errSequenceNumberOverflow
  419. }
  420. recordlayerHeader := &recordlayer.Header{
  421. Version: p.record.Header.Version,
  422. ContentType: p.record.Header.ContentType,
  423. ContentLen: uint16(len(handshakeFragment)),
  424. Epoch: p.record.Header.Epoch,
  425. SequenceNumber: seq,
  426. }
  427. rawPacket, err := recordlayerHeader.Marshal()
  428. if err != nil {
  429. return nil, err
  430. }
  431. p.record.Header = *recordlayerHeader
  432. rawPacket = append(rawPacket, handshakeFragment...)
  433. if p.shouldEncrypt {
  434. var err error
  435. rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
  436. if err != nil {
  437. return nil, err
  438. }
  439. }
  440. rawPackets = append(rawPackets, rawPacket)
  441. }
  442. return rawPackets, nil
  443. }
  444. func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) {
  445. content, err := h.Message.Marshal()
  446. if err != nil {
  447. return nil, err
  448. }
  449. fragmentedHandshakes := make([][]byte, 0)
  450. contentFragments := splitBytes(content, c.maximumTransmissionUnit)
  451. if len(contentFragments) == 0 {
  452. contentFragments = [][]byte{
  453. {},
  454. }
  455. }
  456. offset := 0
  457. for _, contentFragment := range contentFragments {
  458. contentFragmentLen := len(contentFragment)
  459. headerFragment := &handshake.Header{
  460. Type: h.Header.Type,
  461. Length: h.Header.Length,
  462. MessageSequence: h.Header.MessageSequence,
  463. FragmentOffset: uint32(offset),
  464. FragmentLength: uint32(contentFragmentLen),
  465. }
  466. offset += contentFragmentLen
  467. fragmentedHandshake, err := headerFragment.Marshal()
  468. if err != nil {
  469. return nil, err
  470. }
  471. fragmentedHandshake = append(fragmentedHandshake, contentFragment...)
  472. fragmentedHandshakes = append(fragmentedHandshakes, fragmentedHandshake)
  473. }
  474. return fragmentedHandshakes, nil
  475. }
  476. var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals
  477. New: func() interface{} {
  478. b := make([]byte, inboundBufferSize)
  479. return &b
  480. },
  481. }
  482. func (c *Conn) readAndBuffer(ctx context.Context) error {
  483. bufptr, ok := poolReadBuffer.Get().(*[]byte)
  484. if !ok {
  485. return errFailedToAccessPoolReadBuffer
  486. }
  487. defer poolReadBuffer.Put(bufptr)
  488. b := *bufptr
  489. i, err := c.nextConn.ReadContext(ctx, b)
  490. if err != nil {
  491. return netError(err)
  492. }
  493. pkts, err := recordlayer.UnpackDatagram(b[:i])
  494. if err != nil {
  495. return err
  496. }
  497. var hasHandshake bool
  498. for _, p := range pkts {
  499. hs, alert, err := c.handleIncomingPacket(ctx, p, true)
  500. if alert != nil {
  501. if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
  502. if err == nil {
  503. err = alertErr
  504. }
  505. }
  506. }
  507. if hs {
  508. hasHandshake = true
  509. }
  510. var e *alertError
  511. if errors.As(err, &e) {
  512. if e.IsFatalOrCloseNotify() {
  513. return e
  514. }
  515. } else if err != nil {
  516. return e
  517. }
  518. }
  519. if hasHandshake {
  520. done := make(chan struct{})
  521. select {
  522. case c.handshakeRecv <- done:
  523. // If the other party may retransmit the flight,
  524. // we should respond even if it not a new message.
  525. <-done
  526. case <-c.fsm.Done():
  527. }
  528. }
  529. return nil
  530. }
  531. func (c *Conn) handleQueuedPackets(ctx context.Context) error {
  532. pkts := c.encryptedPackets
  533. c.encryptedPackets = nil
  534. for _, p := range pkts {
  535. _, alert, err := c.handleIncomingPacket(ctx, p, false) // don't re-enqueue
  536. if alert != nil {
  537. if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
  538. if err == nil {
  539. err = alertErr
  540. }
  541. }
  542. }
  543. var e *alertError
  544. if errors.As(err, &e) {
  545. if e.IsFatalOrCloseNotify() {
  546. return e
  547. }
  548. } else if err != nil {
  549. return e
  550. }
  551. }
  552. return nil
  553. }
  554. func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit
  555. h := &recordlayer.Header{}
  556. if err := h.Unmarshal(buf); err != nil {
  557. // Decode error must be silently discarded
  558. // [RFC6347 Section-4.1.2.7]
  559. c.log.Debugf("discarded broken packet: %v", err)
  560. return false, nil, nil
  561. }
  562. // Validate epoch
  563. remoteEpoch := c.state.getRemoteEpoch()
  564. if h.Epoch > remoteEpoch {
  565. if h.Epoch > remoteEpoch+1 {
  566. c.log.Debugf("discarded future packet (epoch: %d, seq: %d)",
  567. h.Epoch, h.SequenceNumber,
  568. )
  569. return false, nil, nil
  570. }
  571. if enqueue {
  572. c.log.Debug("received packet of next epoch, queuing packet")
  573. c.encryptedPackets = append(c.encryptedPackets, buf)
  574. }
  575. return false, nil, nil
  576. }
  577. // Anti-replay protection
  578. for len(c.state.replayDetector) <= int(h.Epoch) {
  579. c.state.replayDetector = append(c.state.replayDetector,
  580. replaydetector.New(c.replayProtectionWindow, recordlayer.MaxSequenceNumber),
  581. )
  582. }
  583. markPacketAsValid, ok := c.state.replayDetector[int(h.Epoch)].Check(h.SequenceNumber)
  584. if !ok {
  585. c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)",
  586. h.Epoch, h.SequenceNumber,
  587. )
  588. return false, nil, nil
  589. }
  590. // Decrypt
  591. if h.Epoch != 0 {
  592. if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
  593. if enqueue {
  594. c.encryptedPackets = append(c.encryptedPackets, buf)
  595. c.log.Debug("handshake not finished, queuing packet")
  596. }
  597. return false, nil, nil
  598. }
  599. var err error
  600. buf, err = c.state.cipherSuite.Decrypt(buf)
  601. if err != nil {
  602. c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err)
  603. return false, nil, nil
  604. }
  605. }
  606. isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...))
  607. if err != nil {
  608. // Decode error must be silently discarded
  609. // [RFC6347 Section-4.1.2.7]
  610. c.log.Debugf("defragment failed: %s", err)
  611. return false, nil, nil
  612. } else if isHandshake {
  613. markPacketAsValid()
  614. for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() {
  615. header := &handshake.Header{}
  616. if err := header.Unmarshal(out); err != nil {
  617. c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err)
  618. continue
  619. }
  620. c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient)
  621. }
  622. return true, nil, nil
  623. }
  624. r := &recordlayer.RecordLayer{}
  625. if err := r.Unmarshal(buf); err != nil {
  626. return false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err
  627. }
  628. switch content := r.Content.(type) {
  629. case *alert.Alert:
  630. c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String())
  631. var a *alert.Alert
  632. if content.Description == alert.CloseNotify {
  633. // Respond with a close_notify [RFC5246 Section 7.2.1]
  634. a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify}
  635. }
  636. markPacketAsValid()
  637. return false, a, &alertError{content}
  638. case *protocol.ChangeCipherSpec:
  639. if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
  640. if enqueue {
  641. c.encryptedPackets = append(c.encryptedPackets, buf)
  642. c.log.Debugf("CipherSuite not initialized, queuing packet")
  643. }
  644. return false, nil, nil
  645. }
  646. newRemoteEpoch := h.Epoch + 1
  647. c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch)
  648. if c.state.getRemoteEpoch()+1 == newRemoteEpoch {
  649. c.setRemoteEpoch(newRemoteEpoch)
  650. markPacketAsValid()
  651. }
  652. case *protocol.ApplicationData:
  653. if h.Epoch == 0 {
  654. return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero
  655. }
  656. markPacketAsValid()
  657. select {
  658. case c.decrypted <- content.Data:
  659. case <-c.closed.Done():
  660. case <-ctx.Done():
  661. }
  662. default:
  663. return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType())
  664. }
  665. return false, nil, nil
  666. }
  667. func (c *Conn) recvHandshake() <-chan chan struct{} {
  668. return c.handshakeRecv
  669. }
  670. func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error {
  671. if level == alert.Fatal && len(c.state.SessionID) > 0 {
  672. // According to the RFC, we need to delete the stored session.
  673. // https://datatracker.ietf.org/doc/html/rfc5246#section-7.2
  674. if ss := c.fsm.cfg.sessionStore; ss != nil {
  675. c.log.Tracef("clean invalid session: %s", c.state.SessionID)
  676. if err := ss.Del(c.sessionKey()); err != nil {
  677. return err
  678. }
  679. }
  680. }
  681. return c.writePackets(ctx, []*packet{
  682. {
  683. record: &recordlayer.RecordLayer{
  684. Header: recordlayer.Header{
  685. Epoch: c.state.getLocalEpoch(),
  686. Version: protocol.Version1_2,
  687. },
  688. Content: &alert.Alert{
  689. Level: level,
  690. Description: desc,
  691. },
  692. },
  693. shouldEncrypt: c.isHandshakeCompletedSuccessfully(),
  694. },
  695. })
  696. }
  697. func (c *Conn) setHandshakeCompletedSuccessfully() {
  698. c.handshakeCompletedSuccessfully.Store(struct{ bool }{true})
  699. }
  700. func (c *Conn) isHandshakeCompletedSuccessfully() bool {
  701. boolean, _ := c.handshakeCompletedSuccessfully.Load().(struct{ bool })
  702. return boolean.bool
  703. }
// handshake starts the handshake machinery and blocks until the handshake
// completes, fails, or ctx is done.
//
// It launches two goroutines tracked by c.handshakeLoopsFinished:
//   - the FSM runner (c.fsm.Run). It is meant to outlive a successful
//     handshake so the last flight can be retransmitted on request.
//   - the read loop. It calls readAndBuffer until a terminal error, and on
//     exit closes c.decrypted and cancels the FSM context.
//
// The first error reported by either goroutine is used (the FSM's
// context.Canceled is not reported), as is expiry of ctx. Either one cancels
// both goroutines, waits for them, and returns the error through
// translateHandshakeCtxError. When the FSM reaches handshakeFinished, the
// onFlightState callback closes done and this returns nil while both
// goroutines keep running until close.
func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFlight flightVal, initialState handshakeState) error { //nolint:gocognit
	c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight)

	done := make(chan struct{})
	// The read loop's context derives from Background, not ctx, so the loop
	// can outlive the handshake; close stops it via c.cancelHandshakeReader.
	ctxRead, cancelRead := context.WithCancel(context.Background())
	c.cancelHandshakeReader = cancelRead
	cfg.onFlightState = func(f flightVal, s handshakeState) {
		// The completion check makes sure done is closed only once.
		if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() {
			c.setHandshakeCompletedSuccessfully()
			close(done)
		}
	}

	ctxHs, cancel := context.WithCancel(context.Background())
	c.cancelHandshaker = cancel

	// Capacity 1 plus non-blocking sends below: only the first error is kept.
	firstErr := make(chan error, 1)

	c.handshakeLoopsFinished.Add(2)

	// Handshake routine should be live until close.
	// The other party may request retransmission of the last flight to cope with packet drop.
	go func() {
		defer c.handshakeLoopsFinished.Done()
		err := c.fsm.Run(ctxHs, c, initialState)
		if !errors.Is(err, context.Canceled) {
			select {
			case firstErr <- err:
			default:
			}
		}
	}()

	go func() {
		// Deferred calls run last-in-first-out: handshakeLoopsFinished.Done
		// runs before c.decrypted is closed and the FSM context is cancelled.
		defer func() {
			// Escaping read loop.
			// It's safe to close decrypted channel now.
			close(c.decrypted)

			// Force stop handshaker when the underlying connection is closed.
			cancel()
		}()
		defer c.handshakeLoopsFinished.Done()
		for {
			if err := c.readAndBuffer(ctxRead); err != nil {
				// e stays nil for non-alert errors.
				var e *alertError
				if errors.As(err, &e) {
					if !e.IsFatalOrCloseNotify() {
						if c.isHandshakeCompletedSuccessfully() {
							// Pass the error to Read()
							select {
							case c.decrypted <- err:
							case <-c.closed.Done():
							case <-ctxRead.Done():
							}
						}
						continue // non-fatal alert must not stop read loop
					}
				} else {
					switch {
					case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled), errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed):
						// Terminal: leave the switch and stop the loop below.
					case errors.Is(err, recordlayer.ErrInvalidPacketLength):
						// Decode error must be silently discarded
						// [RFC6347 Section-4.1.2.7]
						continue
					default:
						if c.isHandshakeCompletedSuccessfully() {
							// Keep read loop and pass the read error to Read()
							select {
							case c.decrypted <- err:
							case <-c.closed.Done():
							case <-ctxRead.Done():
							}
							continue // other read errors after the handshake must not stop the read loop
						}
					}
				}

				select {
				case firstErr <- err:
				default:
				}

				if e != nil {
					if e.IsFatalOrCloseNotify() {
						_ = c.close(false) //nolint:contextcheck
					}
				}
				// ctxRead was cancelled while the connection is still open
				// (e.g. handshake failure/timeout): close it here.
				if !c.isConnectionClosed() && errors.Is(err, context.Canceled) {
					c.log.Trace("handshake timeouts - closing underline connection")
					_ = c.close(false) //nolint:contextcheck
				}
				return
			}
		}
	}()

	select {
	case err := <-firstErr:
		cancelRead()
		cancel()
		c.handshakeLoopsFinished.Wait()
		return c.translateHandshakeCtxError(err)
	case <-ctx.Done():
		cancelRead()
		cancel()
		c.handshakeLoopsFinished.Wait()
		return c.translateHandshakeCtxError(ctx.Err())
	case <-done:
		return nil
	}
}
  806. func (c *Conn) translateHandshakeCtxError(err error) error {
  807. if err == nil {
  808. return nil
  809. }
  810. if errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully() {
  811. return nil
  812. }
  813. return &HandshakeError{Err: err}
  814. }
  815. func (c *Conn) close(byUser bool) error {
  816. c.cancelHandshaker()
  817. c.cancelHandshakeReader()
  818. if c.isHandshakeCompletedSuccessfully() && byUser {
  819. // Discard error from notify() to return non-error on the first user call of Close()
  820. // even if the underlying connection is already closed.
  821. _ = c.notify(context.Background(), alert.Warning, alert.CloseNotify)
  822. }
  823. c.closeLock.Lock()
  824. // Don't return ErrConnClosed at the first time of the call from user.
  825. closedByUser := c.connectionClosedByUser
  826. if byUser {
  827. c.connectionClosedByUser = true
  828. }
  829. isClosed := c.isConnectionClosed()
  830. c.closed.Close()
  831. c.closeLock.Unlock()
  832. if closedByUser {
  833. return ErrConnClosed
  834. }
  835. if isClosed {
  836. return nil
  837. }
  838. return c.nextConn.Close()
  839. }
  840. func (c *Conn) isConnectionClosed() bool {
  841. select {
  842. case <-c.closed.Done():
  843. return true
  844. default:
  845. return false
  846. }
  847. }
  848. func (c *Conn) setLocalEpoch(epoch uint16) {
  849. c.state.localEpoch.Store(epoch)
  850. }
  851. func (c *Conn) setRemoteEpoch(epoch uint16) {
  852. c.state.remoteEpoch.Store(epoch)
  853. }
  854. // LocalAddr implements net.Conn.LocalAddr
  855. func (c *Conn) LocalAddr() net.Addr {
  856. return c.nextConn.LocalAddr()
  857. }
  858. // RemoteAddr implements net.Conn.RemoteAddr
  859. func (c *Conn) RemoteAddr() net.Addr {
  860. return c.nextConn.RemoteAddr()
  861. }
  862. func (c *Conn) sessionKey() []byte {
  863. if c.state.isClient {
  864. // As ServerName can be like 0.example.com, it's better to add
  865. // delimiter character which is not allowed to be in
  866. // neither address or domain name.
  867. return []byte(c.nextConn.RemoteAddr().String() + "_" + c.fsm.cfg.serverName)
  868. }
  869. return c.state.SessionID
  870. }
  871. // SetDeadline implements net.Conn.SetDeadline
  872. func (c *Conn) SetDeadline(t time.Time) error {
  873. c.readDeadline.Set(t)
  874. return c.SetWriteDeadline(t)
  875. }
  876. // SetReadDeadline implements net.Conn.SetReadDeadline
  877. func (c *Conn) SetReadDeadline(t time.Time) error {
  878. c.readDeadline.Set(t)
  879. // Read deadline is fully managed by this layer.
  880. // Don't set read deadline to underlying connection.
  881. return nil
  882. }
  883. // SetWriteDeadline implements net.Conn.SetWriteDeadline
  884. func (c *Conn) SetWriteDeadline(t time.Time) error {
  885. c.writeDeadline.Set(t)
  886. // Write deadline is also fully managed by this layer.
  887. return nil
  888. }