| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360 |
- // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
- // SPDX-License-Identifier: MIT
- package dtls
- import (
- "context"
- "crypto/tls"
- "crypto/x509"
- "fmt"
- "io"
- "net"
- "sync"
- "time"
- "github.com/pion/dtls/v2/pkg/crypto/elliptic"
- "github.com/pion/dtls/v2/pkg/crypto/signaturehash"
- "github.com/pion/dtls/v2/pkg/protocol/alert"
- "github.com/pion/dtls/v2/pkg/protocol/handshake"
- "github.com/pion/logging"
- )
- // [RFC6347 Section-4.2.4]
//                       +-----------+
//                 +---> | PREPARING | <--------------------+
//                 |     +-----------+                      |
//                 |           |                            |
//                 |           | Buffer next flight         |
//                 |           |                            |
//                 |          \|/                           |
//                 |     +-----------+                      |
//                 |     |  SENDING  |<------------------+  | Send
//                 |     +-----------+                   |  | HelloRequest
//         Receive |           |                         |  |
//            next |           | Send flight             |  | or
//          flight |  +--------+                         |  |
//                 |  |        | Set retransmit timer    |  | Receive
//                 |  |       \|/                        |  | HelloRequest
//                 |  |  +-----------+                   |  | Send
//                 +--)--|  WAITING  |-------------------+  | ClientHello
//                 |  |  +-----------+   Timer expires   |  |
//                 |  |        |                         |  |
//                 |  |        +-------------------------+  |
//         Receive |  | Send           Read retransmit      |
//            last |  | last                                |
//          flight |  | flight                              |
//                 |  |                                     |
//                \|/\|/                                    |
//             +-----------+                                |
//             | FINISHED  | -------------------------------+
//             +-----------+
//                  |  /|\
//                  |   |
//                  +---+
//             Read retransmit
//          Retransmit last flight
// handshakeState enumerates the nodes of the state machine drawn above, plus
// handshakeErrored, which is the zero value.
type handshakeState uint8

const (
	handshakeErrored handshakeState = iota
	handshakePreparing
	handshakeSending
	handshakeWaiting
	handshakeFinished
)

