Browse Source

Fix: add missing parameter validation and test cases

Rod Hynes 3 years ago
parent
commit
e7cb9e0efc

+ 33 - 0
psiphon/common/parameters/parameters.go

@@ -833,6 +833,17 @@ func (p *Parameters) Set(
 	serverPacketManipulationSpecs, _ :=
 		serverPacketManipulationSpecsValue.(PacketManipulationSpecs)
 
+	// Special case: ProtocolTransformScopedSpecNames will reference
+	// ProtocolTransformSpecs.
+
+	dnsResolverProtocolTransformSpecsValue, err := getAppliedValue(
+		DNSResolverProtocolTransformSpecs, parameters, applyParameters)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+	dnsResolverProtocolTransformSpecs, _ :=
+		dnsResolverProtocolTransformSpecsValue.(transforms.Specs)
+
 	for i := 0; i < len(applyParameters); i++ {
 
 		count := 0
@@ -995,6 +1006,28 @@ func (p *Parameters) Set(
 					}
 					return nil, errors.Trace(err)
 				}
+			case transforms.Specs:
+				err := v.Validate()
+				if err != nil {
+					if skipOnError {
+						continue
+					}
+					return nil, errors.Trace(err)
+				}
+			case transforms.ScopedSpecNames:
+
+				var specs transforms.Specs
+				if name == DNSResolverProtocolTransformScopedSpecNames {
+					specs = dnsResolverProtocolTransformSpecs
+				}
+
+				err := v.Validate(specs)
+				if err != nil {
+					if skipOnError {
+						continue
+					}
+					return nil, errors.Trace(err)
+				}
 			}
 
 			// Enforce any minimums. Assumes defaultParameters[name]

+ 19 - 1
psiphon/common/parameters/parameters_test.go

@@ -28,6 +28,7 @@ import (
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/transforms"
 )
 
 func TestGetDefaultParameters(t *testing.T) {
@@ -159,8 +160,25 @@ func TestGetDefaultParameters(t *testing.T) {
 			if !reflect.DeepEqual(v, g) {
 				t.Fatalf("TunnelProtocolPortLists returned %+v expected %+v", g, v)
 			}
+		case LabeledCIDRs:
+			for label, CIDRs := range v {
+				g := p.Get().LabeledCIDRs(name, label)
+				if !reflect.DeepEqual(CIDRs, g) {
+					t.Fatalf("LabeledCIDRs returned %+v expected %+v", g, CIDRs)
+				}
+			}
+		case transforms.Specs:
+			g := p.Get().ProtocolTransformSpecs(name)
+			if !reflect.DeepEqual(v, g) {
+				t.Fatalf("ProtocolTransformSpecs returned %+v expected %+v", g, v)
+			}
+		case transforms.ScopedSpecNames:
+			g := p.Get().ProtocolTransformScopedSpecNames(name)
+			if !reflect.DeepEqual(v, g) {
+				t.Fatalf("ProtocolTransformScopedSpecNames returned %+v expected %+v", g, v)
+			}
 		default:
-			t.Fatalf("Unhandled default type: %s", name)
+			t.Fatalf("Unhandled default type: %s (%T)", name, defaults.value)
 		}
 	}
 }