Ver código fonte

Fixes from PR review

Adam Pritchard 8 anos atrás
pai
commit
2e8550915b
2 arquivos alterados com 35 adições e 22 exclusões
  1. 27 16
      psiphon/httpProxy.go
  2. 8 6
      psiphon/httpProxy_test.go

+ 27 - 16
psiphon/httpProxy.go

@@ -63,7 +63,7 @@ import (
 // An example use case for tunneled relaying with rewriting (/tunneled-rewrite/) is when the
 // content of retrieved files contains URLs that also need to be modified to be tunneled.
 // For example, in iOS 10 the UIWebView media player does not put requests through the
-// NSURLProtocol, so they are no tunneled. Instead, we rewrite those URLs to use the URL
+// NSURLProtocol, so they are not tunneled. Instead, we rewrite those URLs to use the URL
 // proxy, and rewrite retrieved playlist files so they also contain proxied URLs.
 //
 // Origin URLs must include the scheme prefix ("http://" or "https://") and must be
@@ -80,6 +80,8 @@ type HttpProxy struct {
 	urlProxyDirectClient   *http.Client
 	openConns              *common.Conns
 	stopListeningBroadcast chan struct{}
+	listenIP               string
+	listenPort             int
 }
 
 var _HTTP_PROXY_TYPE = "HTTP"
@@ -148,6 +150,9 @@ func NewHttpProxy(
 		Jar:       nil,
 	}
 
+	proxyIP, proxyPortString, _ := net.SplitHostPort(listener.Addr().String())
+	proxyPort, _ := strconv.Atoi(proxyPortString)
+
 	proxy = &HttpProxy{
 		tunneler:               tunneler,
 		listener:               listener,
@@ -159,6 +164,8 @@ func NewHttpProxy(
 		urlProxyDirectClient:   urlProxyDirectClient,
 		openConns:              new(common.Conns),
 		stopListeningBroadcast: make(chan struct{}),
+		listenIP:               proxyIP,
+		listenPort:             proxyPort,
 	}
 	proxy.serveWaitGroup.Add(1)
 	go proxy.serve()
@@ -174,7 +181,7 @@ func NewHttpProxy(
 	// NoticeListeningHttpProxyPort after that call.
 	// Also, check the listen backlog queue length -- shouldn't it be possible
 	// to enqueue pending connections between net.Listen() and httpServer.Serve()?
-	NoticeListeningHttpProxyPort(proxy.listener.Addr().(*net.TCPAddr).Port)
+	NoticeListeningHttpProxyPort(proxy.listenPort)
 
 	return proxy, nil
 }
@@ -341,6 +348,8 @@ func (proxy *HttpProxy) relayHTTPRequest(
 		return
 	}
 
+	defer response.Body.Close()
+
 	if rewrites != nil {
 		// NOTE: Rewrite functions are responsible for leaving response.Body in
 		// a valid, readable state if there's no error.
@@ -350,7 +359,7 @@ func (proxy *HttpProxy) relayHTTPRequest(
 
 			switch key {
 			case "m3u8":
-				err = rewriteM3U8(proxy.listener.Addr().(*net.TCPAddr).Port, response)
+				err = rewriteM3U8(proxy.listenIP, proxy.listenPort, response)
 			}
 
 			if err != nil {
@@ -375,8 +384,6 @@ func (proxy *HttpProxy) relayHTTPRequest(
 		}
 	}
 
-	defer response.Body.Close()
-
 	// Relay the response code and body
 	responseWriter.WriteHeader(response.StatusCode)
 	_, err = io.Copy(responseWriter, response.Body)
@@ -476,18 +483,22 @@ func toAbsoluteURL(baseURL *url.URL, relativeURLString string) string {
 // urlProxy port is the local HTTP proxy port.
 // If rewriteParams is nil, then no rewriting will be done. Otherwise, it should contain
 // supported rewriting flags (like "m3u8").
-func proxifyURL(urlProxyPort int, urlString string, rewriteParams []string) string {
+func proxifyURL(localHTTPProxyIP string, localHTTPProxyPort int, urlString string, rewriteParams []string) string {
 	// Note that we need to use the "opaque" form of URL so that it doesn't double-escape the path. See: https://github.com/golang/go/issues/10887
 
-	opaqueFormat := "//127.0.0.1:%d/tunneled/%s"
+	if localHTTPProxyIP == "0.0.0.0" {
+		localHTTPProxyIP = "127.0.0.1"
+	}
+
+	opaqueFormat := "//%s:%d/tunneled/%s"
 	if rewriteParams != nil {
-		opaqueFormat = "//127.0.0.1:%d/tunneled-rewrite/%s"
+		opaqueFormat = "//%s:%d/tunneled-rewrite/%s"
 	}
 
 	var proxifiedURL url.URL
 
 	proxifiedURL.Scheme = "http"
-	proxifiedURL.Opaque = fmt.Sprintf(opaqueFormat, urlProxyPort, url.QueryEscape(urlString))
+	proxifiedURL.Opaque = fmt.Sprintf(opaqueFormat, localHTTPProxyIP, localHTTPProxyPort, url.QueryEscape(urlString))
 
 	qp := proxifiedURL.Query()
 	for _, rewrite := range rewriteParams {
@@ -499,8 +510,8 @@ func proxifyURL(urlProxyPort int, urlString string, rewriteParams []string) stri
 }
 
 // Rewrite the contents of the M3U8 file in body to be compatible with URL proxying.
-// If error is returned, response body may not be valid.
-func rewriteM3U8(httpProxyPort int, response *http.Response) error {
+// If error is returned, response body may not be valid for reading.
+func rewriteM3U8(localHTTPProxyIP string, localHTTPProxyPort int, response *http.Response) error {
 	// Check URL path extension
 	extension := filepath.Ext(response.Request.URL.Path)
 	var shouldHandle = (extension == ".m3u8")
@@ -560,15 +571,15 @@ func rewriteM3U8(httpProxyPort int, response *http.Response) error {
 			}
 
 			if segment.URI != "" {
-				segment.URI = proxifyURL(httpProxyPort, toAbsoluteURL(response.Request.URL, segment.URI), nil)
+				segment.URI = proxifyURL(localHTTPProxyIP, localHTTPProxyPort, toAbsoluteURL(response.Request.URL, segment.URI), nil)
 			}
 
 			if segment.Key != nil && segment.Key.URI != "" {
-				segment.Key.URI = proxifyURL(httpProxyPort, toAbsoluteURL(response.Request.URL, segment.Key.URI), nil)
+				segment.Key.URI = proxifyURL(localHTTPProxyIP, localHTTPProxyPort, toAbsoluteURL(response.Request.URL, segment.Key.URI), nil)
 			}
 
 			if segment.Map != nil && segment.Map.URI != "" {
-				segment.Map.URI = proxifyURL(httpProxyPort, toAbsoluteURL(response.Request.URL, segment.Map.URI), nil)
+				segment.Map.URI = proxifyURL(localHTTPProxyIP, localHTTPProxyPort, toAbsoluteURL(response.Request.URL, segment.Map.URI), nil)
 			}
 		}
 		rewrittenBodyBytes = []byte(mediapl.String())
@@ -580,7 +591,7 @@ func rewriteM3U8(httpProxyPort int, response *http.Response) error {
 			}
 
 			if variant.URI != "" {
-				variant.URI = proxifyURL(httpProxyPort, toAbsoluteURL(response.Request.URL, variant.URI), []string{"m3u8"})
+				variant.URI = proxifyURL(localHTTPProxyIP, localHTTPProxyPort, toAbsoluteURL(response.Request.URL, variant.URI), []string{"m3u8"})
 			}
 
 			for _, alternative := range variant.Alternatives {
@@ -589,7 +600,7 @@ func rewriteM3U8(httpProxyPort int, response *http.Response) error {
 				}
 
 				if alternative.URI != "" {
-					alternative.URI = proxifyURL(httpProxyPort, toAbsoluteURL(response.Request.URL, alternative.URI), []string{"m3u8"})
+					alternative.URI = proxifyURL(localHTTPProxyIP, localHTTPProxyPort, toAbsoluteURL(response.Request.URL, alternative.URI), []string{"m3u8"})
 				}
 			}
 		}

+ 8 - 6
psiphon/httpProxy_test.go

@@ -26,6 +26,7 @@ import (
 	"net/url"
 	"os"
 	"strconv"
+	"strings"
 	"testing"
 )
 
@@ -55,18 +56,19 @@ func TestToAbsoluteURL(t *testing.T) {
 
 func TestProxifyURL(t *testing.T) {
 	var urlTests = []struct {
+		ip            string
 		port          int
 		urlString     string
 		rewriteParams []string
 		expected      string
 	}{
-		{1234, "http://example.com/media/pl.m3u8?q=p&p=q#hash", []string{"rewriter1"}, "http://127.0.0.1:1234/tunneled-rewrite/http%3A%2F%2Fexample.com%2Fmedia%2Fpl.m3u8%3Fq%3Dp%26p%3Dq%23hash?rewriter1="},
-		{12345, "http://example.com/media/pl.aaa", []string{"rewriter1", "rewriter2"}, "http://127.0.0.1:12345/tunneled-rewrite/http%3A%2F%2Fexample.com%2Fmedia%2Fpl.aaa?rewriter1=&rewriter2="},
-		{12346, "http://example.com/media/bbb", nil, "http://127.0.0.1:12346/tunneled/http%3A%2F%2Fexample.com%2Fmedia%2Fbbb"},
+		{"127.0.0.1", 1234, "http://example.com/media/pl.m3u8?q=p&p=q#hash", []string{"rewriter1"}, "http://127.0.0.1:1234/tunneled-rewrite/http%3A%2F%2Fexample.com%2Fmedia%2Fpl.m3u8%3Fq%3Dp%26p%3Dq%23hash?rewriter1="},
+		{"127.0.0.2", 12345, "http://example.com/media/pl.aaa", []string{"rewriter1", "rewriter2"}, "http://127.0.0.2:12345/tunneled-rewrite/http%3A%2F%2Fexample.com%2Fmedia%2Fpl.aaa?rewriter1=&rewriter2="},
+		{"127.0.0.3", 12346, "http://example.com/media/bbb", nil, "http://127.0.0.3:12346/tunneled/http%3A%2F%2Fexample.com%2Fmedia%2Fbbb"},
 	}
 
 	for _, tt := range urlTests {
-		actual := proxifyURL(tt.port, tt.urlString, tt.rewriteParams)
+		actual := proxifyURL(tt.ip, tt.port, tt.urlString, tt.rewriteParams)
 		if actual != tt.expected {
 			t.Errorf("proxifyURL(%d, %s, %v): expected %s, actual %s", tt.port, tt.urlString, tt.rewriteParams, tt.expected, actual)
 		}
@@ -115,7 +117,7 @@ func TestRewriteM3U8(t *testing.T) {
 		response.Body = inFile
 		response.Header.Set("Content-Length", strconv.FormatInt(inFileInfo.Size(), 10))
 
-		err := rewriteM3U8(12345, &response)
+		err := rewriteM3U8("127.0.0.1", 12345, &response)
 		if err != nil {
 			t.Errorf("rewriteM3U8 returned error: %s", err)
 		}
@@ -129,7 +131,7 @@ func TestRewriteM3U8(t *testing.T) {
 			t.Errorf("rewriteM3U8 body mismatch for test %d", i)
 		}
 
-		if tt.expectedContentType != "" && response.Header.Get("Content-Type") != tt.expectedContentType {
+		if tt.expectedContentType != "" && strings.ToLower(response.Header.Get("Content-Type")) != strings.ToLower(tt.expectedContentType) {
 			t.Errorf("rewriteM3U8 Content-Type mismatch for test %d: %s %s", i, tt.expectedContentType, response.Header.Get("Content-Type"))
 		}