소스 검색

Log: More flexible mask addr

Fangliding 4 달 전
부모
커밋
95aff2e35d
2개의 변경된 파일119개의 추가작업 그리고 32개의 파일을 삭제
  1. 82 32
      app/log/log.go
  2. 37 0
      app/log/log_test.go

+ 82 - 32
app/log/log.go

@@ -2,8 +2,9 @@ package log
 
 import (
 	"context"
-	"fmt"
+	"net"
 	"regexp"
+	"strconv"
 	"strings"
 	"sync"
 
@@ -20,14 +21,23 @@ type Instance struct {
 	errorLogger  log.Handler
 	active       bool
 	dns          bool
+	mask4        int
+	mask6        int
 }
 
 // New creates a new log.Instance based on the given config.
 func New(ctx context.Context, config *Config) (*Instance, error) {
+	m4, m6, err := ParseMaskAddress(config.MaskAddress)
+	if err != nil {
+		return nil, err
+	}
+
 	g := &Instance{
 		config: config,
 		active: false,
 		dns:    config.EnableDnsLog,
+		mask4:  m4,
+		mask6:  m6,
 	}
 	log.RegisterHandler(g)
 
@@ -104,7 +114,11 @@ func (g *Instance) Handle(msg log.Message) {
 
 	var Msg log.Message
 	if g.config.MaskAddress != "" {
-		Msg = &MaskedMsgWrapper{Message: msg, config: g.config}
+		Msg = &MaskedMsgWrapper{
+			Message: msg,
+			Mask4:   g.mask4,
+			Mask6:   g.mask6,
+		}
 	} else {
 		Msg = msg
 	}
@@ -149,51 +163,87 @@ func (g *Instance) Close() error {
 	return nil
 }
 
+func ParseMaskAddress(c string) (int, int, error) {
+	var m4, m6 int
+	switch c {
+	case "half":
+		m4, m6 = 16, 32
+	case "quarter":
+		m4, m6 = 8, 16
+	case "full":
+		m4, m6 = 0, 0
+	case "":
+		// do nothing
+	default:
+		if parts := strings.Split(c, "+"); len(parts) > 0 {
+			if len(parts) >= 1 && parts[0] != "" {
+				i, err := strconv.Atoi(strings.TrimPrefix(parts[0], "/"))
+				if err != nil {
+					return 32, 128, err
+				}
+				m4 = i
+			}
+			if len(parts) >= 2 && parts[1] != "" {
+				i, err := strconv.Atoi(strings.TrimPrefix(parts[1], "/"))
+				if err != nil {
+					return 32, 128, err
+				}
+				m6 = i
+			}
+		}
+	}
+
+	if m4%8 != 0 || m4 > 32 || m4 < 0 {
+		return 32, 128, errors.New("Log Mask: ipv4 mask must be divisible by 8 and between 0-32")
+	}
+
+	return m4, m6, nil
+}
+
 // MaskedMsgWrapper is to wrap the string() method to mask IP addresses in the log.
 type MaskedMsgWrapper struct {
 	log.Message
-	config *Config
+	Mask4 int
+	Mask6 int
 }
 
+var (
+	ipv4Regex = regexp.MustCompile(`(\d{1,3}\.){3}\d{1,3}`)
+	ipv6Regex = regexp.MustCompile(`(?:[\da-fA-F]{0,4}:[\da-fA-F]{0,4}){2,7}`)
+)
+
 func (m *MaskedMsgWrapper) String() string {
 	str := m.Message.String()
 
-	ipv4Regex := regexp.MustCompile(`(\d{1,3}\.){3}\d{1,3}`)
-	ipv6Regex := regexp.MustCompile(`((?:[\da-fA-F]{0,4}:[\da-fA-F]{0,4}){2,7})(?:[\/\\%](\d{1,3}))?`)
-
 	// Process ipv4
-	maskedMsg := ipv4Regex.ReplaceAllStringFunc(str, func(ip string) string {
-		parts := strings.Split(ip, ".")
-		switch m.config.MaskAddress {
-		case "half":
-			return fmt.Sprintf("%s.%s.*.*", parts[0], parts[1])
-		case "quarter":
-			return fmt.Sprintf("%s.*.*.*", parts[0])
-		case "full":
+	maskedMsg := ipv4Regex.ReplaceAllStringFunc(str, func(s string) string {
+		if m.Mask4 == 32 {
+			return s
+		}
+		if m.Mask4 == 0 {
 			return "[Masked IPv4]"
-		default:
-			return ip
 		}
+
+		parts := strings.Split(s, ".")
+		for i := m.Mask4 / 8; i < 4; i++ {
+			parts[i] = "*"
+		}
+		return strings.Join(parts, ".")
 	})
 
 	// process ipv6
-	maskedMsg = ipv6Regex.ReplaceAllStringFunc(maskedMsg, func(ip string) string {
-		parts := strings.Split(ip, ":")
-		switch m.config.MaskAddress {
-		case "half":
-			if len(parts) >= 2 {
-				return fmt.Sprintf("%s:%s::/32", parts[0], parts[1])
-			}
-		case "quarter":
-			if len(parts) >= 1 {
-				return fmt.Sprintf("%s::/16", parts[0])
-			}
-		case "full":
-			return "Masked IPv6" // Do not use [Masked IPv6] like ipv4, or you will get "[[Masked IPv6]]" (v6 address already has [])
-		default:
-			return ip
+	maskedMsg = ipv6Regex.ReplaceAllStringFunc(maskedMsg, func(s string) string {
+		if m.Mask6 == 128 {
+			return s
+		}
+		if m.Mask6 == 0 {
+			return "Masked IPv6"
+		}
+		ip := net.ParseIP(s)
+		if ip == nil {
+			return s
 		}
-		return ip
+		return ip.Mask(net.CIDRMask(m.Mask6, 128)).String() + "/" + strconv.Itoa(m.Mask6)
 	})
 
 	return maskedMsg

+ 37 - 0
app/log/log_test.go

@@ -2,6 +2,7 @@ package log_test
 
 import (
 	"context"
+	"net"
 	"testing"
 
 	"github.com/golang/mock/gomock"
@@ -50,3 +51,39 @@ func TestCustomLogHandler(t *testing.T) {
 
 	common.Must(logger.Close())
 }
+
+func TestMaskAddress(t *testing.T) {
+	m4, m6, err := log.ParseMaskAddress("half")
+	if err != nil {
+		t.Fatal(err)
+	}
+	maskedAddr := log.MaskedMsgWrapper{
+		Mask4: m4,
+		Mask6: m6,
+	}
+	maskedAddr.Message = net.ParseIP("11.45.1.4")
+	if maskedAddr.String() != "11.45.*.*" {
+		t.Fatal("expected '11.45.*.*', but actually ", maskedAddr.String())
+	}
+	maskedAddr.Message = net.ParseIP("11:45:14:19:19:81:0::")
+	if maskedAddr.String() != "11:45::/32" {
+		t.Fatal("expected '11:45::/32', but actually", maskedAddr.String())
+	}
+
+	m4, m6, err = log.ParseMaskAddress("/16+/64")
+	if err != nil {
+		t.Fatal(err)
+	}
+	maskedAddr = log.MaskedMsgWrapper{
+		Mask4: m4,
+		Mask6: m6,
+	}
+	maskedAddr.Message = net.ParseIP("11.45.1.4")
+	if maskedAddr.String() != "11.45.*.*" {
+		t.Fatal("expected '11.45.*.*', but actually ", maskedAddr.String())
+	}
+	maskedAddr.Message = net.ParseIP("11:45:14:19:19:81:0::")
+	if maskedAddr.String() != "11:45:14:19::/64" {
+		t.Fatal("expected '11:45:14:19::/64', but actually", maskedAddr.String())
+	}
+}