e2e_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. //go:build !js
  4. // +build !js
  5. package e2e
  6. import (
  7. "context"
  8. "crypto/ed25519"
  9. "crypto/rand"
  10. "crypto/rsa"
  11. "crypto/tls"
  12. "crypto/x509"
  13. "errors"
  14. "fmt"
  15. "io"
  16. "net"
  17. "sync"
  18. "sync/atomic"
  19. "testing"
  20. "time"
  21. "github.com/pion/dtls/v2"
  22. "github.com/pion/dtls/v2/pkg/crypto/selfsign"
  23. "github.com/pion/transport/v2/test"
  24. )
  25. const (
  26. testMessage = "Hello World"
  27. testTimeLimit = 5 * time.Second
  28. messageRetry = 200 * time.Millisecond
  29. )
  30. var errServerTimeout = errors.New("waiting on serverReady err: timeout")
  31. func randomPort(t testing.TB) int {
  32. t.Helper()
  33. conn, err := net.ListenPacket("udp4", "127.0.0.1:0")
  34. if err != nil {
  35. t.Fatalf("failed to pickPort: %v", err)
  36. }
  37. defer func() {
  38. _ = conn.Close()
  39. }()
  40. switch addr := conn.LocalAddr().(type) {
  41. case *net.UDPAddr:
  42. return addr.Port
  43. default:
  44. t.Fatalf("unknown addr type %T", addr)
  45. return 0
  46. }
  47. }
  48. func simpleReadWrite(errChan chan error, outChan chan string, conn io.ReadWriter, messageRecvCount *uint64) {
  49. go func() {
  50. buffer := make([]byte, 8192)
  51. n, err := conn.Read(buffer)
  52. if err != nil {
  53. errChan <- err
  54. return
  55. }
  56. outChan <- string(buffer[:n])
  57. atomic.AddUint64(messageRecvCount, 1)
  58. }()
  59. for {
  60. if atomic.LoadUint64(messageRecvCount) == 2 {
  61. break
  62. } else if _, err := conn.Write([]byte(testMessage)); err != nil {
  63. errChan <- err
  64. break
  65. }
  66. time.Sleep(messageRetry)
  67. }
  68. }
  69. type comm struct {
  70. ctx context.Context
  71. clientConfig, serverConfig *dtls.Config
  72. serverPort int
  73. messageRecvCount *uint64 // Counter to make sure both sides got a message
  74. clientMutex *sync.Mutex
  75. clientConn net.Conn
  76. serverMutex *sync.Mutex
  77. serverConn net.Conn
  78. serverListener net.Listener
  79. serverReady chan struct{}
  80. errChan chan error
  81. clientChan chan string
  82. serverChan chan string
  83. client func(*comm)
  84. server func(*comm)
  85. }
  86. func newComm(ctx context.Context, clientConfig, serverConfig *dtls.Config, serverPort int, server, client func(*comm)) *comm {
  87. messageRecvCount := uint64(0)
  88. c := &comm{
  89. ctx: ctx,
  90. clientConfig: clientConfig,
  91. serverConfig: serverConfig,
  92. serverPort: serverPort,
  93. messageRecvCount: &messageRecvCount,
  94. clientMutex: &sync.Mutex{},
  95. serverMutex: &sync.Mutex{},
  96. serverReady: make(chan struct{}),
  97. errChan: make(chan error),
  98. clientChan: make(chan string),
  99. serverChan: make(chan string),
  100. server: server,
  101. client: client,
  102. }
  103. return c
  104. }
  105. func (c *comm) assert(t *testing.T) {
  106. // DTLS Client
  107. go c.client(c)
  108. // DTLS Server
  109. go c.server(c)
  110. defer func() {
  111. if c.clientConn != nil {
  112. if err := c.clientConn.Close(); err != nil {
  113. t.Fatal(err)
  114. }
  115. }
  116. if c.serverConn != nil {
  117. if err := c.serverConn.Close(); err != nil {
  118. t.Fatal(err)
  119. }
  120. }
  121. if c.serverListener != nil {
  122. if err := c.serverListener.Close(); err != nil {
  123. t.Fatal(err)
  124. }
  125. }
  126. }()
  127. func() {
  128. seenClient, seenServer := false, false
  129. for {
  130. select {
  131. case err := <-c.errChan:
  132. t.Fatal(err)
  133. case <-time.After(testTimeLimit):
  134. t.Fatalf("Test timeout, seenClient %t seenServer %t", seenClient, seenServer)
  135. case clientMsg := <-c.clientChan:
  136. if clientMsg != testMessage {
  137. t.Fatalf("clientMsg does not equal test message: %s %s", clientMsg, testMessage)
  138. }
  139. seenClient = true
  140. if seenClient && seenServer {
  141. return
  142. }
  143. case serverMsg := <-c.serverChan:
  144. if serverMsg != testMessage {
  145. t.Fatalf("serverMsg does not equal test message: %s %s", serverMsg, testMessage)
  146. }
  147. seenServer = true
  148. if seenClient && seenServer {
  149. return
  150. }
  151. }
  152. }
  153. }()
  154. }
  155. func clientPion(c *comm) {
  156. select {
  157. case <-c.serverReady:
  158. // OK
  159. case <-time.After(time.Second):
  160. c.errChan <- errServerTimeout
  161. }
  162. c.clientMutex.Lock()
  163. defer c.clientMutex.Unlock()
  164. var err error
  165. c.clientConn, err = dtls.DialWithContext(c.ctx, "udp",
  166. &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort},
  167. c.clientConfig,
  168. )
  169. if err != nil {
  170. c.errChan <- err
  171. return
  172. }
  173. simpleReadWrite(c.errChan, c.clientChan, c.clientConn, c.messageRecvCount)
  174. }
  175. func serverPion(c *comm) {
  176. c.serverMutex.Lock()
  177. defer c.serverMutex.Unlock()
  178. var err error
  179. c.serverListener, err = dtls.Listen("udp",
  180. &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort},
  181. c.serverConfig,
  182. )
  183. if err != nil {
  184. c.errChan <- err
  185. return
  186. }
  187. c.serverReady <- struct{}{}
  188. c.serverConn, err = c.serverListener.Accept()
  189. if err != nil {
  190. c.errChan <- err
  191. return
  192. }
  193. simpleReadWrite(c.errChan, c.serverChan, c.serverConn, c.messageRecvCount)
  194. }
  195. /*
  196. Simple DTLS Client/Server can communicate
  197. - Assert that you can send messages both ways
  198. - Assert that Close() on both ends work
  199. - Assert that no Goroutines are leaked
  200. */
  201. func testPionE2ESimple(t *testing.T, server, client func(*comm)) {
  202. lim := test.TimeOut(time.Second * 30)
  203. defer lim.Stop()
  204. report := test.CheckRoutines(t)
  205. defer report()
  206. for _, cipherSuite := range []dtls.CipherSuiteID{
  207. dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
  208. dtls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
  209. dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
  210. } {
  211. cipherSuite := cipherSuite
  212. t.Run(cipherSuite.String(), func(t *testing.T) {
  213. ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
  214. defer cancel()
  215. cert, err := selfsign.GenerateSelfSignedWithDNS("localhost")
  216. if err != nil {
  217. t.Fatal(err)
  218. }
  219. cfg := &dtls.Config{
  220. Certificates: []tls.Certificate{cert},
  221. CipherSuites: []dtls.CipherSuiteID{cipherSuite},
  222. InsecureSkipVerify: true,
  223. }
  224. serverPort := randomPort(t)
  225. comm := newComm(ctx, cfg, cfg, serverPort, server, client)
  226. comm.assert(t)
  227. })
  228. }
  229. }
  230. func testPionE2ESimplePSK(t *testing.T, server, client func(*comm)) {
  231. lim := test.TimeOut(time.Second * 30)
  232. defer lim.Stop()
  233. report := test.CheckRoutines(t)
  234. defer report()
  235. for _, cipherSuite := range []dtls.CipherSuiteID{
  236. dtls.TLS_PSK_WITH_AES_128_CCM,
  237. dtls.TLS_PSK_WITH_AES_128_CCM_8,
  238. dtls.TLS_PSK_WITH_AES_256_CCM_8,
  239. dtls.TLS_PSK_WITH_AES_128_GCM_SHA256,
  240. dtls.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256,
  241. } {
  242. cipherSuite := cipherSuite
  243. t.Run(cipherSuite.String(), func(t *testing.T) {
  244. ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
  245. defer cancel()
  246. cfg := &dtls.Config{
  247. PSK: func(hint []byte) ([]byte, error) {
  248. return []byte{0xAB, 0xC1, 0x23}, nil
  249. },
  250. PSKIdentityHint: []byte{0x01, 0x02, 0x03, 0x04, 0x05},
  251. CipherSuites: []dtls.CipherSuiteID{cipherSuite},
  252. }
  253. serverPort := randomPort(t)
  254. comm := newComm(ctx, cfg, cfg, serverPort, server, client)
  255. comm.assert(t)
  256. })
  257. }
  258. }
  259. func testPionE2EMTUs(t *testing.T, server, client func(*comm)) {
  260. lim := test.TimeOut(time.Second * 30)
  261. defer lim.Stop()
  262. report := test.CheckRoutines(t)
  263. defer report()
  264. for _, mtu := range []int{
  265. 10000,
  266. 1000,
  267. 100,
  268. } {
  269. mtu := mtu
  270. t.Run(fmt.Sprintf("MTU%d", mtu), func(t *testing.T) {
  271. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  272. defer cancel()
  273. cert, err := selfsign.GenerateSelfSignedWithDNS("localhost")
  274. if err != nil {
  275. t.Fatal(err)
  276. }
  277. cfg := &dtls.Config{
  278. Certificates: []tls.Certificate{cert},
  279. CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
  280. InsecureSkipVerify: true,
  281. MTU: mtu,
  282. }
  283. serverPort := randomPort(t)
  284. comm := newComm(ctx, cfg, cfg, serverPort, server, client)
  285. comm.assert(t)
  286. })
  287. }
  288. }
  289. func testPionE2ESimpleED25519(t *testing.T, server, client func(*comm)) {
  290. lim := test.TimeOut(time.Second * 30)
  291. defer lim.Stop()
  292. report := test.CheckRoutines(t)
  293. defer report()
  294. for _, cipherSuite := range []dtls.CipherSuiteID{
  295. dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM,
  296. dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8,
  297. dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
  298. dtls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
  299. dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
  300. } {
  301. cipherSuite := cipherSuite
  302. t.Run(cipherSuite.String(), func(t *testing.T) {
  303. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  304. defer cancel()
  305. _, key, err := ed25519.GenerateKey(rand.Reader)
  306. if err != nil {
  307. t.Fatal(err)
  308. }
  309. cert, err := selfsign.SelfSign(key)
  310. if err != nil {
  311. t.Fatal(err)
  312. }
  313. cfg := &dtls.Config{
  314. Certificates: []tls.Certificate{cert},
  315. CipherSuites: []dtls.CipherSuiteID{cipherSuite},
  316. InsecureSkipVerify: true,
  317. }
  318. serverPort := randomPort(t)
  319. comm := newComm(ctx, cfg, cfg, serverPort, server, client)
  320. comm.assert(t)
  321. })
  322. }
  323. }
  324. func testPionE2ESimpleED25519ClientCert(t *testing.T, server, client func(*comm)) {
  325. lim := test.TimeOut(time.Second * 30)
  326. defer lim.Stop()
  327. report := test.CheckRoutines(t)
  328. defer report()
  329. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  330. defer cancel()
  331. _, skey, err := ed25519.GenerateKey(rand.Reader)
  332. if err != nil {
  333. t.Fatal(err)
  334. }
  335. scert, err := selfsign.SelfSign(skey)
  336. if err != nil {
  337. t.Fatal(err)
  338. }
  339. _, ckey, err := ed25519.GenerateKey(rand.Reader)
  340. if err != nil {
  341. t.Fatal(err)
  342. }
  343. ccert, err := selfsign.SelfSign(ckey)
  344. if err != nil {
  345. t.Fatal(err)
  346. }
  347. scfg := &dtls.Config{
  348. Certificates: []tls.Certificate{scert},
  349. CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
  350. ClientAuth: dtls.RequireAnyClientCert,
  351. }
  352. ccfg := &dtls.Config{
  353. Certificates: []tls.Certificate{ccert},
  354. CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
  355. InsecureSkipVerify: true,
  356. }
  357. serverPort := randomPort(t)
  358. comm := newComm(ctx, ccfg, scfg, serverPort, server, client)
  359. comm.assert(t)
  360. }
  361. func testPionE2ESimpleECDSAClientCert(t *testing.T, server, client func(*comm)) {
  362. lim := test.TimeOut(time.Second * 30)
  363. defer lim.Stop()
  364. report := test.CheckRoutines(t)
  365. defer report()
  366. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  367. defer cancel()
  368. scert, err := selfsign.GenerateSelfSigned()
  369. if err != nil {
  370. t.Fatal(err)
  371. }
  372. ccert, err := selfsign.GenerateSelfSigned()
  373. if err != nil {
  374. t.Fatal(err)
  375. }
  376. clientCAs := x509.NewCertPool()
  377. caCert, err := x509.ParseCertificate(ccert.Certificate[0])
  378. if err != nil {
  379. t.Fatal(err)
  380. }
  381. clientCAs.AddCert(caCert)
  382. scfg := &dtls.Config{
  383. ClientCAs: clientCAs,
  384. Certificates: []tls.Certificate{scert},
  385. CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
  386. ClientAuth: dtls.RequireAnyClientCert,
  387. }
  388. ccfg := &dtls.Config{
  389. Certificates: []tls.Certificate{ccert},
  390. CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
  391. InsecureSkipVerify: true,
  392. }
  393. serverPort := randomPort(t)
  394. comm := newComm(ctx, ccfg, scfg, serverPort, server, client)
  395. comm.assert(t)
  396. }
  397. func testPionE2ESimpleRSAClientCert(t *testing.T, server, client func(*comm)) {
  398. lim := test.TimeOut(time.Second * 30)
  399. defer lim.Stop()
  400. report := test.CheckRoutines(t)
  401. defer report()
  402. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  403. defer cancel()
  404. spriv, err := rsa.GenerateKey(rand.Reader, 2048)
  405. if err != nil {
  406. t.Fatal(err)
  407. }
  408. scert, err := selfsign.SelfSign(spriv)
  409. if err != nil {
  410. t.Fatal(err)
  411. }
  412. cpriv, err := rsa.GenerateKey(rand.Reader, 2048)
  413. if err != nil {
  414. t.Fatal(err)
  415. }
  416. ccert, err := selfsign.SelfSign(cpriv)
  417. if err != nil {
  418. t.Fatal(err)
  419. }
  420. scfg := &dtls.Config{
  421. Certificates: []tls.Certificate{scert},
  422. CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
  423. ClientAuth: dtls.RequireAnyClientCert,
  424. }
  425. ccfg := &dtls.Config{
  426. Certificates: []tls.Certificate{ccert},
  427. CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
  428. InsecureSkipVerify: true,
  429. }
  430. serverPort := randomPort(t)
  431. comm := newComm(ctx, ccfg, scfg, serverPort, server, client)
  432. comm.assert(t)
  433. }
  434. func TestPionE2ESimple(t *testing.T) {
  435. testPionE2ESimple(t, serverPion, clientPion)
  436. }
  437. func TestPionE2ESimplePSK(t *testing.T) {
  438. testPionE2ESimplePSK(t, serverPion, clientPion)
  439. }
  440. func TestPionE2EMTUs(t *testing.T) {
  441. testPionE2EMTUs(t, serverPion, clientPion)
  442. }
  443. func TestPionE2ESimpleED25519(t *testing.T) {
  444. testPionE2ESimpleED25519(t, serverPion, clientPion)
  445. }
  446. func TestPionE2ESimpleED25519ClientCert(t *testing.T) {
  447. testPionE2ESimpleED25519ClientCert(t, serverPion, clientPion)
  448. }
  449. func TestPionE2ESimpleECDSAClientCert(t *testing.T) {
  450. testPionE2ESimpleECDSAClientCert(t, serverPion, clientPion)
  451. }
  452. func TestPionE2ESimpleRSAClientCert(t *testing.T) {
  453. testPionE2ESimpleRSAClientCert(t, serverPion, clientPion)
  454. }