conn.go 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package dtls
  4. import (
  5. "context"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net"
  10. "sync"
  11. "sync/atomic"
  12. "time"
  13. "github.com/pion/dtls/v2/internal/closer"
  14. "github.com/pion/dtls/v2/pkg/crypto/elliptic"
  15. "github.com/pion/dtls/v2/pkg/crypto/signaturehash"
  16. "github.com/pion/dtls/v2/pkg/protocol"
  17. "github.com/pion/dtls/v2/pkg/protocol/alert"
  18. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  19. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  20. "github.com/pion/logging"
  21. "github.com/pion/transport/v2/connctx"
  22. "github.com/pion/transport/v2/deadline"
  23. "github.com/pion/transport/v2/replaydetector"
  24. )
  25. const (
  26. initialTickerInterval = time.Second
  27. cookieLength = 20
  28. sessionLength = 32
  29. defaultNamedCurve = elliptic.X25519
  30. inboundBufferSize = 8192
  31. // Default replay protection window is specified by RFC 6347 Section 4.1.2.6
  32. defaultReplayProtectionWindow = 64
  33. )
  // invalidKeyingLabels returns the set of labels that may not be used for
// exporting keying material, because they are already used by the TLS/DTLS
// key schedule itself. A fresh map is built on every call, so callers may
// modify the result freely.
func invalidKeyingLabels() map[string]bool {
	reserved := [...]string{
		"client finished",
		"server finished",
		"master secret",
		"key expansion",
	}

	labels := make(map[string]bool, len(reserved))
	for _, label := range reserved {
		labels[label] = true
	}
	return labels
}
  // Conn represents a DTLS connection
//
// A Conn is produced by createConn (reached through Dial, Client, Server and
// their *WithContext variants). createConn only returns it after the
// handshake has completed.
type Conn struct {
	lock           sync.RWMutex     // Internal lock (must not be public); held for writing by writePackets and for reading by ConnectionState
	nextConn       connctx.ConnCtx  // Embedded Conn, typically a udpconn we read/write from
	fragmentBuffer *fragmentBuffer  // out-of-order and missing fragment handling
	handshakeCache *handshakeCache  // caching of handshake messages for verifyData generation
	decrypted      chan interface{} // Decrypted Application Data ([]byte) or error, pull by calling `Read`; closed when the read loop exits

	state State // Internal state

	maximumTransmissionUnit int // used both as the handshake fragment size and as the datagram packing limit

	handshakeCompletedSuccessfully atomic.Value // holds a struct{ bool }; see setHandshakeCompletedSuccessfully

	encryptedPackets [][]byte // records that could not be processed yet; drained by handleQueuedPackets

	connectionClosedByUser bool           // set by close(true); guarded by closeLock
	closeLock              sync.Mutex     // guards the closed-state bookkeeping in close()
	closed                 *closer.Closer // closed once close() has run
	handshakeLoopsFinished sync.WaitGroup // tracks the FSM and read-loop goroutines started by handshake

	readDeadline  *deadline.Deadline // observed by Read only; never applied to nextConn
	writeDeadline *deadline.Deadline // passed as the context of writes issued by Write

	log logging.LeveledLogger

	reading               chan struct{}      // NOTE(review): not referenced in this file apart from its creation in createConn
	handshakeRecv         chan chan struct{} // readAndBuffer announces buffered handshake data here; the receiver must signal the inner channel
	cancelHandshaker      func()             // cancels the FSM goroutine's context; a no-op until handshake runs
	cancelHandshakeReader func()             // cancels the read loop's context

	fsm *handshakeFSM

	replayProtectionWindow uint // window size for the per-epoch replay detectors
}
  // createConn validates config, builds a Conn around nextConn and runs the
