Просмотр исходного кода

Refine must2 and apply NewAesGcm() to all usage (#5011)

* Refine must2 and apply NewAesGcm() to all usage

* Remove unused package

* Fix test
风扇滑翔翼 10 месяцев назад
Родитель
Сommit
b1107b9810

+ 13 - 13
app/dns/dnscommon_test.go

@@ -18,31 +18,31 @@ func Test_parseResponse(t *testing.T) {
 
 	ans := new(dns.Msg)
 	ans.Id = 0
-	p = append(p, common.Must2(ans.Pack()).([]byte))
+	p = append(p, common.Must2(ans.Pack()))
 
 	p = append(p, []byte{})
 
 	ans = new(dns.Msg)
 	ans.Id = 1
 	ans.Answer = append(ans.Answer,
-		common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR),
-		common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")).(dns.RR),
-		common.Must2(dns.NewRR("google.com. IN A 8.8.8.8")).(dns.RR),
-		common.Must2(dns.NewRR("google.com. IN A 8.8.4.4")).(dns.RR),
+		common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")),
+		common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")),
+		common.Must2(dns.NewRR("google.com. IN A 8.8.8.8")),
+		common.Must2(dns.NewRR("google.com. IN A 8.8.4.4")),
 	)
-	p = append(p, common.Must2(ans.Pack()).([]byte))
+	p = append(p, common.Must2(ans.Pack()))
 
 	ans = new(dns.Msg)
 	ans.Id = 2
 	ans.Answer = append(ans.Answer,
-		common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR),
-		common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")).(dns.RR),
-		common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR),
-		common.Must2(dns.NewRR("google.com. IN CNAME test.google.com")).(dns.RR),
-		common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8888")).(dns.RR),
-		common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8844")).(dns.RR),
+		common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")),
+		common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")),
+		common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")),
+		common.Must2(dns.NewRR("google.com. IN CNAME test.google.com")),
+		common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8888")),
+		common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8844")),
 	)
