websocket.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. package goproxy
  2. import (
  3. "bufio"
  4. "crypto/tls"
  5. "io"
  6. "net/http"
  7. "net/url"
  8. "strings"
  9. )
  10. func headerContains(header http.Header, name string, value string) bool {
  11. for _, v := range header[name] {
  12. for _, s := range strings.Split(v, ",") {
  13. if strings.EqualFold(value, strings.TrimSpace(s)) {
  14. return true
  15. }
  16. }
  17. }
  18. return false
  19. }
  20. func isWebSocketRequest(r *http.Request) bool {
  21. return headerContains(r.Header, "Connection", "upgrade") &&
  22. headerContains(r.Header, "Upgrade", "websocket")
  23. }
  24. func (proxy *ProxyHttpServer) serveWebsocketTLS(ctx *ProxyCtx, w http.ResponseWriter, req *http.Request, tlsConfig *tls.Config, clientConn *tls.Conn) {
  25. targetURL := url.URL{Scheme: "wss", Host: req.URL.Host, Path: req.URL.Path}
  26. // Connect to upstream
  27. targetConn, err := tls.Dial("tcp", targetURL.Host, tlsConfig)
  28. if err != nil {
  29. ctx.Warnf("Error dialing target site: %v", err)
  30. return
  31. }
  32. defer targetConn.Close()
  33. // Perform handshake
  34. if err := proxy.websocketHandshake(ctx, req, targetConn, clientConn); err != nil {
  35. ctx.Warnf("Websocket handshake error: %v", err)
  36. return
  37. }
  38. // Proxy wss connection
  39. proxy.proxyWebsocket(ctx, targetConn, clientConn)
  40. }
  41. func (proxy *ProxyHttpServer) serveWebsocket(ctx *ProxyCtx, w http.ResponseWriter, req *http.Request) {
  42. targetURL := url.URL{Scheme: "ws", Host: req.URL.Host, Path: req.URL.Path}
  43. targetConn, err := proxy.connectDial("tcp", targetURL.Host)
  44. if err != nil {
  45. ctx.Warnf("Error dialing target site: %v", err)
  46. return
  47. }
  48. defer targetConn.Close()
  49. // Connect to Client
  50. hj, ok := w.(http.Hijacker)
  51. if !ok {
  52. panic("httpserver does not support hijacking")
  53. }
  54. clientConn, _, err := hj.Hijack()
  55. if err != nil {
  56. ctx.Warnf("Hijack error: %v", err)
  57. return
  58. }
  59. // Perform handshake
  60. if err := proxy.websocketHandshake(ctx, req, targetConn, clientConn); err != nil {
  61. ctx.Warnf("Websocket handshake error: %v", err)
  62. return
  63. }
  64. // Proxy ws connection
  65. proxy.proxyWebsocket(ctx, targetConn, clientConn)
  66. }
  67. func (proxy *ProxyHttpServer) websocketHandshake(ctx *ProxyCtx, req *http.Request, targetSiteConn io.ReadWriter, clientConn io.ReadWriter) error {
  68. // write handshake request to target
  69. err := req.Write(targetSiteConn)
  70. if err != nil {
  71. ctx.Warnf("Error writing upgrade request: %v", err)
  72. return err
  73. }
  74. targetTLSReader := bufio.NewReader(targetSiteConn)
  75. // Read handshake response from target
  76. resp, err := http.ReadResponse(targetTLSReader, req)
  77. if err != nil {
  78. ctx.Warnf("Error reading handhsake response %v", err)
  79. return err
  80. }
  81. // Run response through handlers
  82. resp = proxy.filterResponse(resp, ctx)
  83. // Proxy handshake back to client
  84. err = resp.Write(clientConn)
  85. if err != nil {
  86. ctx.Warnf("Error writing handshake response: %v", err)
  87. return err
  88. }
  89. return nil
  90. }
  91. func (proxy *ProxyHttpServer) proxyWebsocket(ctx *ProxyCtx, dest io.ReadWriter, source io.ReadWriter) {
  92. errChan := make(chan error, 2)
  93. cp := func(dst io.Writer, src io.Reader) {
  94. _, err := io.Copy(dst, src)
  95. ctx.Warnf("Websocket error: %v", err)
  96. errChan <- err
  97. }
  98. // Start proxying websocket data
  99. go cp(dest, source)
  100. go cp(source, dest)
  101. <-errChan
  102. }