Ver código fonte

Add listable

Fangliding 4 meses atrás
pai
commit
37efc4237a

+ 1 - 1
infra/conf/cfgcommon/duration/duration.go → infra/conf/cfgcommon/types/duration.go

@@ -1,4 +1,4 @@
-package duration
+package types
 
 import (
 	"encoding/json"

+ 4 - 4
infra/conf/cfgcommon/duration/duration_test.go → infra/conf/cfgcommon/types/duration_test.go

@@ -1,20 +1,20 @@
-package duration_test
+package types_test
 
 import (
 	"encoding/json"
 	"testing"
 	"time"
 
-	"github.com/xtls/xray-core/infra/conf/cfgcommon/duration"
+	"github.com/xtls/xray-core/infra/conf/cfgcommon/types"
 )
 
 type testWithDuration struct {
-	Duration duration.Duration
+	Duration types.Duration
 }
 
 func TestDurationJSON(t *testing.T) {
 	expected := &testWithDuration{
-		Duration: duration.Duration(time.Hour),
+		Duration: types.Duration(time.Hour),
 	}
 	data, err := json.Marshal(expected)
 	if err != nil {

+ 47 - 0
infra/conf/cfgcommon/types/listable.go

@@ -0,0 +1,47 @@
+package types
+
+import (
+	"encoding/json"
+	"reflect"
+	"slices"
+	"strings"
+)
+
+// Listable allows a field to be unmarshalled from a single object or a list of objects.
+// If the json input is a single object, it will be stored as a slice with one element.
+// If the json input is null or empty or a single empty object, it will be nil.
+type Listable[T any] []T
+
+func (l *Listable[T]) UnmarshalJSON(data []byte) error {
+	var v T
+	if len(data) != 0 && !slices.Equal(data, []byte("null")) && data[0] != '[' {
+		if err := json.Unmarshal(data, &v); err == nil {
+			// make the list nil if the single value is the zero value
+			var zero T
+			if reflect.DeepEqual(v, zero) {
+				return nil
+			}
+			*l = []T{v}
+			return err
+		}
+	}
+	return json.Unmarshal(data, (*[]T)(l))
+}
+
+// ListableSimpleString is like Listable[string], but able to separate by `~`
+type ListableSimpleString []string
+
+func (l *ListableSimpleString) UnmarshalJSON(data []byte) error {
+	var v string
+	if len(data) != 0 && !slices.Equal(data, []byte("null")) && data[0] != '[' {
+		if err := json.Unmarshal(data, &v); err == nil {
+			if v == "" {
+				// make the list nil if the single value is empty string
+				return nil
+			}
+			*l = strings.Split(v, "~")
+			return nil
+		}
+	}
+	return json.Unmarshal(data, (*[]string)(l))
+}

+ 159 - 0
infra/conf/cfgcommon/types/listable_test.go

@@ -0,0 +1,159 @@
+package types_test
+
+import (
+	"encoding/json"
+	"slices"
+	"testing"
+
+	"github.com/xtls/xray-core/infra/conf/cfgcommon/types"
+)
+
+type TestGroup[T any] struct {
+	name     string
+	input    string
+	expected []T
+}
+
+// intentionally to be so chaos
+var rawJson = `{
+	"field": 
+	["value1",
+			"value2", "value3"
+			]
+}`
+
+func TestListableUnmarshal(t *testing.T) {
+	type TestStruct struct {
+		Field types.Listable[string] `json:"field"`
+	}
+
+	tests := []TestGroup[string]{
+		{
+			name:     "SingleString",
+			input:    `{"field": "hello"}`,
+			expected: []string{"hello"},
+		},
+		{
+			name:     "ArrayString",
+			input:    `{"field": ["value1", "value2", "value3"]}`,
+			expected: []string{"value1", "value2", "value3"},
+		},
+		{
+			name:     "ComplexArray",
+			input:    rawJson,
+			expected: []string{"value1", "value2", "value3"},
+		},
+		{
+			name:     "SingleStringWithSpace",
+			input:    `{"field":   "hello"  }`,
+			expected: []string{"hello"},
+		},
+		{
+			name:     "ArrayWithSpace",
+			input:    `{"field":   [ "a", "b" ]  }`,
+			expected: []string{"a", "b"},
+		},
+		{
+			name:     "SingleEmptyString",
+			input:    `{"field": ""}`,
+			expected: nil,
+		},
+		{
+			name:     "Null",
+			input:    `{"field": null}`,
+			expected: nil,
+		},
+		{
+			name:     "Missing (default)",
+			input:    `{}`,
+			expected: nil,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			var ts TestStruct
+			err := json.Unmarshal([]byte(tt.input), &ts)
+			if err != nil {
+				t.Fatalf("Unmarshal failed: %v", err)
+			}
+			if !slices.Equal([]string(ts.Field), tt.expected) {
+				t.Errorf("Expected %v, got %v", tt.expected, ts.Field)
+			}
+		})
+	}
+}
+
+func TestListableInt(t *testing.T) {
+	tests := []TestGroup[int]{
+		{
+			name:     "SingleInt",
+			input:    `123`,
+			expected: []int{123},
+		},
+		{
+			name:     "ArrayInt",
+			input:    `[1, 2]`,
+			expected: []int{1, 2},
+		},
+		{
+			name:     "Null",
+			input:    `null`,
+			expected: nil,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			var l types.Listable[int]
+			err := json.Unmarshal([]byte(tt.input), &l)
+			if err != nil {
+				t.Fatalf("Unmarshal failed: %v", err)
+			}
+			if !slices.Equal([]int(l), tt.expected) {
+				t.Errorf("Expected %v, got %v", tt.expected, l)
+			}
+		})
+	}
+}
+
+func TestListableSimpleString(t *testing.T) {
+	type TestStruct struct {
+		Field types.ListableSimpleString `json:"field"`
+	}
+
+	tests := []TestGroup[string]{
+		{
+			name:     "SingleString",
+			input:    `{"field": "singleValue"}`,
+			expected: []string{"singleValue"},
+		},
+		{
+			name:     "ArrayString",
+			input:    `{"field": ["value1", "value2", "value3"]}`,
+			expected: []string{"value1", "value2", "value3"},
+		},
+		{
+			name:     "SingleEmptyString",
+			input:    `{"field": ""}`,
+			expected: nil,
+		},
+		{
+			name:     "WaveSplit",
+			input:    `{"field": "value1~value2~value3"}`,
+			expected: []string{"value1", "value2", "value3"},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			var ts TestStruct
+			err := json.Unmarshal([]byte(tt.input), &ts)
+			if err != nil {
+				t.Fatalf("Unmarshal failed: %v", err)
+			}
+			if !slices.Equal([]string(ts.Field), tt.expected) {
+				t.Errorf("Expected %v, got %v", tt.expected, ts.Field)
+			}
+		})
+	}
+}

