flight3handler.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package dtls
  4. import (
  5. "bytes"
  6. "context"
  7. "errors"
  8. "github.com/pion/dtls/v2/internal/ciphersuite/types"
  9. "github.com/pion/dtls/v2/pkg/crypto/elliptic"
  10. "github.com/pion/dtls/v2/pkg/crypto/prf"
  11. "github.com/pion/dtls/v2/pkg/protocol"
  12. "github.com/pion/dtls/v2/pkg/protocol/alert"
  13. "github.com/pion/dtls/v2/pkg/protocol/extension"
  14. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  15. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  16. )
  17. func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { //nolint:gocognit
  18. // Clients may receive multiple HelloVerifyRequest messages with different cookies.
  19. // Clients SHOULD handle this by sending a new ClientHello with a cookie in response
  20. // to the new HelloVerifyRequest. RFC 6347 Section 4.2.1
  21. seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
  22. handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true},
  23. )
  24. if ok {
  25. if h, msgOk := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); msgOk {
  26. // DTLS 1.2 clients must not assume that the server will use the protocol version
  27. // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1
  28. if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) {
  29. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
  30. }
  31. state.cookie = append([]byte{}, h.Cookie...)
  32. state.handshakeRecvSequence = seq
  33. return flight3, nil, nil
  34. }
  35. }
  36. _, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
  37. handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
  38. )
  39. if !ok {
  40. // Don't have enough messages. Keep reading
  41. return 0, nil, nil
  42. }
  43. if h, msgOk := msgs[handshake.TypeServerHello].(*handshake.MessageServerHello); msgOk {
  44. if !h.Version.Equal(protocol.Version1_2) {
  45. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
  46. }
  47. for _, v := range h.Extensions {
  48. switch e := v.(type) {
  49. case *extension.UseSRTP:
  50. profile, found := findMatchingSRTPProfile(e.ProtectionProfiles, cfg.localSRTPProtectionProfiles)
  51. if !found {
  52. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, errClientNoMatchingSRTPProfile
  53. }
  54. state.setSRTPProtectionProfile(profile)
  55. case *extension.UseExtendedMasterSecret:
  56. if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
  57. state.extendedMasterSecret = true
  58. }
  59. case *extension.ALPN:
  60. if len(e.ProtocolNameList) > 1 { // This should be exactly 1, the zero case is handle when unmarshalling
  61. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, extension.ErrALPNInvalidFormat // Meh, internal error?
  62. }
  63. state.NegotiatedProtocol = e.ProtocolNameList[0]
  64. }
  65. }
  66. if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
  67. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errClientRequiredButNoServerEMS
  68. }
  69. if len(cfg.localSRTPProtectionProfiles) > 0 && state.getSRTPProtectionProfile() == 0 {
  70. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errRequestedButNoSRTPExtension
  71. }
  72. remoteCipherSuite := cipherSuiteForID(CipherSuiteID(*h.CipherSuiteID), cfg.customCipherSuites)
  73. if remoteCipherSuite == nil {
  74. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection
  75. }
  76. selectedCipherSuite, found := findMatchingCipherSuite([]CipherSuite{remoteCipherSuite}, cfg.localCipherSuites)
  77. if !found {
  78. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite
  79. }
  80. state.cipherSuite = selectedCipherSuite
  81. state.remoteRandom = h.Random
  82. cfg.log.Tracef("[handshake] use cipher suite: %s", selectedCipherSuite.String())
  83. if len(h.SessionID) > 0 && bytes.Equal(state.SessionID, h.SessionID) {
  84. return handleResumption(ctx, c, state, cache, cfg)
  85. }
  86. if len(state.SessionID) > 0 {
  87. cfg.log.Tracef("[handshake] clean old session : %s", state.SessionID)
  88. if err := cfg.sessionStore.Del(state.SessionID); err != nil {
  89. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  90. }
  91. }
  92. if cfg.sessionStore == nil {
  93. state.SessionID = []byte{}
  94. } else {
  95. state.SessionID = h.SessionID
  96. }
  97. state.masterSecret = []byte{}
  98. }
  99. if cfg.localPSKCallback != nil {
  100. seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite,
  101. handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, true},
  102. handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
  103. )
  104. } else {
  105. seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite,
  106. handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, true},
  107. handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
  108. handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, true},
  109. handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
  110. )
  111. }
  112. if !ok {
  113. // Don't have enough messages. Keep reading
  114. return 0, nil, nil
  115. }
  116. state.handshakeRecvSequence = seq
  117. if h, ok := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); ok {
  118. state.PeerCertificates = h.Certificate
  119. } else if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
  120. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errInvalidCertificate
  121. }
  122. if h, ok := msgs[handshake.TypeServerKeyExchange].(*handshake.MessageServerKeyExchange); ok {
  123. alertPtr, err := handleServerKeyExchange(c, state, cfg, h)
  124. if err != nil {
  125. return 0, alertPtr, err
  126. }
  127. }
  128. if _, ok := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok {
  129. state.remoteRequestedCertificate = true
  130. }
  131. return flight5, nil, nil
  132. }
  133. func handleResumption(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
  134. if err := state.initCipherSuite(); err != nil {
  135. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  136. }
  137. // Now, encrypted packets can be handled
  138. if err := c.handleQueuedPackets(ctx); err != nil {
  139. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  140. }
  141. _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite,
  142. handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false},
  143. )
  144. if !ok {
  145. // No valid message received. Keep reading
  146. return 0, nil, nil
  147. }
  148. var finished *handshake.MessageFinished
  149. if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
  150. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
  151. }
  152. plainText := cache.pullAndMerge(
  153. handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
  154. handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
  155. )
  156. expectedVerifyData, err := prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc())
  157. if err != nil {
  158. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  159. }
  160. if !bytes.Equal(expectedVerifyData, finished.VerifyData) {
  161. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch
  162. }
  163. clientRandom := state.localRandom.MarshalFixed()
  164. cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
  165. return flight5b, nil, nil
  166. }
  167. func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange) (*alert.Alert, error) {
  168. var err error
  169. if state.cipherSuite == nil {
  170. return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite
  171. }
  172. if cfg.localPSKCallback != nil {
  173. var psk []byte
  174. if psk, err = cfg.localPSKCallback(h.IdentityHint); err != nil {
  175. return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  176. }
  177. state.IdentityHint = h.IdentityHint
  178. switch state.cipherSuite.KeyExchangeAlgorithm() {
  179. case types.KeyExchangeAlgorithmPsk:
  180. state.preMasterSecret = prf.PSKPreMasterSecret(psk)
  181. case (types.KeyExchangeAlgorithmEcdhe | types.KeyExchangeAlgorithmPsk):
  182. if state.localKeypair, err = elliptic.GenerateKeypair(h.NamedCurve); err != nil {
  183. return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  184. }
  185. state.preMasterSecret, err = prf.EcdhePSKPreMasterSecret(psk, h.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve)
  186. if err != nil {
  187. return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  188. }
  189. default:
  190. return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite
  191. }
  192. } else {
  193. if state.localKeypair, err = elliptic.GenerateKeypair(h.NamedCurve); err != nil {
  194. return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  195. }
  196. if state.preMasterSecret, err = prf.PreMasterSecret(h.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve); err != nil {
  197. return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  198. }
  199. }
  200. return nil, nil //nolint:nilnil
  201. }
