errors_test.go 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
  2. // SPDX-License-Identifier: MIT
  3. package dtls
  4. import (
  5. "errors"
  6. "fmt"
  7. "net"
  8. "testing"
  9. )
  10. var errExample = errors.New("an example error")
  11. func TestErrorUnwrap(t *testing.T) {
  12. cases := []struct {
  13. err error
  14. errUnwrapped []error
  15. }{
  16. {
  17. &FatalError{Err: errExample},
  18. []error{errExample},
  19. },
  20. {
  21. &TemporaryError{Err: errExample},
  22. []error{errExample},
  23. },
  24. {
  25. &InternalError{Err: errExample},
  26. []error{errExample},
  27. },
  28. {
  29. &TimeoutError{Err: errExample},
  30. []error{errExample},
  31. },
  32. {
  33. &HandshakeError{Err: errExample},
  34. []error{errExample},
  35. },
  36. }
  37. for _, c := range cases {
  38. c := c
  39. t.Run(fmt.Sprintf("%T", c.err), func(t *testing.T) {
  40. err := c.err
  41. for _, unwrapped := range c.errUnwrapped {
  42. e := errors.Unwrap(err)
  43. if !errors.Is(e, unwrapped) {
  44. t.Errorf("Unwrapped error is expected to be '%v', got '%v'", unwrapped, e)
  45. }
  46. }
  47. })
  48. }
  49. }
  50. func TestErrorNetError(t *testing.T) {
  51. cases := []struct {
  52. err error
  53. str string
  54. timeout, temporary bool
  55. }{
  56. {&FatalError{Err: errExample}, "dtls fatal: an example error", false, false},
  57. {&TemporaryError{Err: errExample}, "dtls temporary: an example error", false, true},
  58. {&InternalError{Err: errExample}, "dtls internal: an example error", false, false},
  59. {&TimeoutError{Err: errExample}, "dtls timeout: an example error", true, true},
  60. {&HandshakeError{Err: errExample}, "handshake error: an example error", false, false},
  61. {&HandshakeError{Err: &TimeoutError{Err: errExample}}, "handshake error: dtls timeout: an example error", true, true},
  62. }
  63. for _, c := range cases {
  64. c := c
  65. t.Run(fmt.Sprintf("%T", c.err), func(t *testing.T) {
  66. var ne net.Error
  67. if !errors.As(c.err, &ne) {
  68. t.Fatalf("%T doesn't implement net.Error", c.err)
  69. }
  70. if ne.Timeout() != c.timeout {
  71. t.Errorf("%T.Timeout() should be %v", c.err, c.timeout)
  72. }
  73. if ne.Temporary() != c.temporary { //nolint:staticcheck
  74. t.Errorf("%T.Temporary() should be %v", c.err, c.temporary)
  75. }
  76. if ne.Error() != c.str {
  77. t.Errorf("%T.Error() should be %v", c.err, c.str)
  78. }
  79. })
  80. }
  81. }