replayprotection_test.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package dtls
  4. import (
  5. "context"
  6. "net"
  7. "reflect"
  8. "sync"
  9. "sync/atomic"
  10. "testing"
  11. "time"
  12. "github.com/pion/transport/v2/dpipe"
  13. "github.com/pion/transport/v2/test"
  14. )
  15. func TestReplayProtection(t *testing.T) {
  16. // Limit runtime in case of deadlocks
  17. lim := test.TimeOut(5 * time.Second)
  18. defer lim.Stop()
  19. // Check for leaking routines
  20. report := test.CheckRoutines(t)
  21. defer report()
  22. c0, c1 := dpipe.Pipe()
  23. c2, c3 := dpipe.Pipe()
  24. conn := []net.Conn{c0, c1, c2, c3}
  25. var wgRoutines sync.WaitGroup
  26. var cntReplays int32 = 1
  27. ctxReplayDone, replayDone := context.WithCancel(context.Background())
  28. replaySendDone := func() {
  29. cnt := atomic.AddInt32(&cntReplays, -1)
  30. if cnt == 0 {
  31. replayDone()
  32. }
  33. }
  34. replayer := func(ca, cb net.Conn) {
  35. defer wgRoutines.Done()
  36. // Man in the middle
  37. for {
  38. b := make([]byte, 2048)
  39. n, rerr := ca.Read(b)
  40. if rerr != nil {
  41. return
  42. }
  43. if _, werr := cb.Write(b[:n]); werr != nil {
  44. t.Error(werr)
  45. return
  46. }
  47. atomic.AddInt32(&cntReplays, 1)
  48. go func() {
  49. defer replaySendDone()
  50. // Replay bit later
  51. time.Sleep(time.Millisecond)
  52. if _, werr := cb.Write(b[:n]); werr != nil {
  53. t.Error(werr)
  54. }
  55. }()
  56. }
  57. }
  58. wgRoutines.Add(2)
  59. go replayer(conn[1], conn[2])
  60. go replayer(conn[2], conn[1])
  61. ca, cb, err := pipeConn(conn[0], conn[3])
  62. if err != nil {
  63. t.Fatal(err)
  64. }
  65. const numMsgs = 10
  66. var received [2][][]byte
  67. for i, c := range []net.Conn{ca, cb} {
  68. i := i
  69. c := c
  70. wgRoutines.Add(1)
  71. atomic.AddInt32(&cntReplays, 1) // Keep locked until the final message
  72. var lastMsgDone sync.Once
  73. go func() {
  74. defer wgRoutines.Done()
  75. for {
  76. b := make([]byte, 2048)
  77. n, rerr := c.Read(b)
  78. if rerr != nil {
  79. return
  80. }
  81. received[i] = append(received[i], b[:n])
  82. if b[0] == numMsgs-1 {
  83. // Final message received
  84. lastMsgDone.Do(func() {
  85. defer replaySendDone()
  86. })
  87. }
  88. }
  89. }()
  90. }
  91. var sent [][]byte
  92. for i := 0; i < numMsgs; i++ {
  93. data := []byte{byte(i)}
  94. sent = append(sent, data)
  95. if _, werr := ca.Write(data); werr != nil {
  96. t.Error(werr)
  97. return
  98. }
  99. if _, werr := cb.Write(data); werr != nil {
  100. t.Error(werr)
  101. return
  102. }
  103. }
  104. replaySendDone()
  105. <-ctxReplayDone.Done()
  106. time.Sleep(10 * time.Millisecond) // Ensure all replayed packets are sent
  107. for i := 0; i < 4; i++ {
  108. if err := conn[i].Close(); err != nil {
  109. t.Error(err)
  110. }
  111. }
  112. if err := ca.Close(); err != nil {
  113. t.Error(err)
  114. }
  115. if err := cb.Close(); err != nil {
  116. t.Error(err)
  117. }
  118. wgRoutines.Wait()
  119. for _, r := range received {
  120. if !reflect.DeepEqual(sent, r) {
  121. t.Errorf("Received data differs, expected: %v, got: %v", sent, r)
  122. }
  123. }
  124. }