| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245 |
- package proxyproto
- import (
- "bufio"
- "bytes"
- "errors"
- "fmt"
- "io"
- "log"
- "net"
- "strconv"
- "strings"
- "sync"
- "time"
- )
// prefix is the byte sequence that opens every PROXY protocol (version 1)
// header. The start of each connection is compared against it to decide
// whether the peer is speaking the proxy protocol.
var prefix = []byte("PROXY ")

// prefixLen is the number of bytes in prefix, computed once at start-up.
var prefixLen = len(prefix)

// ErrInvalidUpstream is the error a SourceChecker should return when the
// connecting address is not allowed to supply PROXY information.
var ErrInvalidUpstream = errors.New("upstream connection address not trusted for PROXY information")
// SourceChecker decides, per accepted connection, whether the PROXY
// information sent by the peer should be trusted. It receives the remote
// address of the connecting peer.
//
// Results:
//   - usePROXY true: the address claimed in the PROXY header is reported as
//     the connection's remote address.
//   - usePROXY false: the PROXY-claimed address is ignored and the
//     connection's own remote address is reported instead.
//   - err non-nil: the call to Accept() fails with that error. When the
//     reason is that the source is disallowed, the function should return
//     ErrInvalidUpstream.
type SourceChecker func(addr net.Addr) (usePROXY bool, err error)
- // Listener is used to wrap an underlying listener,
- // whose connections may be using the HAProxy Proxy Protocol (version 1).
- // If the connection is using the protocol, the RemoteAddr() will return
- // the correct client address.
- //
- // Optionally define ProxyHeaderTimeout to set a maximum time to
- // receive the Proxy Protocol Header. Zero means no timeout.
- type Listener struct {
- Listener net.Listener
- ProxyHeaderTimeout time.Duration
- SourceCheck SourceChecker
- }
// Conn wraps an underlying connection that may be speaking the Proxy
// Protocol. When it is, RemoteAddr() returns the address of the client
// instead of the address of the proxy.
type Conn struct {
	// bufReader buffers reads from conn so the header can be peeked at
	// without consuming application data.
	bufReader *bufio.Reader
	// conn is the wrapped connection.
	conn net.Conn
	// dstAddr and srcAddr hold the addresses parsed from the PROXY header;
	// they stay nil when no header was parsed.
	dstAddr, srcAddr *net.TCPAddr
	// useConnRemoteAddr makes RemoteAddr() report conn's own peer address
	// even when a PROXY header supplied a source address.
	useConnRemoteAddr bool
	// once guards the header check so it runs a single time, whichever of
	// Read or RemoteAddr is called first.
	once sync.Once
	// proxyHeaderTimeout bounds the wait for the header; zero disables it.
	proxyHeaderTimeout time.Duration
}
- // Accept waits for and returns the next connection to the listener.
- func (p *Listener) Accept() (net.Conn, error) {
- // Get the underlying connection
- conn, err := p.Listener.Accept()
- if err != nil {
- return nil, err
- }
- var useConnRemoteAddr bool
- if p.SourceCheck != nil {
- allowed, err := p.SourceCheck(conn.RemoteAddr())
- if err != nil {
- return nil, err
- }
- if !allowed {
- useConnRemoteAddr = true
- }
- }
- newConn := NewConn(conn, p.ProxyHeaderTimeout)
- newConn.useConnRemoteAddr = useConnRemoteAddr
- return newConn, nil
- }
- // Close closes the underlying listener.
- func (p *Listener) Close() error {
- return p.Listener.Close()
- }
- // Addr returns the underlying listener's network address.
- func (p *Listener) Addr() net.Addr {
- return p.Listener.Addr()
- }
- // NewConn is used to wrap a net.Conn that may be speaking
- // the proxy protocol into a proxyproto.Conn
- func NewConn(conn net.Conn, timeout time.Duration) *Conn {
- pConn := &Conn{
- bufReader: bufio.NewReader(conn),
- conn: conn,
- proxyHeaderTimeout: timeout,
- }
- return pConn
- }
- // Read is check for the proxy protocol header when doing
- // the initial scan. If there is an error parsing the header,
- // it is returned and the socket is closed.
- func (p *Conn) Read(b []byte) (int, error) {
- var err error
- p.once.Do(func() { err = p.checkPrefix() })
- if err != nil {
- return 0, err
- }
- return p.bufReader.Read(b)
- }
- func (p *Conn) Write(b []byte) (int, error) {
- return p.conn.Write(b)
- }
- func (p *Conn) Close() error {
- return p.conn.Close()
- }
- func (p *Conn) LocalAddr() net.Addr {
- return p.conn.LocalAddr()
- }
- // RemoteAddr returns the address of the client if the proxy
- // protocol is being used, otherwise just returns the address of
- // the socket peer. If there is an error parsing the header, the
- // address of the client is not returned, and the socket is closed.
- // Once implication of this is that the call could block if the
- // client is slow. Using a Deadline is recommended if this is called
- // before Read()
- func (p *Conn) RemoteAddr() net.Addr {
- p.once.Do(func() {
- if err := p.checkPrefix(); err != nil && err != io.EOF {
- log.Printf("[ERR] Failed to read proxy prefix: %v", err)
- p.Close()
- p.bufReader = bufio.NewReader(p.conn)
- }
- })
- if p.srcAddr != nil && !p.useConnRemoteAddr {
- return p.srcAddr
- }
- return p.conn.RemoteAddr()
- }
- func (p *Conn) SetDeadline(t time.Time) error {
- return p.conn.SetDeadline(t)
- }
- func (p *Conn) SetReadDeadline(t time.Time) error {
- return p.conn.SetReadDeadline(t)
- }
- func (p *Conn) SetWriteDeadline(t time.Time) error {
- return p.conn.SetWriteDeadline(t)
- }
- func (p *Conn) checkPrefix() error {
- if p.proxyHeaderTimeout != 0 {
- readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
- p.conn.SetReadDeadline(readDeadLine)
- defer p.conn.SetReadDeadline(time.Time{})
- }
- // Incrementally check each byte of the prefix
- for i := 1; i <= prefixLen; i++ {
- inp, err := p.bufReader.Peek(i)
- if err != nil {
- if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
- return nil
- } else {
- return err
- }
- }
- // Check for a prefix mis-match, quit early
- if !bytes.Equal(inp, prefix[:i]) {
- return nil
- }
- }
- // Read the header line
- header, err := p.bufReader.ReadString('\n')
- if err != nil {
- p.conn.Close()
- return err
- }
- // Strip the carriage return and new line
- header = header[:len(header)-2]
- // Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>)
- parts := strings.Split(header, " ")
- if len(parts) != 6 {
- p.conn.Close()
- return fmt.Errorf("Invalid header line: %s", header)
- }
- // Verify the type is known
- switch parts[1] {
- case "TCP4":
- case "TCP6":
- default:
- p.conn.Close()
- return fmt.Errorf("Unhandled address type: %s", parts[1])
- }
- // Parse out the source address
- ip := net.ParseIP(parts[2])
- if ip == nil {
- p.conn.Close()
- return fmt.Errorf("Invalid source ip: %s", parts[2])
- }
- port, err := strconv.Atoi(parts[4])
- if err != nil {
- p.conn.Close()
- return fmt.Errorf("Invalid source port: %s", parts[4])
- }
- p.srcAddr = &net.TCPAddr{IP: ip, Port: port}
- // Parse out the destination address
- ip = net.ParseIP(parts[3])
- if ip == nil {
- p.conn.Close()
- return fmt.Errorf("Invalid destination ip: %s", parts[3])
- }
- port, err = strconv.Atoi(parts[5])
- if err != nil {
- p.conn.Close()
- return fmt.Errorf("Invalid destination port: %s", parts[5])
- }
- p.dstAddr = &net.TCPAddr{IP: ip, Port: port}
- return nil
- }
|