net.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. /*
  2. * Copyright (c) 2016, Psiphon Inc.
  3. * All rights reserved.
  4. *
  5. * This program is free software: you can redistribute it and/or modify
  6. * it under the terms of the GNU General Public License as published by
  7. * the Free Software Foundation, either version 3 of the License, or
  8. * (at your option) any later version.
  9. *
  10. * This program is distributed in the hope that it will be useful,
  11. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. * GNU General Public License for more details.
  14. *
  15. * You should have received a copy of the GNU General Public License
  16. * along with this program. If not, see <http://www.gnu.org/licenses/>.
  17. *
  18. */
  19. package common
  20. import (
  21. "container/list"
  22. "context"
  23. "io"
  24. "net"
  25. "net/http"
  26. "net/netip"
  27. "strconv"
  28. "sync"
  29. "time"
  30. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
  31. "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
  32. "github.com/miekg/dns"
  33. "github.com/wader/filtertransport"
  34. )
  35. // Dialer is a custom network dialer.
  36. type Dialer func(context.Context, string, string) (net.Conn, error)
  37. // NetDialer mimicks the net.Dialer interface.
  38. type NetDialer interface {
  39. Dial(network, address string) (net.Conn, error)
  40. DialContext(ctx context.Context, network, address string) (net.Conn, error)
  41. }
  42. // Closer defines the interface to a type, typically a net.Conn, that can be
  43. // closed.
  44. type Closer interface {
  45. IsClosed() bool
  46. }
  47. // CloseWriter defines the interface to a type, typically a net.TCPConn, that
  48. // implements CloseWrite.
  49. type CloseWriter interface {
  50. CloseWrite() error
  51. }
  52. // IrregularIndicator defines the interface for a type, typically a net.Conn,
  53. // that detects and reports irregular conditions during initial network
  54. // connection establishment.
  55. type IrregularIndicator interface {
  56. IrregularTunnelError() error
  57. }
  58. // UnderlyingTCPAddrSource defines the interface for a type, typically a
  59. // net.Conn, such as a server meek Conn, which has an underlying TCP conn(s),
  60. // providing access to the LocalAddr and RemoteAddr properties of the
  61. // underlying TCP conn.
  62. type UnderlyingTCPAddrSource interface {
  63. // GetUnderlyingTCPAddrs returns the LocalAddr and RemoteAddr properties of
  64. // the underlying TCP conn.
  65. GetUnderlyingTCPAddrs() (*net.TCPAddr, *net.TCPAddr, bool)
  66. }
  67. // FragmentorAccessor defines the interface for accessing properties
  68. // of a fragmentor Conn.
  69. type FragmentorAccessor interface {
  70. SetReplay(*prng.PRNG)
  71. GetReplay() (*prng.Seed, bool)
  72. StopFragmenting()
  73. }
  74. // HTTPRoundTripper is an adapter that allows using a function as a
  75. // http.RoundTripper.
  76. type HTTPRoundTripper struct {
  77. roundTrip func(*http.Request) (*http.Response, error)
  78. }
  79. // NewHTTPRoundTripper creates a new HTTPRoundTripper, using the specified
  80. // roundTrip function for HTTP round trips.
  81. func NewHTTPRoundTripper(
  82. roundTrip func(*http.Request) (*http.Response, error)) *HTTPRoundTripper {
  83. return &HTTPRoundTripper{roundTrip: roundTrip}
  84. }
  85. // RoundTrip implements http.RoundTripper RoundTrip.
  86. func (h HTTPRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
  87. return h.roundTrip(request)
  88. }
  89. // TerminateHTTPConnection sends a 404 response to a client and also closes
  90. // the persistent connection.
  91. func TerminateHTTPConnection(
  92. responseWriter http.ResponseWriter, request *http.Request) {
  93. responseWriter.Header().Set("Content-Length", "0")
  94. http.NotFound(responseWriter, request)
  95. hijack, ok := responseWriter.(http.Hijacker)
  96. if !ok {
  97. return
  98. }
  99. conn, buffer, err := hijack.Hijack()
  100. if err != nil {
  101. return
  102. }
  103. buffer.Flush()
  104. conn.Close()
  105. }
  106. // IPAddressFromAddr is a helper which extracts an IP address
  107. // from a net.Addr or returns "" if there is no IP address.
  108. func IPAddressFromAddr(addr net.Addr) string {
  109. ipAddress := ""
  110. if addr != nil {
  111. host, _, err := net.SplitHostPort(addr.String())
  112. if err == nil {
  113. ipAddress = host
  114. }
  115. }
  116. return ipAddress
  117. }
  118. // PortFromAddr is a helper which extracts a port number from a net.Addr or
  119. // returns 0 if there is no port number.
  120. func PortFromAddr(addr net.Addr) int {
  121. port := 0
  122. if addr != nil {
  123. _, portStr, err := net.SplitHostPort(addr.String())
  124. if err == nil {
  125. port, _ = strconv.Atoi(portStr)
  126. }
  127. }
  128. return port
  129. }
  130. // Conns is a synchronized list of Conns that is used to coordinate
  131. // interrupting a set of goroutines establishing connections, or
  132. // close a set of open connections, etc.
  133. // Once the list is closed, no more items may be added to the
  134. // list (unless it is reset).
  135. type Conns[T interface {
  136. comparable
  137. io.Closer
  138. }] struct {
  139. mutex sync.Mutex
  140. isClosed bool
  141. conns map[T]bool
  142. }
  143. // NewConns initializes a new Conns.
  144. func NewConns[T interface {
  145. comparable
  146. io.Closer
  147. }]() *Conns[T] {
  148. return &Conns[T]{}
  149. }
  150. func (conns *Conns[T]) Reset() {
  151. conns.mutex.Lock()
  152. defer conns.mutex.Unlock()
  153. conns.isClosed = false
  154. conns.conns = make(map[T]bool)
  155. }
  156. func (conns *Conns[T]) Add(conn T) bool {
  157. conns.mutex.Lock()
  158. defer conns.mutex.Unlock()
  159. if conns.isClosed {
  160. return false
  161. }
  162. if conns.conns == nil {
  163. conns.conns = make(map[T]bool)
  164. }
  165. conns.conns[conn] = true
  166. return true
  167. }
  168. func (conns *Conns[T]) Remove(conn T) {
  169. conns.mutex.Lock()
  170. defer conns.mutex.Unlock()
  171. delete(conns.conns, conn)
  172. }
  173. func (conns *Conns[T]) CloseAll() {
  174. conns.mutex.Lock()
  175. conns.isClosed = true
  176. closeConns := conns.conns
  177. conns.conns = make(map[T]bool)
  178. conns.mutex.Unlock()
  179. // Close is invoked outside of the mutex in case a member conn's Close
  180. // invokes Remove.
  181. for conn := range closeConns {
  182. _ = conn.Close()
  183. }
  184. }
  185. func (conns *Conns[T]) IsClosed() bool {
  186. conns.mutex.Lock()
  187. defer conns.mutex.Unlock()
  188. return conns.isClosed
  189. }
  190. // LRUConns is a concurrency-safe list of net.Conns ordered
  191. // by recent activity. Its purpose is to facilitate closing
  192. // the oldest connection in a set of connections.
  193. //
  194. // New connections added are referenced by a LRUConnsEntry,
  195. // which is used to Touch() active connections, which
  196. // promotes them to the front of the order and to Remove()
  197. // connections that are no longer LRU candidates.
  198. //
  199. // CloseOldest() will remove the oldest connection from the
  200. // list and call net.Conn.Close() on the connection.
  201. //
  202. // After an entry has been removed, LRUConnsEntry Touch()
  203. // and Remove() will have no effect.
  204. type LRUConns struct {
  205. mutex sync.Mutex
  206. list *list.List
  207. }
  208. // NewLRUConns initializes a new LRUConns.
  209. func NewLRUConns() *LRUConns {
  210. return &LRUConns{list: list.New()}
  211. }
  212. // Add inserts a net.Conn as the freshest connection
  213. // in a LRUConns and returns an LRUConnsEntry to be
  214. // used to freshen the connection or remove the connection
  215. // from the LRU list.
  216. func (conns *LRUConns) Add(conn net.Conn) *LRUConnsEntry {
  217. conns.mutex.Lock()
  218. defer conns.mutex.Unlock()
  219. return &LRUConnsEntry{
  220. lruConns: conns,
  221. element: conns.list.PushFront(conn),
  222. }
  223. }
  224. // CloseOldest closes the oldest connection in a
  225. // LRUConns. It calls net.Conn.Close() on the
  226. // connection.
  227. func (conns *LRUConns) CloseOldest() {
  228. conns.mutex.Lock()
  229. oldest := conns.list.Back()
  230. if oldest != nil {
  231. conns.list.Remove(oldest)
  232. }
  233. // Release mutex before closing conn
  234. conns.mutex.Unlock()
  235. if oldest != nil {
  236. oldest.Value.(net.Conn).Close()
  237. }
  238. }
  239. // LRUConnsEntry is an entry in a LRUConns list.
  240. type LRUConnsEntry struct {
  241. lruConns *LRUConns
  242. element *list.Element
  243. }
  244. // Remove deletes the connection referenced by the
  245. // LRUConnsEntry from the associated LRUConns.
  246. // Has no effect if the entry was not initialized
  247. // or previously removed.
  248. func (entry *LRUConnsEntry) Remove() {
  249. if entry.lruConns == nil || entry.element == nil {
  250. return
  251. }
  252. entry.lruConns.mutex.Lock()
  253. defer entry.lruConns.mutex.Unlock()
  254. entry.lruConns.list.Remove(entry.element)
  255. }
  256. // Touch promotes the connection referenced by the
  257. // LRUConnsEntry to the front of the associated LRUConns.
  258. // Has no effect if the entry was not initialized
  259. // or previously removed.
  260. func (entry *LRUConnsEntry) Touch() {
  261. if entry.lruConns == nil || entry.element == nil {
  262. return
  263. }
  264. entry.lruConns.mutex.Lock()
  265. defer entry.lruConns.mutex.Unlock()
  266. entry.lruConns.list.MoveToFront(entry.element)
  267. }
  268. // IsBogon checks if the specified IP is a bogon (loopback, private addresses,
  269. // link-local addresses, etc.)
  270. func IsBogon(IP net.IP) bool {
  271. return filtertransport.FindIPNet(
  272. filtertransport.DefaultFilteredNetworks, IP)
  273. }
  274. // ParseDNSQuestion parses a DNS message. When the message is a query,
  275. // the first question, a fully-qualified domain name, is returned.
  276. //
  277. // For other valid DNS messages, "" is returned. An error is returned only
  278. // for invalid DNS messages.
  279. //
  280. // Limitations:
  281. // - Only the first Question field is extracted.
  282. // - ParseDNSQuestion only functions for plaintext DNS and cannot
  283. // extract domains from DNS-over-TLS/HTTPS, etc.
  284. func ParseDNSQuestion(request []byte) (string, error) {
  285. m := new(dns.Msg)
  286. err := m.Unpack(request)
  287. if err != nil {
  288. return "", errors.Trace(err)
  289. }
  290. if len(m.Question) > 0 {
  291. return m.Question[0].Name, nil
  292. }
  293. return "", nil
  294. }
  295. // WriteTimeoutUDPConn sets write deadlines before each UDP packet write.
  296. //
  297. // Generally, a UDP packet write doesn't block. However, Go's
  298. // internal/poll.FD.WriteMsg continues to loop when syscall.SendmsgN fails
  299. // with EAGAIN, which indicates that an OS socket buffer is currently full;
  300. // in certain OS states this may cause WriteMsgUDP/etc. to block
  301. // indefinitely. In this scenario, we want to instead behave as if the packet
  302. // were dropped, so we set a write deadline which will eventually interrupt
  303. // any EAGAIN loop.
  304. type WriteTimeoutUDPConn struct {
  305. *net.UDPConn
  306. }
  307. func (conn *WriteTimeoutUDPConn) Write(b []byte) (int, error) {
  308. err := conn.SetWriteDeadline(time.Now().Add(UDP_PACKET_WRITE_TIMEOUT))
  309. if err != nil {
  310. return 0, errors.Trace(err)
  311. }
  312. // Do not wrap any I/O err returned by UDPConn
  313. return conn.UDPConn.Write(b)
  314. }
  315. func (conn *WriteTimeoutUDPConn) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (int, int, error) {
  316. err := conn.SetWriteDeadline(time.Now().Add(UDP_PACKET_WRITE_TIMEOUT))
  317. if err != nil {
  318. return 0, 0, errors.Trace(err)
  319. }
  320. // Do not wrap any I/O err returned by UDPConn
  321. return conn.UDPConn.WriteMsgUDP(b, oob, addr)
  322. }
  323. func (conn *WriteTimeoutUDPConn) WriteMsgUDPAddrPort(b, oob []byte, addr netip.AddrPort) (int, int, error) {
  324. err := conn.SetWriteDeadline(time.Now().Add(UDP_PACKET_WRITE_TIMEOUT))
  325. if err != nil {
  326. return 0, 0, errors.Trace(err)
  327. }
  328. // Do not wrap any I/O err returned by UDPConn
  329. return conn.UDPConn.WriteMsgUDPAddrPort(b, oob, addr)
  330. }
  331. func (conn *WriteTimeoutUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
  332. err := conn.SetWriteDeadline(time.Now().Add(UDP_PACKET_WRITE_TIMEOUT))
  333. if err != nil {
  334. return 0, errors.Trace(err)
  335. }
  336. // Do not wrap any I/O err returned by UDPConn
  337. return conn.UDPConn.WriteTo(b, addr)
  338. }
  339. func (conn *WriteTimeoutUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
  340. err := conn.SetWriteDeadline(time.Now().Add(UDP_PACKET_WRITE_TIMEOUT))
  341. if err != nil {
  342. return 0, errors.Trace(err)
  343. }
  344. // Do not wrap any I/O err returned by UDPConn
  345. return conn.UDPConn.WriteToUDPAddrPort(b, addr)
  346. }
  347. func (conn *WriteTimeoutUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
  348. err := conn.SetWriteDeadline(time.Now().Add(UDP_PACKET_WRITE_TIMEOUT))
  349. if err != nil {
  350. return 0, errors.Trace(err)
  351. }
  352. // Do not wrap any I/O err returned by UDPConn
  353. return conn.UDPConn.WriteToUDP(b, addr)
  354. }
  355. // WriteTimeoutPacketConn is the equivilent of WriteTimeoutUDPConn for
  356. // non-*net.UDPConns.
  357. type WriteTimeoutPacketConn struct {
  358. net.PacketConn
  359. }
  360. const UDP_PACKET_WRITE_TIMEOUT = 1 * time.Second
  361. func (conn *WriteTimeoutPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
  362. err := conn.SetWriteDeadline(time.Now().Add(UDP_PACKET_WRITE_TIMEOUT))
  363. if err != nil {
  364. return 0, errors.Trace(err)
  365. }
  366. // Do not wrap any I/O err returned by PacketConn
  367. return conn.PacketConn.WriteTo(b, addr)
  368. }
  369. // GetMetrics implements the common.MetricsSource interface.
  370. func (conn *WriteTimeoutPacketConn) GetMetrics() LogFields {
  371. logFields := make(LogFields)
  372. // Include metrics, such as inproxy and fragmentor metrics, from the
  373. // underlying dial conn.
  374. underlyingMetrics, ok := conn.PacketConn.(MetricsSource)
  375. if ok {
  376. logFields.Add(underlyingMetrics.GetMetrics())
  377. }
  378. return logFields
  379. }