-	p = append(p, common.Must2(ans.Pack()).([]byte))
+	p = append(p, common.Must2(ans.Pack()))
 
 	tests := []struct {
 		name    string

+ 3 - 1
common/common.go

@@ -23,7 +23,9 @@ func Must(err error) {
 }
 
 // Must2 panics if the second parameter is not nil, otherwise returns the first parameter.
-func Must2(v interface{}, err error) interface{} {
+// This is useful when function returned "sth, err" and avoid many "if err != nil"
+// Internal usage only, if user input can cause err, it must be handled
+func Must2[T any](v T, err error) T {
 	Must(err)
 	return v
 }

+ 2 - 4
common/crypto/aes.go

@@ -32,9 +32,7 @@ func NewAesCTRStream(key []byte, iv []byte) cipher.Stream {
 
 // NewAesGcm creates a AEAD cipher based on AES-GCM.
 func NewAesGcm(key []byte) cipher.AEAD {
-	block, err := aes.NewCipher(key)
-	common.Must(err)
-	aead, err := cipher.NewGCM(block)
-	common.Must(err)
+	block := common.Must2(aes.NewCipher(key))
+	aead := common.Must2(cipher.NewGCM(block))
 	return aead
 }

+ 3 - 11
common/crypto/auth_test.go

@@ -2,8 +2,6 @@ package crypto_test
 
 import (
 	"bytes"
-	"crypto/aes"
-	"crypto/cipher"
 	"crypto/rand"
 	"io"
 	"testing"
@@ -18,11 +16,8 @@ import (
 func TestAuthenticationReaderWriter(t *testing.T) {
 	key := make([]byte, 16)
 	rand.Read(key)
-	block, err := aes.NewCipher(key)
-	common.Must(err)
 
-	aead, err := cipher.NewGCM(block)
-	common.Must(err)
+	aead := NewAesGcm(key)
 
 	const payloadSize = 1024 * 80
 	rawPayload := make([]byte, payloadSize)
@@ -71,7 +66,7 @@ func TestAuthenticationReaderWriter(t *testing.T) {
 		t.Error(r)
 	}
 
-	_, err = reader.ReadMultiBuffer()
+	_, err := reader.ReadMultiBuffer()
 	if err != io.EOF {
 		t.Error("error: ", err)
 	}
@@ -80,11 +75,8 @@ func TestAuthenticationReaderWriter(t *testing.T) {
 func TestAuthenticationReaderWriterPacket(t *testing.T) {
 	key := make([]byte, 16)
 	common.Must2(rand.Read(key))
-	block, err := aes.NewCipher(key)
-	common.Must(err)
 
-	aead, err := cipher.NewGCM(block)
-	common.Must(err)
+	aead := NewAesGcm(key)
 
 	cache := buf.New()
 	iv := make([]byte, 12)

+ 2 - 2
proxy/dokodemo/dokodemo.go

@@ -91,7 +91,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
 			}
 		}
 		if dest.Port == 0 {
-			dest.Port = net.Port(common.Must2(strconv.Atoi(port)).(int))
+			dest.Port = net.Port(common.Must2(strconv.Atoi(port)))
 		}
 		if d.portMap != nil && d.portMap[port] != "" {
 			h, p, _ := net.SplitHostPort(d.portMap[port])
@@ -99,7 +99,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
 				dest.Address = net.ParseAddress(h)
 			}
 			if len(p) > 0 {
-				dest.Port = net.Port(common.Must2(strconv.Atoi(p)).(int))
+				dest.Port = net.Port(common.Must2(strconv.Atoi(p)))
 			}
 		}
 	}

+ 1 - 6
proxy/shadowsocks/config.go

@@ -2,7 +2,6 @@ package shadowsocks
 
 import (
 	"bytes"
-	"crypto/aes"
 	"crypto/cipher"
 	"crypto/md5"
 	"crypto/sha1"
@@ -58,11 +57,7 @@ func (a *MemoryAccount) CheckIV(iv []byte) error {
 }
 
 func createAesGcm(key []byte) cipher.AEAD {
-	block, err := aes.NewCipher(key)
-	common.Must(err)
-	gcm, err := cipher.NewGCM(block)
-	common.Must(err)
-	return gcm
+	return crypto.NewAesGcm(key)
 }
 
 func createChaCha20Poly1305(key []byte) cipher.AEAD {

+ 5 - 38
proxy/vmess/aead/encrypt.go

@@ -2,14 +2,13 @@ package aead
 
 import (
 	"bytes"
-	"crypto/aes"
-	"crypto/cipher"
 	"crypto/rand"
 	"encoding/binary"
 	"io"
 	"time"
 
 	"github.com/xtls/xray-core/common"
+	"github.com/xtls/xray-core/common/crypto"
 )
 
 func SealVMessAEADHeader(key [16]byte, data []byte) []byte {
@@ -34,15 +33,7 @@ func SealVMessAEADHeader(key [16]byte, data []byte) []byte {
 
 		payloadHeaderLengthAEADNonce := KDF(key[:], KDFSaltConstVMessHeaderPayloadLengthAEADIV, string(generatedAuthID[:]), string(connectionNonce))[:12]
 
-		payloadHeaderLengthAEADAESBlock, err := aes.NewCipher(payloadHeaderLengthAEADKey)
-		if err != nil {
-			panic(err.Error())
-		}
-
-		payloadHeaderAEAD, err := cipher.NewGCM(payloadHeaderLengthAEADAESBlock)
-		if err != nil {
-			panic(err.Error())
-		}
+		payloadHeaderAEAD := crypto.NewAesGcm(payloadHeaderLengthAEADKey)
 
 		payloadHeaderLengthAEADEncrypted = payloadHeaderAEAD.Seal(nil, payloadHeaderLengthAEADNonce, aeadPayloadLengthSerializedByte, generatedAuthID[:])
 	}
@@ -54,15 +45,7 @@ func SealVMessAEADHeader(key [16]byte, data []byte) []byte {
 
 		payloadHeaderAEADNonce := KDF(key[:], KDFSaltConstVMessHeaderPayloadAEADIV, string(generatedAuthID[:]), string(connectionNonce))[:12]
 
-		payloadHeaderAEADAESBlock, err := aes.NewCipher(payloadHeaderAEADKey)
-		if err != nil {
-			panic(err.Error())
-		}
-
-		payloadHeaderAEAD, err := cipher.NewGCM(payloadHeaderAEADAESBlock)
-		if err != nil {
-			panic(err.Error())
-		}
+		payloadHeaderAEAD := crypto.NewAesGcm(payloadHeaderAEADKey)
 
 		payloadHeaderAEADEncrypted = payloadHeaderAEAD.Seal(nil, payloadHeaderAEADNonce, data, generatedAuthID[:])
 	}
@@ -104,15 +87,7 @@ func OpenVMessAEADHeader(key [16]byte, authid [16]byte, data io.Reader) ([]byte,
 
 		payloadHeaderLengthAEADNonce := KDF(key[:], KDFSaltConstVMessHeaderPayloadLengthAEADIV, string(authid[:]), string(nonce[:]))[:12]
 
-		payloadHeaderAEADAESBlock, err := aes.NewCipher(payloadHeaderLengthAEADKey)
-		if err != nil {
-			panic(err.Error())
-		}
-
-		payloadHeaderLengthAEAD, err := cipher.NewGCM(payloadHeaderAEADAESBlock)
-		if err != nil {
-			panic(err.Error())
-		}
+		payloadHeaderLengthAEAD := crypto.NewAesGcm(payloadHeaderLengthAEADKey)
 
 		decryptedAEADHeaderLengthPayload, erropenAEAD := payloadHeaderLengthAEAD.Open(nil, payloadHeaderLengthAEADNonce, payloadHeaderLengthAEADEncrypted[:], authid[:])
 
@@ -145,15 +120,7 @@ func OpenVMessAEADHeader(key [16]byte, authid [16]byte, data io.Reader) ([]byte,
 			return nil, false, bytesRead, err
 		}
 
-		payloadHeaderAEADAESBlock, err := aes.NewCipher(payloadHeaderAEADKey)
-		if err != nil {
-			panic(err.Error())
-		}
-
-		payloadHeaderAEAD, err := cipher.NewGCM(payloadHeaderAEADAESBlock)
-		if err != nil {
-			panic(err.Error())
-		}
+		payloadHeaderAEAD := crypto.NewAesGcm(payloadHeaderAEADKey)
 
 		decryptedAEADHeaderPayload, erropenAEAD := payloadHeaderAEAD.Open(nil, payloadHeaderAEADNonce, payloadHeaderAEADEncrypted, authid[:])
 

+ 2 - 6
proxy/vmess/encoding/client.go

@@ -3,8 +3,6 @@ package encoding
 import (
 	"bytes"
 	"context"
-	"crypto/aes"
-	"crypto/cipher"
 	"crypto/rand"
 	"crypto/sha256"
 	"encoding/binary"
@@ -182,8 +180,7 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon
 	aeadResponseHeaderLengthEncryptionKey := vmessaead.KDF16(c.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderLenKey)
 	aeadResponseHeaderLengthEncryptionIV := vmessaead.KDF(c.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderLenIV)[:12]
 
-	aeadResponseHeaderLengthEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderLengthEncryptionKey)).(cipher.Block)
-	aeadResponseHeaderLengthEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderLengthEncryptionKeyAESBlock)).(cipher.AEAD)
+	aeadResponseHeaderLengthEncryptionAEAD := crypto.NewAesGcm(aeadResponseHeaderLengthEncryptionKey)
 
 	var aeadEncryptedResponseHeaderLength [18]byte
 	var decryptedResponseHeaderLength int
@@ -205,8 +202,7 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon
 	aeadResponseHeaderPayloadEncryptionKey := vmessaead.KDF16(c.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadKey)
 	aeadResponseHeaderPayloadEncryptionIV := vmessaead.KDF(c.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadIV)[:12]
 
-	aeadResponseHeaderPayloadEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderPayloadEncryptionKey)).(cipher.Block)
-	aeadResponseHeaderPayloadEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderPayloadEncryptionKeyAESBlock)).(cipher.AEAD)
+	aeadResponseHeaderPayloadEncryptionAEAD := crypto.NewAesGcm(aeadResponseHeaderPayloadEncryptionKey)
 
 	encryptedResponseHeaderBuffer := make([]byte, decryptedResponseHeaderLength+16)
 

+ 2 - 6
proxy/vmess/encoding/server.go

@@ -2,8 +2,6 @@ package encoding
 
 import (
 	"bytes"
-	"crypto/aes"
-	"crypto/cipher"
 	"crypto/sha256"
 	"encoding/binary"
 	"hash/fnv"
@@ -350,8 +348,7 @@ func (s *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, wr
 	aeadResponseHeaderLengthEncryptionKey := vmessaead.KDF16(s.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderLenKey)
 	aeadResponseHeaderLengthEncryptionIV := vmessaead.KDF(s.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderLenIV)[:12]
 
-	aeadResponseHeaderLengthEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderLengthEncryptionKey)).(cipher.Block)
-	aeadResponseHeaderLengthEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderLengthEncryptionKeyAESBlock)).(cipher.AEAD)
+	aeadResponseHeaderLengthEncryptionAEAD := crypto.NewAesGcm(aeadResponseHeaderLengthEncryptionKey)
 
 	aeadResponseHeaderLengthEncryptionBuffer := bytes.NewBuffer(nil)
 
