| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178 |
- // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
- // SPDX-License-Identifier: MIT
- //go:build !js
- // +build !js
- package dtls
- import (
- "bytes"
- "context"
- "crypto/tls"
- "errors"
- "net"
- "testing"
- "time"
- "github.com/pion/dtls/v2/pkg/crypto/selfsign"
- "github.com/pion/transport/v2/dpipe"
- "github.com/pion/transport/v2/test"
- )
- func TestContextConfig(t *testing.T) {
- // Limit runtime in case of deadlocks
- lim := test.TimeOut(time.Second * 20)
- defer lim.Stop()
- report := test.CheckRoutines(t)
- defer report()
- addrListen, err := net.ResolveUDPAddr("udp", "localhost:0")
- if err != nil {
- t.Fatalf("Unexpected error: %v", err)
- }
- // Dummy listener
- listen, err := net.ListenUDP("udp", addrListen)
- if err != nil {
- t.Fatalf("Unexpected error: %v", err)
- }
- defer func() {
- _ = listen.Close()
- }()
- addr, ok := listen.LocalAddr().(*net.UDPAddr)
- if !ok {
- t.Fatal("Failed to cast net.UDPAddr")
- }
- cert, err := selfsign.GenerateSelfSigned()
- if err != nil {
- t.Fatalf("Unexpected error: %v", err)
- }
- config := &Config{
- ConnectContextMaker: func() (context.Context, func()) {
- return context.WithTimeout(context.Background(), 40*time.Millisecond)
- },
- Certificates: []tls.Certificate{cert},
- }
- dials := map[string]struct {
- f func() (func() (net.Conn, error), func())
- order []byte
- }{
- "Dial": {
- f: func() (func() (net.Conn, error), func()) {
- return func() (net.Conn, error) {
- return Dial("udp", addr, config)
- }, func() {
- }
- },
- order: []byte{0, 1, 2},
- },
- "DialWithContext": {
- f: func() (func() (net.Conn, error), func()) {
- ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond)
- return func() (net.Conn, error) {
- return DialWithContext(ctx, "udp", addr, config)
- }, func() {
- cancel()
- }
- },
- order: []byte{0, 2, 1},
- },
- "Client": {
- f: func() (func() (net.Conn, error), func()) {
- ca, _ := dpipe.Pipe()
- return func() (net.Conn, error) {
- return Client(ca, config)
- }, func() {
- _ = ca.Close()
- }
- },
- order: []byte{0, 1, 2},
- },
- "ClientWithContext": {
- f: func() (func() (net.Conn, error), func()) {
- ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond)
- ca, _ := dpipe.Pipe()
- return func() (net.Conn, error) {
- return ClientWithContext(ctx, ca, config)
- }, func() {
- cancel()
- _ = ca.Close()
- }
- },
- order: []byte{0, 2, 1},
- },
- "Server": {
- f: func() (func() (net.Conn, error), func()) {
- ca, _ := dpipe.Pipe()
- return func() (net.Conn, error) {
- return Server(ca, config)
- }, func() {
- _ = ca.Close()
- }
- },
- order: []byte{0, 1, 2},
- },
- "ServerWithContext": {
- f: func() (func() (net.Conn, error), func()) {
- ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond)
- ca, _ := dpipe.Pipe()
- return func() (net.Conn, error) {
- return ServerWithContext(ctx, ca, config)
- }, func() {
- cancel()
- _ = ca.Close()
- }
- },
- order: []byte{0, 2, 1},
- },
- }
- for name, dial := range dials {
- dial := dial
- t.Run(name, func(t *testing.T) {
- done := make(chan struct{})
- go func() {
- d, cancel := dial.f()
- conn, err := d()
- defer cancel()
- var netError net.Error
- if !errors.As(err, &netError) || !netError.Temporary() { //nolint:staticcheck
- t.Errorf("Client error exp(Temporary network error) failed(%v)", err)
- close(done)
- return
- }
- done <- struct{}{}
- if err == nil {
- _ = conn.Close()
- }
- }()
- var order []byte
- early := time.After(20 * time.Millisecond)
- late := time.After(60 * time.Millisecond)
- func() {
- for len(order) < 3 {
- select {
- case <-early:
- order = append(order, 0)
- case _, ok := <-done:
- if !ok {
- return
- }
- order = append(order, 1)
- case <-late:
- order = append(order, 2)
- }
- }
- }()
- if !bytes.Equal(dial.order, order) {
- t.Errorf("Invalid cancel timing, expected: %v, got: %v", dial.order, order)
- }
- })
- }
- }
|