// DTLS handshake before returning it.
//
// isClient selects the role. When initialState is non-nil the connection is
// resumed from that state: the FSM starts in handshakeFinished at flight5
// (client) or flight6 (server) instead of at flight1/flight0 in
// handshakePreparing.
//
// On any error nil is returned with the error. The validation failures at the
// top return without touching nextConn. Once the handshake has started, a
// failure may leave nextConn closed by the read loop (see handshake).
func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
	err := validateConfig(config)
	if err != nil {
		return nil, err
	}

	if nextConn == nil {
		return nil, errNilNextConn
	}

	cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
	if err != nil {
		return nil, err
	}

	signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
	if err != nil {
		return nil, err
	}

	// Fall back to package defaults for tunables left at their zero value.
	workerInterval := initialTickerInterval
	if config.FlightInterval != 0 {
		workerInterval = config.FlightInterval
	}

	loggerFactory := config.LoggerFactory
	if loggerFactory == nil {
		loggerFactory = logging.NewDefaultLoggerFactory()
	}

	logger := loggerFactory.NewLogger("dtls")

	mtu := config.MTU
	if mtu <= 0 {
		mtu = defaultMTU
	}

	replayProtectionWindow := config.ReplayProtectionWindow
	if replayProtectionWindow <= 0 {
		replayProtectionWindow = defaultReplayProtectionWindow
	}

	c := &Conn{
		nextConn:                connctx.New(nextConn),
		fragmentBuffer:          newFragmentBuffer(),
		handshakeCache:          newHandshakeCache(),
		maximumTransmissionUnit: mtu,

		// Capacity 1: the read loop can hand over one item before blocking on Read.
		decrypted: make(chan interface{}, 1),
		log:       logger,

		readDeadline:  deadline.New(),
		writeDeadline: deadline.New(),

		reading:          make(chan struct{}, 1),
		handshakeRecv:    make(chan chan struct{}),
		closed:           closer.NewCloser(),
		cancelHandshaker: func() {},

		replayProtectionWindow: uint(replayProtectionWindow),

		state: State{
			isClient: isClient,
		},
	}

	c.setRemoteEpoch(0)
	c.setLocalEpoch(0)

	serverName := config.ServerName
	// Do not allow the use of an IP address literal as an SNI value.
	// See RFC 6066, Section 3.
	if net.ParseIP(serverName) != nil {
		serverName = ""
	}

	curves := config.EllipticCurves
	if len(curves) == 0 {
		curves = defaultCurves
	}

	hsCfg := &handshakeConfig{
		localPSKCallback:            config.PSK,
		localPSKIdentityHint:        config.PSKIdentityHint,
		localCipherSuites:           cipherSuites,
		localSignatureSchemes:       signatureSchemes,
		extendedMasterSecret:        config.ExtendedMasterSecret,
		localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
		serverName:                  serverName,
		supportedProtocols:          config.SupportedProtocols,
		clientAuth:                  config.ClientAuth,
		localCertificates:           config.Certificates,
		insecureSkipVerify:          config.InsecureSkipVerify,
		verifyPeerCertificate:       config.VerifyPeerCertificate,
		verifyConnection:            config.VerifyConnection,
		rootCAs:                     config.RootCAs,
		clientCAs:                   config.ClientCAs,
		customCipherSuites:          config.CustomCipherSuites,
		retransmitInterval:          workerInterval,
		log:                         logger,
		initialEpoch:                0,
		keyLogWriter:                config.KeyLogWriter,
		sessionStore:                config.SessionStore,
		ellipticCurves:              curves,
		localGetCertificate:         config.GetCertificate,
		localGetClientCertificate:   config.GetClientCertificate,
		insecureSkipHelloVerify:     config.InsecureSkipVerifyHello,

		// [Psiphon]
		// Conjure DTLS support, from: https://github.com/mingyech/dtls/commit/a56eccc1
		customClientHelloRandom: config.CustomClientHelloRandom,
	}

	// rfc5246#section-7.4.3
	// In addition, the hash and signature algorithms MUST be compatible
	// with the key in the server's end-entity certificate.
	// A server without any certificate (errNoCertificates) is tolerated here,
	// e.g. PSK-only configurations; cert is then passed through as returned.
	if !isClient {
		cert, err := hsCfg.getCertificate(&ClientHelloInfo{})
		if err != nil && !errors.Is(err, errNoCertificates) {
			return nil, err
		}
		hsCfg.localCipherSuites = filterCipherSuitesForCertificate(cert, cipherSuites)
	}

	var initialFlight flightVal
	var initialFSMState handshakeState

	if initialState != nil {
		// Resumption: the role is taken from the freshly built state (i.e. the
		// isClient argument) before c.state is overwritten with *initialState.
		if c.state.isClient {
			initialFlight = flight5
		} else {
			initialFlight = flight6
		}
		initialFSMState = handshakeFinished

		c.state = *initialState
	} else {
		if c.state.isClient {
			initialFlight = flight1
		} else {
			initialFlight = flight0
		}
		initialFSMState = handshakePreparing
	}
	// Do handshake
	if err := c.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
		return nil, err
	}

	c.log.Trace("Handshake Completed")

	return c, nil
}
  195. // Dial connects to the given network address and establishes a DTLS connection on top.
  196. // Connection handshake will timeout using ConnectContextMaker in the Config.
  197. // If you want to specify the timeout duration, use DialWithContext() instead.
  198. func Dial(network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
  199. ctx, cancel := config.connectContextMaker()
  200. defer cancel()
  201. return DialWithContext(ctx, network, raddr, config)
  202. }
  203. // Client establishes a DTLS connection over an existing connection.
  204. // Connection handshake will timeout using ConnectContextMaker in the Config.
  205. // If you want to specify the timeout duration, use ClientWithContext() instead.
  206. func Client(conn net.Conn, config *Config) (*Conn, error) {
  207. ctx, cancel := config.connectContextMaker()
  208. defer cancel()
  209. return ClientWithContext(ctx, conn, config)
  210. }
  211. // Server listens for incoming DTLS connections.
  212. // Connection handshake will timeout using ConnectContextMaker in the Config.
  213. // If you want to specify the timeout duration, use ServerWithContext() instead.
  214. func Server(conn net.Conn, config *Config) (*Conn, error) {
  215. ctx, cancel := config.connectContextMaker()
  216. defer cancel()
  217. return ServerWithContext(ctx, conn, config)
  218. }
  219. // DialWithContext connects to the given network address and establishes a DTLS connection on top.
  220. func DialWithContext(ctx context.Context, network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
  221. pConn, err := net.DialUDP(network, nil, raddr)
  222. if err != nil {
  223. return nil, err
  224. }
  225. return ClientWithContext(ctx, pConn, config)
  226. }
  227. // ClientWithContext establishes a DTLS connection over an existing connection.
  228. func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
  229. switch {
  230. case config == nil:
  231. return nil, errNoConfigProvided
  232. case config.PSK != nil && config.PSKIdentityHint == nil:
  233. return nil, errPSKAndIdentityMustBeSetForClient
  234. }
  235. return createConn(ctx, conn, config, true, nil)
  236. }
  237. // ServerWithContext listens for incoming DTLS connections.
  238. func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
  239. if config == nil {
  240. return nil, errNoConfigProvided
  241. }
  242. return createConn(ctx, conn, config, false, nil)
  243. }
  244. // Read reads data from the connection.
  245. func (c *Conn) Read(p []byte) (n int, err error) {
  246. if !c.isHandshakeCompletedSuccessfully() {
  247. return 0, errHandshakeInProgress
  248. }
  249. select {
  250. case <-c.readDeadline.Done():
  251. return 0, errDeadlineExceeded
  252. default:
  253. }
  254. for {
  255. select {
  256. case <-c.readDeadline.Done():
  257. return 0, errDeadlineExceeded
  258. case out, ok := <-c.decrypted:
  259. if !ok {
  260. return 0, io.EOF
  261. }
  262. switch val := out.(type) {
  263. case ([]byte):
  264. if len(p) < len(val) {
  265. return 0, errBufferTooSmall
  266. }
  267. copy(p, val)
  268. return len(val), nil
  269. case (error):
  270. return 0, val
  271. }
  272. }
  273. }
  274. }
  275. // Write writes len(p) bytes from p to the DTLS connection
  276. func (c *Conn) Write(p []byte) (int, error) {
  277. if c.isConnectionClosed() {
  278. return 0, ErrConnClosed
  279. }
  280. select {
  281. case <-c.writeDeadline.Done():
  282. return 0, errDeadlineExceeded
  283. default:
  284. }
  285. if !c.isHandshakeCompletedSuccessfully() {
  286. return 0, errHandshakeInProgress
  287. }
  288. return len(p), c.writePackets(c.writeDeadline, []*packet{
  289. {
  290. record: &recordlayer.RecordLayer{
  291. Header: recordlayer.Header{
  292. Epoch: c.state.getLocalEpoch(),
  293. Version: protocol.Version1_2,
  294. },
  295. Content: &protocol.ApplicationData{
  296. Data: p,
  297. },
  298. },
  299. shouldEncrypt: true,
  300. },
  301. })
  302. }
  303. // Close closes the connection.
  304. func (c *Conn) Close() error {
  305. err := c.close(true) //nolint:contextcheck
  306. c.handshakeLoopsFinished.Wait()
  307. return err
  308. }
  309. // ConnectionState returns basic DTLS details about the connection.
  310. // Note that this replaced the `Export` function of v1.
  311. func (c *Conn) ConnectionState() State {
  312. c.lock.RLock()
  313. defer c.lock.RUnlock()
  314. return *c.state.clone()
  315. }
  316. // SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile
  317. func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) {
  318. profile := c.state.getSRTPProtectionProfile()
  319. if profile == 0 {
  320. return 0, false
  321. }
  322. return profile, true
  323. }
  324. func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error {
  325. c.lock.Lock()
  326. defer c.lock.Unlock()
  327. var rawPackets [][]byte
  328. for _, p := range pkts {
  329. if h, ok := p.record.Content.(*handshake.Handshake); ok {
  330. handshakeRaw, err := p.record.Marshal()
  331. if err != nil {
  332. return err
  333. }
  334. c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)",
  335. srvCliStr(c.state.isClient), h.Header.Type.String(),
  336. p.record.Header.Epoch, h.Header.MessageSequence)
  337. c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient)
  338. rawHandshakePackets, err := c.processHandshakePacket(p, h)
  339. if err != nil {
  340. return err
  341. }
  342. rawPackets = append(rawPackets, rawHandshakePackets...)
  343. } else {
  344. rawPacket, err := c.processPacket(p)
  345. if err != nil {
  346. return err
  347. }
  348. rawPackets = append(rawPackets, rawPacket)
  349. }
  350. }
  351. if len(rawPackets) == 0 {
  352. return nil
  353. }
  354. compactedRawPackets := c.compactRawPackets(rawPackets)
  355. for _, compactedRawPackets := range compactedRawPackets {
  356. if _, err := c.nextConn.WriteContext(ctx, compactedRawPackets); err != nil {
  357. return netError(err)
  358. }
  359. }
  360. return nil
  361. }
  362. func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte {
  363. // avoid a useless copy in the common case
  364. if len(rawPackets) == 1 {
  365. return rawPackets
  366. }
  367. combinedRawPackets := make([][]byte, 0)
  368. currentCombinedRawPacket := make([]byte, 0)
  369. for _, rawPacket := range rawPackets {
  370. if len(currentCombinedRawPacket) > 0 && len(currentCombinedRawPacket)+len(rawPacket) >= c.maximumTransmissionUnit {
  371. combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
  372. currentCombinedRawPacket = []byte{}
  373. }
  374. currentCombinedRawPacket = append(currentCombinedRawPacket, rawPacket...)
  375. }
  376. combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
  377. return combinedRawPackets
  378. }
  379. func (c *Conn) processPacket(p *packet) ([]byte, error) {
  380. epoch := p.record.Header.Epoch
  381. for len(c.state.localSequenceNumber) <= int(epoch) {
  382. c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
  383. }
  384. seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
  385. if seq > recordlayer.MaxSequenceNumber {
  386. // RFC 6347 Section 4.1.0
  387. // The implementation must either abandon an association or rehandshake
  388. // prior to allowing the sequence number to wrap.
  389. return nil, errSequenceNumberOverflow
  390. }
  391. p.record.Header.SequenceNumber = seq
  392. rawPacket, err := p.record.Marshal()
  393. if err != nil {
  394. return nil, err
  395. }
  396. if p.shouldEncrypt {
  397. var err error
  398. rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
  399. if err != nil {
  400. return nil, err
  401. }
  402. }
  403. return rawPacket, nil
  404. }
  405. func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]byte, error) {
  406. rawPackets := make([][]byte, 0)
  407. handshakeFragments, err := c.fragmentHandshake(h)
  408. if err != nil {
  409. return nil, err
  410. }
  411. epoch := p.record.Header.Epoch
  412. for len(c.state.localSequenceNumber) <= int(epoch) {
  413. c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
  414. }
  415. for _, handshakeFragment := range handshakeFragments {
  416. seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
  417. if seq > recordlayer.MaxSequenceNumber {
  418. return nil, errSequenceNumberOverflow
  419. }
  420. recordlayerHeader := &recordlayer.Header{
  421. Version: p.record.Header.Version,
  422. ContentType: p.record.Header.ContentType,
  423. ContentLen: uint16(len(handshakeFragment)),
  424. Epoch: p.record.Header.Epoch,
  425. SequenceNumber: seq,
  426. }
  427. rawPacket, err := recordlayerHeader.Marshal()
  428. if err != nil {
  429. return nil, err
  430. }
  431. p.record.Header = *recordlayerHeader
  432. rawPacket = append(rawPacket, handshakeFragment...)
  433. if p.shouldEncrypt {
  434. var err error
  435. rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
  436. if err != nil {
  437. return nil, err
  438. }
  439. }
  440. rawPackets = append(rawPackets, rawPacket)
  441. }
  442. return rawPackets, nil
  443. }
  444. func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) {
  445. content, err := h.Message.Marshal()
  446. if err != nil {
  447. return nil, err
  448. }
  449. fragmentedHandshakes := make([][]byte, 0)
  450. contentFragments := splitBytes(content, c.maximumTransmissionUnit)
  451. if len(contentFragments) == 0 {
  452. contentFragments = [][]byte{
  453. {},
  454. }
  455. }
  456. offset := 0
  457. for _, contentFragment := range contentFragments {
  458. contentFragmentLen := len(contentFragment)
  459. headerFragment := &handshake.Header{
  460. Type: h.Header.Type,
  461. Length: h.Header.Length,
  462. MessageSequence: h.Header.MessageSequence,
  463. FragmentOffset: uint32(offset),
  464. FragmentLength: uint32(contentFragmentLen),
  465. }
  466. offset += contentFragmentLen
  467. fragmentedHandshake, err := headerFragment.Marshal()
  468. if err != nil {
  469. return nil, err
  470. }
  471. fragmentedHandshake = append(fragmentedHandshake, contentFragment...)
  472. fragmentedHandshakes = append(fragmentedHandshakes, fragmentedHandshake)
  473. }
  474. return fragmentedHandshakes, nil
  475. }
  476. var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals
  477. New: func() interface{} {
  478. b := make([]byte, inboundBufferSize)
  479. return &b
  480. },
  481. }
  482. func (c *Conn) readAndBuffer(ctx context.Context) error {
  483. bufptr, ok := poolReadBuffer.Get().(*[]byte)
  484. if !ok {
  485. return errFailedToAccessPoolReadBuffer
  486. }
  487. defer poolReadBuffer.Put(bufptr)
  488. b := *bufptr
  489. i, err := c.nextConn.ReadContext(ctx, b)
  490. if err != nil {
  491. return netError(err)
  492. }
  493. pkts, err := recordlayer.UnpackDatagram(b[:i])
  494. if err != nil {
  495. return err
  496. }
  497. var hasHandshake bool
  498. for _, p := range pkts {
  499. hs, alert, err := c.handleIncomingPacket(ctx, p, true)
  500. if alert != nil {
  501. if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
  502. if err == nil {
  503. err = alertErr
  504. }
  505. }
  506. }
  507. if hs {
  508. hasHandshake = true
  509. }
  510. var e *alertError
  511. if errors.As(err, &e) {
  512. if e.IsFatalOrCloseNotify() {
  513. return e
  514. }
  515. } else if err != nil {
  516. return e
  517. }
  518. }
  519. if hasHandshake {
  520. done := make(chan struct{})
  521. select {
  522. case c.handshakeRecv <- done:
  523. // If the other party may retransmit the flight,
  524. // we should respond even if it not a new message.
  525. <-done
  526. case <-c.fsm.Done():
  527. }
  528. }
  529. return nil
  530. }
  531. func (c *Conn) handleQueuedPackets(ctx context.Context) error {
  532. pkts := c.encryptedPackets
  533. c.encryptedPackets = nil
  534. for _, p := range pkts {
  535. _, alert, err := c.handleIncomingPacket(ctx, p, false) // don't re-enqueue
  536. if alert != nil {
  537. if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
  538. if err == nil {
  539. err = alertErr
  540. }
  541. }
  542. }
  543. var e *alertError
  544. if errors.As(err, &e) {
  545. if e.IsFatalOrCloseNotify() {
  546. return e
  547. }
  548. } else if err != nil {
  549. return e
  550. }
  551. }
  552. return nil
  553. }
  554. func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit
  555. h := &recordlayer.Header{}
  556. if err := h.Unmarshal(buf); err != nil {
  557. // Decode error must be silently discarded
  558. // [RFC6347 Section-4.1.2.7]
  559. c.log.Debugf("discarded broken packet: %v", err)
  560. return false, nil, nil
  561. }
  562. // Validate epoch
  563. remoteEpoch := c.state.getRemoteEpoch()
  564. if h.Epoch > remoteEpoch {
  565. if h.Epoch > remoteEpoch+1 {
  566. c.log.Debugf("discarded future packet (epoch: %d, seq: %d)",
  567. h.Epoch, h.SequenceNumber,
  568. )
  569. return false, nil, nil
  570. }
  571. if enqueue {
  572. c.log.Debug("received packet of next epoch, queuing packet")
  573. c.encryptedPackets = append(c.encryptedPackets, buf)
  574. }
  575. return false, nil, nil
  576. }
  577. // Anti-replay protection
  578. for len(c.state.replayDetector) <= int(h.Epoch) {
  579. c.state.replayDetector = append(c.state.replayDetector,
  580. replaydetector.New(c.replayProtectionWindow, recordlayer.MaxSequenceNumber),
  581. )
  582. }
  583. markPacketAsValid, ok := c.state.replayDetector[int(h.Epoch)].Check(h.SequenceNumber)
  584. if !ok {
  585. c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)",
  586. h.Epoch, h.SequenceNumber,
  587. )
  588. return false, nil, nil
  589. }
  590. // Decrypt
  591. if h.Epoch != 0 {
  592. if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
  593. if enqueue {
  594. c.encryptedPackets = append(c.encryptedPackets, buf)
  595. c.log.Debug("handshake not finished, queuing packet")
  596. }
  597. return false, nil, nil
  598. }
  599. var err error
  600. buf, err = c.state.cipherSuite.Decrypt(buf)
  601. if err != nil {
  602. c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err)
  603. return false, nil, nil
  604. }
  605. }
  606. isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...))
  607. if err != nil {
  608. // Decode error must be silently discarded
  609. // [RFC6347 Section-4.1.2.7]
  610. c.log.Debugf("defragment failed: %s", err)
  611. return false, nil, nil
  612. } else if isHandshake {
  613. markPacketAsValid()
  614. for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() {
  615. header := &handshake.Header{}
  616. if err := header.Unmarshal(out); err != nil {
  617. c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err)
  618. continue
  619. }
  620. c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient)
  621. }
  622. return true, nil, nil
  623. }
  624. r := &recordlayer.RecordLayer{}
  625. if err := r.Unmarshal(buf); err != nil {
  626. return false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err
  627. }
  628. switch content := r.Content.(type) {
  629. case *alert.Alert:
  630. c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String())
  631. var a *alert.Alert
  632. if content.Description == alert.CloseNotify {
  633. // Respond with a close_notify [RFC5246 Section 7.2.1]
  634. a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify}
  635. }
  636. markPacketAsValid()
  637. return false, a, &alertError{content}
  638. case *protocol.ChangeCipherSpec:
  639. if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
  640. if enqueue {
  641. c.encryptedPackets = append(c.encryptedPackets, buf)
  642. c.log.Debugf("CipherSuite not initialized, queuing packet")
  643. }
  644. return false, nil, nil
  645. }
  646. newRemoteEpoch := h.Epoch + 1
  647. c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch)
  648. if c.state.getRemoteEpoch()+1 == newRemoteEpoch {
  649. c.setRemoteEpoch(newRemoteEpoch)
  650. markPacketAsValid()
  651. }
  652. case *protocol.ApplicationData:
  653. if h.Epoch == 0 {
  654. return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero
  655. }
  656. markPacketAsValid()
  657. select {
  658. case c.decrypted <- content.Data:
  659. case <-c.closed.Done():
  660. case <-ctx.Done():
  661. }
  662. default:
  663. return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType())
  664. }
  665. return false, nil, nil
  666. }
  667. func (c *Conn) recvHandshake() <-chan chan struct{} {
  668. return c.handshakeRecv
  669. }
  670. func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error {
  671. if level == alert.Fatal && len(c.state.SessionID) > 0 {
  672. // According to the RFC, we need to delete the stored session.
  673. // https://datatracker.ietf.org/doc/html/rfc5246#section-7.2
  674. if ss := c.fsm.cfg.sessionStore; ss != nil {
  675. c.log.Tracef("clean invalid session: %s", c.state.SessionID)
  676. if err := ss.Del(c.sessionKey()); err != nil {
  677. return err
  678. }
  679. }
  680. }
  681. return c.writePackets(ctx, []*packet{
  682. {
  683. record: &recordlayer.RecordLayer{
  684. Header: recordlayer.Header{
  685. Epoch: c.state.getLocalEpoch(),
  686. Version: protocol.Version1_2,
  687. },
  688. Content: &alert.Alert{
  689. Level: level,
  690. Description: desc,
  691. },
  692. },
  693. shouldEncrypt: c.isHandshakeCompletedSuccessfully(),
  694. },
  695. })
  696. }
  697. func (c *Conn) setHandshakeCompletedSuccessfully() {
  698. c.handshakeCompletedSuccessfully.Store(struct{ bool }{true})
  699. }
  700. func (c *Conn) isHandshakeCompletedSuccessfully() bool {
  701. boolean, _ := c.handshakeCompletedSuccessfully.Load().(struct{ bool })
  702. return boolean.bool
  703. }
  // handshake runs the DTLS handshake and starts the two goroutines that serve