@@ -365,8 +362,7 @@ func (s *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, wr
 	aeadResponseHeaderPayloadEncryptionKey := vmessaead.KDF16(s.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadKey)
 	aeadResponseHeaderPayloadEncryptionIV := vmessaead.KDF(s.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadIV)[:12]
 
-	aeadResponseHeaderPayloadEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderPayloadEncryptionKey)).(cipher.Block)
-	aeadResponseHeaderPayloadEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderPayloadEncryptionKeyAESBlock)).(cipher.AEAD)
+	aeadResponseHeaderPayloadEncryptionAEAD := crypto.NewAesGcm(aeadResponseHeaderPayloadEncryptionKey)
 
 	aeadEncryptedHeaderPayload := aeadResponseHeaderPayloadEncryptionAEAD.Seal(nil, aeadResponseHeaderPayloadEncryptionIV, aeadEncryptedHeaderBuffer.Bytes(), nil)
 	common.Must2(io.Copy(writer, bytes.NewReader(aeadEncryptedHeaderPayload)))

+ 2 - 4
transport/internet/kcp/cryptreal.go

@@ -1,15 +1,13 @@
 package kcp
 
 import (
-	"crypto/aes"
 	"crypto/cipher"
 	"crypto/sha256"
 
-	"github.com/xtls/xray-core/common"
+	"github.com/xtls/xray-core/common/crypto"
 )
 
 func NewAEADAESGCMBasedOnSeed(seed string) cipher.AEAD {
 	hashedSeed := sha256.Sum256([]byte(seed))
-	aesBlock := common.Must2(aes.NewCipher(hashedSeed[:16])).(cipher.Block)
-	return common.Must2(cipher.NewGCM(aesBlock)).(cipher.AEAD)
+	return crypto.NewAesGcm(hashedSeed[:])
 }