+ 5 - 5
infra/conf/observatory.go

@@ -6,14 +6,14 @@ import (
 	"github.com/xtls/xray-core/app/observatory"
 	"github.com/xtls/xray-core/app/observatory/burst"
 	"github.com/xtls/xray-core/common/errors"
-	"github.com/xtls/xray-core/infra/conf/cfgcommon/duration"
+	"github.com/xtls/xray-core/infra/conf/cfgcommon/types"
 )
 
 type ObservatoryConfig struct {
-	SubjectSelector   []string          `json:"subjectSelector"`
-	ProbeURL          string            `json:"probeURL"`
-	ProbeInterval     duration.Duration `json:"probeInterval"`
-	EnableConcurrency bool              `json:"enableConcurrency"`
+	SubjectSelector   []string       `json:"subjectSelector"`
+	ProbeURL          string         `json:"probeURL"`
+	ProbeInterval     types.Duration `json:"probeInterval"`
+	EnableConcurrency bool           `json:"enableConcurrency"`
 }
 
 func (o *ObservatoryConfig) Build() (proto.Message, error) {

+ 11 - 10
infra/conf/router_strategy.go

@@ -1,12 +1,13 @@
 package conf
 
 import (
-	"google.golang.org/protobuf/proto"
 	"strings"
 
+	"google.golang.org/protobuf/proto"
+
 	"github.com/xtls/xray-core/app/observatory/burst"
 	"github.com/xtls/xray-core/app/router"
-	"github.com/xtls/xray-core/infra/conf/cfgcommon/duration"
+	"github.com/xtls/xray-core/infra/conf/cfgcommon/types"
 )
 
 const (
@@ -36,23 +37,23 @@ type strategyLeastLoadConfig struct {
 	// weight settings
 	Costs []*router.StrategyWeight `json:"costs,omitempty"`
 	// ping rtt baselines
-	Baselines []duration.Duration `json:"baselines,omitempty"`
+	Baselines []types.Duration `json:"baselines,omitempty"`
 	// expected nodes count to select
 	Expected int32 `json:"expected,omitempty"`
 	// max acceptable rtt, filter away high delay nodes. default 0
-	MaxRTT duration.Duration `json:"maxRTT,omitempty"`
+	MaxRTT types.Duration `json:"maxRTT,omitempty"`
 	// acceptable failure rate
 	Tolerance float64 `json:"tolerance,omitempty"`
 }
 
 // healthCheckSettings holds settings for health Checker
 type healthCheckSettings struct {
-	Destination   string            `json:"destination"`
-	Connectivity  string            `json:"connectivity"`
-	Interval      duration.Duration `json:"interval"`
-	SamplingCount int               `json:"sampling"`
-	Timeout       duration.Duration `json:"timeout"`
-	HttpMethod    string            `json:"httpMethod"`
+	Destination   string         `json:"destination"`
+	Connectivity  string         `json:"connectivity"`
+	Interval      types.Duration `json:"interval"`
+	SamplingCount int            `json:"sampling"`
+	Timeout       types.Duration `json:"timeout"`
+	HttpMethod    string         `json:"httpMethod"`
 }
 
 func (h healthCheckSettings) Build() (proto.Message, error) {