| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040 |
- // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
- // SPDX-License-Identifier: MIT
- package dtls
- import (
- "context"
- "errors"
- "fmt"
- "io"
- "net"
- "sync"
- "sync/atomic"
- "time"
- "github.com/pion/dtls/v2/internal/closer"
- "github.com/pion/dtls/v2/pkg/crypto/elliptic"
- "github.com/pion/dtls/v2/pkg/crypto/signaturehash"
- "github.com/pion/dtls/v2/pkg/protocol"
- "github.com/pion/dtls/v2/pkg/protocol/alert"
- "github.com/pion/dtls/v2/pkg/protocol/handshake"
- "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
- "github.com/pion/logging"
- "github.com/pion/transport/v2/connctx"
- "github.com/pion/transport/v2/deadline"
- "github.com/pion/transport/v2/replaydetector"
- )
- const (
- initialTickerInterval = time.Second
- cookieLength = 20
- sessionLength = 32
- defaultNamedCurve = elliptic.X25519
- inboundBufferSize = 8192
- // Default replay protection window is specified by RFC 6347 Section 4.1.2.6
- defaultReplayProtectionWindow = 64
- )
// invalidKeyingLabels returns the set of labels that are rejected as keying
// material labels. A new map is built on every call, so a caller may modify
// the result without affecting other callers.
func invalidKeyingLabels() map[string]bool {
	reserved := []string{
		"client finished",
		"server finished",
		"master secret",
		"key expansion",
	}
	set := make(map[string]bool, len(reserved))
	for _, label := range reserved {
		set[label] = true
	}
	return set
}
- // Conn represents a DTLS connection
- type Conn struct {
- lock sync.RWMutex // Internal lock (must not be public)
- nextConn connctx.ConnCtx // Embedded Conn, typically a udpconn we read/write from
- fragmentBuffer *fragmentBuffer // out-of-order and missing fragment handling
- handshakeCache *handshakeCache // caching of handshake messages for verifyData generation
- decrypted chan interface{} // Decrypted Application Data or error, pull by calling `Read`
- state State // Internal state
- maximumTransmissionUnit int
- handshakeCompletedSuccessfully atomic.Value
- encryptedPackets [][]byte
- connectionClosedByUser bool
- closeLock sync.Mutex
- closed *closer.Closer
- handshakeLoopsFinished sync.WaitGroup
- readDeadline *deadline.Deadline
- writeDeadline *deadline.Deadline
- log logging.LeveledLogger
- reading chan struct{}
- handshakeRecv chan chan struct{}
- cancelHandshaker func()
- cancelHandshakeReader func()
- fsm *handshakeFSM
- replayProtectionWindow uint
- }
- func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
- err := validateConfig(config)
- if err != nil {
- return nil, err
- }
- if nextConn == nil {
- return nil, errNilNextConn
- }
- cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
- if err != nil {
- return nil, err
- }
- signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
- if err != nil {
- return nil, err
- }
- workerInterval := initialTickerInterval
- if config.FlightInterval != 0 {
- workerInterval = config.FlightInterval
- }
- loggerFactory := config.LoggerFactory
- if loggerFactory == nil {
- loggerFactory = logging.NewDefaultLoggerFactory()
- }
- logger := loggerFactory.NewLogger("dtls")
- mtu := config.MTU
- if mtu <= 0 {
- mtu = defaultMTU
- }
- replayProtectionWindow := config.ReplayProtectionWindow
- if replayProtectionWindow <= 0 {
- replayProtectionWindow = defaultReplayProtectionWindow
- }
- c := &Conn{
- nextConn: connctx.New(nextConn),
- fragmentBuffer: newFragmentBuffer(),
- handshakeCache: newHandshakeCache(),
- maximumTransmissionUnit: mtu,
- decrypted: make(chan interface{}, 1),
- log: logger,
- readDeadline: deadline.New(),
- writeDeadline: deadline.New(),
- reading: make(chan struct{}, 1),
- handshakeRecv: make(chan chan struct{}),
- closed: closer.NewCloser(),
- cancelHandshaker: func() {},
- replayProtectionWindow: uint(replayProtectionWindow),
- state: State{
- isClient: isClient,
- },
- }
- c.setRemoteEpoch(0)
- c.setLocalEpoch(0)
- serverName := config.ServerName
- // Do not allow the use of an IP address literal as an SNI value.
- // See RFC 6066, Section 3.
- if net.ParseIP(serverName) != nil {
- serverName = ""
- }
- curves := config.EllipticCurves
- if len(curves) == 0 {
- curves = defaultCurves
- }
- hsCfg := &handshakeConfig{
- localPSKCallback: config.PSK,
- localPSKIdentityHint: config.PSKIdentityHint,
- localCipherSuites: cipherSuites,
- localSignatureSchemes: signatureSchemes,
- extendedMasterSecret: config.ExtendedMasterSecret,
- localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
- serverName: serverName,
- supportedProtocols: config.SupportedProtocols,
- clientAuth: config.ClientAuth,
- localCertificates: config.Certificates,
- insecureSkipVerify: config.InsecureSkipVerify,
- verifyPeerCertificate: config.VerifyPeerCertificate,
- verifyConnection: config.VerifyConnection,
- rootCAs: config.RootCAs,
- clientCAs: config.ClientCAs,
- customCipherSuites: config.CustomCipherSuites,
- retransmitInterval: workerInterval,
- log: logger,
- initialEpoch: 0,
- keyLogWriter: config.KeyLogWriter,
- sessionStore: config.SessionStore,
- ellipticCurves: curves,
- localGetCertificate: config.GetCertificate,
- localGetClientCertificate: config.GetClientCertificate,
- insecureSkipHelloVerify: config.InsecureSkipVerifyHello,
- // [Psiphon]
- // Conjure DTLS support, from: https://github.com/mingyech/dtls/commit/a56eccc1
- customClientHelloRandom: config.CustomClientHelloRandom,
- }
- // rfc5246#section-7.4.3
- // In addition, the hash and signature algorithms MUST be compatible
- // with the key in the server's end-entity certificate.
- if !isClient {
- cert, err := hsCfg.getCertificate(&ClientHelloInfo{})
- if err != nil && !errors.Is(err, errNoCertificates) {
- return nil, err
- }
- hsCfg.localCipherSuites = filterCipherSuitesForCertificate(cert, cipherSuites)
- }
- var initialFlight flightVal
- var initialFSMState handshakeState
- if initialState != nil {
- if c.state.isClient {
- initialFlight = flight5
- } else {
- initialFlight = flight6
- }
- initialFSMState = handshakeFinished
- c.state = *initialState
- } else {
- if c.state.isClient {
- initialFlight = flight1
- } else {
- initialFlight = flight0
- }
- initialFSMState = handshakePreparing
- }
- // Do handshake
- if err := c.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
- return nil, err
- }
- c.log.Trace("Handshake Completed")
- return c, nil
- }
- // Dial connects to the given network address and establishes a DTLS connection on top.
- // Connection handshake will timeout using ConnectContextMaker in the Config.
- // If you want to specify the timeout duration, use DialWithContext() instead.
- func Dial(network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
- ctx, cancel := config.connectContextMaker()
- defer cancel()
- return DialWithContext(ctx, network, raddr, config)
- }
- // Client establishes a DTLS connection over an existing connection.
- // Connection handshake will timeout using ConnectContextMaker in the Config.
- // If you want to specify the timeout duration, use ClientWithContext() instead.
- func Client(conn net.Conn, config *Config) (*Conn, error) {
- ctx, cancel := config.connectContextMaker()
- defer cancel()
- return ClientWithContext(ctx, conn, config)
- }
- // Server listens for incoming DTLS connections.
- // Connection handshake will timeout using ConnectContextMaker in the Config.
- // If you want to specify the timeout duration, use ServerWithContext() instead.
- func Server(conn net.Conn, config *Config) (*Conn, error) {
- ctx, cancel := config.connectContextMaker()
- defer cancel()
- return ServerWithContext(ctx, conn, config)
- }
- // DialWithContext connects to the given network address and establishes a DTLS connection on top.
- func DialWithContext(ctx context.Context, network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
- pConn, err := net.DialUDP(network, nil, raddr)
- if err != nil {
- return nil, err
- }
- return ClientWithContext(ctx, pConn, config)
- }
- // ClientWithContext establishes a DTLS connection over an existing connection.
- func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
- switch {
- case config == nil:
- return nil, errNoConfigProvided
- case config.PSK != nil && config.PSKIdentityHint == nil:
- return nil, errPSKAndIdentityMustBeSetForClient
- }
- return createConn(ctx, conn, config, true, nil)
- }
- // ServerWithContext listens for incoming DTLS connections.
- func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
- if config == nil {
- return nil, errNoConfigProvided
- }
- return createConn(ctx, conn, config, false, nil)
- }
- // Read reads data from the connection.
- func (c *Conn) Read(p []byte) (n int, err error) {
- if !c.isHandshakeCompletedSuccessfully() {
- return 0, errHandshakeInProgress
- }
- select {
- case <-c.readDeadline.Done():
- return 0, errDeadlineExceeded
- default:
- }
- for {
- select {
- case <-c.readDeadline.Done():
- return 0, errDeadlineExceeded
- case out, ok := <-c.decrypted:
- if !ok {
- return 0, io.EOF
- }
- switch val := out.(type) {
- case ([]byte):
- if len(p) < len(val) {
- return 0, errBufferTooSmall
- }
- copy(p, val)
- return len(val), nil
- case (error):
- return 0, val
- }
- }
- }
- }
- // Write writes len(p) bytes from p to the DTLS connection
- func (c *Conn) Write(p []byte) (int, error) {
- if c.isConnectionClosed() {
- return 0, ErrConnClosed
- }
- select {
- case <-c.writeDeadline.Done():
- return 0, errDeadlineExceeded
- default:
- }
- if !c.isHandshakeCompletedSuccessfully() {
- return 0, errHandshakeInProgress
- }
- return len(p), c.writePackets(c.writeDeadline, []*packet{
- {
- record: &recordlayer.RecordLayer{
- Header: recordlayer.Header{
- Epoch: c.state.getLocalEpoch(),
- Version: protocol.Version1_2,
- },
- Content: &protocol.ApplicationData{
- Data: p,
- },
- },
- shouldEncrypt: true,
- },
- })
- }
- // Close closes the connection.
- func (c *Conn) Close() error {
- err := c.close(true) //nolint:contextcheck
- c.handshakeLoopsFinished.Wait()
- return err
- }
- // ConnectionState returns basic DTLS details about the connection.
- // Note that this replaced the `Export` function of v1.
- func (c *Conn) ConnectionState() State {
- c.lock.RLock()
- defer c.lock.RUnlock()
- return *c.state.clone()
- }
- // SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile
- func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) {
- profile := c.state.getSRTPProtectionProfile()
- if profile == 0 {
- return 0, false
- }
- return profile, true
- }
- func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error {
- c.lock.Lock()
- defer c.lock.Unlock()
- var rawPackets [][]byte
- for _, p := range pkts {
- if h, ok := p.record.Content.(*handshake.Handshake); ok {
- handshakeRaw, err := p.record.Marshal()
- if err != nil {
- return err
- }
- c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)",
- srvCliStr(c.state.isClient), h.Header.Type.String(),
- p.record.Header.Epoch, h.Header.MessageSequence)
- c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient)
- rawHandshakePackets, err := c.processHandshakePacket(p, h)
- if err != nil {
- return err
- }
- rawPackets = append(rawPackets, rawHandshakePackets...)
- } else {
- rawPacket, err := c.processPacket(p)
- if err != nil {
- return err
- }
- rawPackets = append(rawPackets, rawPacket)
- }
- }
- if len(rawPackets) == 0 {
- return nil
- }
- compactedRawPackets := c.compactRawPackets(rawPackets)
- for _, compactedRawPackets := range compactedRawPackets {
- if _, err := c.nextConn.WriteContext(ctx, compactedRawPackets); err != nil {
- return netError(err)
- }
- }
- return nil
- }
- func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte {
- // avoid a useless copy in the common case
- if len(rawPackets) == 1 {
- return rawPackets
- }
- combinedRawPackets := make([][]byte, 0)
- currentCombinedRawPacket := make([]byte, 0)
- for _, rawPacket := range rawPackets {
- if len(currentCombinedRawPacket) > 0 && len(currentCombinedRawPacket)+len(rawPacket) >= c.maximumTransmissionUnit {
- combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
- currentCombinedRawPacket = []byte{}
- }
- currentCombinedRawPacket = append(currentCombinedRawPacket, rawPacket...)
- }
- combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
- return combinedRawPackets
- }
- func (c *Conn) processPacket(p *packet) ([]byte, error) {
- epoch := p.record.Header.Epoch
- for len(c.state.localSequenceNumber) <= int(epoch) {
- c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
- }
- seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
- if seq > recordlayer.MaxSequenceNumber {
- // RFC 6347 Section 4.1.0
- // The implementation must either abandon an association or rehandshake
- // prior to allowing the sequence number to wrap.
- return nil, errSequenceNumberOverflow
- }
- p.record.Header.SequenceNumber = seq
- rawPacket, err := p.record.Marshal()
- if err != nil {
- return nil, err
- }
- if p.shouldEncrypt {
- var err error
- rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
- if err != nil {
- return nil, err
- }
- }
- return rawPacket, nil
- }
- func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]byte, error) {
- rawPackets := make([][]byte, 0)
- handshakeFragments, err := c.fragmentHandshake(h)
- if err != nil {
- return nil, err
- }
- epoch := p.record.Header.Epoch
- for len(c.state.localSequenceNumber) <= int(epoch) {
- c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
- }
- for _, handshakeFragment := range handshakeFragments {
- seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
- if seq > recordlayer.MaxSequenceNumber {
- return nil, errSequenceNumberOverflow
- }
- recordlayerHeader := &recordlayer.Header{
- Version: p.record.Header.Version,
- ContentType: p.record.Header.ContentType,
- ContentLen: uint16(len(handshakeFragment)),
- Epoch: p.record.Header.Epoch,
- SequenceNumber: seq,
- }
- rawPacket, err := recordlayerHeader.Marshal()
- if err != nil {
- return nil, err
- }
- p.record.Header = *recordlayerHeader
- rawPacket = append(rawPacket, handshakeFragment...)
- if p.shouldEncrypt {
- var err error
- rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
- if err != nil {
- return nil, err
- }
- }
- rawPackets = append(rawPackets, rawPacket)
- }
- return rawPackets, nil
- }
- func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) {
- content, err := h.Message.Marshal()
- if err != nil {
- return nil, err
- }
- fragmentedHandshakes := make([][]byte, 0)
- contentFragments := splitBytes(content, c.maximumTransmissionUnit)
- if len(contentFragments) == 0 {
- contentFragments = [][]byte{
- {},
- }
- }
- offset := 0
- for _, contentFragment := range contentFragments {
- contentFragmentLen := len(contentFragment)
- headerFragment := &handshake.Header{
- Type: h.Header.Type,
- Length: h.Header.Length,
- MessageSequence: h.Header.MessageSequence,
- FragmentOffset: uint32(offset),
- FragmentLength: uint32(contentFragmentLen),
- }
- offset += contentFragmentLen
- fragmentedHandshake, err := headerFragment.Marshal()
- if err != nil {
- return nil, err
- }
- fragmentedHandshake = append(fragmentedHandshake, contentFragment...)
- fragmentedHandshakes = append(fragmentedHandshakes, fragmentedHandshake)
- }
- return fragmentedHandshakes, nil
- }
- var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals
- New: func() interface{} {
- b := make([]byte, inboundBufferSize)
- return &b
- },
- }
- func (c *Conn) readAndBuffer(ctx context.Context) error {
- bufptr, ok := poolReadBuffer.Get().(*[]byte)
- if !ok {
- return errFailedToAccessPoolReadBuffer
- }
- defer poolReadBuffer.Put(bufptr)
- b := *bufptr
- i, err := c.nextConn.ReadContext(ctx, b)
- if err != nil {
- return netError(err)
- }
- pkts, err := recordlayer.UnpackDatagram(b[:i])
- if err != nil {
- return err
- }
- var hasHandshake bool
- for _, p := range pkts {
- hs, alert, err := c.handleIncomingPacket(ctx, p, true)
- if alert != nil {
- if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
- if err == nil {
- err = alertErr
- }
- }
- }
- if hs {
- hasHandshake = true
- }
- var e *alertError
- if errors.As(err, &e) {
- if e.IsFatalOrCloseNotify() {
- return e
- }
- } else if err != nil {
- return e
- }
- }
- if hasHandshake {
- done := make(chan struct{})
- select {
- case c.handshakeRecv <- done:
- // If the other party may retransmit the flight,
- // we should respond even if it not a new message.
- <-done
- case <-c.fsm.Done():
- }
- }
- return nil
- }
- func (c *Conn) handleQueuedPackets(ctx context.Context) error {
- pkts := c.encryptedPackets
- c.encryptedPackets = nil
- for _, p := range pkts {
- _, alert, err := c.handleIncomingPacket(ctx, p, false) // don't re-enqueue
- if alert != nil {
- if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
- if err == nil {
- err = alertErr
- }
- }
- }
- var e *alertError
- if errors.As(err, &e) {
- if e.IsFatalOrCloseNotify() {
- return e
- }
- } else if err != nil {
- return e
- }
- }
- return nil
- }
- func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit
- h := &recordlayer.Header{}
- if err := h.Unmarshal(buf); err != nil {
- // Decode error must be silently discarded
- // [RFC6347 Section-4.1.2.7]
- c.log.Debugf("discarded broken packet: %v", err)
- return false, nil, nil
- }
- // Validate epoch
- remoteEpoch := c.state.getRemoteEpoch()
- if h.Epoch > remoteEpoch {
- if h.Epoch > remoteEpoch+1 {
- c.log.Debugf("discarded future packet (epoch: %d, seq: %d)",
- h.Epoch, h.SequenceNumber,
- )
- return false, nil, nil
- }
- if enqueue {
- c.log.Debug("received packet of next epoch, queuing packet")
- c.encryptedPackets = append(c.encryptedPackets, buf)
- }
- return false, nil, nil
- }
- // Anti-replay protection
- for len(c.state.replayDetector) <= int(h.Epoch) {
- c.state.replayDetector = append(c.state.replayDetector,
- replaydetector.New(c.replayProtectionWindow, recordlayer.MaxSequenceNumber),
- )
- }
- markPacketAsValid, ok := c.state.replayDetector[int(h.Epoch)].Check(h.SequenceNumber)
- if !ok {
- c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)",
- h.Epoch, h.SequenceNumber,
- )
- return false, nil, nil
- }
- // Decrypt
- if h.Epoch != 0 {
- if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
- if enqueue {
- c.encryptedPackets = append(c.encryptedPackets, buf)
- c.log.Debug("handshake not finished, queuing packet")
- }
- return false, nil, nil
- }
- var err error
- buf, err = c.state.cipherSuite.Decrypt(buf)
- if err != nil {
- c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err)
- return false, nil, nil
- }
- }
- isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...))
- if err != nil {
- // Decode error must be silently discarded
- // [RFC6347 Section-4.1.2.7]
- c.log.Debugf("defragment failed: %s", err)
- return false, nil, nil
- } else if isHandshake {
- markPacketAsValid()
- for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() {
- header := &handshake.Header{}
- if err := header.Unmarshal(out); err != nil {
- c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err)
- continue
- }
- c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient)
- }
- return true, nil, nil
- }
- r := &recordlayer.RecordLayer{}
- if err := r.Unmarshal(buf); err != nil {
- return false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err
- }
- switch content := r.Content.(type) {
- case *alert.Alert:
- c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String())
- var a *alert.Alert
- if content.Description == alert.CloseNotify {
- // Respond with a close_notify [RFC5246 Section 7.2.1]
- a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify}
- }
- markPacketAsValid()
- return false, a, &alertError{content}
- case *protocol.ChangeCipherSpec:
- if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
- if enqueue {
- c.encryptedPackets = append(c.encryptedPackets, buf)
- c.log.Debugf("CipherSuite not initialized, queuing packet")
- }
- return false, nil, nil
- }
- newRemoteEpoch := h.Epoch + 1
- c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch)
- if c.state.getRemoteEpoch()+1 == newRemoteEpoch {
- c.setRemoteEpoch(newRemoteEpoch)
- markPacketAsValid()
- }
- case *protocol.ApplicationData:
- if h.Epoch == 0 {
- return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero
- }
- markPacketAsValid()
- select {
- case c.decrypted <- content.Data:
- case <-c.closed.Done():
- case <-ctx.Done():
- }
- default:
- return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType())
- }
- return false, nil, nil
- }
- func (c *Conn) recvHandshake() <-chan chan struct{} {
- return c.handshakeRecv
- }
- func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error {
- if level == alert.Fatal && len(c.state.SessionID) > 0 {
- // According to the RFC, we need to delete the stored session.
- // https://datatracker.ietf.org/doc/html/rfc5246#section-7.2
- if ss := c.fsm.cfg.sessionStore; ss != nil {
- c.log.Tracef("clean invalid session: %s", c.state.SessionID)
- if err := ss.Del(c.sessionKey()); err != nil {
- return err
- }
- }
- }
- return c.writePackets(ctx, []*packet{
- {
- record: &recordlayer.RecordLayer{
- Header: recordlayer.Header{
- Epoch: c.state.getLocalEpoch(),
- Version: protocol.Version1_2,
- },
- Content: &alert.Alert{
- Level: level,
- Description: desc,
- },
- },
- shouldEncrypt: c.isHandshakeCompletedSuccessfully(),
- },
- })
- }
- func (c *Conn) setHandshakeCompletedSuccessfully() {
- c.handshakeCompletedSuccessfully.Store(struct{ bool }{true})
- }
- func (c *Conn) isHandshakeCompletedSuccessfully() bool {
- boolean, _ := c.handshakeCompletedSuccessfully.Load().(struct{ bool })
- return boolean.bool
- }
// handshake creates the handshake FSM and starts two goroutines, both tracked
// by c.handshakeLoopsFinished: the FSM loop and the record read loop. It then
// blocks until one of three things happens:
//   - the FSM reports handshakeFinished (returns nil),
//   - either goroutine reports an error, or
//   - ctx is done.
//
// On the error and ctx paths, both goroutines are cancelled and waited for
// before the error is translated by translateHandshakeCtxError.
//
// On success both goroutines keep running. The FSM keeps going so it can
// answer retransmissions, and the read loop keeps delivering application data
// to Read. They are stopped later through c.cancelHandshaker and
// c.cancelHandshakeReader.
func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFlight flightVal, initialState handshakeState) error { //nolint:gocognit
	c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight)

	// done is closed once, the first time the FSM reports handshakeFinished.
	done := make(chan struct{})
	// The read loop's context is derived from context.Background(), not ctx,
	// so it is only stopped via cancelRead / c.cancelHandshakeReader.
	ctxRead, cancelRead := context.WithCancel(context.Background())
	c.cancelHandshakeReader = cancelRead
	cfg.onFlightState = func(f flightVal, s handshakeState) {
		if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() {
			c.setHandshakeCompletedSuccessfully()
			close(done)
		}
	}

	// [Psiphon]
	// Pass dial context into handshake goroutine for GetDTLSSeed.
	// NOTE(review): because ctxHs derives from ctx, the FSM goroutine is also
	// cancelled when the caller cancels ctx. Dial/Client/Server do that via
	// `defer cancel()` right after the handshake returns. This appears to
	// conflict with the "live until close" intent of the FSM goroutine below;
	// confirm whether post-handshake retransmission is still required.
	ctxHs, cancel := context.WithCancel(ctx)
	c.cancelHandshaker = cancel

	// firstErr keeps only the first error reported by either goroutine;
	// later sends hit the default case and are dropped.
	firstErr := make(chan error, 1)

	c.handshakeLoopsFinished.Add(2)

	// Handshake routine should be live until close.
	// The other party may request retransmission of the last flight to cope with packet drop.
	go func() {
		defer c.handshakeLoopsFinished.Done()
		err := c.fsm.Run(ctxHs, c, initialState)
		// Cancellation is the normal shutdown path, not a handshake failure.
		if !errors.Is(err, context.Canceled) {
			select {
			case firstErr <- err:
			default:
			}
		}
	}()
	// Read loop: repeatedly reads and dispatches datagrams until a terminal error.
	go func() {
		defer func() {
			// Escaping read loop.
			// It's safe to close decrypted channel now.
			close(c.decrypted)

			// Force stop handshaker when the underlying connection is closed.
			cancel()
		}()
		// Deferred calls run in LIFO order: handshakeLoopsFinished.Done()
		// fires before the deferred close(c.decrypted) / cancel() above.
		defer c.handshakeLoopsFinished.Done()
		for {
			if err := c.readAndBuffer(ctxRead); err != nil {
				var e *alertError
				if errors.As(err, &e) {
					if !e.IsFatalOrCloseNotify() {
						if c.isHandshakeCompletedSuccessfully() {
							// Pass the error to Read()
							select {
							case c.decrypted <- err:
							case <-c.closed.Done():
							case <-ctxRead.Done():
							}
						}
						continue // non-fatal alert must not stop read loop
					}
				} else {
					switch {
					case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled), errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed):
						// Terminal: fall out of the switch, report and stop the loop.
					case errors.Is(err, recordlayer.ErrInvalidPacketLength):
						// Decode error must be silently discarded
						// [RFC6347 Section-4.1.2.7]
						continue
					default:
						if c.isHandshakeCompletedSuccessfully() {
							// Keep read loop and pass the read error to Read()
							select {
							case c.decrypted <- err:
							case <-c.closed.Done():
							case <-ctxRead.Done():
							}
							continue // non-fatal alert must not stop read loop
						}
						// Before completion, any other error ends the handshake.
					}
				}

				select {
				case firstErr <- err:
				default:
				}

				if e != nil {
					if e.IsFatalOrCloseNotify() {
						_ = c.close(false) //nolint:contextcheck
					}
				}
				if !c.isConnectionClosed() && errors.Is(err, context.Canceled) {
					c.log.Trace("handshake timeouts - closing underline connection")
					_ = c.close(false) //nolint:contextcheck
				}

				return
			}
		}
	}()

	select {
	case err := <-firstErr:
		cancelRead()
		cancel()
		c.handshakeLoopsFinished.Wait()
		return c.translateHandshakeCtxError(err)
	case <-ctx.Done():
		cancelRead()
		cancel()
		c.handshakeLoopsFinished.Wait()
		return c.translateHandshakeCtxError(ctx.Err())
	case <-done:
		return nil
	}
}
- func (c *Conn) translateHandshakeCtxError(err error) error {
- if err == nil {
- return nil
- }
- if errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully() {
- return nil
- }
- return &HandshakeError{Err: err}
- }
- func (c *Conn) close(byUser bool) error {
- c.cancelHandshaker()
- c.cancelHandshakeReader()
- if c.isHandshakeCompletedSuccessfully() && byUser {
- // Discard error from notify() to return non-error on the first user call of Close()
- // even if the underlying connection is already closed.
- _ = c.notify(context.Background(), alert.Warning, alert.CloseNotify)
- }
- c.closeLock.Lock()
- // Don't return ErrConnClosed at the first time of the call from user.
- closedByUser := c.connectionClosedByUser
- if byUser {
- c.connectionClosedByUser = true
- }
- isClosed := c.isConnectionClosed()
- c.closed.Close()
- c.closeLock.Unlock()
- if closedByUser {
- return ErrConnClosed
- }
- if isClosed {
- return nil
- }
- return c.nextConn.Close()
- }
- func (c *Conn) isConnectionClosed() bool {
- select {
- case <-c.closed.Done():
- return true
- default:
- return false
- }
- }
- func (c *Conn) setLocalEpoch(epoch uint16) {
- c.state.localEpoch.Store(epoch)
- }
- func (c *Conn) setRemoteEpoch(epoch uint16) {
- c.state.remoteEpoch.Store(epoch)
- }
- // LocalAddr implements net.Conn.LocalAddr
- func (c *Conn) LocalAddr() net.Addr {
- return c.nextConn.LocalAddr()
- }
- // RemoteAddr implements net.Conn.RemoteAddr
- func (c *Conn) RemoteAddr() net.Addr {
- return c.nextConn.RemoteAddr()
- }
- func (c *Conn) sessionKey() []byte {
- if c.state.isClient {
- // As ServerName can be like 0.example.com, it's better to add
- // delimiter character which is not allowed to be in
- // neither address or domain name.
- return []byte(c.nextConn.RemoteAddr().String() + "_" + c.fsm.cfg.serverName)
- }
- return c.state.SessionID
- }
- // SetDeadline implements net.Conn.SetDeadline
- func (c *Conn) SetDeadline(t time.Time) error {
- c.readDeadline.Set(t)
- return c.SetWriteDeadline(t)
- }
- // SetReadDeadline implements net.Conn.SetReadDeadline
- func (c *Conn) SetReadDeadline(t time.Time) error {
- c.readDeadline.Set(t)
- // Read deadline is fully managed by this layer.
- // Don't set read deadline to underlying connection.
- return nil
- }
- // SetWriteDeadline implements net.Conn.SetWriteDeadline
- func (c *Conn) SetWriteDeadline(t time.Time) error {
- c.writeDeadline.Set(t)
- // Write deadline is also fully managed by this layer.
- return nil
- }
|