+ 1 - 4
transport/internet/reality/reality.go

@@ -3,8 +3,6 @@ package reality
 import (
 	"bytes"
 	"context"
-	"crypto/aes"
-	"crypto/cipher"
 	"crypto/ecdh"
 	"crypto/ed25519"
 	"crypto/hmac"
@@ -169,8 +167,7 @@ func UClient(c net.Conn, config *Config, ctx context.Context, dest net.Destinati
 		if _, err := hkdf.New(sha256.New, uConn.AuthKey, hello.Random[:20], []byte("REALITY")).Read(uConn.AuthKey); err != nil {
 			return nil, err
 		}
-		block, _ := aes.NewCipher(uConn.AuthKey)
-		aead, _ := cipher.NewGCM(block)
+		aead := crypto.NewAesGcm(uConn.AuthKey)
 		if config.Show {
 			fmt.Printf("REALITY localAddr: %v\tuConn.AuthKey[:16]: %v\tAEAD: %T\n", localAddr, uConn.AuthKey[:16], aead)
 		}

+ 1 - 1
transport/internet/splithttp/dialer.go

@@ -297,7 +297,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
 	if transportConfiguration.DownloadSettings != nil {
 		globalDialerAccess.Lock()
 		if streamSettings.DownloadSettings == nil {
-			streamSettings.DownloadSettings = common.Must2(internet.ToMemoryStreamConfig(transportConfiguration.DownloadSettings)).(*internet.MemoryStreamConfig)
+			streamSettings.DownloadSettings = common.Must2(internet.ToMemoryStreamConfig(transportConfiguration.DownloadSettings))
 			if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.Penetrate {
 				streamSettings.DownloadSettings.SocketSettings = streamSettings.SocketSettings
 			}