| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142 |
- // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
- // SPDX-License-Identifier: MIT
- package dtls
- import (
- "context"
- "net"
- "reflect"
- "sync"
- "sync/atomic"
- "testing"
- "time"
- "github.com/pion/transport/v2/dpipe"
- "github.com/pion/transport/v2/test"
- )
- func TestReplayProtection(t *testing.T) {
- // Limit runtime in case of deadlocks
- lim := test.TimeOut(5 * time.Second)
- defer lim.Stop()
- // Check for leaking routines
- report := test.CheckRoutines(t)
- defer report()
- c0, c1 := dpipe.Pipe()
- c2, c3 := dpipe.Pipe()
- conn := []net.Conn{c0, c1, c2, c3}
- var wgRoutines sync.WaitGroup
- var cntReplays int32 = 1
- ctxReplayDone, replayDone := context.WithCancel(context.Background())
- replaySendDone := func() {
- cnt := atomic.AddInt32(&cntReplays, -1)
- if cnt == 0 {
- replayDone()
- }
- }
- replayer := func(ca, cb net.Conn) {
- defer wgRoutines.Done()
- // Man in the middle
- for {
- b := make([]byte, 2048)
- n, rerr := ca.Read(b)
- if rerr != nil {
- return
- }
- if _, werr := cb.Write(b[:n]); werr != nil {
- t.Error(werr)
- return
- }
- atomic.AddInt32(&cntReplays, 1)
- go func() {
- defer replaySendDone()
- // Replay bit later
- time.Sleep(time.Millisecond)
- if _, werr := cb.Write(b[:n]); werr != nil {
- t.Error(werr)
- }
- }()
- }
- }
- wgRoutines.Add(2)
- go replayer(conn[1], conn[2])
- go replayer(conn[2], conn[1])
- ca, cb, err := pipeConn(conn[0], conn[3])
- if err != nil {
- t.Fatal(err)
- }
- const numMsgs = 10
- var received [2][][]byte
- for i, c := range []net.Conn{ca, cb} {
- i := i
- c := c
- wgRoutines.Add(1)
- atomic.AddInt32(&cntReplays, 1) // Keep locked until the final message
- var lastMsgDone sync.Once
- go func() {
- defer wgRoutines.Done()
- for {
- b := make([]byte, 2048)
- n, rerr := c.Read(b)
- if rerr != nil {
- return
- }
- received[i] = append(received[i], b[:n])
- if b[0] == numMsgs-1 {
- // Final message received
- lastMsgDone.Do(func() {
- defer replaySendDone()
- })
- }
- }
- }()
- }
- var sent [][]byte
- for i := 0; i < numMsgs; i++ {
- data := []byte{byte(i)}
- sent = append(sent, data)
- if _, werr := ca.Write(data); werr != nil {
- t.Error(werr)
- return
- }
- if _, werr := cb.Write(data); werr != nil {
- t.Error(werr)
- return
- }
- }
- replaySendDone()
- <-ctxReplayDone.Done()
- time.Sleep(10 * time.Millisecond) // Ensure all replayed packets are sent
- for i := 0; i < 4; i++ {
- if err := conn[i].Close(); err != nil {
- t.Error(err)
- }
- }
- if err := ca.Close(); err != nil {
- t.Error(err)
- }
- if err := cb.Close(); err != nil {
- t.Error(err)
- }
- wgRoutines.Wait()
- for _, r := range received {
- if !reflect.DeepEqual(sent, r) {
- t.Errorf("Received data differs, expected: %v, got: %v", sent, r)
- }
- }
- }
|