|
|
@@ -23,17 +23,43 @@ import (
|
|
|
"fmt"
|
|
|
"reflect"
|
|
|
"testing"
|
|
|
+
|
|
|
+ "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
|
|
|
)
|
|
|
|
|
|
func TestTunnelProtocolValidation(t *testing.T) {
|
|
|
|
|
|
- err := SupportedTunnelProtocols.Validate()
|
|
|
- if err != nil {
|
|
|
- t.Errorf("unexpected Validate error: %s", err)
|
|
|
+ validSupportedProtocols := make(TunnelProtocols, 0)
|
|
|
+ for _, p := range SupportedTunnelProtocols {
|
|
|
+ if common.Contains(DisabledTunnelProtocols, p) {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ validSupportedProtocols = append(validSupportedProtocols, p)
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(validSupportedProtocols) == len(SupportedTunnelProtocols) {
|
|
|
+
|
|
|
+ err := SupportedTunnelProtocols.Validate()
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unexpected Validate error: %s", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ } else {
|
|
|
+
|
|
|
+ err := SupportedTunnelProtocols.Validate()
|
|
|
+ if err == nil {
|
|
|
+ t.Errorf("unexpected Validate success")
|
|
|
+ }
|
|
|
+
|
|
|
+ err = validSupportedProtocols.Validate()
|
|
|
+ if err != nil {
|
|
|
+ t.Errorf("unexpected Validate error: %s", err)
|
|
|
+ }
|
|
|
+
|
|
|
}
|
|
|
|
|
|
invalidProtocols := TunnelProtocols{"OSSH", "INVALID-PROTOCOL"}
|
|
|
- err = invalidProtocols.Validate()
|
|
|
+ err := invalidProtocols.Validate()
|
|
|
if err == nil {
|
|
|
t.Errorf("unexpected Validate success")
|
|
|
}
|
|
|
@@ -47,8 +73,8 @@ func TestTunnelProtocolValidation(t *testing.T) {
|
|
|
|
|
|
prunedProtocols := pruneProtocols.PruneInvalid()
|
|
|
|
|
|
- if !reflect.DeepEqual(prunedProtocols, SupportedTunnelProtocols) {
|
|
|
- t.Errorf("unexpected %+v != %+v", prunedProtocols, SupportedTunnelProtocols)
|
|
|
+ if !reflect.DeepEqual(prunedProtocols, validSupportedProtocols) {
|
|
|
+ t.Errorf("unexpected %+v != %+v", prunedProtocols, validSupportedProtocols)
|
|
|
}
|
|
|
}
|
|
|
|