flight3handler.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. package dtls
  2. import (
  3. "bytes"
  4. "context"
  5. "github.com/pion/dtls/v2/internal/ciphersuite/types"
  6. "github.com/pion/dtls/v2/pkg/crypto/elliptic"
  7. "github.com/pion/dtls/v2/pkg/crypto/prf"
  8. "github.com/pion/dtls/v2/pkg/protocol"
  9. "github.com/pion/dtls/v2/pkg/protocol/alert"
  10. "github.com/pion/dtls/v2/pkg/protocol/extension"
  11. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  12. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  13. )
  14. func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { //nolint:gocognit
  15. // Clients may receive multiple HelloVerifyRequest messages with different cookies.
  16. // Clients SHOULD handle this by sending a new ClientHello with a cookie in response
  17. // to the new HelloVerifyRequest. RFC 6347 Section 4.2.1
  18. seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
  19. handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true},
  20. )
  21. if ok {
  22. if h, msgOk := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); msgOk {
  23. // DTLS 1.2 clients must not assume that the server will use the protocol version
  24. // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1
  25. if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) {
  26. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
  27. }
  28. state.cookie = append([]byte{}, h.Cookie...)
  29. state.handshakeRecvSequence = seq
  30. return flight3, nil, nil
  31. }
  32. }
  33. _, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
  34. handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
  35. )
  36. if !ok {
  37. // Don't have enough messages. Keep reading
  38. return 0, nil, nil
  39. }
  40. if h, msgOk := msgs[handshake.TypeServerHello].(*handshake.MessageServerHello); msgOk {
  41. if !h.Version.Equal(protocol.Version1_2) {
  42. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion
  43. }
  44. for _, v := range h.Extensions {
  45. switch e := v.(type) {
  46. case *extension.UseSRTP:
  47. profile, found := findMatchingSRTPProfile(e.ProtectionProfiles, cfg.localSRTPProtectionProfiles)
  48. if !found {
  49. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, errClientNoMatchingSRTPProfile
  50. }
  51. state.srtpProtectionProfile = profile
  52. case *extension.UseExtendedMasterSecret:
  53. if cfg.extendedMasterSecret != DisableExtendedMasterSecret {
  54. state.extendedMasterSecret = true
  55. }
  56. case *extension.ALPN:
  57. if len(e.ProtocolNameList) > 1 { // This should be exactly 1, the zero case is handle when unmarshalling
  58. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, extension.ErrALPNInvalidFormat // Meh, internal error?
  59. }
  60. state.NegotiatedProtocol = e.ProtocolNameList[0]
  61. }
  62. }
  63. if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
  64. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errClientRequiredButNoServerEMS
  65. }
  66. if len(cfg.localSRTPProtectionProfiles) > 0 && state.srtpProtectionProfile == 0 {
  67. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errRequestedButNoSRTPExtension
  68. }
  69. remoteCipherSuite := cipherSuiteForID(CipherSuiteID(*h.CipherSuiteID), cfg.customCipherSuites)
  70. if remoteCipherSuite == nil {
  71. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection
  72. }
  73. selectedCipherSuite, found := findMatchingCipherSuite([]CipherSuite{remoteCipherSuite}, cfg.localCipherSuites)
  74. if !found {
  75. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite
  76. }
  77. state.cipherSuite = selectedCipherSuite
  78. state.remoteRandom = h.Random
  79. cfg.log.Tracef("[handshake] use cipher suite: %s", selectedCipherSuite.String())
  80. if len(h.SessionID) > 0 && bytes.Equal(state.SessionID, h.SessionID) {
  81. return handleResumption(ctx, c, state, cache, cfg)
  82. }
  83. if len(state.SessionID) > 0 {
  84. cfg.log.Tracef("[handshake] clean old session : %s", state.SessionID)
  85. if err := cfg.sessionStore.Del(state.SessionID); err != nil {
  86. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  87. }
  88. }
  89. if cfg.sessionStore == nil {
  90. state.SessionID = []byte{}
  91. } else {
  92. state.SessionID = h.SessionID
  93. }
  94. state.masterSecret = []byte{}
  95. }
  96. if cfg.localPSKCallback != nil {
  97. seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite,
  98. handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, true},
  99. handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
  100. )
  101. } else {
  102. seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite,
  103. handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, true},
  104. handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
  105. handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, true},
  106. handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
  107. )
  108. }
  109. if !ok {
  110. // Don't have enough messages. Keep reading
  111. return 0, nil, nil
  112. }
  113. state.handshakeRecvSequence = seq
  114. if h, ok := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); ok {
  115. state.PeerCertificates = h.Certificate
  116. } else if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
  117. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errInvalidCertificate
  118. }
  119. if h, ok := msgs[handshake.TypeServerKeyExchange].(*handshake.MessageServerKeyExchange); ok {
  120. alertPtr, err := handleServerKeyExchange(c, state, cfg, h)
  121. if err != nil {
  122. return 0, alertPtr, err
  123. }
  124. }
  125. if _, ok := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok {
  126. state.remoteRequestedCertificate = true
  127. }
  128. return flight5, nil, nil
  129. }
  130. func handleResumption(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
  131. if err := state.initCipherSuite(); err != nil {
  132. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  133. }
  134. // Now, encrypted packets can be handled
  135. if err := c.handleQueuedPackets(ctx); err != nil {
  136. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  137. }
  138. _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite,
  139. handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false},
  140. )
  141. if !ok {
  142. // No valid message received. Keep reading
  143. return 0, nil, nil
  144. }
  145. var finished *handshake.MessageFinished
  146. if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
  147. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
  148. }
  149. plainText := cache.pullAndMerge(
  150. handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
  151. handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
  152. )
  153. expectedVerifyData, err := prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc())
  154. if err != nil {
  155. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  156. }
  157. if !bytes.Equal(expectedVerifyData, finished.VerifyData) {
  158. return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch
  159. }
  160. clientRandom := state.localRandom.MarshalFixed()
  161. cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
  162. return flight5b, nil, nil
  163. }
  164. func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange) (*alert.Alert, error) {
  165. var err error
  166. if state.cipherSuite == nil {
  167. return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite
  168. }
  169. if cfg.localPSKCallback != nil {
  170. var psk []byte
  171. if psk, err = cfg.localPSKCallback(h.IdentityHint); err != nil {
  172. return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  173. }
  174. state.IdentityHint = h.IdentityHint
  175. switch state.cipherSuite.KeyExchangeAlgorithm() {
  176. case types.KeyExchangeAlgorithmPsk:
  177. state.preMasterSecret = prf.PSKPreMasterSecret(psk)
  178. case (types.KeyExchangeAlgorithmEcdhe | types.KeyExchangeAlgorithmPsk):
  179. if state.localKeypair, err = elliptic.GenerateKeypair(h.NamedCurve); err != nil {
  180. return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  181. }
  182. state.preMasterSecret, err = prf.EcdhePSKPreMasterSecret(psk, h.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve)
  183. if err != nil {
  184. return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  185. }
  186. default:
  187. return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite
  188. }
  189. } else {
  190. if state.localKeypair, err = elliptic.GenerateKeypair(h.NamedCurve); err != nil {
  191. return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  192. }
  193. if state.preMasterSecret, err = prf.PreMasterSecret(h.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve); err != nil {
  194. return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
  195. }
  196. }
  197. return nil, nil //nolint:nilnil
  198. }
  199. func flight3Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
  200. extensions := []extension.Extension{
  201. &extension.SupportedSignatureAlgorithms{
  202. SignatureHashAlgorithms: cfg.localSignatureSchemes,
  203. },
  204. &extension.RenegotiationInfo{
  205. RenegotiatedConnection: 0,
  206. },
  207. }
  208. if state.namedCurve != 0 {
  209. extensions = append(extensions, []extension.Extension{
  210. &extension.SupportedEllipticCurves{
  211. EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384},
  212. },
  213. &extension.SupportedPointFormats{
  214. PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
  215. },
  216. }...)
  217. }
  218. if len(cfg.localSRTPProtectionProfiles) > 0 {
  219. extensions = append(extensions, &extension.UseSRTP{
  220. ProtectionProfiles: cfg.localSRTPProtectionProfiles,
  221. })
  222. }
  223. if cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
  224. cfg.extendedMasterSecret == RequireExtendedMasterSecret {
  225. extensions = append(extensions, &extension.UseExtendedMasterSecret{
  226. Supported: true,
  227. })
  228. }
  229. if len(cfg.serverName) > 0 {
  230. extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName})
  231. }
  232. if len(cfg.supportedProtocols) > 0 {
  233. extensions = append(extensions, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols})
  234. }
  235. return []*packet{
  236. {
  237. record: &recordlayer.RecordLayer{
  238. Header: recordlayer.Header{
  239. Version: protocol.Version1_2,
  240. },
  241. Content: &handshake.Handshake{
  242. Message: &handshake.MessageClientHello{
  243. Version: protocol.Version1_2,
  244. SessionID: state.SessionID,
  245. Cookie: state.cookie,
  246. Random: state.localRandom,
  247. CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites),
  248. CompressionMethods: defaultCompressionMethods(),
  249. Extensions: extensions,
  250. },
  251. },
  252. },
  253. },
  254. }, nil, nil
  255. }