// flight3Generate always returns an error in this fork; it never produces packets.
//
// The code after the early return is the unmodified builder for the client's
// second ClientHello. That message carries state.cookie, the cookie that
// flight3Parse records from a HelloVerifyRequest, plus the client's extensions.
// The code is unreachable, and go vet's unreachable analyzer will flag it.
// NOTE(review): presumably retained to keep the diff against upstream
// pion/dtls small — confirm. Within this file it is also the only user of the
// recordlayer import, so deleting it requires removing that import as well.
func flight3Generate(_ context.Context, _ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
	// [Psiphon]
	// With SetDTLSInsecureSkipHelloVerify set, this should never be called,
	// so handshake randomization is not implemented here.
	return nil, nil, errors.New("unexpected flight3Generate call")

	// --- Unreachable upstream implementation below. ---
	extensions := []extension.Extension{
		&extension.SupportedSignatureAlgorithms{
			SignatureHashAlgorithms: cfg.localSignatureSchemes,
		},
		&extension.RenegotiationInfo{
			RenegotiatedConnection: 0,
		},
	}
	// Curve and point-format extensions are only advertised when state.namedCurve is set.
	if state.namedCurve != 0 {
		extensions = append(extensions, []extension.Extension{
			&extension.SupportedEllipticCurves{
				EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384},
			},
			&extension.SupportedPointFormats{
				PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
			},
		}...)
	}
	if len(cfg.localSRTPProtectionProfiles) > 0 {
		extensions = append(extensions, &extension.UseSRTP{
			ProtectionProfiles: cfg.localSRTPProtectionProfiles,
		})
	}
	// Extended master secret is offered when it is either requested or required.
	if cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
		cfg.extendedMasterSecret == RequireExtendedMasterSecret {
		extensions = append(extensions, &extension.UseExtendedMasterSecret{
			Supported: true,
		})
	}
	if len(cfg.serverName) > 0 {
		extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName})
	}
	if len(cfg.supportedProtocols) > 0 {
		extensions = append(extensions, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols})
	}
	return []*packet{
		{
			record: &recordlayer.RecordLayer{
				Header: recordlayer.Header{
					Version: protocol.Version1_2,
				},
				Content: &handshake.Handshake{
					Message: &handshake.MessageClientHello{
						Version:            protocol.Version1_2,
						SessionID:          state.SessionID,
						Cookie:             state.cookie,
						Random:             state.localRandom,
						CipherSuiteIDs:     cipherSuiteIDs(cfg.localCipherSuites),
						CompressionMethods: defaultCompressionMethods(),
						Extensions:         extensions,
					},
				},
			},
		},
	}, nil, nil
}