flight3handler.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package dtls
  4. import (
  5. "bytes"
  6. "context"
  7. "github.com/pion/dtls/v2/internal/ciphersuite/types"
  8. "github.com/pion/dtls/v2/pkg/crypto/elliptic"
  9. "github.com/pion/dtls/v2/pkg/crypto/prf"
  10. "github.com/pion/dtls/v2/pkg/protocol"
  11. "github.com/pion/dtls/v2/pkg/protocol/alert"
  12. "github.com/pion/dtls/v2/pkg/protocol/extension"
  13. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  14. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  15. )
  16. func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { //nolint:gocognit
  17. // Clients may receive multiple HelloVerifyRequest messages with different cookies.
  18. // Clients SHOULD handle this by sending a new ClientHello with a cookie in response
  19. // to the new HelloVerifyRequest. RFC 6347 Section 4.2.1
  20. seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
  21. handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true},
  22. )
  23. if ok {
  24. if h, msgOk := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); msgOk {
  25. // DTLS 1.2 clients must not assume that the server will use the protocol version
  26. // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1
  27. if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) {
  28. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
  29. }
  30. state.cookie = append([]byte{}, h.Cookie...)
  31. state.handshakeRecvSequence = seq
  32. return flight3, nil, nil
  33. }
  34. }
  35. _, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
  36. handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
  37. )
  38. if !ok {
  39. // Don't have enough messages. Keep reading
  40. return 0, nil, nil
  41. }
  42. if h, msgOk := msgs[handshake.TypeServerHello].(*handshake.MessageServerHello); msgOk {
  43. if !h.Version.Equal(protocol.Version1_2) {
  44. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
  45. }
  46. for _, v := range h.Extensions {
  47. switch e := v.(type) {
  48. case *extension.UseSRTP:
  49. profile, found := findMatchingSRTPProfile(e.ProtectionProfiles, cfg.localSRTPProtectionProfiles)
  50. if !found {
  51. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, errClientNoMatchingSRTPProfile
  52. }
  53. state.srtpProtectionProfile = profile
  54. case *extension.UseExtendedMasterSecret:
  55. if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
  56. state.extendedMasterSecret = true
  57. }
  58. case *extension.ALPN:
  59. if len(e.ProtocolNameList) > 1 { // This should be exactly 1, the zero case is handle when unmarshalling
  60. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, extension.ErrALPNInvalidFormat // Meh, internal error?
  61. }
  62. state.NegotiatedProtocol = e.ProtocolNameList[0]
  63. }
  64. }
  65. if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
  66. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errClientRequiredButNoServerEMS
  67. }
  68. if len(cfg.localSRTPProtectionProfiles) > 0 && state.srtpProtectionProfile == 0 {
  69. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errRequestedButNoSRTPExtension
  70. }
  71. remoteCipherSuite := cipherSuiteForID(CipherSuiteID(*h.CipherSuiteID), cfg.customCipherSuites)
  72. if remoteCipherSuite == nil {
  73. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection
  74. }
  75. selectedCipherSuite, found := findMatchingCipherSuite([]CipherSuite{remoteCipherSuite}, cfg.localCipherSuites)
  76. if !found {
  77. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite
  78. }
  79. state.cipherSuite = selectedCipherSuite
  80. state.remoteRandom = h.Random
  81. cfg.log.Tracef("[handshake] use cipher suite: %s", selectedCipherSuite.String())
  82. if len(h.SessionID) > 0 && bytes.Equal(state.SessionID, h.SessionID) {
  83. return handleResumption(ctx, c, state, cache, cfg)
  84. }
  85. if len(state.SessionID) > 0 {
  86. cfg.log.Tracef("[handshake] clean old session : %s", state.SessionID)
  87. if err := cfg.sessionStore.Del(state.SessionID); err != nil {
  88. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  89. }
  90. }
  91. if cfg.sessionStore == nil {
  92. state.SessionID = []byte{}
  93. } else {
  94. state.SessionID = h.SessionID
  95. }
  96. state.masterSecret = []byte{}
  97. }
  98. if cfg.localPSKCallback != nil {
  99. seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite,
  100. handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, true},
  101. handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
  102. )
  103. } else {
  104. seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite,
  105. handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, true},
  106. handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
  107. handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, true},
  108. handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
  109. )
  110. }
  111. if !ok {
  112. // Don't have enough messages. Keep reading
  113. return 0, nil, nil
  114. }
  115. state.handshakeRecvSequence = seq
  116. if h, ok := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); ok {
  117. state.PeerCertificates = h.Certificate
  118. } else if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
  119. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errInvalidCertificate
  120. }
  121. if h, ok := msgs[handshake.TypeServerKeyExchange].(*handshake.MessageServerKeyExchange); ok {
  122. alertPtr, err := handleServerKeyExchange(c, state, cfg, h)
  123. if err != nil {
  124. return 0, alertPtr, err
  125. }
  126. }
  127. if _, ok := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok {
  128. state.remoteRequestedCertificate = true
  129. }
  130. return flight5, nil, nil
  131. }
  132. func handleResumption(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
  133. if err := state.initCipherSuite(); err != nil {
  134. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  135. }
  136. // Now, encrypted packets can be handled
  137. if err := c.handleQueuedPackets(ctx); err != nil {
  138. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  139. }
  140. _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite,
  141. handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false},
  142. )
  143. if !ok {
  144. // No valid message received. Keep reading
  145. return 0, nil, nil
  146. }
  147. var finished *handshake.MessageFinished
  148. if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
  149. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
  150. }
  151. plainText := cache.pullAndMerge(
  152. handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
  153. handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
  154. )
  155. expectedVerifyData, err := prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc())
  156. if err != nil {
  157. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  158. }
  159. if !bytes.Equal(expectedVerifyData, finished.VerifyData) {
  160. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch
  161. }
  162. clientRandom := state.localRandom.MarshalFixed()
  163. cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
  164. return flight5b, nil, nil
  165. }
  166. func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange) (*alert.Alert, error) {
  167. var err error
  168. if state.cipherSuite == nil {
  169. return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite
  170. }
  171. if cfg.localPSKCallback != nil {
  172. var psk []byte
  173. if psk, err = cfg.localPSKCallback(h.IdentityHint); err != nil {
  174. return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  175. }
  176. state.IdentityHint = h.IdentityHint
  177. switch state.cipherSuite.KeyExchangeAlgorithm() {
  178. case types.KeyExchangeAlgorithmPsk:
  179. state.preMasterSecret = prf.PSKPreMasterSecret(psk)
  180. case (types.KeyExchangeAlgorithmEcdhe | types.KeyExchangeAlgorithmPsk):
  181. if state.localKeypair, err = elliptic.GenerateKeypair(h.NamedCurve); err != nil {
  182. return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  183. }
  184. state.preMasterSecret, err = prf.EcdhePSKPreMasterSecret(psk, h.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve)
  185. if err != nil {
  186. return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  187. }
  188. default:
  189. return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite
  190. }
  191. } else {
  192. if state.localKeypair, err = elliptic.GenerateKeypair(h.NamedCurve); err != nil {
  193. return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  194. }
  195. if state.preMasterSecret, err = prf.PreMasterSecret(h.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve); err != nil {
  196. return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  197. }
  198. }
  199. return nil, nil //nolint:nilnil
  200. }
  201. func flight3Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
  202. extensions := []extension.Extension{
  203. &extension.SupportedSignatureAlgorithms{
  204. SignatureHashAlgorithms: cfg.localSignatureSchemes,
  205. },
  206. &extension.RenegotiationInfo{
  207. RenegotiatedConnection: 0,
  208. },
  209. }
  210. if state.namedCurve != 0 {
  211. extensions = append(extensions, []extension.Extension{
  212. &extension.SupportedEllipticCurves{
  213. EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384},
  214. },
  215. &extension.SupportedPointFormats{
  216. PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
  217. },
  218. }...)
  219. }
  220. if len(cfg.localSRTPProtectionProfiles) > 0 {
  221. extensions = append(extensions, &extension.UseSRTP{
  222. ProtectionProfiles: cfg.localSRTPProtectionProfiles,
  223. })
  224. }
  225. if cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
  226. cfg.extendedMasterSecret == RequireExtendedMasterSecret {
  227. extensions = append(extensions, &extension.UseExtendedMasterSecret{
  228. Supported: true,
  229. })
  230. }
  231. if len(cfg.serverName) > 0 {
  232. extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName})
  233. }
  234. if len(cfg.supportedProtocols) > 0 {
  235. extensions = append(extensions, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols})
  236. }
  237. return []*packet{
  238. {
  239. record: &recordlayer.RecordLayer{
  240. Header: recordlayer.Header{
  241. Version: protocol.Version1_2,
  242. },
  243. Content: &handshake.Handshake{
  244. Message: &handshake.MessageClientHello{
  245. Version: protocol.Version1_2,
  246. SessionID: state.SessionID,
  247. Cookie: state.cookie,
  248. Random: state.localRandom,
  249. CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites),
  250. CompressionMethods: defaultCompressionMethods(),
  251. Extensions: extensions,
  252. },
  253. },
  254. },
  255. },
  256. }, nil, nil
  257. }