فهرست منبع

Proxy: Add Hysteria 2 inbound & transport (supports listening port range, Salamander finalmask) (#5679)

https://github.com/XTLS/Xray-core/pull/5679#issuecomment-3888548778

Closes https://github.com/XTLS/Xray-core/issues/5605
LjhAUMEM 4 ماه پیش
والد
کامیت
6a909b2507

+ 9 - 0
app/proxyman/inbound/worker.go

@@ -19,6 +19,8 @@ import (
 	"github.com/xtls/xray-core/features/routing"
 	"github.com/xtls/xray-core/features/stats"
 	"github.com/xtls/xray-core/proxy"
+	"github.com/xtls/xray-core/proxy/hysteria/account"
+	hyCtx "github.com/xtls/xray-core/proxy/hysteria/ctx"
 	"github.com/xtls/xray-core/transport/internet"
 	"github.com/xtls/xray-core/transport/internet/stat"
 	"github.com/xtls/xray-core/transport/internet/tcp"
@@ -138,6 +140,13 @@ func (w *tcpWorker) Proxy() proxy.Inbound {
 
 func (w *tcpWorker) Start() error {
 	ctx := context.Background()
+
+	type HysteriaInboundValidator interface{ HysteriaInboundValidator() *account.Validator }
+	if v, ok := w.proxy.(HysteriaInboundValidator); ok {
+		ctx = hyCtx.ContextWithRequireDatagram(ctx, true)
+		ctx = hyCtx.ContextWithValidator(ctx, v.HysteriaInboundValidator())
+	}
+
 	hub, err := internet.ListenTCP(ctx, w.address, w.port, w.stream, func(conn stat.Connection) {
 		go w.callback(conn)
 	})

+ 32 - 0
infra/conf/hysteria.go

@@ -3,7 +3,9 @@ package conf
 import (
 	"github.com/xtls/xray-core/common/errors"
 	"github.com/xtls/xray-core/common/protocol"
+	"github.com/xtls/xray-core/common/serial"
 	"github.com/xtls/xray-core/proxy/hysteria"
+	"github.com/xtls/xray-core/proxy/hysteria/account"
 	"google.golang.org/protobuf/proto"
 )
 
@@ -27,3 +29,33 @@ func (c *HysteriaClientConfig) Build() (proto.Message, error) {
 
 	return config, nil
 }
+
+type HysteriaUserConfig struct {
+	Auth  string `json:"auth"`
+	Level uint32 `json:"level"`
+	Email string `json:"email"`
+}
+
+type HysteriaServerConfig struct {
+	Version int32                 `json:"version"`
+	Users   []*HysteriaUserConfig `json:"clients"`
+}
+
+func (c *HysteriaServerConfig) Build() (proto.Message, error) {
+	config := new(hysteria.ServerConfig)
+
+	if c.Users != nil {
+		for _, user := range c.Users {
+			account := &account.Account{
+				Auth: user.Auth,
+			}
+			config.Users = append(config.Users, &protocol.User{
+				Email:   user.Email,
+				Level:   user.Level,
+				Account: serial.ToTypedMessage(account),
+			})
+		}
+	}
+
+	return config, nil
+}

+ 42 - 2
infra/conf/transport_internet.go

@@ -508,6 +508,20 @@ type UdpHop struct {
 	Interval *Int32Range     `json:"interval"`
 }
 
+type Masquerade struct {
+	Type string `json:"type"`
+
+	Dir string `json:"dir"`
+
+	Url         string `json:"url"`
+	RewriteHost bool   `json:"rewriteHost"`
+	Insecure    bool   `json:"insecure"`
+
+	Content    string            `json:"content"`
+	Headers    map[string]string `json:"headers"`
+	StatusCode int32             `json:"statusCode"`
+}
+
 type HysteriaConfig struct {
 	Version    int32     `json:"version"`
 	Auth       string    `json:"auth"`
@@ -523,6 +537,10 @@ type HysteriaConfig struct {
 	MaxIdleTimeout              int64  `json:"maxIdleTimeout"`
 	KeepAlivePeriod             int64  `json:"keepAlivePeriod"`
 	DisablePathMTUDiscovery     bool   `json:"disablePathMTUDiscovery"`
+	MaxIncomingStreams          int64  `json:"maxIncomingStreams"`
+
+	UdpIdleTimeout int64      `json:"udpIdleTimeout"`
+	Masquerade     Masquerade `json:"masquerade"`
 }
 
 func (c *HysteriaConfig) Build() (proto.Message, error) {
@@ -556,10 +574,10 @@ func (c *HysteriaConfig) Build() (proto.Message, error) {
 	}
 
 	if up > 0 && up < 65536 {
-		return nil, errors.New("Up must be at least 65536 Bps")
+		return nil, errors.New("Up must be at least 65536 bytes per second")
 	}
 	if down > 0 && down < 65536 {
-		return nil, errors.New("Down must be at least 65536 Bps")
+		return nil, errors.New("Down must be at least 65536 bytes per second")
 	}
 	if (inertvalMin != 0 && inertvalMin < 5) || (inertvalMax != 0 && inertvalMax < 5) {
 		return nil, errors.New("Interval must be at least 5")
@@ -583,6 +601,12 @@ func (c *HysteriaConfig) Build() (proto.Message, error) {
 	if c.KeepAlivePeriod != 0 && (c.KeepAlivePeriod < 2 || c.KeepAlivePeriod > 60) {
 		return nil, errors.New("KeepAlivePeriod must be between 2 and 60")
 	}
+	if c.MaxIncomingStreams != 0 && c.MaxIncomingStreams < 8 {
+		return nil, errors.New("MaxIncomingStreams must be at least 8")
+	}
+	if c.UdpIdleTimeout != 0 && (c.UdpIdleTimeout < 2 || c.UdpIdleTimeout > 600) {
+		return nil, errors.New("UdpIdleTimeout must be between 2 and 600")
+	}
 
 	config := &hysteria.Config{}
 	config.Version = c.Version
@@ -600,6 +624,16 @@ func (c *HysteriaConfig) Build() (proto.Message, error) {
 	config.MaxIdleTimeout = c.MaxIdleTimeout
 	config.KeepAlivePeriod = c.KeepAlivePeriod
 	config.DisablePathMtuDiscovery = c.DisablePathMTUDiscovery
+	config.MaxIncomingStreams = c.MaxIncomingStreams
+	config.UdpIdleTimeout = c.UdpIdleTimeout
+	config.MasqType = c.Masquerade.Type
+	config.MasqFile = c.Masquerade.Dir
+	config.MasqUrl = c.Masquerade.Url
+	config.MasqUrlRewriteHost = c.Masquerade.RewriteHost
+	config.MasqUrlInsecure = c.Masquerade.Insecure
+	config.MasqString = c.Masquerade.Content
+	config.MasqStringHeaders = c.Masquerade.Headers
+	config.MasqStringStatusCode = c.Masquerade.StatusCode
 
 	if config.InitStreamReceiveWindow == 0 {
 		config.InitStreamReceiveWindow = 8388608
@@ -619,6 +653,12 @@ func (c *HysteriaConfig) Build() (proto.Message, error) {
 	// if config.KeepAlivePeriod == 0 {
 	// 	config.KeepAlivePeriod = 10
 	// }
+	if config.MaxIncomingStreams == 0 {
+		config.MaxIncomingStreams = 1024
+	}
+	if config.UdpIdleTimeout == 0 {
+		config.UdpIdleTimeout = 60
+	}
 
 	return config, nil
 }

+ 1 - 0
infra/conf/xray.go

@@ -33,6 +33,7 @@ var (
 		"vmess":         func() interface{} { return new(VMessInboundConfig) },
 		"trojan":        func() interface{} { return new(TrojanServerConfig) },
 		"wireguard":     func() interface{} { return &WireGuardConfig{IsClient: false} },
+		"hysteria":      func() interface{} { return new(HysteriaServerConfig) },
 		"tun":           func() interface{} { return new(TunConfig) },
 	}, "protocol", "settings")
 

+ 129 - 0
proxy/hysteria/account/config.go

@@ -0,0 +1,129 @@
+package account
+
+import (
+	"sync"
+
+	"github.com/xtls/xray-core/common/errors"
+	"github.com/xtls/xray-core/common/protocol"
+
+	"google.golang.org/protobuf/proto"
+)
+
+func (a *Account) AsAccount() (protocol.Account, error) {
+	return &MemoryAccount{
+		Auth: a.Auth,
+	}, nil
+}
+
+type MemoryAccount struct {
+	Auth string
+}
+
+func (a *MemoryAccount) Equals(another protocol.Account) bool {
+	if account, ok := another.(*MemoryAccount); ok {
+		return a.Auth == account.Auth
+	}
+	return false
+}
+
+func (a *MemoryAccount) ToProto() proto.Message {
+	return &Account{
+		Auth: a.Auth,
+	}
+}
+
+type Validator struct {
+	emails map[string]struct{}
+	users  map[string]*protocol.MemoryUser
+
+	mutex sync.Mutex
+}
+
+func NewValidator() *Validator {
+	return &Validator{
+		emails: make(map[string]struct{}),
+		users:  make(map[string]*protocol.MemoryUser),
+	}
+}
+
+func (v *Validator) Add(u *protocol.MemoryUser) error {
+	v.mutex.Lock()
+	defer v.mutex.Unlock()
+
+	if u.Email != "" {
+		if _, ok := v.emails[u.Email]; ok {
+			return errors.New("User ", u.Email, " already exists.")
+		}
+		v.emails[u.Email] = struct{}{}
+	}
+	v.users[u.Account.(*MemoryAccount).Auth] = u
+
+	return nil
+}
+
+func (v *Validator) Del(email string) error {
+	if email == "" {
+		return errors.New("Email must not be empty.")
+	}
+
+	v.mutex.Lock()
+	defer v.mutex.Unlock()
+
+	if _, ok := v.emails[email]; !ok {
+		return errors.New("User ", email, " not found.")
+	}
+	delete(v.emails, email)
+	for key, user := range v.users {
+		if user.Email == email {
+			delete(v.users, key)
+			break
+		}
+	}
+
+	return nil
+}
+
+func (v *Validator) Get(auth string) *protocol.MemoryUser {
+	v.mutex.Lock()
+	defer v.mutex.Unlock()
+
+	return v.users[auth]
+}
+
+func (v *Validator) GetByEmail(email string) *protocol.MemoryUser {
+	if email == "" {
+		return nil
+	}
+
+	v.mutex.Lock()
+	defer v.mutex.Unlock()
+
+	if _, ok := v.emails[email]; ok {
+		for _, user := range v.users {
+			if user.Email == email {
+				return user
+			}
+		}
+	}
+
+	return nil
+}
+
+func (v *Validator) GetAll() []*protocol.MemoryUser {
+	v.mutex.Lock()
+	defer v.mutex.Unlock()
+
+	var users = make([]*protocol.MemoryUser, 0, len(v.users))
+	for _, user := range v.users {
+		users = append(users, user)
+	}
+
+	return users
+}
+
+func (v *Validator) GetCount() int64 {
+	v.mutex.Lock()
+	defer v.mutex.Unlock()
+
+	return int64(len(v.users))
+}

+ 123 - 0
proxy/hysteria/account/config.pb.go

@@ -0,0 +1,123 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// versions:
+// 	protoc-gen-go v1.36.11
+// 	protoc        v6.33.5
+// source: proxy/hysteria/account/config.proto
+
+package account
+
+import (
+	protoreflect "google.golang.org/protobuf/reflect/protoreflect"
+	protoimpl "google.golang.org/protobuf/runtime/protoimpl"
+	reflect "reflect"
+	sync "sync"
+	unsafe "unsafe"
+)
+
+const (
+	// Verify that this generated code is sufficiently up-to-date.
+	_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
+	// Verify that runtime/protoimpl is sufficiently up-to-date.
+	_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
+)
+
+type Account struct {
+	state         protoimpl.MessageState `protogen:"open.v1"`
+	Auth          string                 `protobuf:"bytes,1,opt,name=auth,proto3" json:"auth,omitempty"`
+	unknownFields protoimpl.UnknownFields
+	sizeCache     protoimpl.SizeCache
+}
+
+func (x *Account) Reset() {
+	*x = Account{}
+	mi := &file_proxy_hysteria_account_config_proto_msgTypes[0]
+	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+	ms.StoreMessageInfo(mi)
+}
+
+func (x *Account) String() string {
+	return protoimpl.X.MessageStringOf(x)
+}
+
+func (*Account) ProtoMessage() {}
+
+func (x *Account) ProtoReflect() protoreflect.Message {
+	mi := &file_proxy_hysteria_account_config_proto_msgTypes[0]
+	if x != nil {
+		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+		if ms.LoadMessageInfo() == nil {
+			ms.StoreMessageInfo(mi)
+		}
+		return ms
+	}
+	return mi.MessageOf(x)
+}
+
+// Deprecated: Use Account.ProtoReflect.Descriptor instead.
+func (*Account) Descriptor() ([]byte, []int) {
+	return file_proxy_hysteria_account_config_proto_rawDescGZIP(), []int{0}
+}
+
+func (x *Account) GetAuth() string {
+	if x != nil {
+		return x.Auth
+	}
+	return ""
+}
+
+var File_proxy_hysteria_account_config_proto protoreflect.FileDescriptor
+
+const file_proxy_hysteria_account_config_proto_rawDesc = "" +
+	"\n" +
+	"#proxy/hysteria/account/config.proto\x12\x1bxray.proxy.hysteria.account\"\x1d\n" +
+	"\aAccount\x12\x12\n" +
+	"\x04auth\x18\x01 \x01(\tR\x04authBs\n" +
+	"\x1fcom.xray.proxy.hysteria.accountP\x01Z0github.com/xtls/xray-core/proxy/hysteria/account\xaa\x02\x1bXray.Proxy.Hysteria.Accountb\x06proto3"
+
+var (
+	file_proxy_hysteria_account_config_proto_rawDescOnce sync.Once
+	file_proxy_hysteria_account_config_proto_rawDescData []byte
+)
+
+func file_proxy_hysteria_account_config_proto_rawDescGZIP() []byte {
+	file_proxy_hysteria_account_config_proto_rawDescOnce.Do(func() {
+		file_proxy_hysteria_account_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_proxy_hysteria_account_config_proto_rawDesc), len(file_proxy_hysteria_account_config_proto_rawDesc)))
+	})
+	return file_proxy_hysteria_account_config_proto_rawDescData
+}
+
+var file_proxy_hysteria_account_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
+var file_proxy_hysteria_account_config_proto_goTypes = []any{
+	(*Account)(nil), // 0: xray.proxy.hysteria.account.Account
+}
+var file_proxy_hysteria_account_config_proto_depIdxs = []int32{
+	0, // [0:0] is the sub-list for method output_type
+	0, // [0:0] is the sub-list for method input_type
+	0, // [0:0] is the sub-list for extension type_name
+	0, // [0:0] is the sub-list for extension extendee
+	0, // [0:0] is the sub-list for field type_name
+}
+
+func init() { file_proxy_hysteria_account_config_proto_init() }
+func file_proxy_hysteria_account_config_proto_init() {
+	if File_proxy_hysteria_account_config_proto != nil {
+		return
+	}
+	type x struct{}
+	out := protoimpl.TypeBuilder{
+		File: protoimpl.DescBuilder{
+			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
+			RawDescriptor: unsafe.Slice(unsafe.StringData(file_proxy_hysteria_account_config_proto_rawDesc), len(file_proxy_hysteria_account_config_proto_rawDesc)),
+			NumEnums:      0,
+			NumMessages:   1,
+			NumExtensions: 0,
+			NumServices:   0,
+		},
+		GoTypes:           file_proxy_hysteria_account_config_proto_goTypes,
+		DependencyIndexes: file_proxy_hysteria_account_config_proto_depIdxs,
+		MessageInfos:      file_proxy_hysteria_account_config_proto_msgTypes,
+	}.Build()
+	File_proxy_hysteria_account_config_proto = out.File
+	file_proxy_hysteria_account_config_proto_goTypes = nil
+	file_proxy_hysteria_account_config_proto_depIdxs = nil
+}

+ 11 - 0
proxy/hysteria/account/config.proto

@@ -0,0 +1,11 @@
+syntax = "proto3";
+
+package xray.proxy.hysteria.account;
+option csharp_namespace = "Xray.Proxy.Hysteria.Account";
+option go_package = "github.com/xtls/xray-core/proxy/hysteria/account";
+option java_package = "com.xray.proxy.hysteria.account";
+option java_multiple_files = true;
+
+message Account {
+  string auth = 1;
+}

+ 44 - 25
proxy/hysteria/client.go

@@ -135,6 +135,7 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 			if err := buf.Copy(link.Reader, writer, buf.UpdateActivity(timer)); err != nil {
 				return errors.New("failed to transport all UDP request").Base(err)
 			}
+
 			return nil
 		}
 
@@ -143,12 +144,14 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 
 			reader := &UDPReader{
 				Reader: conn,
+				buf:    make([]byte, MaxUDPSize),
 				df:     &Defragger{},
 			}
 
 			if err := buf.Copy(reader, link.Writer, buf.UpdateActivity(timer)); err != nil {
 				return errors.New("failed to transport all UDP response").Base(err)
 			}
+
 			return nil
 		}
 
@@ -178,7 +181,6 @@ type UDPWriter struct {
 func (w *UDPWriter) sendMsg(msg *UDPMessage) error {
 	msgN := msg.Serialize(w.buf)
 	if msgN < 0 {
-		// Message larger than buffer, silent drop
 		return nil
 	}
 	_, err := w.Writer.Write(w.buf[:msgN])
@@ -192,10 +194,12 @@ func (w *UDPWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 		if b == nil {
 			break
 		}
+
 		addr := w.addr
 		if b.UDP != nil {
 			addr = b.UDP.NetAddr()
 		}
+
 		msg := &UDPMessage{
 			SessionID: 0,
 			PacketID:  0,
@@ -204,47 +208,58 @@ func (w *UDPWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 			Addr:      addr,
 			Data:      b.Bytes(),
 		}
-		if err := w.sendMsg(msg); err != nil {
-			var errTooLarge *quic.DatagramTooLargeError
-			if go_errors.As(err, &errTooLarge) {
-				msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1
-				fMsgs := FragUDPMessage(msg, int(errTooLarge.MaxDatagramPayloadSize))
-				for _, fMsg := range fMsgs {
-					err := w.sendMsg(&fMsg)
-					if err != nil {
-						b.Release()
-						buf.ReleaseMulti(mb)
-						return err
-					}
+
+		err := w.sendMsg(msg)
+		var errTooLarge *quic.DatagramTooLargeError
+		if go_errors.As(err, &errTooLarge) {
+			msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1
+			fMsgs := FragUDPMessage(msg, int(errTooLarge.MaxDatagramPayloadSize))
+			for _, fMsg := range fMsgs {
+				err := w.sendMsg(&fMsg)
+				if err != nil {
+					b.Release()
+					buf.ReleaseMulti(mb)
+					return err
 				}
-			} else {
-				b.Release()
-				buf.ReleaseMulti(mb)
-				return err
 			}
+		} else if err != nil {
+			b.Release()
+			buf.ReleaseMulti(mb)
+			return err
 		}
+
 		b.Release()
 	}
+
 	return nil
 }
 
 type UDPReader struct {
-	Reader io.Reader
-	df     *Defragger
+	Reader    io.Reader
+	buf       []byte
+	df        *Defragger
+	firstMsg  *UDPMessage
+	firstDest *net.Destination
 }
 
 func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
+	if r.firstMsg != nil {
+		buffer := buf.New()
+		buffer.Write(r.firstMsg.Data)
+		buffer.UDP = r.firstDest
+
+		r.firstMsg = nil
+
+		return buf.MultiBuffer{buffer}, nil
+	}
 	for {
-		b := buf.New()
-		_, err := b.ReadFrom(r.Reader)
+		n, err := r.Reader.Read(r.buf)
 		if err != nil {
-			b.Release()
 			return nil, err
 		}
 
-		msg, err := ParseUDPMessage(b.Bytes())
+		msg, err := ParseUDPMessage(r.buf[:n])
 		if err != nil {
-			b.Release()
 			continue
 		}
 
@@ -253,7 +268,11 @@ func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
 			continue
 		}
 
-		dest, _ := net.ParseDestination("udp:" + dfMsg.Addr)
+		dest, err := net.ParseDestination("udp:" + dfMsg.Addr)
+		if err != nil {
+			errors.LogDebug(context.Background(), dfMsg.Addr, " ParseDestination err ", err)
+			continue
+		}
 
 		buffer := buf.New()
 		buffer.Write(dfMsg.Data)

+ 2 - 2
proxy/hysteria/config.go

@@ -5,6 +5,6 @@ import (
 )
 
 var (
-	tcpRequestPadding = padding.Padding{Min: 64, Max: 512}
-	// tcpResponsePadding = padding.Padding{Min: 128, Max: 1024}
+	tcpRequestPadding  = padding.Padding{Min: 64, Max: 512}
+	tcpResponsePadding = padding.Padding{Min: 128, Max: 1024}
 )

+ 60 - 11
proxy/hysteria/config.pb.go

@@ -74,14 +74,60 @@ func (x *ClientConfig) GetServer() *protocol.ServerEndpoint {
 	return nil
 }
 
+type ServerConfig struct {
+	state         protoimpl.MessageState `protogen:"open.v1"`
+	Users         []*protocol.User       `protobuf:"bytes,1,rep,name=users,proto3" json:"users,omitempty"`
+	unknownFields protoimpl.UnknownFields
+	sizeCache     protoimpl.SizeCache
+}
+
+func (x *ServerConfig) Reset() {
+	*x = ServerConfig{}
+	mi := &file_proxy_hysteria_config_proto_msgTypes[1]
+	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+	ms.StoreMessageInfo(mi)
+}
+
+func (x *ServerConfig) String() string {
+	return protoimpl.X.MessageStringOf(x)
+}
+
+func (*ServerConfig) ProtoMessage() {}
+
+func (x *ServerConfig) ProtoReflect() protoreflect.Message {
+	mi := &file_proxy_hysteria_config_proto_msgTypes[1]
+	if x != nil {
+		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+		if ms.LoadMessageInfo() == nil {
+			ms.StoreMessageInfo(mi)
+		}
+		return ms
+	}
+	return mi.MessageOf(x)
+}
+
+// Deprecated: Use ServerConfig.ProtoReflect.Descriptor instead.
+func (*ServerConfig) Descriptor() ([]byte, []int) {
+	return file_proxy_hysteria_config_proto_rawDescGZIP(), []int{1}
+}
+
+func (x *ServerConfig) GetUsers() []*protocol.User {
+	if x != nil {
+		return x.Users
+	}
+	return nil
+}
+
 var File_proxy_hysteria_config_proto protoreflect.FileDescriptor
 
 const file_proxy_hysteria_config_proto_rawDesc = "" +
 	"\n" +
-	"\x1bproxy/hysteria/config.proto\x12\x13xray.proxy.hysteria\x1a!common/protocol/server_spec.proto\"f\n" +
+	"\x1bproxy/hysteria/config.proto\x12\x13xray.proxy.hysteria\x1a!common/protocol/server_spec.proto\x1a\x1acommon/protocol/user.proto\"f\n" +
 	"\fClientConfig\x12\x18\n" +
 	"\aversion\x18\x01 \x01(\x05R\aversion\x12<\n" +
-	"\x06server\x18\x02 \x01(\v2$.xray.common.protocol.ServerEndpointR\x06serverB[\n" +
+	"\x06server\x18\x02 \x01(\v2$.xray.common.protocol.ServerEndpointR\x06server\"@\n" +
+	"\fServerConfig\x120\n" +
+	"\x05users\x18\x01 \x03(\v2\x1a.xray.common.protocol.UserR\x05usersB[\n" +
 	"\x17com.xray.proxy.hysteriaP\x01Z(github.com/xtls/xray-core/proxy/hysteria\xaa\x02\x13Xray.Proxy.Hysteriab\x06proto3"
 
 var (
@@ -96,18 +142,21 @@ func file_proxy_hysteria_config_proto_rawDescGZIP() []byte {
 	return file_proxy_hysteria_config_proto_rawDescData
 }
 
-var file_proxy_hysteria_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
+var file_proxy_hysteria_config_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
 var file_proxy_hysteria_config_proto_goTypes = []any{
 	(*ClientConfig)(nil),            // 0: xray.proxy.hysteria.ClientConfig
-	(*protocol.ServerEndpoint)(nil), // 1: xray.common.protocol.ServerEndpoint
+	(*ServerConfig)(nil),            // 1: xray.proxy.hysteria.ServerConfig
+	(*protocol.ServerEndpoint)(nil), // 2: xray.common.protocol.ServerEndpoint
+	(*protocol.User)(nil),           // 3: xray.common.protocol.User
 }
 var file_proxy_hysteria_config_proto_depIdxs = []int32{
-	1, // 0: xray.proxy.hysteria.ClientConfig.server:type_name -> xray.common.protocol.ServerEndpoint
-	1, // [1:1] is the sub-list for method output_type
-	1, // [1:1] is the sub-list for method input_type
-	1, // [1:1] is the sub-list for extension type_name
-	1, // [1:1] is the sub-list for extension extendee
-	0, // [0:1] is the sub-list for field type_name
+	2, // 0: xray.proxy.hysteria.ClientConfig.server:type_name -> xray.common.protocol.ServerEndpoint
+	3, // 1: xray.proxy.hysteria.ServerConfig.users:type_name -> xray.common.protocol.User
+	2, // [2:2] is the sub-list for method output_type
+	2, // [2:2] is the sub-list for method input_type
+	2, // [2:2] is the sub-list for extension type_name
+	2, // [2:2] is the sub-list for extension extendee
+	0, // [0:2] is the sub-list for field type_name
 }
 
 func init() { file_proxy_hysteria_config_proto_init() }
@@ -121,7 +170,7 @@ func file_proxy_hysteria_config_proto_init() {
 			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
 			RawDescriptor: unsafe.Slice(unsafe.StringData(file_proxy_hysteria_config_proto_rawDesc), len(file_proxy_hysteria_config_proto_rawDesc)),
 			NumEnums:      0,
-			NumMessages:   1,
+			NumMessages:   2,
 			NumExtensions: 0,
 			NumServices:   0,
 		},

+ 5 - 0
proxy/hysteria/config.proto

@@ -7,8 +7,13 @@ option java_package = "com.xray.proxy.hysteria";
 option java_multiple_files = true;
 
 import "common/protocol/server_spec.proto";
+import "common/protocol/user.proto";
 
 message ClientConfig {
   int32 version = 1;
   xray.common.protocol.ServerEndpoint server = 2;
 }
+
+message ServerConfig {
+  repeated xray.common.protocol.User users = 1;
+}

+ 12 - 0
proxy/hysteria/ctx/ctx.go

@@ -2,12 +2,15 @@ package ctx
 
 import (
 	"context"
+
+	"github.com/xtls/xray-core/proxy/hysteria/account"
 )
 
 type key int
 
 const (
 	requireDatagram key = iota
+	validator
 )
 
 func ContextWithRequireDatagram(ctx context.Context, udp bool) context.Context {
@@ -21,3 +24,12 @@ func RequireDatagramFromContext(ctx context.Context) bool {
 	_, ok := ctx.Value(requireDatagram).(struct{})
 	return ok
 }
+
+func ContextWithValidator(ctx context.Context, v *account.Validator) context.Context {
+	return context.WithValue(ctx, validator, v)
+}
+
+func ValidatorFromContext(ctx context.Context) *account.Validator {
+	v, _ := ctx.Value(validator).(*account.Validator)
+	return v
+}

+ 52 - 7
proxy/hysteria/protocol.go

@@ -11,8 +11,6 @@ import (
 )
 
 const (
-	FrameTypeTCPRequest = 0x401
-
 	// Max length values are for preventing DoS attacks
 
 	MaxAddressLength = 2048
@@ -28,22 +26,49 @@ const (
 )
 
 // TCPRequest format:
-// 0x401 (QUIC varint)
 // Address length (QUIC varint)
 // Address (bytes)
 // Padding length (QUIC varint)
 // Padding (bytes)
 
+func ReadTCPRequest(r io.Reader) (string, error) {
+	bReader := quicvarint.NewReader(r)
+	addrLen, err := quicvarint.Read(bReader)
+	if err != nil {
+		return "", err
+	}
+	if addrLen == 0 || addrLen > MaxAddressLength {
+		return "", errors.New("invalid address length")
+	}
+	addrBuf := make([]byte, addrLen)
+	_, err = io.ReadFull(r, addrBuf)
+	if err != nil {
+		return "", err
+	}
+	paddingLen, err := quicvarint.Read(bReader)
+	if err != nil {
+		return "", err
+	}
+	if paddingLen > MaxPaddingLength {
+		return "", errors.New("invalid padding length")
+	}
+	if paddingLen > 0 {
+		_, err = io.CopyN(io.Discard, r, int64(paddingLen))
+		if err != nil {
+			return "", err
+		}
+	}
+	return string(addrBuf), nil
+}
+
 func WriteTCPRequest(w io.Writer, addr string) error {
 	padding := tcpRequestPadding.String()
 	paddingLen := len(padding)
 	addrLen := len(addr)
-	sz := int(quicvarint.Len(FrameTypeTCPRequest)) +
-		int(quicvarint.Len(uint64(addrLen))) + addrLen +
+	sz := int(quicvarint.Len(uint64(addrLen))) + addrLen +
 		int(quicvarint.Len(uint64(paddingLen))) + paddingLen
 	buf := make([]byte, sz)
-	i := varintPut(buf, FrameTypeTCPRequest)
-	i += varintPut(buf[i:], uint64(addrLen))
+	i := varintPut(buf, uint64(addrLen))
 	i += copy(buf[i:], addr)
 	i += varintPut(buf[i:], uint64(paddingLen))
 	copy(buf[i:], padding)
@@ -96,6 +121,26 @@ func ReadTCPResponse(r io.Reader) (bool, string, error) {
 	return status[0] == 0, string(msgBuf), nil
 }
 
+func WriteTCPResponse(w io.Writer, ok bool, msg string) error {
+	padding := tcpResponsePadding.String()
+	paddingLen := len(padding)
+	msgLen := len(msg)
+	sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen +
+		int(quicvarint.Len(uint64(paddingLen))) + paddingLen
+	buf := make([]byte, sz)
+	if ok {
+		buf[0] = 0
+	} else {
+		buf[0] = 1
+	}
+	i := varintPut(buf[1:], uint64(msgLen))
+	i += copy(buf[1+i:], msg)
+	i += varintPut(buf[1+i:], uint64(paddingLen))
+	copy(buf[1+i:], padding)
+	_, err := w.Write(buf)
+	return err
+}
+
 // UDPMessage format:
 // Session ID (uint32 BE)
 // Packet ID (uint16 BE)

+ 198 - 0
proxy/hysteria/server.go

@@ -0,0 +1,198 @@
+package hysteria
+
+import (
+	"context"
+	"io"
+	"time"
+
+	"github.com/xtls/xray-core/common"
+	"github.com/xtls/xray-core/common/buf"
+	"github.com/xtls/xray-core/common/errors"
+	"github.com/xtls/xray-core/common/log"
+	"github.com/xtls/xray-core/common/net"
+	"github.com/xtls/xray-core/common/protocol"
+	"github.com/xtls/xray-core/common/session"
+	"github.com/xtls/xray-core/core"
+	"github.com/xtls/xray-core/features/policy"
+	"github.com/xtls/xray-core/features/routing"
+	"github.com/xtls/xray-core/proxy/hysteria/account"
+	"github.com/xtls/xray-core/transport"
+	"github.com/xtls/xray-core/transport/internet/hysteria"
+	"github.com/xtls/xray-core/transport/internet/stat"
+)
+
+type Server struct {
+	config        *ServerConfig
+	validator     *account.Validator
+	policyManager policy.Manager
+}
+
+func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
+	validator := account.NewValidator()
+	for _, user := range config.Users {
+		u, err := user.ToMemoryUser()
+		if err != nil {
+			return nil, errors.New("failed to get hysteria user").Base(err).AtError()
+		}
+
+		if err := validator.Add(u); err != nil {
+			return nil, errors.New("failed to add user").Base(err).AtError()
+		}
+	}
+
+	v := core.MustFromContext(ctx)
+	s := &Server{
+		config:        config,
+		validator:     validator,
+		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
+	}
+
+	return s, nil
+}
+
+func (s *Server) HysteriaInboundValidator() *account.Validator {
+	return s.validator
+}
+
+func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error {
+	return s.validator.Add(u)
+}
+
+func (s *Server) RemoveUser(ctx context.Context, e string) error {
+	return s.validator.Del(e)
+}
+
+func (s *Server) GetUser(ctx context.Context, email string) *protocol.MemoryUser {
+	return s.validator.GetByEmail(email)
+}
+
+func (s *Server) GetUsers(ctx context.Context) []*protocol.MemoryUser {
+	return s.validator.GetAll()
+}
+
+func (s *Server) GetUsersCount(context.Context) int64 {
+	return s.validator.GetCount()
+}
+
+func (s *Server) Network() []net.Network {
+	return []net.Network{net.Network_TCP}
+}
+
+func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
+	inbound := session.InboundFromContext(ctx)
+	inbound.Name = "hysteria"
+	inbound.CanSpliceCopy = 3
+
+	var useremail string
+	var userlevel uint32
+	type User interface{ User() *protocol.MemoryUser }
+	if v, ok := conn.(User); ok {
+		inbound.User = v.User()
+		if inbound.User != nil {
+			useremail = inbound.User.Email
+			userlevel = inbound.User.Level
+		}
+	}
+
+	iConn := stat.TryUnwrapStatsConn(conn)
+	if _, ok := iConn.(*hysteria.InterUdpConn); ok {
+		r := io.Reader(conn)
+		b := make([]byte, MaxUDPSize)
+		df := &Defragger{}
+		var firstMsg *UDPMessage
+		var firstDest net.Destination
+
+		for {
+			n, err := r.Read(b)
+			if err != nil {
+				return err
+			}
+
+			msg, err := ParseUDPMessage(b[:n])
+			if err != nil {
+				continue
+			}
+
+			dfMsg := df.Feed(msg)
+			if dfMsg == nil {
+				continue
+			}
+
+			firstMsg = dfMsg
+			firstDest, err = net.ParseDestination("udp:" + firstMsg.Addr)
+			if err != nil {
+				errors.LogDebug(context.Background(), dfMsg.Addr, " ParseDestination err ", err)
+				continue
+			}
+
+			break
+		}
+
+		reader := &UDPReader{
+			Reader:    r,
+			buf:       b,
+			df:        df,
+			firstMsg:  firstMsg,
+			firstDest: &firstDest,
+		}
+
+		writer := &UDPWriter{
+			Writer: conn,
+			buf:    make([]byte, MaxUDPSize),
+			addr:   firstMsg.Addr,
+		}
+
+		return dispatcher.DispatchLink(ctx, firstDest, &transport.Link{
+			Reader: reader,
+			Writer: writer,
+		})
+	} else {
+		sessionPolicy := s.policyManager.ForLevel(userlevel)
+
+		common.Must(conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)))
+		addr, err := ReadTCPRequest(conn)
+		if err != nil {
+			log.Record(&log.AccessMessage{
+				From:   conn.RemoteAddr(),
+				To:     "",
+				Status: log.AccessRejected,
+				Reason: err,
+			})
+			return errors.New("failed to create request from: ", conn.RemoteAddr()).Base(err)
+		}
+		common.Must(conn.SetReadDeadline(time.Time{}))
+
+		dest, err := net.ParseDestination("tcp:" + addr)
+		if err != nil {
+			return err
+		}
+		ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
+			From:   conn.RemoteAddr(),
+			To:     dest,
+			Status: log.AccessAccepted,
+			Reason: "",
+			Email:  useremail,
+		})
+		errors.LogInfo(ctx, "tunnelling request to ", dest)
+
+		bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn))
+		err = WriteTCPResponse(bufferedWriter, true, "")
+		if err != nil {
+			return errors.New("failed to write response").Base(err)
+		}
+		if err := bufferedWriter.SetBuffered(false); err != nil {
+			return err
+		}
+
+		return dispatcher.DispatchLink(ctx, dest, &transport.Link{
+			Reader: buf.NewReader(conn),
+			Writer: bufferedWriter,
+		})
+	}
+}
+
+func init() {
+	common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
+		return NewServer(ctx, config.(*ServerConfig))
+	}))
+}

+ 8 - 2
transport/internet/hysteria/config.go

@@ -1,6 +1,8 @@
 package hysteria
 
 import (
+	"time"
+
 	"github.com/xtls/xray-core/common"
 	"github.com/xtls/xray-core/transport/internet"
 	"github.com/xtls/xray-core/transport/internet/hysteria/padding"
@@ -23,11 +25,15 @@ const (
 	StatusAuthOK = 233
 
 	udpMessageChanSize = 1024
+
+	FrameTypeTCPRequest = 0x401
+
+	idleCleanupInterval = 1 * time.Second
 )
 
 var (
-	authRequestPadding = padding.Padding{Min: 256, Max: 2048}
-	// authResponsePadding = padding.Padding{Min: 256, Max: 2048}
+	authRequestPadding  = padding.Padding{Min: 256, Max: 2048}
+	authResponsePadding = padding.Padding{Min: 256, Max: 2048}
 )
 
 type Status int

+ 105 - 9
transport/internet/hysteria/config.pb.go

@@ -38,6 +38,16 @@ type Config struct {
 	MaxIdleTimeout          int64                  `protobuf:"varint,13,opt,name=max_idle_timeout,json=maxIdleTimeout,proto3" json:"max_idle_timeout,omitempty"`
 	KeepAlivePeriod         int64                  `protobuf:"varint,14,opt,name=keep_alive_period,json=keepAlivePeriod,proto3" json:"keep_alive_period,omitempty"`
 	DisablePathMtuDiscovery bool                   `protobuf:"varint,15,opt,name=disable_path_mtu_discovery,json=disablePathMtuDiscovery,proto3" json:"disable_path_mtu_discovery,omitempty"`
+	MaxIncomingStreams      int64                  `protobuf:"varint,16,opt,name=max_incoming_streams,json=maxIncomingStreams,proto3" json:"max_incoming_streams,omitempty"`
+	UdpIdleTimeout          int64                  `protobuf:"varint,17,opt,name=udp_idle_timeout,json=udpIdleTimeout,proto3" json:"udp_idle_timeout,omitempty"`
+	MasqType                string                 `protobuf:"bytes,18,opt,name=masq_type,json=masqType,proto3" json:"masq_type,omitempty"`
+	MasqFile                string                 `protobuf:"bytes,19,opt,name=masq_file,json=masqFile,proto3" json:"masq_file,omitempty"`
+	MasqUrl                 string                 `protobuf:"bytes,20,opt,name=masq_url,json=masqUrl,proto3" json:"masq_url,omitempty"`
+	MasqUrlRewriteHost      bool                   `protobuf:"varint,21,opt,name=masq_url_rewrite_host,json=masqUrlRewriteHost,proto3" json:"masq_url_rewrite_host,omitempty"`
+	MasqUrlInsecure         bool                   `protobuf:"varint,22,opt,name=masq_url_insecure,json=masqUrlInsecure,proto3" json:"masq_url_insecure,omitempty"`
+	MasqString              string                 `protobuf:"bytes,23,opt,name=masq_string,json=masqString,proto3" json:"masq_string,omitempty"`
+	MasqStringHeaders       map[string]string      `protobuf:"bytes,24,rep,name=masq_string_headers,json=masqStringHeaders,proto3" json:"masq_string_headers,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
+	MasqStringStatusCode    int32                  `protobuf:"varint,25,opt,name=masq_string_status_code,json=masqStringStatusCode,proto3" json:"masq_string_status_code,omitempty"`
 	unknownFields           protoimpl.UnknownFields
 	sizeCache               protoimpl.SizeCache
 }
@@ -177,11 +187,81 @@ func (x *Config) GetDisablePathMtuDiscovery() bool {
 	return false
 }
 
+func (x *Config) GetMaxIncomingStreams() int64 {
+	if x != nil {
+		return x.MaxIncomingStreams
+	}
+	return 0
+}
+
+func (x *Config) GetUdpIdleTimeout() int64 {
+	if x != nil {
+		return x.UdpIdleTimeout
+	}
+	return 0
+}
+
+func (x *Config) GetMasqType() string {
+	if x != nil {
+		return x.MasqType
+	}
+	return ""
+}
+
+func (x *Config) GetMasqFile() string {
+	if x != nil {
+		return x.MasqFile
+	}
+	return ""
+}
+
+func (x *Config) GetMasqUrl() string {
+	if x != nil {
+		return x.MasqUrl
+	}
+	return ""
+}
+
+func (x *Config) GetMasqUrlRewriteHost() bool {
+	if x != nil {
+		return x.MasqUrlRewriteHost
+	}
+	return false
+}
+
+func (x *Config) GetMasqUrlInsecure() bool {
+	if x != nil {
+		return x.MasqUrlInsecure
+	}
+	return false
+}
+
+func (x *Config) GetMasqString() string {
+	if x != nil {
+		return x.MasqString
+	}
+	return ""
+}
+
+func (x *Config) GetMasqStringHeaders() map[string]string {
+	if x != nil {
+		return x.MasqStringHeaders
+	}
+	return nil
+}
+
+func (x *Config) GetMasqStringStatusCode() int32 {
+	if x != nil {
+		return x.MasqStringStatusCode
+	}
+	return 0
+}
+
 var File_transport_internet_hysteria_config_proto protoreflect.FileDescriptor
 
 const file_transport_internet_hysteria_config_proto_rawDesc = "" +
 	"\n" +
-	"(transport/internet/hysteria/config.proto\x12 xray.transport.internet.hysteria\"\xd1\x04\n" +
+	"(transport/internet/hysteria/config.proto\x12 xray.transport.internet.hysteria\"\xf0\b\n" +
 	"\x06Config\x12\x18\n" +
 	"\aversion\x18\x01 \x01(\x05R\aversion\x12\x12\n" +
 	"\x04auth\x18\x02 \x01(\tR\x04auth\x12\x1e\n" +
@@ -200,7 +280,21 @@ const file_transport_internet_hysteria_config_proto_rawDesc = "" +
 	"\x17max_conn_receive_window\x18\f \x01(\x04R\x14maxConnReceiveWindow\x12(\n" +
 	"\x10max_idle_timeout\x18\r \x01(\x03R\x0emaxIdleTimeout\x12*\n" +
 	"\x11keep_alive_period\x18\x0e \x01(\x03R\x0fkeepAlivePeriod\x12;\n" +
-	"\x1adisable_path_mtu_discovery\x18\x0f \x01(\bR\x17disablePathMtuDiscoveryB\x82\x01\n" +
+	"\x1adisable_path_mtu_discovery\x18\x0f \x01(\bR\x17disablePathMtuDiscovery\x120\n" +
+	"\x14max_incoming_streams\x18\x10 \x01(\x03R\x12maxIncomingStreams\x12(\n" +
+	"\x10udp_idle_timeout\x18\x11 \x01(\x03R\x0eudpIdleTimeout\x12\x1b\n" +
+	"\tmasq_type\x18\x12 \x01(\tR\bmasqType\x12\x1b\n" +
+	"\tmasq_file\x18\x13 \x01(\tR\bmasqFile\x12\x19\n" +
+	"\bmasq_url\x18\x14 \x01(\tR\amasqUrl\x121\n" +
+	"\x15masq_url_rewrite_host\x18\x15 \x01(\bR\x12masqUrlRewriteHost\x12*\n" +
+	"\x11masq_url_insecure\x18\x16 \x01(\bR\x0fmasqUrlInsecure\x12\x1f\n" +
+	"\vmasq_string\x18\x17 \x01(\tR\n" +
+	"masqString\x12o\n" +
+	"\x13masq_string_headers\x18\x18 \x03(\v2?.xray.transport.internet.hysteria.Config.MasqStringHeadersEntryR\x11masqStringHeaders\x125\n" +
+	"\x17masq_string_status_code\x18\x19 \x01(\x05R\x14masqStringStatusCode\x1aD\n" +
+	"\x16MasqStringHeadersEntry\x12\x10\n" +
+	"\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" +
+	"\x05value\x18\x02 \x01(\tR\x05value:\x028\x01B\x82\x01\n" +
 	"$com.xray.transport.internet.hysteriaP\x01Z5github.com/xtls/xray-core/transport/internet/hysteria\xaa\x02 Xray.Transport.Internet.Hysteriab\x06proto3"
 
 var (
@@ -215,16 +309,18 @@ func file_transport_internet_hysteria_config_proto_rawDescGZIP() []byte {
 	return file_transport_internet_hysteria_config_proto_rawDescData
 }
 
-var file_transport_internet_hysteria_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
+var file_transport_internet_hysteria_config_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
 var file_transport_internet_hysteria_config_proto_goTypes = []any{
 	(*Config)(nil), // 0: xray.transport.internet.hysteria.Config
+	nil,            // 1: xray.transport.internet.hysteria.Config.MasqStringHeadersEntry
 }
 var file_transport_internet_hysteria_config_proto_depIdxs = []int32{
-	0, // [0:0] is the sub-list for method output_type
-	0, // [0:0] is the sub-list for method input_type
-	0, // [0:0] is the sub-list for extension type_name
-	0, // [0:0] is the sub-list for extension extendee
-	0, // [0:0] is the sub-list for field type_name
+	1, // 0: xray.transport.internet.hysteria.Config.masq_string_headers:type_name -> xray.transport.internet.hysteria.Config.MasqStringHeadersEntry
+	1, // [1:1] is the sub-list for method output_type
+	1, // [1:1] is the sub-list for method input_type
+	1, // [1:1] is the sub-list for extension type_name
+	1, // [1:1] is the sub-list for extension extendee
+	0, // [0:1] is the sub-list for field type_name
 }
 
 func init() { file_transport_internet_hysteria_config_proto_init() }
@@ -238,7 +334,7 @@ func file_transport_internet_hysteria_config_proto_init() {
 			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
 			RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_hysteria_config_proto_rawDesc), len(file_transport_internet_hysteria_config_proto_rawDesc)),
 			NumEnums:      0,
-			NumMessages:   1,
+			NumMessages:   2,
 			NumExtensions: 0,
 			NumServices:   0,
 		},

+ 11 - 1
transport/internet/hysteria/config.proto

@@ -23,5 +23,15 @@ message Config {
   int64 max_idle_timeout = 13;
   int64 keep_alive_period = 14;
   bool disable_path_mtu_discovery = 15;
-}
+  int64 max_incoming_streams = 16;
 
+  int64 udp_idle_timeout = 17;
+  string masq_type = 18;
+  string masq_file = 19;
+  string masq_url = 20;
+  bool masq_url_rewrite_host = 21;
+  bool masq_url_insecure = 22;
+  string masq_string = 23;
+  map<string, string> masq_string_headers = 24;
+  int32 masq_string_status_code = 25;
+}

+ 58 - 2
transport/internet/hysteria/conn.go

@@ -3,16 +3,28 @@ package hysteria
 import (
 	"encoding/binary"
 	"io"
+	"sync"
 	"time"
 
 	"github.com/apernet/quic-go"
+	"github.com/apernet/quic-go/quicvarint"
 	"github.com/xtls/xray-core/common/net"
+	"github.com/xtls/xray-core/common/protocol"
 )
 
 type interConn struct {
 	stream *quic.Stream
 	local  net.Addr
 	remote net.Addr
+
+	client bool
+	mutex  sync.Mutex
+
+	user *protocol.MemoryUser
+}
+
+func (i *interConn) User() *protocol.MemoryUser {
+	return i.user
 }
 
 func (i *interConn) Read(b []byte) (int, error) {
@@ -20,6 +32,22 @@ func (i *interConn) Read(b []byte) (int, error) {
 }
 
 func (i *interConn) Write(b []byte) (int, error) {
+	if i.client {
+		i.mutex.Lock()
+		if i.client {
+			buf := make([]byte, 0, quicvarint.Len(FrameTypeTCPRequest)+len(b))
+			buf = quicvarint.Append(buf, FrameTypeTCPRequest)
+			buf = append(buf, b...)
+			_, err := i.stream.Write(buf)
+			if err != nil {
+				return 0, err
+			}
+			i.client = false
+			return len(b), nil
+		}
+		i.mutex.Unlock()
+	}
+
 	return i.stream.Write(b)
 }
 
@@ -53,10 +81,34 @@ type InterUdpConn struct {
 	local  net.Addr
 	remote net.Addr
 
-	id        uint32
-	ch        chan []byte
+	id uint32
+	ch chan []byte
+
 	closed    bool
 	closeFunc func()
+
+	last  time.Time
+	mutex sync.Mutex
+
+	user *protocol.MemoryUser
+}
+
+func (i *InterUdpConn) User() *protocol.MemoryUser {
+	return i.user
+}
+
+func (i *InterUdpConn) SetLast() {
+	i.mutex.Lock()
+	defer i.mutex.Unlock()
+
+	i.last = time.Now()
+}
+
+func (i *InterUdpConn) GetLast() time.Time {
+	i.mutex.Lock()
+	defer i.mutex.Unlock()
+
+	return i.last
 }
 
 func (i *InterUdpConn) Read(p []byte) (int, error) {
@@ -68,10 +120,14 @@ func (i *InterUdpConn) Read(p []byte) (int, error) {
 	if n != len(b) {
 		return 0, io.ErrShortBuffer
 	}
+
+	i.SetLast()
 	return n, nil
 }
 
 func (i *InterUdpConn) Write(p []byte) (int, error) {
+	i.SetLast()
+
 	binary.BigEndian.PutUint32(p, i.id)
 	if err := i.conn.SendDatagram(p); err != nil {
 		return 0, err

+ 27 - 24
transport/internet/hysteria/dialer.go

@@ -26,15 +26,23 @@ import (
 	"github.com/xtls/xray-core/transport/internet/tls"
 )
 
-type udpSessionManager struct {
+type udpSessionManagerClient struct {
 	conn   *quic.Conn
 	m      map[uint32]*InterUdpConn
-	nextId uint32
+	next   uint32
 	closed bool
 	mutex  sync.RWMutex
 }
 
-func (m *udpSessionManager) run() {
+func (m *udpSessionManagerClient) close(udpConn *InterUdpConn) {
+	if !udpConn.closed {
+		udpConn.closed = true
+		close(udpConn.ch)
+		delete(m.m, udpConn.id)
+	}
+}
+
+func (m *udpSessionManagerClient) run() {
 	for {
 		d, err := m.conn.ReceiveDatagram(context.Background())
 		if err != nil {
@@ -44,29 +52,22 @@ func (m *udpSessionManager) run() {
 		if len(d) < 4 {
 			continue
 		}
-		sessionId := binary.BigEndian.Uint32(d[:4])
+		id := binary.BigEndian.Uint32(d[:4])
 
-		m.feed(sessionId, d)
+		m.feed(id, d)
 	}
 
 	m.mutex.Lock()
 	defer m.mutex.Unlock()
 
 	m.closed = true
+
 	for _, udpConn := range m.m {
 		m.close(udpConn)
 	}
 }
 
-func (m *udpSessionManager) close(udpConn *InterUdpConn) {
-	if !udpConn.closed {
-		udpConn.closed = true
-		close(udpConn.ch)
-		delete(m.m, udpConn.id)
-	}
-}
-
-func (m *udpSessionManager) udp() (*InterUdpConn, error) {
+func (m *udpSessionManagerClient) udp() (*InterUdpConn, error) {
 	m.mutex.Lock()
 	defer m.mutex.Unlock()
 
@@ -79,7 +80,7 @@ func (m *udpSessionManager) udp() (*InterUdpConn, error) {
 		local:  m.conn.LocalAddr(),
 		remote: m.conn.RemoteAddr(),
 
-		id: m.nextId,
+		id: m.next,
 		ch: make(chan []byte, udpMessageChanSize),
 	}
 	udpConn.closeFunc = func() {
@@ -87,17 +88,17 @@ func (m *udpSessionManager) udp() (*InterUdpConn, error) {
 		defer m.mutex.Unlock()
 		m.close(udpConn)
 	}
-	m.m[m.nextId] = udpConn
-	m.nextId++
+	m.m[m.next] = udpConn
+	m.next++
 
 	return udpConn, nil
 }
 
-func (m *udpSessionManager) feed(sessionId uint32, d []byte) {
+func (m *udpSessionManagerClient) feed(id uint32, d []byte) {
 	m.mutex.RLock()
 	defer m.mutex.RUnlock()
 
-	udpConn, ok := m.m[sessionId]
+	udpConn, ok := m.m[id]
 	if !ok {
 		return
 	}
@@ -117,7 +118,7 @@ type client struct {
 	tlsConfig      *go_tls.Config
 	socketConfig   *internet.SocketConfig
 	udpmaskManager *finalmask.UdpmaskManager
-	udpSM          *udpSessionManager
+	udpSM          *udpSessionManagerClient
 	mutex          sync.Mutex
 }
 
@@ -269,10 +270,10 @@ func (c *client) dial() error {
 	c.pktConn = pktConn
 	c.conn = quicConn
 	if serverUdp {
-		c.udpSM = &udpSessionManager{
-			conn:   quicConn,
-			m:      make(map[uint32]*InterUdpConn),
-			nextId: 1,
+		c.udpSM = &udpSessionManagerClient{
+			conn: quicConn,
+			m:    make(map[uint32]*InterUdpConn),
+			next: 1,
 		}
 		go c.udpSM.run()
 	}
@@ -307,6 +308,8 @@ func (c *client) tcp() (stat.Connection, error) {
 		stream: stream,
 		local:  c.conn.LocalAddr(),
 		remote: c.conn.RemoteAddr(),
+
+		client: true,
 	}, nil
 }
 

+ 412 - 0
transport/internet/hysteria/hub.go

@@ -0,0 +1,412 @@
+package hysteria
+
+import (
+	"context"
+	gotls "crypto/tls"
+	"encoding/binary"
+	"net/http"
+	"net/http/httputil"
+	"net/url"
+	"strconv"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/apernet/quic-go"
+	"github.com/apernet/quic-go/http3"
+	"github.com/xtls/xray-core/common"
+	"github.com/xtls/xray-core/common/errors"
+	"github.com/xtls/xray-core/common/net"
+	"github.com/xtls/xray-core/common/protocol"
+	"github.com/xtls/xray-core/proxy/hysteria/account"
+	hyCtx "github.com/xtls/xray-core/proxy/hysteria/ctx"
+	"github.com/xtls/xray-core/transport/internet"
+	"github.com/xtls/xray-core/transport/internet/hysteria/congestion"
+	"github.com/xtls/xray-core/transport/internet/tls"
+)
+
+type udpSessionManagerServer struct {
+	conn           *quic.Conn
+	m              map[uint32]*InterUdpConn
+	addConn        internet.ConnHandler
+	stopCh         chan struct{}
+	udpIdleTimeout time.Duration
+	mutex          sync.RWMutex
+
+	user *protocol.MemoryUser
+}
+
+func (m *udpSessionManagerServer) close(udpConn *InterUdpConn) {
+	if !udpConn.closed {
+		udpConn.closed = true
+		close(udpConn.ch)
+		delete(m.m, udpConn.id)
+	}
+}
+
+func (m *udpSessionManagerServer) clean() {
+	ticker := time.NewTicker(idleCleanupInterval)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-ticker.C:
+			m.mutex.RLock()
+			now := time.Now()
+			timeoutConn := make([]*InterUdpConn, 0, len(m.m))
+			for _, udpConn := range m.m {
+				if now.Sub(udpConn.GetLast()) > m.udpIdleTimeout {
+					timeoutConn = append(timeoutConn, udpConn)
+				}
+			}
+			m.mutex.RUnlock()
+
+			for _, udpConn := range timeoutConn {
+				m.mutex.Lock()
+				m.close(udpConn)
+				m.mutex.Unlock()
+			}
+		case <-m.stopCh:
+			return
+		}
+	}
+}
+
+func (m *udpSessionManagerServer) run() {
+	for {
+		d, err := m.conn.ReceiveDatagram(context.Background())
+		if err != nil {
+			break
+		}
+
+		if len(d) < 4 {
+			continue
+		}
+		id := binary.BigEndian.Uint32(d[:4])
+
+		m.feed(id, d)
+	}
+
+	m.mutex.Lock()
+	defer m.mutex.Unlock()
+
+	close(m.stopCh)
+
+	for _, udpConn := range m.m {
+		m.close(udpConn)
+	}
+}
+
+func (m *udpSessionManagerServer) feed(id uint32, d []byte) {
+	m.mutex.RLock()
+	udpConn, ok := m.m[id]
+	m.mutex.RUnlock()
+
+	if !ok {
+		m.mutex.Lock()
+		udpConn, ok = m.m[id]
+		if !ok {
+			udpConn = &InterUdpConn{
+				conn:   m.conn,
+				local:  m.conn.LocalAddr(),
+				remote: m.conn.RemoteAddr(),
+
+				id:   id,
+				ch:   make(chan []byte, udpMessageChanSize),
+				last: time.Now(),
+
+				user: m.user,
+			}
+			udpConn.closeFunc = func() {
+				m.mutex.Lock()
+				defer m.mutex.Unlock()
+				m.close(udpConn)
+			}
+			m.m[id] = udpConn
+			m.addConn(udpConn)
+		}
+		m.mutex.Unlock()
+	}
+
+	select {
+	case udpConn.ch <- d:
+	default:
+	}
+}
+
+type httpHandler struct {
+	ctx     context.Context
+	conn    *quic.Conn
+	addConn internet.ConnHandler
+
+	config      *Config
+	validator   *account.Validator
+	masqHandler http.Handler
+
+	auth  bool
+	mutex sync.Mutex
+	user  *protocol.MemoryUser
+}
+
+func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	if r.Method == http.MethodPost && r.Host == URLHost && r.URL.Path == URLPath {
+		h.mutex.Lock()
+		defer h.mutex.Unlock()
+
+		if h.auth {
+			w.Header().Set(ResponseHeaderUDPEnabled, strconv.FormatBool(hyCtx.RequireDatagramFromContext(h.ctx)))
+			w.Header().Set(CommonHeaderCCRX, strconv.FormatUint(h.config.Down, 10))
+			w.Header().Set(CommonHeaderPadding, authResponsePadding.String())
+			w.WriteHeader(StatusAuthOK)
+			return
+		}
+
+		auth := r.Header.Get(RequestHeaderAuth)
+		clientDown, _ := strconv.ParseUint(r.Header.Get(CommonHeaderCCRX), 10, 64)
+
+		var user *protocol.MemoryUser
+		var ok bool
+		if h.validator != nil {
+			user = h.validator.Get(auth)
+		} else if auth == h.config.Auth {
+			ok = true
+		}
+
+		if user != nil || ok {
+			h.auth = true
+			h.user = user
+
+			switch h.config.Congestion {
+			case "reno":
+				errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion reno")
+			case "bbr":
+				errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion bbr")
+				congestion.UseBBR(h.conn)
+			case "brutal", "":
+				if h.config.Up == 0 || clientDown == 0 {
+					errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion bbr")
+					congestion.UseBBR(h.conn)
+				} else {
+					errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion brutal bytes per second ", min(h.config.Up, clientDown))
+					congestion.UseBrutal(h.conn, min(h.config.Up, clientDown))
+				}
+			case "force-brutal":
+				errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion brutal bytes per second ", h.config.Up)
+				congestion.UseBrutal(h.conn, h.config.Up)
+			default:
+				errors.LogDebug(context.Background(), h.conn.RemoteAddr(), " ", "congestion reno")
+			}
+
+			if hyCtx.RequireDatagramFromContext(h.ctx) {
+				udpSM := &udpSessionManagerServer{
+					conn:           h.conn,
+					m:              make(map[uint32]*InterUdpConn),
+					addConn:        h.addConn,
+					stopCh:         make(chan struct{}),
+					udpIdleTimeout: time.Duration(h.config.UdpIdleTimeout) * time.Second,
+
+					user: h.user,
+				}
+				go udpSM.clean()
+				go udpSM.run()
+			}
+
+			w.Header().Set(ResponseHeaderUDPEnabled, strconv.FormatBool(hyCtx.RequireDatagramFromContext(h.ctx)))
+			w.Header().Set(CommonHeaderCCRX, strconv.FormatUint(h.config.Down, 10))
+			w.Header().Set(CommonHeaderPadding, authResponsePadding.String())
+			w.WriteHeader(StatusAuthOK)
+			return
+		}
+	}
+
+	h.masqHandler.ServeHTTP(w, r)
+}
+
+func (h *httpHandler) ProxyStreamHijacker(ft http3.FrameType, id quic.ConnectionTracingID, stream *quic.Stream, err error) (bool, error) {
+	if err != nil || !h.auth {
+		return false, nil
+	}
+
+	switch ft {
+	case FrameTypeTCPRequest:
+		h.addConn(&interConn{
+			stream: stream,
+			local:  h.conn.LocalAddr(),
+			remote: h.conn.RemoteAddr(),
+
+			user: h.user,
+		})
+		return true, nil
+	default:
+		return false, nil
+	}
+}
+
+type Listener struct {
+	ctx      context.Context
+	pktConn  net.PacketConn
+	listener *quic.Listener
+	addConn  internet.ConnHandler
+
+	config      *Config
+	validator   *account.Validator
+	masqHandler http.Handler
+}
+
+func (l *Listener) handleClient(conn *quic.Conn) {
+	handler := &httpHandler{
+		ctx:     l.ctx,
+		conn:    conn,
+		addConn: l.addConn,
+
+		config:      l.config,
+		validator:   l.validator,
+		masqHandler: l.masqHandler,
+	}
+	h3 := http3.Server{
+		Handler:        handler,
+		StreamHijacker: handler.ProxyStreamHijacker,
+	}
+	err := h3.ServeQUICConn(conn)
+	errors.LogDebug(context.Background(), conn.RemoteAddr(), " disconnected with err ", err)
+	_ = conn.CloseWithError(closeErrCodeOK, "")
+}
+
+func (l *Listener) keepAccepting() {
+	for {
+		conn, err := l.listener.Accept(context.Background())
+		if err != nil {
+			errors.LogInfoInner(context.Background(), err, "failed to accept QUIC connection")
+			break
+		}
+		go l.handleClient(conn)
+	}
+}
+
+func (l *Listener) Addr() net.Addr {
+	return l.listener.Addr()
+}
+
+func (l *Listener) Close() error {
+	err := l.listener.Close()
+	_ = l.pktConn.Close()
+	return err
+}
+
+func Listen(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) {
+	if address.Family().IsDomain() {
+		return nil, errors.New("address is domain")
+	}
+
+	tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
+	if tlsConfig == nil {
+		return nil, errors.New("tls config is nil")
+	}
+
+	config := streamSettings.ProtocolSettings.(*Config)
+
+	validator := hyCtx.ValidatorFromContext(ctx)
+
+	if config.Auth == "" && validator == nil {
+		return nil, errors.New("validator is nil")
+	}
+
+	var masqHandler http.Handler
+	switch strings.ToLower(config.MasqType) {
+	case "", "404":
+		masqHandler = http.NotFoundHandler()
+	case "file":
+		masqHandler = http.FileServer(http.Dir(config.MasqFile))
+	case "proxy":
+		u, err := url.Parse(config.MasqUrl)
+		if err != nil {
+			return nil, err
+		}
+		transport := http.DefaultTransport.(*http.Transport)
+		if config.MasqUrlInsecure {
+			transport = transport.Clone()
+			transport.TLSClientConfig = &gotls.Config{
+				InsecureSkipVerify: true,
+			}
+		}
+		masqHandler = &httputil.ReverseProxy{
+			Rewrite: func(pr *httputil.ProxyRequest) {
+				pr.SetURL(u)
+				if !config.MasqUrlRewriteHost {
+					pr.Out.Host = pr.In.Host
+				}
+			},
+			Transport: transport,
+			ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
+				w.WriteHeader(http.StatusBadGateway)
+			},
+		}
+	case "string":
+		masqHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			for k, v := range config.MasqStringHeaders {
+				w.Header().Set(k, v)
+			}
+			if config.MasqStringStatusCode != 0 {
+				w.WriteHeader(int(config.MasqStringStatusCode))
+			} else {
+				w.WriteHeader(http.StatusOK)
+			}
+			_, _ = w.Write([]byte(config.MasqString))
+		})
+	default:
+		return nil, errors.New("unknown masq type")
+	}
+
+	raw, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{IP: address.IP(), Port: int(port)}, streamSettings.SocketSettings)
+	if err != nil {
+		return nil, err
+	}
+
+	var pktConn net.PacketConn
+	pktConn = raw
+
+	if streamSettings.UdpmaskManager != nil {
+		pktConn, err = streamSettings.UdpmaskManager.WrapPacketConnServer(raw)
+		if err != nil {
+			raw.Close()
+			return nil, errors.New("mask err").Base(err)
+		}
+	}
+
+	quicConfig := &quic.Config{
+		InitialStreamReceiveWindow:     config.InitStreamReceiveWindow,
+		MaxStreamReceiveWindow:         config.MaxStreamReceiveWindow,
+		InitialConnectionReceiveWindow: config.InitConnReceiveWindow,
+		MaxConnectionReceiveWindow:     config.MaxConnReceiveWindow,
+		MaxIdleTimeout:                 time.Duration(config.MaxIdleTimeout) * time.Second,
+		MaxIncomingStreams:             config.MaxIncomingStreams,
+		DisablePathMTUDiscovery:        config.DisablePathMtuDiscovery,
+		EnableDatagrams:                true,
+		MaxDatagramFrameSize:           MaxDatagramFrameSize,
+		DisablePathManager:             true,
+	}
+
+	qListener, err := quic.Listen(pktConn, tlsConfig.GetTLSConfig(), quicConfig)
+	if err != nil {
+		_ = pktConn.Close()
+		return nil, err
+	}
+
+	listener := &Listener{
+		ctx:      ctx,
+		pktConn:  pktConn,
+		listener: qListener,
+		addConn:  handler,
+
+		config:      config,
+		validator:   validator,
+		masqHandler: masqHandler,
+	}
+
+	go listener.keepAccepting()
+
+	return listener, nil
+}
+
+func init() {
+	common.Must(internet.RegisterTransportListener(protocolName, Listen))
+}