Просмотр исходного кода

Recover from all panics in regen package

Amir Khan 2 лет назад
Родитель
Сommit
0f5e4e4142
3 измененных файлов с 37 добавлено и 34 удалено
  1. 6 1
      psiphon/common/regen/internal_generator.go
  2. 17 6
      psiphon/common/regen/regen.go
  3. 14 27
      psiphon/common/regen/regen_test.go

+ 6 - 1
psiphon/common/regen/internal_generator.go

@@ -78,7 +78,12 @@ type internalGenerator struct {
 	GenerateFunc func() ([]byte, error)
 	GenerateFunc func() ([]byte, error)
 }
 }
 
 
-func (gen *internalGenerator) Generate() ([]byte, error) {
+func (gen *internalGenerator) Generate() (b []byte, err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			err = fmt.Errorf("panicked on bad input: Generate: %v", r)
+		}
+	}()
 	return gen.GenerateFunc()
 	return gen.GenerateFunc()
 }
 }
 
 

+ 17 - 6
psiphon/common/regen/regen.go

@@ -186,8 +186,8 @@ func (a *GeneratorArgs) initialize() error {
 	}
 	}
 
 
 	if a.MinUnboundedRepeatCount > a.MaxUnboundedRepeatCount {
 	if a.MinUnboundedRepeatCount > a.MaxUnboundedRepeatCount {
-		panic(fmt.Sprintf("MinUnboundedRepeatCount(%d) > MaxUnboundedRepeatCount(%d)",
-			a.MinUnboundedRepeatCount, a.MaxUnboundedRepeatCount))
+		return fmt.Errorf("MinUnboundedRepeatCount(%d) > MaxUnboundedRepeatCount(%d)",
+			a.MinUnboundedRepeatCount, a.MaxUnboundedRepeatCount)
 	}
 	}
 
 
 	if a.CaptureGroupHandler == nil {
 	if a.CaptureGroupHandler == nil {
@@ -199,11 +199,11 @@ func (a *GeneratorArgs) initialize() error {
 
 
 // Rng returns the random number generator used by generators.
 // Rng returns the random number generator used by generators.
 // Panics if called before the GeneratorArgs has been initialized by NewGenerator.
 // Panics if called before the GeneratorArgs has been initialized by NewGenerator.
-func (a *GeneratorArgs) Rng() *rand.Rand {
+func (a *GeneratorArgs) Rng() (*rand.Rand, error) {
 	if a.rng == nil {
 	if a.rng == nil {
-		panic("GeneratorArgs has not been initialized by NewGenerator yet")
+		return nil, fmt.Errorf("GeneratorArgs has not been initialized by NewGenerator yet")
 	}
 	}
-	return a.rng
+	return a.rng, nil
 }
 }
 
 
 // Generator generates random bytes or strings.
 // Generator generates random bytes or strings.
@@ -219,7 +219,12 @@ If args is nil, default values are used.
 This function does not seed the default RNG, so you must call rand.Seed() if you want
 This function does not seed the default RNG, so you must call rand.Seed() if you want
 non-deterministic strings.
 non-deterministic strings.
 */
 */
-func GenerateString(pattern string) (string, error) {
+func GenerateString(pattern string) (str string, err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			err = fmt.Errorf("panicked on bad input: GenerateString: %v", r)
+		}
+	}()
 	generator, err := NewGenerator(pattern, nil)
 	generator, err := NewGenerator(pattern, nil)
 	if err != nil {
 	if err != nil {
 		return "", err
 		return "", err
@@ -237,6 +242,12 @@ func GenerateString(pattern string) (string, error) {
 // character range. This makes it impossible to infer the original negated
 // character range. This makes it impossible to infer the original negated
 // character class.
 // character class.
 func NewGenerator(pattern string, inputArgs *GeneratorArgs) (generator Generator, err error) {
 func NewGenerator(pattern string, inputArgs *GeneratorArgs) (generator Generator, err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			err = fmt.Errorf("panicked on bad input: NewGenerator: %v", r)
+		}
+	}()
+
 	args := GeneratorArgs{}
 	args := GeneratorArgs{}
 
 
 	// Copy inputArgs so the caller can't change them.
 	// Copy inputArgs so the caller can't change them.

+ 14 - 27
psiphon/common/regen/regen_test.go

@@ -188,15 +188,16 @@ func TestGeneratorArgs(t *testing.T) {
 		}
 		}
 	})
 	})
 
 
-	t.Run("Panics if repeat bounds are invalid", func(t *testing.T) {
+	t.Run("Error if repeat bounds are invalid", func(t *testing.T) {
 		args := &GeneratorArgs{
 		args := &GeneratorArgs{
 			MinUnboundedRepeatCount: 2,
 			MinUnboundedRepeatCount: 2,
 			MaxUnboundedRepeatCount: 1,
 			MaxUnboundedRepeatCount: 1,
 		}
 		}
 
 
-		shouldPanicWith(t, func() {
-			_ = args.initialize()
-		}, "MinUnboundedRepeatCount(2) > MaxUnboundedRepeatCount(1)")
+		err := args.initialize()
+		if err.Error() != "MinUnboundedRepeatCount(2) > MaxUnboundedRepeatCount(1)" {
+			t.Fatalf("unexpected error: %v", err)
+		}
 	})
 	})
 
 
 	t.Run("Allow equal repeat bounds", func(t *testing.T) {
 	t.Run("Allow equal repeat bounds", func(t *testing.T) {
@@ -215,12 +216,13 @@ func TestGeneratorArgs(t *testing.T) {
 
 
 	t.Run("Rng", func(t *testing.T) {
 	t.Run("Rng", func(t *testing.T) {
 
 
-		t.Run("Panics if called before initialize", func(t *testing.T) {
+		t.Run("Error if called before initialize", func(t *testing.T) {
 			args := &GeneratorArgs{}
 			args := &GeneratorArgs{}
 
 
-			shouldPanic(t, func() {
-				_ = args.Rng()
-			})
+			_, err := args.Rng()
+			if err == nil {
+				t.Fatal("expected error")
+			}
 		})
 		})
 
 
 		t.Run("Non-nil after initialize", func(t *testing.T) {
 		t.Run("Non-nil after initialize", func(t *testing.T) {
@@ -229,7 +231,10 @@ func TestGeneratorArgs(t *testing.T) {
 			if err != nil {
 			if err != nil {
 				t.Fatal(err)
 				t.Fatal(err)
 			}
 			}
-			rng := args.Rng()
+			rng, err := args.Rng()
+			if err != nil {
+				t.Fatal(err)
+			}
 			if rng == nil {
 			if rng == nil {
 				t.Fatal("expected non-nil")
 				t.Fatal("expected non-nil")
 			}
 			}
@@ -906,24 +911,6 @@ func max(values ...int) int {
 	return m
 	return m
 }
 }
 
 
-func shouldPanic(t *testing.T, f func()) {
-	t.Helper()
-	defer func() { _ = recover() }()
-	f()
-	t.Errorf("should have panicked")
-}
-
-func shouldPanicWith(t *testing.T, f func(), expected string) {
-	t.Helper()
-	defer func() {
-		if r := recover(); r != expected {
-			t.Errorf("expected panic %q, got %q", expected, r)
-		}
-	}()
-	f()
-	t.Errorf("should have panicked")
-}
-
 func shouldNotPanic(t *testing.T, f func()) {
 func shouldNotPanic(t *testing.T, f func()) {
 	t.Helper()
 	t.Helper()
 	defer func() {
 	defer func() {