// String returns the human-readable name of the state, or "Unknown" for any
// value outside the declared constants.
func (s handshakeState) String() string {
	// Indexed by the state's numeric value.
	names := [...]string{
		handshakeErrored:   "Errored",
		handshakePreparing: "Preparing",
		handshakeSending:   "Sending",
		handshakeWaiting:   "Waiting",
		handshakeFinished:  "Finished",
	}
	if int(s) >= len(names) {
		return "Unknown"
	}
	return names[s]
}
// handshakeFSM drives the flight-based DTLS handshake (RFC 6347, Section
// 4.2.4) through the states enumerated by handshakeState.
type handshakeFSM struct {
	currentFlight flightVal // Flight currently being prepared, sent or awaited
	flights       []*packet // Packets of currentFlight, built by prepare and written by send
	// Set from the flight generator in prepare; when the retransmit timer
	// fires in wait, true resends the flight and false keeps waiting.
	retransmit bool
	// state, cache and cfg are handed unchanged to every flight generator
	// and parser.
	state  *State
	cache  *handshakeCache
	cfg    *handshakeConfig
	closed chan struct{} // Closed when Run returns; exposed through Done
}
// handshakeConfig carries the local settings and callbacks consulted while the
// handshake state machine generates and parses flights; the same pointer is
// passed to every flight generator and parser.
type handshakeConfig struct {
	localPSKCallback            PSKCallback
	localPSKIdentityHint        []byte
	localCipherSuites           []CipherSuite             // Available CipherSuites
	localSignatureSchemes       []signaturehash.Algorithm // Available signature schemes
	extendedMasterSecret        ExtendedMasterSecretType  // Policy for the Extended Master Secret extension
	localSRTPProtectionProfiles []SRTPProtectionProfile   // Available SRTPProtectionProfiles, if empty no SRTP support
	serverName                  string
	supportedProtocols          []string
	// Policy for requesting a client certificate. NOTE(review): the previous
	// comment read "if we are a client", but in (D)TLS it is the server that
	// requests a client certificate; presumably only consulted server-side —
	// confirm.
	clientAuth            ClientAuthType
	localCertificates     []tls.Certificate
	nameToCertificate     map[string]*tls.Certificate
	insecureSkipVerify    bool
	verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
	verifyConnection      func(*State) error
	sessionStore          SessionStore
	rootCAs               *x509.CertPool
	clientCAs             *x509.CertPool
	// Duration of the retransmit timer armed in the waiting and finished
	// states.
	retransmitInterval      time.Duration
	customCipherSuites      func() []CipherSuite
	ellipticCurves          []elliptic.Curve
	insecureSkipHelloVerify bool

	// [Psiphon]
	// Conjure DTLS support, from: https://github.com/mingyech/dtls/commit/a56eccc1
	// Presumably overrides the random value placed in the ClientHello — confirm.
	customClientHelloRandom func() [handshake.RandomBytesLength]byte

	// Optional (may be nil); Run invokes it with the current flight and state
	// at the start of every state-machine step.
	onFlightState func(flightVal, handshakeState)
	log           logging.LeveledLogger
	// Destination for "<label> <clientRandom hex> <secret hex>" lines written
	// by writeKeyLog; nil disables key logging.
	keyLogWriter              io.Writer
	localGetCertificate       func(*ClientHelloInfo) (*tls.Certificate, error)
	localGetClientCertificate func(*CertificateRequestInfo) (*tls.Certificate, error)
	// Added to the epoch of every outgoing record by prepare.
	initialEpoch uint16
	// Serializes writes to keyLogWriter in writeKeyLog.
	mu sync.Mutex
}
- type flightConn interface {
- notify(ctx context.Context, level alert.Level, desc alert.Description) error
- writePackets(context.Context, []*packet) error
- recvHandshake() <-chan chan struct{}
- setLocalEpoch(epoch uint16)
- handleQueuedPackets(context.Context) error
- sessionKey() []byte
- // [Psiphon]
- LocalAddr() net.Addr
- }
- func (c *handshakeConfig) writeKeyLog(label string, clientRandom, secret []byte) {
- if c.keyLogWriter == nil {
- return
- }
- c.mu.Lock()
- defer c.mu.Unlock()
- _, err := c.keyLogWriter.Write([]byte(fmt.Sprintf("%s %x %x\n", label, clientRandom, secret)))
- if err != nil {
- c.log.Debugf("failed to write key log file: %s", err)
- }
- }
// srvCliStr names the local role for log prefixes: "client" when isClient is
// true and "server" otherwise.
func srvCliStr(isClient bool) string {
	role := "server"
	if isClient {
		role = "client"
	}
	return role
}
- func newHandshakeFSM(
- s *State, cache *handshakeCache, cfg *handshakeConfig,
- initialFlight flightVal,
- ) *handshakeFSM {
- return &handshakeFSM{
- currentFlight: initialFlight,
- state: s,
- cache: cache,
- cfg: cfg,
- closed: make(chan struct{}),
- }
- }
- func (s *handshakeFSM) Run(ctx context.Context, c flightConn, initialState handshakeState) error {
- state := initialState
- defer func() {
- close(s.closed)
- }()
- for {
- s.cfg.log.Tracef("[handshake:%s] %s: %s", srvCliStr(s.state.isClient), s.currentFlight.String(), state.String())
- if s.cfg.onFlightState != nil {
- s.cfg.onFlightState(s.currentFlight, state)
- }
- var err error
- switch state {
- case handshakePreparing:
- state, err = s.prepare(ctx, c)
- case handshakeSending:
- state, err = s.send(ctx, c)
- case handshakeWaiting:
- state, err = s.wait(ctx, c)
- case handshakeFinished:
- state, err = s.finish(ctx, c)
- default:
- return errInvalidFSMTransition
- }
- if err != nil {
- return err
- }
- }
- }
- func (s *handshakeFSM) Done() <-chan struct{} {
- return s.closed
- }
- func (s *handshakeFSM) prepare(ctx context.Context, c flightConn) (handshakeState, error) {
- s.flights = nil
- // Prepare flights
- var (
- a *alert.Alert
- err error
- pkts []*packet
- )
- gen, retransmit, errFlight := s.currentFlight.getFlightGenerator()
- if errFlight != nil {
- err = errFlight
- a = &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}
- } else {
- // [Psiphon]
- // Pass in dial context for GetDTLSSeed.
- pkts, a, err = gen(ctx, c, s.state, s.cache, s.cfg)
- s.retransmit = retransmit
- }
- if a != nil {
- if alertErr := c.notify(ctx, a.Level, a.Description); alertErr != nil {
- if err != nil {
- err = alertErr
- }
- }
- }
- if err != nil {
- return handshakeErrored, err
- }
- s.flights = pkts
- epoch := s.cfg.initialEpoch
- nextEpoch := epoch
- for _, p := range s.flights {
- p.record.Header.Epoch += epoch
- if p.record.Header.Epoch > nextEpoch {
- nextEpoch = p.record.Header.Epoch
- }
- if h, ok := p.record.Content.(*handshake.Handshake); ok {
- h.Header.MessageSequence = uint16(s.state.handshakeSendSequence)
- s.state.handshakeSendSequence++
- }
- }
- if epoch != nextEpoch {
- s.cfg.log.Tracef("[handshake:%s] -> changeCipherSpec (epoch: %d)", srvCliStr(s.state.isClient), nextEpoch)
- c.setLocalEpoch(nextEpoch)
- }
- return handshakeSending, nil
- }
- func (s *handshakeFSM) send(ctx context.Context, c flightConn) (handshakeState, error) {
- // Send flights
- if err := c.writePackets(ctx, s.flights); err != nil {
- return handshakeErrored, err
- }
- if s.currentFlight.isLastSendFlight() {
- return handshakeFinished, nil
- }
- return handshakeWaiting, nil
- }
- func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState, error) { //nolint:gocognit
- parse, errFlight := s.currentFlight.getFlightParser()
- if errFlight != nil {
- if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil {
- if errFlight != nil {
- return handshakeErrored, alertErr
- }
- }
- return handshakeErrored, errFlight
- }
- retransmitTimer := time.NewTimer(s.cfg.retransmitInterval)
- for {
- select {
- case done := <-c.recvHandshake():
- nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg)
- close(done)
- if alert != nil {
- if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
- if err != nil {
- err = alertErr
- }
- }
- }
- if err != nil {
- return handshakeErrored, err
- }
- if nextFlight == 0 {
- break
- }
- s.cfg.log.Tracef("[handshake:%s] %s -> %s", srvCliStr(s.state.isClient), s.currentFlight.String(), nextFlight.String())
- if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight {
- return handshakeFinished, nil
- }
- s.currentFlight = nextFlight
- return handshakePreparing, nil
- case <-retransmitTimer.C:
- if !s.retransmit {
- return handshakeWaiting, nil
- }
- return handshakeSending, nil
- case <-ctx.Done():
- return handshakeErrored, ctx.Err()
- }
- }
- }
- func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState, error) {
- parse, errFlight := s.currentFlight.getFlightParser()
- if errFlight != nil {
- if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil {
- if errFlight != nil {
- return handshakeErrored, alertErr
- }
- }
- return handshakeErrored, errFlight
- }
- retransmitTimer := time.NewTimer(s.cfg.retransmitInterval)
- select {
- case done := <-c.recvHandshake():
- nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg)
- close(done)
- if alert != nil {
- if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
- if err != nil {
- err = alertErr
- }
- }
- }
- if err != nil {
- return handshakeErrored, err
- }
- if nextFlight == 0 {
- break
- }
- if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight {
- return handshakeFinished, nil
- }
- <-retransmitTimer.C
- // Retransmit last flight
- return handshakeSending, nil
- case <-ctx.Done():
- return handshakeErrored, ctx.Err()
- }
- return handshakeFinished, nil
- }
|