handshaker_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package dtls
  4. import (
  5. "bytes"
  6. "context"
  7. "crypto/tls"
  8. "errors"
  9. "sync"
  10. "testing"
  11. "time"
  12. "github.com/pion/dtls/v2/pkg/crypto/selfsign"
  13. "github.com/pion/dtls/v2/pkg/crypto/signaturehash"
  14. "github.com/pion/dtls/v2/pkg/protocol/alert"
  15. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  16. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  17. "github.com/pion/logging"
  18. "github.com/pion/transport/v2/test"
  19. )
  20. const nonZeroRetransmitInterval = 100 * time.Millisecond
  21. // Test that writes to the key log are in the correct format and only applies
  22. // when a key log writer is given.
  23. func TestWriteKeyLog(t *testing.T) {
  24. var buf bytes.Buffer
  25. cfg := handshakeConfig{
  26. keyLogWriter: &buf,
  27. }
  28. cfg.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF})
  29. // Secrets follow the format <Label> <space> <ClientRandom> <space> <Secret>
  30. // https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format
  31. want := "LABEL aabbcc ddeeff\n"
  32. if buf.String() != want {
  33. t.Fatalf("Got %s want %s", buf.String(), want)
  34. }
  35. // no key log writer = no writes
  36. cfg = handshakeConfig{}
  37. cfg.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF})
  38. }
  39. func TestHandshaker(t *testing.T) {
  40. // Check for leaking routines
  41. report := test.CheckRoutines(t)
  42. defer report()
  43. loggerFactory := logging.NewDefaultLoggerFactory()
  44. logger := loggerFactory.NewLogger("dtls")
  45. cipherSuites, err := parseCipherSuites(nil, nil, true, false)
  46. if err != nil {
  47. t.Fatal(err)
  48. }
  49. clientCert, err := selfsign.GenerateSelfSigned()
  50. if err != nil {
  51. t.Fatal(err)
  52. }
  53. genFilters := map[string]func() (TestEndpoint, TestEndpoint, func(t *testing.T)){
  54. "PassThrough": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) {
  55. return TestEndpoint{}, TestEndpoint{}, nil
  56. },
  57. "HelloVerifyRequestLost": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) {
  58. var (
  59. cntHelloVerifyRequest = 0
  60. cntClientHelloNoCookie = 0
  61. )
  62. const helloVerifyDrop = 5
  63. clientEndpoint := TestEndpoint{
  64. Filter: func(p *packet) bool {
  65. h, ok := p.record.Content.(*handshake.Handshake)
  66. if !ok {
  67. return true
  68. }
  69. if hmch, ok := h.Message.(*handshake.MessageClientHello); ok {
  70. if len(hmch.Cookie) == 0 {
  71. cntClientHelloNoCookie++
  72. }
  73. }
  74. return true
  75. },
  76. }
  77. serverEndpoint := TestEndpoint{
  78. Filter: func(p *packet) bool {
  79. h, ok := p.record.Content.(*handshake.Handshake)
  80. if !ok {
  81. return true
  82. }
  83. if _, ok := h.Message.(*handshake.MessageHelloVerifyRequest); ok {
  84. cntHelloVerifyRequest++
  85. return cntHelloVerifyRequest > helloVerifyDrop
  86. }
  87. return true
  88. },
  89. }
  90. report := func(t *testing.T) {
  91. if cntHelloVerifyRequest != helloVerifyDrop+1 {
  92. t.Errorf("Number of HelloVerifyRequest retransmit is wrong, expected: %d times, got: %d times", helloVerifyDrop+1, cntHelloVerifyRequest)
  93. }
  94. if cntClientHelloNoCookie != cntHelloVerifyRequest {
  95. t.Errorf(
  96. "HelloVerifyRequest must be triggered only by ClientHello, but HelloVerifyRequest was sent %d times and ClientHello was sent %d times",
  97. cntHelloVerifyRequest, cntClientHelloNoCookie,
  98. )
  99. }
  100. }
  101. return clientEndpoint, serverEndpoint, report
  102. },
  103. "NoLatencyTest": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) {
  104. var (
  105. cntClientFinished = 0
  106. cntServerFinished = 0
  107. )
  108. clientEndpoint := TestEndpoint{
  109. Filter: func(p *packet) bool {
  110. h, ok := p.record.Content.(*handshake.Handshake)
  111. if !ok {
  112. return true
  113. }
  114. if _, ok := h.Message.(*handshake.MessageFinished); ok {
  115. cntClientFinished++
  116. }
  117. return true
  118. },
  119. }
  120. serverEndpoint := TestEndpoint{
  121. Filter: func(p *packet) bool {
  122. h, ok := p.record.Content.(*handshake.Handshake)
  123. if !ok {
  124. return true
  125. }
  126. if _, ok := h.Message.(*handshake.MessageFinished); ok {
  127. cntServerFinished++
  128. }
  129. return true
  130. },
  131. }
  132. report := func(t *testing.T) {
  133. if cntClientFinished != 1 {
  134. t.Errorf("Number of client finished is wrong, expected: %d times, got: %d times", 1, cntClientFinished)
  135. }
  136. if cntServerFinished != 1 {
  137. t.Errorf("Number of server finished is wrong, expected: %d times, got: %d times", 1, cntServerFinished)
  138. }
  139. }
  140. return clientEndpoint, serverEndpoint, report
  141. },
  142. "SlowServerTest": func() (TestEndpoint, TestEndpoint, func(t *testing.T)) {
  143. var (
  144. cntClientFinished = 0
  145. isClientFinished = false
  146. cntClientFinishedLastRetransmit = 0
  147. cntServerFinished = 0
  148. isServerFinished = false
  149. cntServerFinishedLastRetransmit = 0
  150. )
  151. clientEndpoint := TestEndpoint{
  152. Filter: func(p *packet) bool {
  153. h, ok := p.record.Content.(*handshake.Handshake)
  154. if !ok {
  155. return true
  156. }
  157. if _, ok := h.Message.(*handshake.MessageFinished); ok {
  158. if isClientFinished {
  159. cntClientFinishedLastRetransmit++
  160. } else {
  161. cntClientFinished++
  162. }
  163. }
  164. return true
  165. },
  166. Delay: 0,
  167. OnFinished: func() {
  168. isClientFinished = true
  169. },
  170. FinishWait: 2000 * time.Millisecond,
  171. }
  172. serverEndpoint := TestEndpoint{
  173. Filter: func(p *packet) bool {
  174. h, ok := p.record.Content.(*handshake.Handshake)
  175. if !ok {
  176. return true
  177. }
  178. if _, ok := h.Message.(*handshake.MessageFinished); ok {
  179. if isServerFinished {
  180. cntServerFinishedLastRetransmit++
  181. } else {
  182. cntServerFinished++
  183. }
  184. }
  185. return true
  186. },
  187. Delay: 1000 * time.Millisecond,
  188. OnFinished: func() {
  189. isServerFinished = true
  190. },
  191. FinishWait: 2000 * time.Millisecond,
  192. }
  193. report := func(t *testing.T) {
  194. // with one second server delay and 100 ms retransmit, there should be close to 10 `Finished` from client
  195. // using a range of 9 - 11 for checking
  196. if cntClientFinished < 8 || cntClientFinished > 11 {
  197. t.Errorf("Number of client finished is wrong, expected: %d - %d times, got: %d times", 9, 11, cntClientFinished)
  198. }
  199. if !isClientFinished {
  200. t.Errorf("Client is not finished")
  201. }
  202. // there should be no `Finished` last retransmit from client
  203. if cntClientFinishedLastRetransmit != 0 {
  204. t.Errorf("Number of client finished last retransmit is wrong, expected: %d times, got: %d times", 0, cntClientFinishedLastRetransmit)
  205. }
  206. if cntServerFinished < 1 {
  207. t.Errorf("Number of server finished is wrong, expected: at least %d times, got: %d times", 1, cntServerFinished)
  208. }
  209. if !isServerFinished {
  210. t.Errorf("Server is not finished")
  211. }
  212. // there should be `Finished` last retransmit from server. Because of slow server, client would have sent several `Finished`.
  213. if cntServerFinishedLastRetransmit < 1 {
  214. t.Errorf("Number of server finished last retransmit is wrong, expected: at least %d times, got: %d times", 1, cntServerFinishedLastRetransmit)
  215. }
  216. }
  217. return clientEndpoint, serverEndpoint, report
  218. },
  219. }
  220. for name, filters := range genFilters {
  221. clientEndpoint, serverEndpoint, report := filters()
  222. t.Run(name, func(t *testing.T) {
  223. ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
  224. defer cancel()
  225. if report != nil {
  226. defer report(t)
  227. }
  228. ca, cb := flightTestPipe(ctx, clientEndpoint, serverEndpoint)
  229. ca.state.isClient = true
  230. var wg sync.WaitGroup
  231. wg.Add(2)
  232. ctxCliFinished, cancelCli := context.WithCancel(ctx)
  233. ctxSrvFinished, cancelSrv := context.WithCancel(ctx)
  234. go func() {
  235. defer wg.Done()
  236. cfg := &handshakeConfig{
  237. localCipherSuites: cipherSuites,
  238. localCertificates: []tls.Certificate{clientCert},
  239. ellipticCurves: defaultCurves,
  240. localSignatureSchemes: signaturehash.Algorithms(),
  241. insecureSkipVerify: true,
  242. log: logger,
  243. onFlightState: func(f flightVal, s handshakeState) {
  244. if s == handshakeFinished {
  245. if clientEndpoint.OnFinished != nil {
  246. clientEndpoint.OnFinished()
  247. }
  248. time.AfterFunc(clientEndpoint.FinishWait, func() {
  249. cancelCli()
  250. })
  251. }
  252. },
  253. retransmitInterval: nonZeroRetransmitInterval,
  254. }
  255. fsm := newHandshakeFSM(&ca.state, ca.handshakeCache, cfg, flight1)
  256. err := fsm.Run(ctx, ca, handshakePreparing)
  257. switch {
  258. case errors.Is(err, context.Canceled):
  259. case errors.Is(err, context.DeadlineExceeded):
  260. t.Error("Timeout")
  261. default:
  262. t.Error(err)
  263. }
  264. }()
  265. go func() {
  266. defer wg.Done()
  267. cfg := &handshakeConfig{
  268. localCipherSuites: cipherSuites,
  269. localCertificates: []tls.Certificate{clientCert},
  270. ellipticCurves: defaultCurves,
  271. localSignatureSchemes: signaturehash.Algorithms(),
  272. insecureSkipVerify: true,
  273. log: logger,
  274. onFlightState: func(f flightVal, s handshakeState) {
  275. if s == handshakeFinished {
  276. if serverEndpoint.OnFinished != nil {
  277. serverEndpoint.OnFinished()
  278. }
  279. time.AfterFunc(serverEndpoint.FinishWait, func() {
  280. cancelSrv()
  281. })
  282. }
  283. },
  284. retransmitInterval: nonZeroRetransmitInterval,
  285. }
  286. fsm := newHandshakeFSM(&cb.state, cb.handshakeCache, cfg, flight0)
  287. err := fsm.Run(ctx, cb, handshakePreparing)
  288. switch {
  289. case errors.Is(err, context.Canceled):
  290. case errors.Is(err, context.DeadlineExceeded):
  291. t.Error("Timeout")
  292. default:
  293. t.Error(err)
  294. }
  295. }()
  296. <-ctxCliFinished.Done()
  297. <-ctxSrvFinished.Done()
  298. cancel()
  299. wg.Wait()
  300. })
  301. }
  302. }
  303. type packetFilter func(p *packet) bool
  304. type TestEndpoint struct {
  305. Filter packetFilter
  306. Delay time.Duration
  307. OnFinished func()
  308. FinishWait time.Duration
  309. }
  310. func flightTestPipe(ctx context.Context, clientEndpoint TestEndpoint, serverEndpoint TestEndpoint) (*flightTestConn, *flightTestConn) {
  311. ca := newHandshakeCache()
  312. cb := newHandshakeCache()
  313. chA := make(chan chan struct{})
  314. chB := make(chan chan struct{})
  315. return &flightTestConn{
  316. handshakeCache: ca,
  317. otherEndCache: cb,
  318. recv: chA,
  319. otherEndRecv: chB,
  320. done: ctx.Done(),
  321. filter: clientEndpoint.Filter,
  322. delay: clientEndpoint.Delay,
  323. }, &flightTestConn{
  324. handshakeCache: cb,
  325. otherEndCache: ca,
  326. recv: chB,
  327. otherEndRecv: chA,
  328. done: ctx.Done(),
  329. filter: serverEndpoint.Filter,
  330. delay: serverEndpoint.Delay,
  331. }
  332. }
  333. type flightTestConn struct {
  334. state State
  335. handshakeCache *handshakeCache
  336. recv chan chan struct{}
  337. done <-chan struct{}
  338. epoch uint16
  339. filter packetFilter
  340. delay time.Duration
  341. otherEndCache *handshakeCache
  342. otherEndRecv chan chan struct{}
  343. }
  344. func (c *flightTestConn) recvHandshake() <-chan chan struct{} {
  345. return c.recv
  346. }
  347. func (c *flightTestConn) setLocalEpoch(epoch uint16) {
  348. c.epoch = epoch
  349. }
  350. func (c *flightTestConn) notify(context.Context, alert.Level, alert.Description) error {
  351. return nil
  352. }
  353. func (c *flightTestConn) writePackets(_ context.Context, pkts []*packet) error {
  354. time.Sleep(c.delay)
  355. for _, p := range pkts {
  356. if c.filter != nil && !c.filter(p) {
  357. continue
  358. }
  359. if h, ok := p.record.Content.(*handshake.Handshake); ok {
  360. handshakeRaw, err := p.record.Marshal()
  361. if err != nil {
  362. return err
  363. }
  364. c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient)
  365. content, err := h.Message.Marshal()
  366. if err != nil {
  367. return err
  368. }
  369. h.Header.Length = uint32(len(content))
  370. h.Header.FragmentLength = uint32(len(content))
  371. hdr, err := h.Header.Marshal()
  372. if err != nil {
  373. return err
  374. }
  375. c.otherEndCache.push(
  376. append(hdr, content...), p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient)
  377. }
  378. }
  379. go func() {
  380. select {
  381. case c.otherEndRecv <- make(chan struct{}):
  382. case <-c.done:
  383. }
  384. }()
  385. // Avoid deadlock on JS/WASM environment due to context switch problem.
  386. time.Sleep(10 * time.Millisecond)
  387. return nil
  388. }
  389. func (c *flightTestConn) handleQueuedPackets(context.Context) error {
  390. return nil
  391. }
  392. func (c *flightTestConn) sessionKey() []byte {
  393. return nil
  394. }