| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525 |
- // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
- // SPDX-License-Identifier: MIT
- //go:build !js
- // +build !js
- package e2e
- import (
- "context"
- "crypto/ed25519"
- "crypto/rand"
- "crypto/rsa"
- "crypto/tls"
- "crypto/x509"
- "errors"
- "fmt"
- "io"
- "net"
- "sync"
- "sync/atomic"
- "testing"
- "time"
- "github.com/pion/dtls/v2"
- "github.com/pion/dtls/v2/pkg/crypto/selfsign"
- "github.com/pion/transport/v2/test"
- )
// Shared parameters for every end-to-end run in this file.
const (
	testMessage   = "Hello World"          // payload each side writes and expects to read
	testTimeLimit = time.Second * 5        // timeout used by assert while waiting for both sides
	messageRetry  = time.Millisecond * 200 // pause between repeated writes in simpleReadWrite
)

// errServerTimeout is sent by the client when the server does not signal
// readiness within the client's wait window.
var errServerTimeout = errors.New("waiting on serverReady err: timeout")
// randomPort asks the OS for a free UDP port on the IPv4 loopback address and
// returns its number. The probing socket is closed before returning, so the
// port is only free at the moment of the probe; a caller binds it later.
func randomPort(t testing.TB) int {
	t.Helper()

	probe, err := net.ListenPacket("udp4", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("failed to pickPort: %v", err)
	}
	defer func() {
		_ = probe.Close()
	}()

	udpAddr, ok := probe.LocalAddr().(*net.UDPAddr)
	if !ok {
		t.Fatalf("unknown addr type %T", probe.LocalAddr())
		return 0
	}
	return udpAddr.Port
}
- func simpleReadWrite(errChan chan error, outChan chan string, conn io.ReadWriter, messageRecvCount *uint64) {
- go func() {
- buffer := make([]byte, 8192)
- n, err := conn.Read(buffer)
- if err != nil {
- errChan <- err
- return
- }
- outChan <- string(buffer[:n])
- atomic.AddUint64(messageRecvCount, 1)
- }()
- for {
- if atomic.LoadUint64(messageRecvCount) == 2 {
- break
- } else if _, err := conn.Write([]byte(testMessage)); err != nil {
- errChan <- err
- break
- }
- time.Sleep(messageRetry)
- }
- }
- type comm struct {
- ctx context.Context
- clientConfig, serverConfig *dtls.Config
- serverPort int
- messageRecvCount *uint64 // Counter to make sure both sides got a message
- clientMutex *sync.Mutex
- clientConn net.Conn
- serverMutex *sync.Mutex
- serverConn net.Conn
- serverListener net.Listener
- serverReady chan struct{}
- errChan chan error
- clientChan chan string
- serverChan chan string
- client func(*comm)
- server func(*comm)
- }
- func newComm(ctx context.Context, clientConfig, serverConfig *dtls.Config, serverPort int, server, client func(*comm)) *comm {
- messageRecvCount := uint64(0)
- c := &comm{
- ctx: ctx,
- clientConfig: clientConfig,
- serverConfig: serverConfig,
- serverPort: serverPort,
- messageRecvCount: &messageRecvCount,
- clientMutex: &sync.Mutex{},
- serverMutex: &sync.Mutex{},
- serverReady: make(chan struct{}),
- errChan: make(chan error),
- clientChan: make(chan string),
- serverChan: make(chan string),
- server: server,
- client: client,
- }
- return c
- }
- func (c *comm) assert(t *testing.T) {
- // DTLS Client
- go c.client(c)
- // DTLS Server
- go c.server(c)
- defer func() {
- if c.clientConn != nil {
- if err := c.clientConn.Close(); err != nil {
- t.Fatal(err)
- }
- }
- if c.serverConn != nil {
- if err := c.serverConn.Close(); err != nil {
- t.Fatal(err)
- }
- }
- if c.serverListener != nil {
- if err := c.serverListener.Close(); err != nil {
- t.Fatal(err)
- }
- }
- }()
- func() {
- seenClient, seenServer := false, false
- for {
- select {
- case err := <-c.errChan:
- t.Fatal(err)
- case <-time.After(testTimeLimit):
- t.Fatalf("Test timeout, seenClient %t seenServer %t", seenClient, seenServer)
- case clientMsg := <-c.clientChan:
- if clientMsg != testMessage {
- t.Fatalf("clientMsg does not equal test message: %s %s", clientMsg, testMessage)
- }
- seenClient = true
- if seenClient && seenServer {
- return
- }
- case serverMsg := <-c.serverChan:
- if serverMsg != testMessage {
- t.Fatalf("serverMsg does not equal test message: %s %s", serverMsg, testMessage)
- }
- seenServer = true
- if seenClient && seenServer {
- return
- }
- }
- }
- }()
- }
- func clientPion(c *comm) {
- select {
- case <-c.serverReady:
- // OK
- case <-time.After(time.Second):
- c.errChan <- errServerTimeout
- }
- c.clientMutex.Lock()
- defer c.clientMutex.Unlock()
- var err error
- c.clientConn, err = dtls.DialWithContext(c.ctx, "udp",
- &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort},
- c.clientConfig,
- )
- if err != nil {
- c.errChan <- err
- return
- }
- simpleReadWrite(c.errChan, c.clientChan, c.clientConn, c.messageRecvCount)
- }
- func serverPion(c *comm) {
- c.serverMutex.Lock()
- defer c.serverMutex.Unlock()
- var err error
- c.serverListener, err = dtls.Listen("udp",
- &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort},
- c.serverConfig,
- )
- if err != nil {
- c.errChan <- err
- return
- }
- c.serverReady <- struct{}{}
- c.serverConn, err = c.serverListener.Accept()
- if err != nil {
- c.errChan <- err
- return
- }
- simpleReadWrite(c.errChan, c.serverChan, c.serverConn, c.messageRecvCount)
- }
- /*
- Simple DTLS Client/Server can communicate
- - Assert that you can send messages both ways
- - Assert that Close() on both ends work
- - Assert that no Goroutines are leaked
- */
- func testPionE2ESimple(t *testing.T, server, client func(*comm)) {
- lim := test.TimeOut(time.Second * 30)
- defer lim.Stop()
- report := test.CheckRoutines(t)
- defer report()
- for _, cipherSuite := range []dtls.CipherSuiteID{
- dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
- dtls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
- dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
- } {
- cipherSuite := cipherSuite
- t.Run(cipherSuite.String(), func(t *testing.T) {
- ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
- defer cancel()
- cert, err := selfsign.GenerateSelfSignedWithDNS("localhost")
- if err != nil {
- t.Fatal(err)
- }
- cfg := &dtls.Config{
- Certificates: []tls.Certificate{cert},
- CipherSuites: []dtls.CipherSuiteID{cipherSuite},
- InsecureSkipVerify: true,
- }
- serverPort := randomPort(t)
- comm := newComm(ctx, cfg, cfg, serverPort, server, client)
- comm.assert(t)
- })
- }
- }
- func testPionE2ESimplePSK(t *testing.T, server, client func(*comm)) {
- lim := test.TimeOut(time.Second * 30)
- defer lim.Stop()
- report := test.CheckRoutines(t)
- defer report()
- for _, cipherSuite := range []dtls.CipherSuiteID{
- dtls.TLS_PSK_WITH_AES_128_CCM,
- dtls.TLS_PSK_WITH_AES_128_CCM_8,
- dtls.TLS_PSK_WITH_AES_256_CCM_8,
- dtls.TLS_PSK_WITH_AES_128_GCM_SHA256,
- dtls.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256,
- } {
- cipherSuite := cipherSuite
- t.Run(cipherSuite.String(), func(t *testing.T) {
- ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
- defer cancel()
- cfg := &dtls.Config{
- PSK: func(hint []byte) ([]byte, error) {
- return []byte{0xAB, 0xC1, 0x23}, nil
- },
- PSKIdentityHint: []byte{0x01, 0x02, 0x03, 0x04, 0x05},
- CipherSuites: []dtls.CipherSuiteID{cipherSuite},
- }
- serverPort := randomPort(t)
- comm := newComm(ctx, cfg, cfg, serverPort, server, client)
- comm.assert(t)
- })
- }
- }
- func testPionE2EMTUs(t *testing.T, server, client func(*comm)) {
- lim := test.TimeOut(time.Second * 30)
- defer lim.Stop()
- report := test.CheckRoutines(t)
- defer report()
- for _, mtu := range []int{
- 10000,
- 1000,
- 100,
- } {
- mtu := mtu
- t.Run(fmt.Sprintf("MTU%d", mtu), func(t *testing.T) {
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
- cert, err := selfsign.GenerateSelfSignedWithDNS("localhost")
- if err != nil {
- t.Fatal(err)
- }
- cfg := &dtls.Config{
- Certificates: []tls.Certificate{cert},
- CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
- InsecureSkipVerify: true,
- MTU: mtu,
- }
- serverPort := randomPort(t)
- comm := newComm(ctx, cfg, cfg, serverPort, server, client)
- comm.assert(t)
- })
- }
- }
- func testPionE2ESimpleED25519(t *testing.T, server, client func(*comm)) {
- lim := test.TimeOut(time.Second * 30)
- defer lim.Stop()
- report := test.CheckRoutines(t)
- defer report()
- for _, cipherSuite := range []dtls.CipherSuiteID{
- dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM,
- dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8,
- dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
- dtls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
- dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
- } {
- cipherSuite := cipherSuite
- t.Run(cipherSuite.String(), func(t *testing.T) {
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
- _, key, err := ed25519.GenerateKey(rand.Reader)
- if err != nil {
- t.Fatal(err)
- }
- cert, err := selfsign.SelfSign(key)
- if err != nil {
- t.Fatal(err)
- }
- cfg := &dtls.Config{
- Certificates: []tls.Certificate{cert},
- CipherSuites: []dtls.CipherSuiteID{cipherSuite},
- InsecureSkipVerify: true,
- }
- serverPort := randomPort(t)
- comm := newComm(ctx, cfg, cfg, serverPort, server, client)
- comm.assert(t)
- })
- }
- }
- func testPionE2ESimpleED25519ClientCert(t *testing.T, server, client func(*comm)) {
- lim := test.TimeOut(time.Second * 30)
- defer lim.Stop()
- report := test.CheckRoutines(t)
- defer report()
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
- _, skey, err := ed25519.GenerateKey(rand.Reader)
- if err != nil {
- t.Fatal(err)
- }
- scert, err := selfsign.SelfSign(skey)
- if err != nil {
- t.Fatal(err)
- }
- _, ckey, err := ed25519.GenerateKey(rand.Reader)
- if err != nil {
- t.Fatal(err)
- }
- ccert, err := selfsign.SelfSign(ckey)
- if err != nil {
- t.Fatal(err)
- }
- scfg := &dtls.Config{
- Certificates: []tls.Certificate{scert},
- CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
- ClientAuth: dtls.RequireAnyClientCert,
- }
- ccfg := &dtls.Config{
- Certificates: []tls.Certificate{ccert},
- CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
- InsecureSkipVerify: true,
- }
- serverPort := randomPort(t)
- comm := newComm(ctx, ccfg, scfg, serverPort, server, client)
- comm.assert(t)
- }
- func testPionE2ESimpleECDSAClientCert(t *testing.T, server, client func(*comm)) {
- lim := test.TimeOut(time.Second * 30)
- defer lim.Stop()
- report := test.CheckRoutines(t)
- defer report()
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
- scert, err := selfsign.GenerateSelfSigned()
- if err != nil {
- t.Fatal(err)
- }
- ccert, err := selfsign.GenerateSelfSigned()
- if err != nil {
- t.Fatal(err)
- }
- clientCAs := x509.NewCertPool()
- caCert, err := x509.ParseCertificate(ccert.Certificate[0])
- if err != nil {
- t.Fatal(err)
- }
- clientCAs.AddCert(caCert)
- scfg := &dtls.Config{
- ClientCAs: clientCAs,
- Certificates: []tls.Certificate{scert},
- CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
- ClientAuth: dtls.RequireAnyClientCert,
- }
- ccfg := &dtls.Config{
- Certificates: []tls.Certificate{ccert},
- CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
- InsecureSkipVerify: true,
- }
- serverPort := randomPort(t)
- comm := newComm(ctx, ccfg, scfg, serverPort, server, client)
- comm.assert(t)
- }
- func testPionE2ESimpleRSAClientCert(t *testing.T, server, client func(*comm)) {
- lim := test.TimeOut(time.Second * 30)
- defer lim.Stop()
- report := test.CheckRoutines(t)
- defer report()
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
- spriv, err := rsa.GenerateKey(rand.Reader, 2048)
- if err != nil {
- t.Fatal(err)
- }
- scert, err := selfsign.SelfSign(spriv)
- if err != nil {
- t.Fatal(err)
- }
- cpriv, err := rsa.GenerateKey(rand.Reader, 2048)
- if err != nil {
- t.Fatal(err)
- }
- ccert, err := selfsign.SelfSign(cpriv)
- if err != nil {
- t.Fatal(err)
- }
- scfg := &dtls.Config{
- Certificates: []tls.Certificate{scert},
- CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
- ClientAuth: dtls.RequireAnyClientCert,
- }
- ccfg := &dtls.Config{
- Certificates: []tls.Certificate{ccert},
- CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
- InsecureSkipVerify: true,
- }
- serverPort := randomPort(t)
- comm := newComm(ctx, ccfg, scfg, serverPort, server, client)
- comm.assert(t)
- }
- func TestPionE2ESimple(t *testing.T) {
- testPionE2ESimple(t, serverPion, clientPion)
- }
- func TestPionE2ESimplePSK(t *testing.T) {
- testPionE2ESimplePSK(t, serverPion, clientPion)
- }
- func TestPionE2EMTUs(t *testing.T) {
- testPionE2EMTUs(t, serverPion, clientPion)
- }
- func TestPionE2ESimpleED25519(t *testing.T) {
- testPionE2ESimpleED25519(t, serverPion, clientPion)
- }
- func TestPionE2ESimpleED25519ClientCert(t *testing.T) {
- testPionE2ESimpleED25519ClientCert(t, serverPion, clientPion)
- }
- func TestPionE2ESimpleECDSAClientCert(t *testing.T) {
- testPionE2ESimpleECDSAClientCert(t, serverPion, clientPion)
- }
- func TestPionE2ESimpleRSAClientCert(t *testing.T) {
- testPionE2ESimpleRSAClientCert(t, serverPion, clientPion)
- }
|