Просмотр исходного кода

Recover from all panics in regen package

Amir Khan 2 лет назад
Родитель
Сommit
0f5e4e4142

+ 6 - 1
psiphon/common/regen/internal_generator.go

@@ -78,7 +78,12 @@ type internalGenerator struct {
 	GenerateFunc func() ([]byte, error)
 }
 
-func (gen *internalGenerator) Generate() ([]byte, error) {
+func (gen *internalGenerator) Generate() (b []byte, err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			err = fmt.Errorf("panicked on bad input: Generate: %v", r)
+		}
+	}()
 	return gen.GenerateFunc()
 }
 

+ 17 - 6
psiphon/common/regen/regen.go

@@ -186,8 +186,8 @@ func (a *GeneratorArgs) initialize() error {
 	}
 
 	if a.MinUnboundedRepeatCount > a.MaxUnboundedRepeatCount {
-		panic(fmt.Sprintf("MinUnboundedRepeatCount(%d) > MaxUnboundedRepeatCount(%d)",
-			a.MinUnboundedRepeatCount, a.MaxUnboundedRepeatCount))
+		return fmt.Errorf("MinUnboundedRepeatCount(%d) > MaxUnboundedRepeatCount(%d)",
+			a.MinUnboundedRepeatCount, a.MaxUnboundedRepeatCount)
 	}
 
 	if a.CaptureGroupHandler == nil {
@@ -199,11 +199,11 @@ func (a *GeneratorArgs) initialize() error {
 
 // Rng returns the random number generator used by generators.
 // Panics if called before the GeneratorArgs has been initialized by NewGenerator.
-func (a *GeneratorArgs) Rng() *rand.Rand {
+func (a *GeneratorArgs) Rng() (*rand.Rand, error) {
 	if a.rng == nil {
-		panic("GeneratorArgs has not been initialized by NewGenerator yet")
+		return nil, fmt.Errorf("GeneratorArgs has not been initialized by NewGenerator yet")
 	}
-	return a.rng
+	return a.rng, nil
 }
 
 // Generator generates random bytes or strings.
@@ -219,7 +219,12 @@ If args is nil, default values are used.
 This function does not seed the default RNG, so you must call rand.Seed() if you want
 non-deterministic strings.
 */
-func GenerateString(pattern string) (string, error) {
+func GenerateString(pattern string) (str string, err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			err = fmt.Errorf("panicked on bad input: GenerateString: %v", r)
+		}
+	}()
 	generator, err := NewGenerator(pattern, nil)
 	if err != nil {
 		return "", err
@@ -237,6 +242,12 @@ func GenerateString(pattern string) (string, error) {
 // character range. This makes it impossible to infer the original negated
 // character class.
 func NewGenerator(pattern string, inputArgs *GeneratorArgs) (generator Generator, err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			err = fmt.Errorf("panicked on bad input: NewGenerator: %v", r)
+		}
+	}()
+
 	args := GeneratorArgs{}
 
 	// Copy inputArgs so the caller can't change them.

+ 14 - 27
psiphon/common/regen/regen_test.go

@@ -188,15 +188,16 @@ func TestGeneratorArgs(t *testing.T) {
 		}
 	})
 
-	t.Run("Panics if repeat bounds are invalid", func(t *testing.T) {
+	t.Run("Error if repeat bounds are invalid", func(t *testing.T) {
 		args := &GeneratorArgs{
 			MinUnboundedRepeatCount: 2,
 			MaxUnboundedRepeatCount: 1,
 		}
 
-		shouldPanicWith(t, func() {
-			_ = args.initialize()
-		}, "MinUnboundedRepeatCount(2) > MaxUnboundedRepeatCount(1)")
+		err := args.initialize()
+		if err.Error() != "MinUnboundedRepeatCount(2) > MaxUnboundedRepeatCount(1)" {
+			t.Fatalf("unexpected error: %v", err)
+		}
 	})
 
 	t.Run("Allow equal repeat bounds", func(t *testing.T) {
@@ -215,12 +216,13 @@ func TestGeneratorArgs(t *testing.T) {
 
 	t.Run("Rng", func(t *testing.T) {
 
-		t.Run("Panics if called before initialize", func(t *testing.T) {
+		t.Run("Error if called before initialize", func(t *testing.T) {
 			args := &GeneratorArgs{}
 
-			shouldPanic(t, func() {
-				_ = args.Rng()
-			})
+			_, err := args.Rng()
+			if err == nil {
+				t.Fatal("expected error")
+			}
 		})
 
 		t.Run("Non-nil after initialize", func(t *testing.T) {
@@ -229,7 +231,10 @@ func TestGeneratorArgs(t *testing.T) {
 			if err != nil {
 				t.Fatal(err)
 			}
-			rng := args.Rng()
+			rng, err := args.Rng()
+			if err != nil {
+				t.Fatal(err)
+			}
 			if rng == nil {
 				t.Fatal("expected non-nil")
 			}
@@ -906,24 +911,6 @@ func max(values ...int) int {
 	return m
 }
 
-func shouldPanic(t *testing.T, f func()) {
-	t.Helper()
-	defer func() { _ = recover() }()
-	f()
-	t.Errorf("should have panicked")
-}
-
-func shouldPanicWith(t *testing.T, f func(), expected string) {
-	t.Helper()
-	defer func() {
-		if r := recover(); r != expected {
-			t.Errorf("expected panic %q, got %q", expected, r)
-		}
-	}()
-	f()
-	t.Errorf("should have panicked")
-}
-
 func shouldNotPanic(t *testing.T, f func()) {
 	t.Helper()
 	defer func() {