conn.go 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package dtls
  4. import (
  5. "context"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net"
  10. "sync"
  11. "sync/atomic"
  12. "time"
  13. "github.com/pion/dtls/v2/internal/closer"
  14. "github.com/pion/dtls/v2/pkg/crypto/elliptic"
  15. "github.com/pion/dtls/v2/pkg/crypto/signaturehash"
  16. "github.com/pion/dtls/v2/pkg/protocol"
  17. "github.com/pion/dtls/v2/pkg/protocol/alert"
  18. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  19. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  20. "github.com/pion/logging"
  21. "github.com/pion/transport/v2/connctx"
  22. "github.com/pion/transport/v2/deadline"
  23. "github.com/pion/transport/v2/replaydetector"
  24. )
  25. const (
  26. initialTickerInterval = time.Second
  27. cookieLength = 20
  28. sessionLength = 32
  29. defaultNamedCurve = elliptic.X25519
  30. inboundBufferSize = 8192
  31. // Default replay protection window is specified by RFC 6347 Section 4.1.2.6
  32. defaultReplayProtectionWindow = 64
  33. )
  // invalidKeyingLabels returns a fresh set of the labels that are rejected
// for keying-material export: the labels the TLS PRF already uses internally.
// Every member maps to true; absent labels read as false.
func invalidKeyingLabels() map[string]bool {
	reserved := [...]string{
		"client finished",
		"server finished",
		"master secret",
		"key expansion",
	}
	labels := make(map[string]bool, len(reserved))
	for _, label := range reserved {
		labels[label] = true
	}
	return labels
}
  42. // Conn represents a DTLS connection
  43. type Conn struct {
  44. lock sync.RWMutex // Internal lock (must not be public)
  45. nextConn connctx.ConnCtx // Embedded Conn, typically a udpconn we read/write from
  46. fragmentBuffer *fragmentBuffer // out-of-order and missing fragment handling
  47. handshakeCache *handshakeCache // caching of handshake messages for verifyData generation
  48. decrypted chan interface{} // Decrypted Application Data or error, pull by calling `Read`
  49. state State // Internal state
  50. maximumTransmissionUnit int
  51. handshakeCompletedSuccessfully atomic.Value
  52. encryptedPackets [][]byte
  53. connectionClosedByUser bool
  54. closeLock sync.Mutex
  55. closed *closer.Closer
  56. handshakeLoopsFinished sync.WaitGroup
  57. readDeadline *deadline.Deadline
  58. writeDeadline *deadline.Deadline
  59. log logging.LeveledLogger
  60. reading chan struct{}
  61. handshakeRecv chan chan struct{}
  62. cancelHandshaker func()
  63. cancelHandshakeReader func()
  64. fsm *handshakeFSM
  65. replayProtectionWindow uint
  66. }
  67. func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
  68. err := validateConfig(config)
  69. if err != nil {
  70. return nil, err
  71. }
  72. if nextConn == nil {
  73. return nil, errNilNextConn
  74. }
  75. cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
  76. if err != nil {
  77. return nil, err
  78. }
  79. signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
  80. if err != nil {
  81. return nil, err
  82. }
  83. workerInterval := initialTickerInterval
  84. if config.FlightInterval != 0 {
  85. workerInterval = config.FlightInterval
  86. }
  87. loggerFactory := config.LoggerFactory
  88. if loggerFactory == nil {
  89. loggerFactory = logging.NewDefaultLoggerFactory()
  90. }
  91. logger := loggerFactory.NewLogger("dtls")
  92. mtu := config.MTU
  93. if mtu <= 0 {
  94. mtu = defaultMTU
  95. }
  96. replayProtectionWindow := config.ReplayProtectionWindow
  97. if replayProtectionWindow <= 0 {
  98. replayProtectionWindow = defaultReplayProtectionWindow
  99. }
  100. c := &Conn{
  101. nextConn: connctx.New(nextConn),
  102. fragmentBuffer: newFragmentBuffer(),
  103. handshakeCache: newHandshakeCache(),
  104. maximumTransmissionUnit: mtu,
  105. decrypted: make(chan interface{}, 1),
  106. log: logger,
  107. readDeadline: deadline.New(),
  108. writeDeadline: deadline.New(),
  109. reading: make(chan struct{}, 1),
  110. handshakeRecv: make(chan chan struct{}),
  111. closed: closer.NewCloser(),
  112. cancelHandshaker: func() {},
  113. replayProtectionWindow: uint(replayProtectionWindow),
  114. state: State{
  115. isClient: isClient,
  116. },
  117. }
  118. c.setRemoteEpoch(0)
  119. c.setLocalEpoch(0)
  120. serverName := config.ServerName
  121. // Do not allow the use of an IP address literal as an SNI value.
  122. // See RFC 6066, Section 3.
  123. if net.ParseIP(serverName) != nil {
  124. serverName = ""
  125. }
  126. curves := config.EllipticCurves
  127. if len(curves) == 0 {
  128. curves = defaultCurves
  129. }
  130. hsCfg := &handshakeConfig{
  131. localPSKCallback: config.PSK,
  132. localPSKIdentityHint: config.PSKIdentityHint,
  133. localCipherSuites: cipherSuites,
  134. localSignatureSchemes: signatureSchemes,
  135. extendedMasterSecret: config.ExtendedMasterSecret,
  136. localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
  137. serverName: serverName,
  138. supportedProtocols: config.SupportedProtocols,
  139. clientAuth: config.ClientAuth,
  140. localCertificates: config.Certificates,
  141. insecureSkipVerify: config.InsecureSkipVerify,
  142. verifyPeerCertificate: config.VerifyPeerCertificate,
  143. verifyConnection: config.VerifyConnection,
  144. rootCAs: config.RootCAs,
  145. clientCAs: config.ClientCAs,
  146. customCipherSuites: config.CustomCipherSuites,
  147. retransmitInterval: workerInterval,
  148. log: logger,
  149. initialEpoch: 0,
  150. keyLogWriter: config.KeyLogWriter,
  151. sessionStore: config.SessionStore,
  152. ellipticCurves: curves,
  153. localGetCertificate: config.GetCertificate,
  154. localGetClientCertificate: config.GetClientCertificate,
  155. insecureSkipHelloVerify: config.InsecureSkipVerifyHello,
  156. customClientHelloRandom: config.CustomClientHelloRandom,
  157. }
  158. // rfc5246#section-7.4.3
  159. // In addition, the hash and signature algorithms MUST be compatible
  160. // with the key in the server's end-entity certificate.
  161. if !isClient {
  162. cert, err := hsCfg.getCertificate(&ClientHelloInfo{})
  163. if err != nil && !errors.Is(err, errNoCertificates) {
  164. return nil, err
  165. }
  166. hsCfg.localCipherSuites = filterCipherSuitesForCertificate(cert, cipherSuites)
  167. }
  168. var initialFlight flightVal
  169. var initialFSMState handshakeState
  170. if initialState != nil {
  171. if c.state.isClient {
  172. initialFlight = flight5
  173. } else {
  174. initialFlight = flight6
  175. }
  176. initialFSMState = handshakeFinished
  177. c.state = *initialState
  178. } else {
  179. if c.state.isClient {
  180. initialFlight = flight1
  181. } else {
  182. initialFlight = flight0
  183. }
  184. initialFSMState = handshakePreparing
  185. }
  186. // Do handshake
  187. if err := c.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
  188. return nil, err
  189. }
  190. c.log.Trace("Handshake Completed")
  191. return c, nil
  192. }
  193. // Dial connects to the given network address and establishes a DTLS connection on top.
  194. // Connection handshake will timeout using ConnectContextMaker in the Config.
  195. // If you want to specify the timeout duration, use DialWithContext() instead.
  196. func Dial(network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
  197. ctx, cancel := config.connectContextMaker()
  198. defer cancel()
  199. return DialWithContext(ctx, network, raddr, config)
  200. }
  201. // Client establishes a DTLS connection over an existing connection.
  202. // Connection handshake will timeout using ConnectContextMaker in the Config.
  203. // If you want to specify the timeout duration, use ClientWithContext() instead.
  204. func Client(conn net.Conn, config *Config) (*Conn, error) {
  205. ctx, cancel := config.connectContextMaker()
  206. defer cancel()
  207. return ClientWithContext(ctx, conn, config)
  208. }
  209. // Server listens for incoming DTLS connections.
  210. // Connection handshake will timeout using ConnectContextMaker in the Config.
  211. // If you want to specify the timeout duration, use ServerWithContext() instead.
  212. func Server(conn net.Conn, config *Config) (*Conn, error) {
  213. ctx, cancel := config.connectContextMaker()
  214. defer cancel()
  215. return ServerWithContext(ctx, conn, config)
  216. }
  217. // DialWithContext connects to the given network address and establishes a DTLS connection on top.
  218. func DialWithContext(ctx context.Context, network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
  219. pConn, err := net.DialUDP(network, nil, raddr)
  220. if err != nil {
  221. return nil, err
  222. }
  223. return ClientWithContext(ctx, pConn, config)
  224. }
  225. // ClientWithContext establishes a DTLS connection over an existing connection.
  226. func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
  227. switch {
  228. case config == nil:
  229. return nil, errNoConfigProvided
  230. case config.PSK != nil && config.PSKIdentityHint == nil:
  231. return nil, errPSKAndIdentityMustBeSetForClient
  232. }
  233. return createConn(ctx, conn, config, true, nil)
  234. }
  235. // ServerWithContext listens for incoming DTLS connections.
  236. func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
  237. if config == nil {
  238. return nil, errNoConfigProvided
  239. }
  240. return createConn(ctx, conn, config, false, nil)
  241. }
  242. // Read reads data from the connection.
  243. func (c *Conn) Read(p []byte) (n int, err error) {
  244. if !c.isHandshakeCompletedSuccessfully() {
  245. return 0, errHandshakeInProgress
  246. }
  247. select {
  248. case <-c.readDeadline.Done():
  249. return 0, errDeadlineExceeded
  250. default:
  251. }
  252. for {
  253. select {
  254. case <-c.readDeadline.Done():
  255. return 0, errDeadlineExceeded
  256. case out, ok := <-c.decrypted:
  257. if !ok {
  258. return 0, io.EOF
  259. }
  260. switch val := out.(type) {
  261. case ([]byte):
  262. if len(p) < len(val) {
  263. return 0, errBufferTooSmall
  264. }
  265. copy(p, val)
  266. return len(val), nil
  267. case (error):
  268. return 0, val
  269. }
  270. }
  271. }
  272. }
  273. // Write writes len(p) bytes from p to the DTLS connection
  274. func (c *Conn) Write(p []byte) (int, error) {
  275. if c.isConnectionClosed() {
  276. return 0, ErrConnClosed
  277. }
  278. select {
  279. case <-c.writeDeadline.Done():
  280. return 0, errDeadlineExceeded
  281. default:
  282. }
  283. if !c.isHandshakeCompletedSuccessfully() {
  284. return 0, errHandshakeInProgress
  285. }
  286. return len(p), c.writePackets(c.writeDeadline, []*packet{
  287. {
  288. record: &recordlayer.RecordLayer{
  289. Header: recordlayer.Header{
  290. Epoch: c.state.getLocalEpoch(),
  291. Version: protocol.Version1_2,
  292. },
  293. Content: &protocol.ApplicationData{
  294. Data: p,
  295. },
  296. },
  297. shouldEncrypt: true,
  298. },
  299. })
  300. }
  301. // Close closes the connection.
  302. func (c *Conn) Close() error {
  303. err := c.close(true) //nolint:contextcheck
  304. c.handshakeLoopsFinished.Wait()
  305. return err
  306. }
  307. // ConnectionState returns basic DTLS details about the connection.
  308. // Note that this replaced the `Export` function of v1.
  309. func (c *Conn) ConnectionState() State {
  310. c.lock.RLock()
  311. defer c.lock.RUnlock()
  312. return *c.state.clone()
  313. }
  314. // SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile
  315. func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) {
  316. c.lock.RLock()
  317. defer c.lock.RUnlock()
  318. if c.state.srtpProtectionProfile == 0 {
  319. return 0, false
  320. }
  321. return c.state.srtpProtectionProfile, true
  322. }
  323. func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error {
  324. c.lock.Lock()
  325. defer c.lock.Unlock()
  326. var rawPackets [][]byte
  327. for _, p := range pkts {
  328. if h, ok := p.record.Content.(*handshake.Handshake); ok {
  329. handshakeRaw, err := p.record.Marshal()
  330. if err != nil {
  331. return err
  332. }
  333. c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)",
  334. srvCliStr(c.state.isClient), h.Header.Type.String(),
  335. p.record.Header.Epoch, h.Header.MessageSequence)
  336. c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient)
  337. rawHandshakePackets, err := c.processHandshakePacket(p, h)
  338. if err != nil {
  339. return err
  340. }
  341. rawPackets = append(rawPackets, rawHandshakePackets...)
  342. } else {
  343. rawPacket, err := c.processPacket(p)
  344. if err != nil {
  345. return err
  346. }
  347. rawPackets = append(rawPackets, rawPacket)
  348. }
  349. }
  350. if len(rawPackets) == 0 {
  351. return nil
  352. }
  353. compactedRawPackets := c.compactRawPackets(rawPackets)
  354. for _, compactedRawPackets := range compactedRawPackets {
  355. if _, err := c.nextConn.WriteContext(ctx, compactedRawPackets); err != nil {
  356. return netError(err)
  357. }
  358. }
  359. return nil
  360. }
  361. func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte {
  362. // avoid a useless copy in the common case
  363. if len(rawPackets) == 1 {
  364. return rawPackets
  365. }
  366. combinedRawPackets := make([][]byte, 0)
  367. currentCombinedRawPacket := make([]byte, 0)
  368. for _, rawPacket := range rawPackets {
  369. if len(currentCombinedRawPacket) > 0 && len(currentCombinedRawPacket)+len(rawPacket) >= c.maximumTransmissionUnit {
  370. combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
  371. currentCombinedRawPacket = []byte{}
  372. }
  373. currentCombinedRawPacket = append(currentCombinedRawPacket, rawPacket...)
  374. }
  375. combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
  376. return combinedRawPackets
  377. }
  378. func (c *Conn) processPacket(p *packet) ([]byte, error) {
  379. epoch := p.record.Header.Epoch
  380. for len(c.state.localSequenceNumber) <= int(epoch) {
  381. c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
  382. }
  383. seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
  384. if seq > recordlayer.MaxSequenceNumber {
  385. // RFC 6347 Section 4.1.0
  386. // The implementation must either abandon an association or rehandshake
  387. // prior to allowing the sequence number to wrap.
  388. return nil, errSequenceNumberOverflow
  389. }
  390. p.record.Header.SequenceNumber = seq
  391. rawPacket, err := p.record.Marshal()
  392. if err != nil {
  393. return nil, err
  394. }
  395. if p.shouldEncrypt {
  396. var err error
  397. rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
  398. if err != nil {
  399. return nil, err
  400. }
  401. }
  402. return rawPacket, nil
  403. }
  404. func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]byte, error) {
  405. rawPackets := make([][]byte, 0)
  406. handshakeFragments, err := c.fragmentHandshake(h)
  407. if err != nil {
  408. return nil, err
  409. }
  410. epoch := p.record.Header.Epoch
  411. for len(c.state.localSequenceNumber) <= int(epoch) {
  412. c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
  413. }
  414. for _, handshakeFragment := range handshakeFragments {
  415. seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
  416. if seq > recordlayer.MaxSequenceNumber {
  417. return nil, errSequenceNumberOverflow
  418. }
  419. recordlayerHeader := &recordlayer.Header{
  420. Version: p.record.Header.Version,
  421. ContentType: p.record.Header.ContentType,
  422. ContentLen: uint16(len(handshakeFragment)),
  423. Epoch: p.record.Header.Epoch,
  424. SequenceNumber: seq,
  425. }
  426. rawPacket, err := recordlayerHeader.Marshal()
  427. if err != nil {
  428. return nil, err
  429. }
  430. p.record.Header = *recordlayerHeader
  431. rawPacket = append(rawPacket, handshakeFragment...)
  432. if p.shouldEncrypt {
  433. var err error
  434. rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
  435. if err != nil {
  436. return nil, err
  437. }
  438. }
  439. rawPackets = append(rawPackets, rawPacket)
  440. }
  441. return rawPackets, nil
  442. }
  443. func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) {
  444. content, err := h.Message.Marshal()
  445. if err != nil {
  446. return nil, err
  447. }
  448. fragmentedHandshakes := make([][]byte, 0)
  449. contentFragments := splitBytes(content, c.maximumTransmissionUnit)
  450. if len(contentFragments) == 0 {
  451. contentFragments = [][]byte{
  452. {},
  453. }
  454. }
  455. offset := 0
  456. for _, contentFragment := range contentFragments {
  457. contentFragmentLen := len(contentFragment)
  458. headerFragment := &handshake.Header{
  459. Type: h.Header.Type,
  460. Length: h.Header.Length,
  461. MessageSequence: h.Header.MessageSequence,
  462. FragmentOffset: uint32(offset),
  463. FragmentLength: uint32(contentFragmentLen),
  464. }
  465. offset += contentFragmentLen
  466. fragmentedHandshake, err := headerFragment.Marshal()
  467. if err != nil {
  468. return nil, err
  469. }
  470. fragmentedHandshake = append(fragmentedHandshake, contentFragment...)
  471. fragmentedHandshakes = append(fragmentedHandshakes, fragmentedHandshake)
  472. }
  473. return fragmentedHandshakes, nil
  474. }
  475. var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals
  476. New: func() interface{} {
  477. b := make([]byte, inboundBufferSize)
  478. return &b
  479. },
  480. }
  481. func (c *Conn) readAndBuffer(ctx context.Context) error {
  482. bufptr, ok := poolReadBuffer.Get().(*[]byte)
  483. if !ok {
  484. return errFailedToAccessPoolReadBuffer
  485. }
  486. defer poolReadBuffer.Put(bufptr)
  487. b := *bufptr
  488. i, err := c.nextConn.ReadContext(ctx, b)
  489. if err != nil {
  490. return netError(err)
  491. }
  492. pkts, err := recordlayer.UnpackDatagram(b[:i])
  493. if err != nil {
  494. return err
  495. }
  496. var hasHandshake bool
  497. for _, p := range pkts {
  498. hs, alert, err := c.handleIncomingPacket(ctx, p, true)
  499. if alert != nil {
  500. if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
  501. if err == nil {
  502. err = alertErr
  503. }
  504. }
  505. }
  506. if hs {
  507. hasHandshake = true
  508. }
  509. var e *alertError
  510. if errors.As(err, &e) {
  511. if e.IsFatalOrCloseNotify() {
  512. return e
  513. }
  514. } else if err != nil {
  515. return e
  516. }
  517. }
  518. if hasHandshake {
  519. done := make(chan struct{})
  520. select {
  521. case c.handshakeRecv <- done:
  522. // If the other party may retransmit the flight,
  523. // we should respond even if it not a new message.
  524. <-done
  525. case <-c.fsm.Done():
  526. }
  527. }
  528. return nil
  529. }
  530. func (c *Conn) handleQueuedPackets(ctx context.Context) error {
  531. pkts := c.encryptedPackets
  532. c.encryptedPackets = nil
  533. for _, p := range pkts {
  534. _, alert, err := c.handleIncomingPacket(ctx, p, false) // don't re-enqueue
  535. if alert != nil {
  536. if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
  537. if err == nil {
  538. err = alertErr
  539. }
  540. }
  541. }
  542. var e *alertError
  543. if errors.As(err, &e) {
  544. if e.IsFatalOrCloseNotify() {
  545. return e
  546. }
  547. } else if err != nil {
  548. return e
  549. }
  550. }
  551. return nil
  552. }
  553. func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit
  554. h := &recordlayer.Header{}
  555. if err := h.Unmarshal(buf); err != nil {
  556. // Decode error must be silently discarded
  557. // [RFC6347 Section-4.1.2.7]
  558. c.log.Debugf("discarded broken packet: %v", err)
  559. return false, nil, nil
  560. }
  561. // Validate epoch
  562. remoteEpoch := c.state.getRemoteEpoch()
  563. if h.Epoch > remoteEpoch {
  564. if h.Epoch > remoteEpoch+1 {
  565. c.log.Debugf("discarded future packet (epoch: %d, seq: %d)",
  566. h.Epoch, h.SequenceNumber,
  567. )
  568. return false, nil, nil
  569. }
  570. if enqueue {
  571. c.log.Debug("received packet of next epoch, queuing packet")
  572. c.encryptedPackets = append(c.encryptedPackets, buf)
  573. }
  574. return false, nil, nil
  575. }
  576. // Anti-replay protection
  577. for len(c.state.replayDetector) <= int(h.Epoch) {
  578. c.state.replayDetector = append(c.state.replayDetector,
  579. replaydetector.New(c.replayProtectionWindow, recordlayer.MaxSequenceNumber),
  580. )
  581. }
  582. markPacketAsValid, ok := c.state.replayDetector[int(h.Epoch)].Check(h.SequenceNumber)
  583. if !ok {
  584. c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)",
  585. h.Epoch, h.SequenceNumber,
  586. )
  587. return false, nil, nil
  588. }
  589. // Decrypt
  590. if h.Epoch != 0 {
  591. if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
  592. if enqueue {
  593. c.encryptedPackets = append(c.encryptedPackets, buf)
  594. c.log.Debug("handshake not finished, queuing packet")
  595. }
  596. return false, nil, nil
  597. }
  598. var err error
  599. buf, err = c.state.cipherSuite.Decrypt(buf)
  600. if err != nil {
  601. c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err)
  602. return false, nil, nil
  603. }
  604. }
  605. isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...))
  606. if err != nil {
  607. // Decode error must be silently discarded
  608. // [RFC6347 Section-4.1.2.7]
  609. c.log.Debugf("defragment failed: %s", err)
  610. return false, nil, nil
  611. } else if isHandshake {
  612. markPacketAsValid()
  613. for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() {
  614. header := &handshake.Header{}
  615. if err := header.Unmarshal(out); err != nil {
  616. c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err)
  617. continue
  618. }
  619. c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient)
  620. }
  621. return true, nil, nil
  622. }
  623. r := &recordlayer.RecordLayer{}
  624. if err := r.Unmarshal(buf); err != nil {
  625. return false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err
  626. }
  627. switch content := r.Content.(type) {
  628. case *alert.Alert:
  629. c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String())
  630. var a *alert.Alert
  631. if content.Description == alert.CloseNotify {
  632. // Respond with a close_notify [RFC5246 Section 7.2.1]
  633. a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify}
  634. }
  635. markPacketAsValid()
  636. return false, a, &alertError{content}
  637. case *protocol.ChangeCipherSpec:
  638. if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
  639. if enqueue {
  640. c.encryptedPackets = append(c.encryptedPackets, buf)
  641. c.log.Debugf("CipherSuite not initialized, queuing packet")
  642. }
  643. return false, nil, nil
  644. }
  645. newRemoteEpoch := h.Epoch + 1
  646. c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch)
  647. if c.state.getRemoteEpoch()+1 == newRemoteEpoch {
  648. c.setRemoteEpoch(newRemoteEpoch)
  649. markPacketAsValid()
  650. }
  651. case *protocol.ApplicationData:
  652. if h.Epoch == 0 {
  653. return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero
  654. }
  655. markPacketAsValid()
  656. select {
  657. case c.decrypted <- content.Data:
  658. case <-c.closed.Done():
  659. case <-ctx.Done():
  660. }
  661. default:
  662. return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType())
  663. }
  664. return false, nil, nil
  665. }
  666. func (c *Conn) recvHandshake() <-chan chan struct{} {
  667. return c.handshakeRecv
  668. }
  669. func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error {
  670. if level == alert.Fatal && len(c.state.SessionID) > 0 {
  671. // According to the RFC, we need to delete the stored session.
  672. // https://datatracker.ietf.org/doc/html/rfc5246#section-7.2
  673. if ss := c.fsm.cfg.sessionStore; ss != nil {
  674. c.log.Tracef("clean invalid session: %s", c.state.SessionID)
  675. if err := ss.Del(c.sessionKey()); err != nil {
  676. return err
  677. }
  678. }
  679. }
  680. return c.writePackets(ctx, []*packet{
  681. {
  682. record: &recordlayer.RecordLayer{
  683. Header: recordlayer.Header{
  684. Epoch: c.state.getLocalEpoch(),
  685. Version: protocol.Version1_2,
  686. },
  687. Content: &alert.Alert{
  688. Level: level,
  689. Description: desc,
  690. },
  691. },
  692. shouldEncrypt: c.isHandshakeCompletedSuccessfully(),
  693. },
  694. })
  695. }
  696. func (c *Conn) setHandshakeCompletedSuccessfully() {
  697. c.handshakeCompletedSuccessfully.Store(struct{ bool }{true})
  698. }
  699. func (c *Conn) isHandshakeCompletedSuccessfully() bool {
  700. boolean, _ := c.handshakeCompletedSuccessfully.Load().(struct{ bool })
  701. return boolean.bool
  702. }
  // handshake creates the handshake FSM and starts the two goroutines that
  // drive the connection, both tracked by c.handshakeLoopsFinished:
  //   - the FSM runner (c.fsm.Run on ctxHs), kept alive after the handshake so
  //     retransmitted flights from the peer can still be answered;
  //   - the read loop (readAndBuffer on ctxRead), which also serves Read after
  //     the handshake via c.decrypted.
  // It returns nil as soon as onFlightState reports handshakeFinished, leaving
  // both goroutines running. If either goroutine reports an error first, or
  // ctx ends first, both are cancelled and awaited, and the error is passed
  // through translateHandshakeCtxError.
  703. func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFlight flightVal, initialState handshakeState) error { //nolint:gocognit
  704. c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight)
  705. done := make(chan struct{})
  // The reader gets its own cancel func so close() can stop it independently of ctx.
  706. ctxRead, cancelRead := context.WithCancel(context.Background())
  707. c.cancelHandshakeReader = cancelRead
  // Mark completion exactly once and release the final select below.
  708. cfg.onFlightState = func(f flightVal, s handshakeState) {
  709. if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() {
  710. c.setHandshakeCompletedSuccessfully()
  711. close(done)
  712. }
  713. }
  714. ctxHs, cancel := context.WithCancel(context.Background())
  715. c.cancelHandshaker = cancel
  // Capacity 1 plus non-blocking sends: only the first error is kept.
  716. firstErr := make(chan error, 1)
  717. c.handshakeLoopsFinished.Add(2)
  718. // Handshake routine should be live until close.
  719. // The other party may request retransmission of the last flight to cope with packet drop.
  720. go func() {
  721. defer c.handshakeLoopsFinished.Done()
  722. err := c.fsm.Run(ctxHs, c, initialState)
  // Cancellation is the normal shutdown path and is not reported.
  723. if !errors.Is(err, context.Canceled) {
  724. select {
  725. case firstErr <- err:
  726. default:
  727. }
  728. }
  729. }()
  730. go func() {
  // NOTE(review): deferred calls run LIFO, so handshakeLoopsFinished.Done()
  // below runs before this func closes c.decrypted and cancels the FSM.
  // Wait() may therefore return slightly before c.decrypted is closed. Confirm
  // this is intended.
  731. defer func() {
  732. // Escaping read loop.
  733. // It's safe to close decrypted channel now.
  734. close(c.decrypted)
  735. // Force stop handshaker when the underlying connection is closed.
  736. cancel()
  737. }()
  738. defer c.handshakeLoopsFinished.Done()
  739. for {
  740. if err := c.readAndBuffer(ctxRead); err != nil {
  741. var e *alertError
  742. if errors.As(err, &e) {
  743. if !e.IsFatalOrCloseNotify() {
  // Non-fatal alerts are forwarded to Read only after the handshake;
  // before that they are dropped.
  744. if c.isHandshakeCompletedSuccessfully() {
  745. // Pass the error to Read()
  746. select {
  747. case c.decrypted <- err:
  748. case <-c.closed.Done():
  749. case <-ctxRead.Done():
  750. }
  751. }
  752. continue // non-fatal alert must not stop read loop
  753. }
  754. } else {
  // Deadline, cancellation and EOF always end the loop. Other errors end
  // it only before the handshake has completed.
  755. switch {
  756. case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled), errors.Is(err, io.EOF):
  757. default:
  758. if c.isHandshakeCompletedSuccessfully() {
  759. // Keep read loop and pass the read error to Read()
  760. select {
  761. case c.decrypted <- err:
  762. case <-c.closed.Done():
  763. case <-ctxRead.Done():
  764. }
  765. continue // after the handshake, a read error must not stop the read loop
  766. }
  767. }
  768. }
  // Terminal error: report it (if no error was reported yet) and shut down.
  769. select {
  770. case firstErr <- err:
  771. default:
  772. }
  773. if e != nil {
  774. if e.IsFatalOrCloseNotify() {
  775. _ = c.close(false) //nolint:contextcheck
  776. }
  777. }
  // Cancellation without a prior close (e.g. handshake aborted) still
  // closes the underlying connection.
  778. if !c.isConnectionClosed() && errors.Is(err, context.Canceled) {
  779. c.log.Trace("handshake timeouts - closing underline connection")
  780. _ = c.close(false) //nolint:contextcheck
  781. }
  782. return
  783. }
  784. }
  785. }()
  786. select {
  787. case err := <-firstErr:
  788. cancelRead()
  789. cancel()
  790. c.handshakeLoopsFinished.Wait()
  791. return c.translateHandshakeCtxError(err)
  792. case <-ctx.Done():
  793. cancelRead()
  794. cancel()
  795. c.handshakeLoopsFinished.Wait()
  796. return c.translateHandshakeCtxError(ctx.Err())
  797. case <-done:
  798. return nil
  799. }
  800. }
  801. func (c *Conn) translateHandshakeCtxError(err error) error {
  802. if err == nil {
  803. return nil
  804. }
  805. if errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully() {
  806. return nil
  807. }
  808. return &HandshakeError{Err: err}
  809. }
  810. func (c *Conn) close(byUser bool) error {
  811. c.cancelHandshaker()
  812. c.cancelHandshakeReader()
  813. if c.isHandshakeCompletedSuccessfully() && byUser {
  814. // Discard error from notify() to return non-error on the first user call of Close()
  815. // even if the underlying connection is already closed.
  816. _ = c.notify(context.Background(), alert.Warning, alert.CloseNotify)
  817. }
  818. c.closeLock.Lock()
  819. // Don't return ErrConnClosed at the first time of the call from user.
  820. closedByUser := c.connectionClosedByUser
  821. if byUser {
  822. c.connectionClosedByUser = true
  823. }
  824. isClosed := c.isConnectionClosed()
  825. c.closed.Close()
  826. c.closeLock.Unlock()
  827. if closedByUser {
  828. return ErrConnClosed
  829. }
  830. if isClosed {
  831. return nil
  832. }
  833. return c.nextConn.Close()
  834. }
  835. func (c *Conn) isConnectionClosed() bool {
  836. select {
  837. case <-c.closed.Done():
  838. return true
  839. default:
  840. return false
  841. }
  842. }
  843. func (c *Conn) setLocalEpoch(epoch uint16) {
  844. c.state.localEpoch.Store(epoch)
  845. }
  846. func (c *Conn) setRemoteEpoch(epoch uint16) {
  847. c.state.remoteEpoch.Store(epoch)
  848. }
  849. // LocalAddr implements net.Conn.LocalAddr
  850. func (c *Conn) LocalAddr() net.Addr {
  851. return c.nextConn.LocalAddr()
  852. }
  853. // RemoteAddr implements net.Conn.RemoteAddr
  854. func (c *Conn) RemoteAddr() net.Addr {
  855. return c.nextConn.RemoteAddr()
  856. }
  857. func (c *Conn) sessionKey() []byte {
  858. if c.state.isClient {
  859. // As ServerName can be like 0.example.com, it's better to add
  860. // delimiter character which is not allowed to be in
  861. // neither address or domain name.
  862. return []byte(c.nextConn.RemoteAddr().String() + "_" + c.fsm.cfg.serverName)
  863. }
  864. return c.state.SessionID
  865. }
  866. // SetDeadline implements net.Conn.SetDeadline
  867. func (c *Conn) SetDeadline(t time.Time) error {
  868. c.readDeadline.Set(t)
  869. return c.SetWriteDeadline(t)
  870. }
  871. // SetReadDeadline implements net.Conn.SetReadDeadline
  872. func (c *Conn) SetReadDeadline(t time.Time) error {
  873. c.readDeadline.Set(t)
  874. // Read deadline is fully managed by this layer.
  875. // Don't set read deadline to underlying connection.
  876. return nil
  877. }
  878. // SetWriteDeadline implements net.Conn.SetWriteDeadline
  879. func (c *Conn) SetWriteDeadline(t time.Time) error {
  880. c.writeDeadline.Set(t)
  881. // Write deadline is also fully managed by this layer.
  882. return nil
  883. }