// the connection afterwards:
//
//   - the FSM goroutine, running c.fsm.Run. Per the comment below it is meant
//     to stay alive so the last flight can be retransmitted on request;
//   - the read loop, repeatedly calling readAndBuffer. Once the handshake has
//     completed, it hands non-fatal errors to Read. Otherwise it reports them
//     as the handshake result and exits.
//
// It returns nil as soon as the FSM reports handshakeFinished (done is closed
// by cfg.onFlightState). If either goroutine reports an error first, or ctx is
// done, both goroutines are cancelled and awaited, and the error is passed
// through translateHandshakeCtxError.
//
// NOTE(review): ctxHs derives from ctx ([Psiphon] change), so cancelling the
// dial context after a successful handshake (e.g. Dial's deferred cancel)
// also cancels the FSM's context. Confirm that retransmission of the last
// flight is not needed after that point. The read loop uses
// context.Background() and is only stopped via cancelRead.
func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFlight flightVal, initialState handshakeState) error { //nolint:gocognit
	c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight)

	done := make(chan struct{})
	ctxRead, cancelRead := context.WithCancel(context.Background())
	c.cancelHandshakeReader = cancelRead
	cfg.onFlightState = func(f flightVal, s handshakeState) {
		// Close done only once: the completed flag guards against a second close.
		if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() {
			c.setHandshakeCompletedSuccessfully()
			close(done)
		}
	}

	// [Psiphon]
	// Pass dial context into handshake goroutine for GetDTLSSeed.
	ctxHs, cancel := context.WithCancel(ctx)
	c.cancelHandshaker = cancel

	// Capacity 1 plus non-blocking sends: only the first reported error is kept.
	firstErr := make(chan error, 1)

	c.handshakeLoopsFinished.Add(2)

	// Handshake routine should be live until close.
	// The other party may request retransmission of the last flight to cope with packet drop.
	go func() {
		defer c.handshakeLoopsFinished.Done()
		err := c.fsm.Run(ctxHs, c, initialState)
		if !errors.Is(err, context.Canceled) {
			select {
			case firstErr <- err:
			default:
			}
		}
	}()
	go func() {
		// Deferred calls run LIFO: handshakeLoopsFinished.Done (registered
		// below) runs before decrypted is closed and the FSM is cancelled.
		defer func() {
			// Escaping read loop.
			// It's safe to close decrypted channel now.
			close(c.decrypted)

			// Force stop handshaker when the underlying connection is closed.
			cancel()
		}()
		defer c.handshakeLoopsFinished.Done()
		for {
			if err := c.readAndBuffer(ctxRead); err != nil {
				var e *alertError
				if errors.As(err, &e) {
					if !e.IsFatalOrCloseNotify() {
						if c.isHandshakeCompletedSuccessfully() {
							// Pass the error to Read()
							select {
							case c.decrypted <- err:
							case <-c.closed.Done():
							case <-ctxRead.Done():
							}
						}
						continue // non-fatal alert must not stop read loop
					}
				} else {
					switch {
					case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled), errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed):
						// Shutdown-type errors: fall through to reporting and exit.
					case errors.Is(err, recordlayer.ErrInvalidPacketLength):
						// Decode error must be silently discarded
						// [RFC6347 Section-4.1.2.7]
						continue
					default:
						if c.isHandshakeCompletedSuccessfully() {
							// Keep read loop and pass the read error to Read()
							select {
							case c.decrypted <- err:
							case <-c.closed.Done():
							case <-ctxRead.Done():
							}
							continue // other errors after the handshake must not stop the read loop either
						}
					}
				}

				select {
				case firstErr <- err:
				default:
				}

				if e != nil {
					if e.IsFatalOrCloseNotify() {
						_ = c.close(false) //nolint:contextcheck
					}
				}
				// A cancelled read without a prior close means the handshake was
				// abandoned, so tear down the underlying connection as well.
				if !c.isConnectionClosed() && errors.Is(err, context.Canceled) {
					c.log.Trace("handshake timeouts - closing underline connection")
					_ = c.close(false) //nolint:contextcheck
				}
				return
			}
		}
	}()

	select {
	case err := <-firstErr:
		cancelRead()
		cancel()
		c.handshakeLoopsFinished.Wait()
		return c.translateHandshakeCtxError(err)
	case <-ctx.Done():
		cancelRead()
		cancel()
		c.handshakeLoopsFinished.Wait()
		return c.translateHandshakeCtxError(ctx.Err())
	case <-done:
		return nil
	}
}
  808. func (c *Conn) translateHandshakeCtxError(err error) error {
  809. if err == nil {
  810. return nil
  811. }
  812. if errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully() {
  813. return nil
  814. }
  815. return &HandshakeError{Err: err}
  816. }
  // close tears the connection down. byUser is true for the public Close and
