protocol.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. package proxyproto
  2. import (
  3. "bufio"
  4. "bytes"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "log"
  9. "net"
  10. "strconv"
  11. "strings"
  12. "sync"
  13. "time"
  14. )
  15. var (
  16. // prefix is the string we look for at the start of a connection
  17. // to check if this connection is using the proxy protocol
  18. prefix = []byte("PROXY ")
  19. prefixLen = len(prefix)
  20. ErrInvalidUpstream = errors.New("upstream connection address not trusted for PROXY information")
  21. )
  22. // SourceChecker can be used to decide whether to trust the PROXY info or pass
  23. // the original connection address through. If set, the connecting address is
  24. // passed in as an argument. If the function returns an error due to the source
  25. // being disallowed, it should return ErrInvalidUpstream.
  26. //
  27. // If error is not nil, the call to Accept() will fail. If the reason for
  28. // triggering this failure is due to a disallowed source, it should return
  29. // ErrInvalidUpstream.
  30. //
  31. // If bool is true, the PROXY-set address is used.
  32. //
  33. // If bool is false, the connection's remote address is used, rather than the
  34. // address claimed in the PROXY info.
  35. type SourceChecker func(net.Addr) (bool, error)
  36. // Listener is used to wrap an underlying listener,
  37. // whose connections may be using the HAProxy Proxy Protocol (version 1).
  38. // If the connection is using the protocol, the RemoteAddr() will return
  39. // the correct client address.
  40. //
  41. // Optionally define ProxyHeaderTimeout to set a maximum time to
  42. // receive the Proxy Protocol Header. Zero means no timeout.
  43. type Listener struct {
  44. Listener net.Listener
  45. ProxyHeaderTimeout time.Duration
  46. SourceCheck SourceChecker
  47. }
  48. // Conn is used to wrap and underlying connection which
  49. // may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
  50. // return the address of the client instead of the proxy address.
  51. type Conn struct {
  52. bufReader *bufio.Reader
  53. conn net.Conn
  54. dstAddr *net.TCPAddr
  55. srcAddr *net.TCPAddr
  56. useConnRemoteAddr bool
  57. once sync.Once
  58. proxyHeaderTimeout time.Duration
  59. }
  60. // Accept waits for and returns the next connection to the listener.
  61. func (p *Listener) Accept() (net.Conn, error) {
  62. // Get the underlying connection
  63. conn, err := p.Listener.Accept()
  64. if err != nil {
  65. return nil, err
  66. }
  67. var useConnRemoteAddr bool
  68. if p.SourceCheck != nil {
  69. allowed, err := p.SourceCheck(conn.RemoteAddr())
  70. if err != nil {
  71. return nil, err
  72. }
  73. if !allowed {
  74. useConnRemoteAddr = true
  75. }
  76. }
  77. newConn := NewConn(conn, p.ProxyHeaderTimeout)
  78. newConn.useConnRemoteAddr = useConnRemoteAddr
  79. return newConn, nil
  80. }
  81. // Close closes the underlying listener.
  82. func (p *Listener) Close() error {
  83. return p.Listener.Close()
  84. }
  85. // Addr returns the underlying listener's network address.
  86. func (p *Listener) Addr() net.Addr {
  87. return p.Listener.Addr()
  88. }
  89. // NewConn is used to wrap a net.Conn that may be speaking
  90. // the proxy protocol into a proxyproto.Conn
  91. func NewConn(conn net.Conn, timeout time.Duration) *Conn {
  92. pConn := &Conn{
  93. bufReader: bufio.NewReader(conn),
  94. conn: conn,
  95. proxyHeaderTimeout: timeout,
  96. }
  97. return pConn
  98. }
  99. // Read is check for the proxy protocol header when doing
  100. // the initial scan. If there is an error parsing the header,
  101. // it is returned and the socket is closed.
  102. func (p *Conn) Read(b []byte) (int, error) {
  103. var err error
  104. p.once.Do(func() { err = p.checkPrefix() })
  105. if err != nil {
  106. return 0, err
  107. }
  108. return p.bufReader.Read(b)
  109. }
  110. func (p *Conn) Write(b []byte) (int, error) {
  111. return p.conn.Write(b)
  112. }
  113. func (p *Conn) Close() error {
  114. return p.conn.Close()
  115. }
  116. func (p *Conn) LocalAddr() net.Addr {
  117. return p.conn.LocalAddr()
  118. }
  119. // RemoteAddr returns the address of the client if the proxy
  120. // protocol is being used, otherwise just returns the address of
  121. // the socket peer. If there is an error parsing the header, the
  122. // address of the client is not returned, and the socket is closed.
  123. // Once implication of this is that the call could block if the
  124. // client is slow. Using a Deadline is recommended if this is called
  125. // before Read()
  126. func (p *Conn) RemoteAddr() net.Addr {
  127. p.once.Do(func() {
  128. if err := p.checkPrefix(); err != nil && err != io.EOF {
  129. log.Printf("[ERR] Failed to read proxy prefix: %v", err)
  130. p.Close()
  131. p.bufReader = bufio.NewReader(p.conn)
  132. }
  133. })
  134. if p.srcAddr != nil && !p.useConnRemoteAddr {
  135. return p.srcAddr
  136. }
  137. return p.conn.RemoteAddr()
  138. }
  139. func (p *Conn) SetDeadline(t time.Time) error {
  140. return p.conn.SetDeadline(t)
  141. }
  142. func (p *Conn) SetReadDeadline(t time.Time) error {
  143. return p.conn.SetReadDeadline(t)
  144. }
  145. func (p *Conn) SetWriteDeadline(t time.Time) error {
  146. return p.conn.SetWriteDeadline(t)
  147. }
  148. func (p *Conn) checkPrefix() error {
  149. if p.proxyHeaderTimeout != 0 {
  150. readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
  151. p.conn.SetReadDeadline(readDeadLine)
  152. defer p.conn.SetReadDeadline(time.Time{})
  153. }
  154. // Incrementally check each byte of the prefix
  155. for i := 1; i <= prefixLen; i++ {
  156. inp, err := p.bufReader.Peek(i)
  157. if err != nil {
  158. if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
  159. return nil
  160. } else {
  161. return err
  162. }
  163. }
  164. // Check for a prefix mis-match, quit early
  165. if !bytes.Equal(inp, prefix[:i]) {
  166. return nil
  167. }
  168. }
  169. // Read the header line
  170. header, err := p.bufReader.ReadString('\n')
  171. if err != nil {
  172. p.conn.Close()
  173. return err
  174. }
  175. // Strip the carriage return and new line
  176. header = header[:len(header)-2]
  177. // Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>)
  178. parts := strings.Split(header, " ")
  179. if len(parts) != 6 {
  180. p.conn.Close()
  181. return fmt.Errorf("Invalid header line: %s", header)
  182. }
  183. // Verify the type is known
  184. switch parts[1] {
  185. case "TCP4":
  186. case "TCP6":
  187. default:
  188. p.conn.Close()
  189. return fmt.Errorf("Unhandled address type: %s", parts[1])
  190. }
  191. // Parse out the source address
  192. ip := net.ParseIP(parts[2])
  193. if ip == nil {
  194. p.conn.Close()
  195. return fmt.Errorf("Invalid source ip: %s", parts[2])
  196. }
  197. port, err := strconv.Atoi(parts[4])
  198. if err != nil {
  199. p.conn.Close()
  200. return fmt.Errorf("Invalid source port: %s", parts[4])
  201. }
  202. p.srcAddr = &net.TCPAddr{IP: ip, Port: port}
  203. // Parse out the destination address
  204. ip = net.ParseIP(parts[3])
  205. if ip == nil {
  206. p.conn.Close()
  207. return fmt.Errorf("Invalid destination ip: %s", parts[3])
  208. }
  209. port, err = strconv.Atoi(parts[5])
  210. if err != nil {
  211. p.conn.Close()
  212. return fmt.Errorf("Invalid destination port: %s", parts[5])
  213. }
  214. p.dstAddr = &net.TCPAddr{IP: ip, Port: port}
  215. return nil
  216. }