Просмотр исходного кода

Add ICY protocol support to the URL proxy

Rod Hynes 7 лет назад
Родитель
Сommit
586023f1ca
1 измененных файлов с 253 добавлено и 31 удалено
  1. 253 31
      psiphon/httpProxy.go

+ 253 - 31
psiphon/httpProxy.go

@@ -33,9 +33,12 @@ import (
 	"strconv"
 	"strings"
 	"sync"
+	"sync/atomic"
+	"time"
 
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
 	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
+	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tls"
 	"github.com/grafov/m3u8"
 )
 
@@ -52,7 +55,8 @@ import (
 // To make the Media Player use the Psiphon tunnel, construct a URL such as:
 // "http://127.0.0.1:<proxy-port>/tunneled/<origin media URL>"; and pass this to the player.
 // The <origin media URL> must be escaped in such a way that it can be used inside a URL query.
-// TODO: add ICY protocol to support certain streaming media (e.g., https://gist.github.com/tulskiy/1008126)
+//
+// The URL proxy offers /tunneled-icy/ which is compatible with ICY protocol resources.
 //
 // An example use case for direct, untunneled, relaying is to make use of Go's TLS
 // stack for HTTPS requests in cases where the native TLS stack is lacking (e.g.,
@@ -65,6 +69,7 @@ import (
 // For example, in iOS 10 the UIWebView media player does not put requests through the
 // NSURLProtocol, so they are not tunneled. Instead, we rewrite those URLs to use the URL
 // proxy, and rewrite retrieved playlist files so they also contain proxied URLs.
+// Media resource links within playlists are rewritten to use the /tunneled-icy/ path.
 //
 // Origin URLs must include the scheme prefix ("http://" or "https://") and must be
 // URL encoded.
@@ -78,6 +83,7 @@ type HttpProxy struct {
 	urlProxyTunneledClient *http.Client
 	urlProxyDirectRelay    *http.Transport
 	urlProxyDirectClient   *http.Client
+	responseHeaderTimeout  time.Duration
 	openConns              *common.Conns
 	stopListeningBroadcast chan struct{}
 	listenIP               string
@@ -165,6 +171,7 @@ func NewHttpProxy(
 		urlProxyTunneledClient: urlProxyTunneledClient,
 		urlProxyDirectRelay:    urlProxyDirectRelay,
 		urlProxyDirectClient:   urlProxyDirectClient,
+		responseHeaderTimeout:  responseHeaderTimeout,
 		openConns:              new(common.Conns),
 		stopListeningBroadcast: make(chan struct{}),
 		listenIP:               proxyIP,
@@ -221,10 +228,9 @@ func (proxy *HttpProxy) Close() {
 //
 func (proxy *HttpProxy) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) {
 	if request.Method == "CONNECT" {
-		hijacker, _ := responseWriter.(http.Hijacker)
-		conn, _, err := hijacker.Hijack()
-		if err != nil {
-			NoticeAlert("%s", common.ContextError(err))
+		conn := hijack(responseWriter)
+		if conn == nil {
+			// hijack emits an alert notice
 			http.Error(responseWriter, "", http.StatusInternalServerError)
 			return
 		}
@@ -262,18 +268,20 @@ func (proxy *HttpProxy) httpConnectHandler(localConn net.Conn, target string) (e
 }
 
 func (proxy *HttpProxy) httpProxyHandler(responseWriter http.ResponseWriter, request *http.Request) {
-	proxy.relayHTTPRequest(nil, proxy.httpProxyTunneledRelay, request, responseWriter, nil)
+	proxy.relayHTTPRequest(nil, proxy.httpProxyTunneledRelay, request, responseWriter, nil, nil)
 }
 
 const (
-	URL_PROXY_TUNNELED_REQUEST_PATH = "/tunneled/"
-	URL_PROXY_REWRITE_REQUEST_PATH  = "/tunneled-rewrite/"
-	URL_PROXY_DIRECT_REQUEST_PATH   = "/direct/"
+	URL_PROXY_TUNNELED_REQUEST_PATH         = "/tunneled/"
+	URL_PROXY_TUNNELED_REWRITE_REQUEST_PATH = "/tunneled-rewrite/"
+	URL_PROXY_TUNNELED_ICY_REQUEST_PATH     = "/tunneled-icy/"
+	URL_PROXY_DIRECT_REQUEST_PATH           = "/direct/"
 )
 
 func (proxy *HttpProxy) urlProxyHandler(responseWriter http.ResponseWriter, request *http.Request) {
 
 	var client *http.Client
+	var rewriteICYStatus *rewriteICYStatus
 	var originURLString string
 	var err error
 	var rewrites url.Values
@@ -284,10 +292,14 @@ func (proxy *HttpProxy) urlProxyHandler(responseWriter http.ResponseWriter, requ
 	case strings.HasPrefix(request.URL.RawPath, URL_PROXY_TUNNELED_REQUEST_PATH):
 		originURLString, err = url.QueryUnescape(request.URL.RawPath[len(URL_PROXY_TUNNELED_REQUEST_PATH):])
 		client = proxy.urlProxyTunneledClient
-	case strings.HasPrefix(request.URL.RawPath, URL_PROXY_REWRITE_REQUEST_PATH):
-		originURLString, err = url.QueryUnescape(request.URL.RawPath[len(URL_PROXY_REWRITE_REQUEST_PATH):])
+	case strings.HasPrefix(request.URL.RawPath, URL_PROXY_TUNNELED_REWRITE_REQUEST_PATH):
+		originURLString, err = url.QueryUnescape(request.URL.RawPath[len(URL_PROXY_TUNNELED_REWRITE_REQUEST_PATH):])
 		client = proxy.urlProxyTunneledClient
 		rewrites = request.URL.Query()
+	case strings.HasPrefix(request.URL.RawPath, URL_PROXY_TUNNELED_ICY_REQUEST_PATH):
+		originURLString, err = url.QueryUnescape(request.URL.RawPath[len(URL_PROXY_TUNNELED_ICY_REQUEST_PATH):])
+		client, rewriteICYStatus = proxy.makeRewriteICYClient()
+		rewrites = request.URL.Query()
 	case strings.HasPrefix(request.URL.RawPath, URL_PROXY_DIRECT_REQUEST_PATH):
 		originURLString, err = url.QueryUnescape(request.URL.RawPath[len(URL_PROXY_DIRECT_REQUEST_PATH):])
 		client = proxy.urlProxyDirectClient
@@ -317,7 +329,141 @@ func (proxy *HttpProxy) urlProxyHandler(responseWriter http.ResponseWriter, requ
 	request.Host = originURL.Host
 	request.URL = originURL
 
-	proxy.relayHTTPRequest(client, nil, request, responseWriter, rewrites)
+	proxy.relayHTTPRequest(client, nil, request, responseWriter, rewrites, rewriteICYStatus)
+}
+
+// rewriteICYConn rewrites an ICY procotol responses to that it may be
+// consumed by Go's http package. rewriteICYConn expects the ICY response to
+// be equivilent to HTTP/1.1 with the exception of the protocol name in the
+// status line, which is the one part that is rewritten. Responses that are
+// already HTTP are passed through unmodified.
+type rewriteICYConn struct {
+	net.Conn
+	doneRewriting int32
+	isICY         *int32
+}
+
+func (conn *rewriteICYConn) Read(b []byte) (int, error) {
+
+	if !atomic.CompareAndSwapInt32(&conn.doneRewriting, 0, 1) {
+		return conn.Conn.Read(b)
+	}
+
+	if len(b) < 3 {
+		// Don't attempt to rewrite the protocol when insufficient
+		// buffer space. This is not expected to happen in practise
+		// when Go's http reads the response, so for now we just
+		// skip the rewrite instead of tracking state accross Reads.
+		return conn.Conn.Read(b)
+	}
+
+	// Expect to read either "ICY" or "HTT".
+
+	n, err := conn.Conn.Read(b[:3])
+	if err != nil {
+		return n, err
+	}
+
+	if string(b[:3]) == "ICY" {
+		atomic.StoreInt32(conn.isICY, 1)
+		copy(b, []byte("HTTP/1.1"))
+		return 8, nil
+	}
+
+	return n, nil
+}
+
+type rewriteICYStatus struct {
+	isFirstConnICY int32
+}
+
+func (status *rewriteICYStatus) isICY() bool {
+	return atomic.LoadInt32(&status.isFirstConnICY) == 1
+}
+
+// makeRewriteICYClient creates an http.Client with a Transport configured to
+// use rewriteICYConn. Both HTTP and HTTPS are handled. The http.Client is
+// intended to be used for one single request. The client disables keep alives
+// as rewriteICYConn can only rewrite the first response in a connection. The
+// returned rewriteICYStatus indicates which the first response for the first
+// request was ICY, allowing the downstream relayed response to replicate the
+// ICY protocol.
+func (proxy *HttpProxy) makeRewriteICYClient() (*http.Client, *rewriteICYStatus) {
+
+	rewriteICYStatus := &rewriteICYStatus{}
+
+	tunneledDialer := func(_, addr string) (conn net.Conn, err error) {
+		// See comment in NewHttpProxy regarding downstreamConn
+		return proxy.tunneler.Dial(addr, false, nil)
+	}
+
+	dial := func(network, address string) (net.Conn, error) {
+
+		conn, err := tunneledDialer(network, address)
+		if err != nil {
+			return nil, common.ContextError(err)
+		}
+
+		return &rewriteICYConn{
+			Conn:  conn,
+			isICY: &rewriteICYStatus.isFirstConnICY,
+		}, nil
+	}
+
+	dialTLS := func(network, address string) (net.Conn, error) {
+
+		conn, err := tunneledDialer(network, address)
+		if err != nil {
+			return nil, common.ContextError(err)
+		}
+
+		serverName, _, err := net.SplitHostPort(address)
+		if err != nil {
+			conn.Close()
+			return nil, common.ContextError(err)
+		}
+
+		tlsConn := tls.Client(conn, &tls.Config{ServerName: serverName})
+
+		resultChannel := make(chan error, 1)
+
+		timeout := proxy.responseHeaderTimeout
+		afterFunc := time.AfterFunc(timeout, func() {
+			resultChannel <- errors.New("TLS handshake timeout")
+		})
+		defer afterFunc.Stop()
+
+		go func() {
+			resultChannel <- tlsConn.Handshake()
+		}()
+
+		err = <-resultChannel
+		if err != nil {
+			conn.Close()
+			return nil, common.ContextError(err)
+		}
+
+		err = tlsConn.VerifyHostname(serverName)
+		if err != nil {
+			conn.Close()
+			return nil, common.ContextError(err)
+		}
+
+		return &rewriteICYConn{
+			Conn:  tlsConn,
+			isICY: &rewriteICYStatus.isFirstConnICY,
+		}, nil
+
+	}
+
+	return &http.Client{
+		Transport: &http.Transport{
+			Dial:                  dial,
+			DialTLS:               dialTLS,
+			DisableKeepAlives:     true,
+			ResponseHeaderTimeout: proxy.responseHeaderTimeout,
+		},
+	}, rewriteICYStatus
 }
 
 func (proxy *HttpProxy) relayHTTPRequest(
@@ -325,7 +471,8 @@ func (proxy *HttpProxy) relayHTTPRequest(
 	transport *http.Transport,
 	request *http.Request,
 	responseWriter http.ResponseWriter,
-	rewrites url.Values) {
+	rewrites url.Values,
+	rewriteICYStatus *rewriteICYStatus) {
 
 	// Transform received request struct before using as input to relayed request
 	request.Close = false
@@ -375,6 +522,7 @@ func (proxy *HttpProxy) relayHTTPRequest(
 	}
 
 	// Relay the remote response headers
+
 	for _, key := range hopHeaders {
 		response.Header.Del(key)
 	}
@@ -387,13 +535,59 @@ func (proxy *HttpProxy) relayHTTPRequest(
 		}
 	}
 
-	// Relay the response code and body
-	responseWriter.WriteHeader(response.StatusCode)
-	_, err = io.Copy(responseWriter, response.Body)
-	if err != nil {
-		NoticeAlert("%s", common.ContextError(err))
-		forceClose(responseWriter)
-		return
+	// Send the response downstream
+
+	if rewriteICYStatus != nil && rewriteICYStatus.isICY() {
+
+		// Custom ICY response, using "ICY" as the protocol name
+		// but otherwise equivilent to the HTTP response.
+
+		// As the ICY http.Transport has disabled keep-alives,
+		// hijacking here does not disrupt an otherwise persistent
+		// connection.
+
+		conn := hijack(responseWriter)
+		if conn == nil {
+			// hijack emits an alert notice
+			return
+		}
+
+		_, err := fmt.Fprint(
+			conn,
+			"ICY %d %s",
+			response.StatusCode,
+			http.StatusText(response.StatusCode))
+		if err != nil {
+			NoticeAlert("write status line failed: %s", common.ContextError(err))
+			conn.Close()
+			return
+		}
+
+		err = responseWriter.Header().Write(conn)
+		if err != nil {
+			NoticeAlert("write headers failed: %s", common.ContextError(err))
+			conn.Close()
+			return
+		}
+
+		_, err = io.Copy(conn, response.Body)
+		if err != nil {
+			NoticeAlert("write body failed: %s", common.ContextError(err))
+			conn.Close()
+			return
+		}
+
+	} else {
+
+		// Standard HTTP response.
+
+		responseWriter.WriteHeader(response.StatusCode)
+		_, err = io.Copy(responseWriter, response.Body)
+		if err != nil {
+			NoticeAlert("%s", common.ContextError(err))
+			forceClose(responseWriter)
+			return
+		}
 	}
 }
 
@@ -401,13 +595,26 @@ func (proxy *HttpProxy) relayHTTPRequest(
 // to ensure local persistent connections into the HTTP proxy are closed
 // when ServeHTTP encounters an error.
 func forceClose(responseWriter http.ResponseWriter) {
-	hijacker, _ := responseWriter.(http.Hijacker)
-	conn, _, err := hijacker.Hijack()
-	if err == nil {
+	conn := hijack(responseWriter)
+	if conn != nil {
 		conn.Close()
 	}
 }
 
+func hijack(responseWriter http.ResponseWriter) net.Conn {
+	hijacker, ok := responseWriter.(http.Hijacker)
+	if !ok {
+		NoticeAlert("%s", common.ContextError(errors.New("responseWriter is not an http.Hijacker")))
+		return nil
+	}
+	conn, _, err := hijacker.Hijack()
+	if err != nil {
+		NoticeAlert("%s", common.ContextError(fmt.Errorf("responseWriter hijack failed: %s", err)))
+		return nil
+	}
+	return conn
+}
+
 // From https://golang.org/src/pkg/net/http/httputil/reverseproxy.go:
 // Hop-by-hop headers. These are removed when sent to the backend.
 // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
@@ -484,9 +691,19 @@ func toAbsoluteURL(baseURL *url.URL, relativeURLString string) string {
 
 // proxifyURL takes an absolute URL and rewrites it to go through the local URL proxy.
 // urlProxy port is the local HTTP proxy port.
+//
 // If rewriteParams is nil, then no rewriting will be done. Otherwise, it should contain
 // supported rewriting flags (like "m3u8").
-func proxifyURL(localHTTPProxyIP string, localHTTPProxyPort int, urlString string, rewriteParams []string) string {
+//
+// If useICY is specified, the ICY rewriting path is selected. useICY is ignored when
+// rewriteParams is set.
+func proxifyURL(
+	localHTTPProxyIP string,
+	localHTTPProxyPort int,
+	urlString string,
+	rewriteParams []string,
+	useICY bool) string {
+
 	// Note that we need to use the "opaque" form of URL so that it doesn't double-escape the path. See: https://github.com/golang/go/issues/10887
 
 	// TODO: IPv6 support
@@ -494,10 +711,13 @@ func proxifyURL(localHTTPProxyIP string, localHTTPProxyPort int, urlString strin
 		localHTTPProxyIP = "127.0.0.1"
 	}
 
-	opaqueFormat := "//%s:%d/tunneled/%s"
+	proxyPath := URL_PROXY_TUNNELED_REQUEST_PATH
 	if rewriteParams != nil {
-		opaqueFormat = "//%s:%d/tunneled-rewrite/%s"
+		proxyPath = URL_PROXY_TUNNELED_REWRITE_REQUEST_PATH
+	} else if useICY {
+		proxyPath = URL_PROXY_TUNNELED_ICY_REQUEST_PATH
 	}
+	opaqueFormat := fmt.Sprintf("//%%s:%%d/%s/%%s", proxyPath)
 
 	var proxifiedURL url.URL
 
@@ -564,6 +784,8 @@ func rewriteM3U8(localHTTPProxyIP string, localHTTPProxyPort int, response *http
 		return nil
 	}
 
+	useICY := true
+
 	var rewrittenBodyBytes []byte
 
 	switch listType {
@@ -575,15 +797,15 @@ func rewriteM3U8(localHTTPProxyIP string, localHTTPProxyPort int, response *http
 			}
 
 			if segment.URI != "" {
-				segment.URI = proxifyURL(localHTTPProxyIP, localHTTPProxyPort, toAbsoluteURL(response.Request.URL, segment.URI), nil)
+				segment.URI = proxifyURL(localHTTPProxyIP, localHTTPProxyPort, toAbsoluteURL(response.Request.URL, segment.URI), nil, useICY)
 			}
 
 			if segment.Key != nil && segment.Key.URI != "" {
-				segment.Key.URI = proxifyURL(localHTTPProxyIP, localHTTPProxyPort, toAbsoluteURL(response.Request.URL, segment.Key.URI), nil)
+				segment.Key.URI = proxifyURL(localHTTPProxyIP, localHTTPProxyPort, toAbsoluteURL(response.Request.URL, segment.Key.URI), nil, useICY)
 			}
 
 			if segment.Map != nil && segment.Map.URI != "" {
-				segment.Map.URI = proxifyURL(localHTTPProxyIP, localHTTPProxyPort, toAbsoluteURL(response.Request.URL, segment.Map.URI), nil)
+				segment.Map.URI = proxifyURL(localHTTPProxyIP, localHTTPProxyPort, toAbsoluteURL(response.Request.URL, segment.Map.URI), nil, useICY)
 			}
 		}
 		rewrittenBodyBytes = []byte(mediapl.String())
@@ -595,7 +817,7 @@ func rewriteM3U8(localHTTPProxyIP string, localHTTPProxyPort int, response *http
 			}
 
 			if variant.URI != "" {
-				variant.URI = proxifyURL(localHTTPProxyIP, localHTTPProxyPort, toAbsoluteURL(response.Request.URL, variant.URI), []string{"m3u8"})
+				variant.URI = proxifyURL(localHTTPProxyIP, localHTTPProxyPort, toAbsoluteURL(response.Request.URL, variant.URI), []string{"m3u8"}, useICY)
 			}
 
 			for _, alternative := range variant.Alternatives {
@@ -604,7 +826,7 @@ func rewriteM3U8(localHTTPProxyIP string, localHTTPProxyPort int, response *http
 				}
 
 				if alternative.URI != "" {
-					alternative.URI = proxifyURL(localHTTPProxyIP, localHTTPProxyPort, toAbsoluteURL(response.Request.URL, alternative.URI), []string{"m3u8"})
+					alternative.URI = proxifyURL(localHTTPProxyIP, localHTTPProxyPort, toAbsoluteURL(response.Request.URL, alternative.URI), []string{"m3u8"}, useICY)
 				}
 			}
 		}