// false for internal shutdowns triggered by the read loop.
//
// It cancels the FSM and read-loop contexts. For a user close after a
// completed handshake, it sends close_notify (best effort). It then marks c
// closed. It returns ErrConnClosed if the user had already called Close
// before, and nil if the connection was already closed internally. Otherwise
// it returns the result of closing the underlying connection.
func (c *Conn) close(byUser bool) error {
	c.cancelHandshaker()
	c.cancelHandshakeReader()

	if c.isHandshakeCompletedSuccessfully() && byUser {
		// Discard error from notify() to return non-error on the first user call of Close()
		// even if the underlying connection is already closed.
		_ = c.notify(context.Background(), alert.Warning, alert.CloseNotify)
	}

	c.closeLock.Lock()
	// Don't return ErrConnClosed at the first time of the call from user.
	closedByUser := c.connectionClosedByUser
	if byUser {
		c.connectionClosedByUser = true
	}
	// Sample the closed state before closing, so only the first closer below
	// closes nextConn.
	isClosed := c.isConnectionClosed()
	c.closed.Close()
	c.closeLock.Unlock()

	if closedByUser {
		return ErrConnClosed
	}

	if isClosed {
		return nil
	}

	return c.nextConn.Close()
}
  842. func (c *Conn) isConnectionClosed() bool {
  843. select {
  844. case <-c.closed.Done():
  845. return true
  846. default:
  847. return false
  848. }
  849. }
  850. func (c *Conn) setLocalEpoch(epoch uint16) {
  851. c.state.localEpoch.Store(epoch)
  852. }
  853. func (c *Conn) setRemoteEpoch(epoch uint16) {
  854. c.state.remoteEpoch.Store(epoch)
  855. }
  856. // LocalAddr implements net.Conn.LocalAddr
  857. func (c *Conn) LocalAddr() net.Addr {
  858. return c.nextConn.LocalAddr()
  859. }
  860. // RemoteAddr implements net.Conn.RemoteAddr
  861. func (c *Conn) RemoteAddr() net.Addr {
  862. return c.nextConn.RemoteAddr()
  863. }
  864. func (c *Conn) sessionKey() []byte {
  865. if c.state.isClient {
  866. // As ServerName can be like 0.example.com, it's better to add
  867. // delimiter character which is not allowed to be in
  868. // neither address or domain name.
  869. return []byte(c.nextConn.RemoteAddr().String() + "_" + c.fsm.cfg.serverName)
  870. }
  871. return c.state.SessionID
  872. }
  873. // SetDeadline implements net.Conn.SetDeadline
  874. func (c *Conn) SetDeadline(t time.Time) error {
  875. c.readDeadline.Set(t)
  876. return c.SetWriteDeadline(t)
  877. }
  878. // SetReadDeadline implements net.Conn.SetReadDeadline
  879. func (c *Conn) SetReadDeadline(t time.Time) error {
  880. c.readDeadline.Set(t)
  881. // Read deadline is fully managed by this layer.
  882. // Don't set read deadline to underlying connection.
  883. return nil
  884. }
  885. // SetWriteDeadline implements net.Conn.SetWriteDeadline
  886. func (c *Conn) SetWriteDeadline(t time.Time) error {
  887. c.writeDeadline.Set(t)
  888. // Write deadline is also fully managed by this layer.
  889. return nil
  890. }