POpen.py 9.7 KB

#!/usr/bin/env python
import sys
import httplib
from SocketServer import ThreadingMixIn
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
from threading import Lock, Timer
from cStringIO import StringIO
from urlparse import urlsplit
import socket
import select
import gzip
import zlib
import re
import traceback


class ThreadingHTTPServer(ThreadingMixIn, HTTPServer):
    address_family = socket.AF_INET

    def handle_error(self, request, client_address):
        print >>sys.stderr, '-'*40
        print >>sys.stderr, 'Exception happened during processing of request from', client_address
        traceback.print_exc()
        print >>sys.stderr, '-'*40


class ThreadingHTTPServer6(ThreadingHTTPServer):
    address_family = socket.AF_INET6


class SimpleHTTPProxyHandler(BaseHTTPRequestHandler):
    global_lock = Lock()
    conn_table = {}          # pooled upstream connections, keyed by (scheme, netloc)
    timeout = 300            # client-side socket timeout in seconds (also the CONNECT relay timeout)
    upstream_timeout = 300   # idle timeout in seconds for pooled upstream connections
    proxy_via = None         # if set, appended to the Via header of requests and responses

    def log_error(self, format, *args):
        # Suppress the noisy log lines produced by keep-alive timeouts.
        if format == "Request timed out: %r":
            return
        self.log_message(format, *args)

    def do_CONNECT(self):
        req = self
        reqbody = None
        req.path = "https://%s/" % req.path.replace(':443', '')

        replaced_reqbody = self.request_handler(req, reqbody)
        if replaced_reqbody is True:
            return

        u = urlsplit(req.path)
        address = (u.hostname, u.port or 443)
        try:
            conn = socket.create_connection(address)
        except socket.error:
            return
        self.send_response(200, 'SOCKS5')
        self.send_header('Connection', 'close')
        self.end_headers()

        # Relay raw bytes between the client and the target in both directions
        # until one side stops sending or the relay times out.
        conns = [self.connection, conn]
        keep_connection = True
        while keep_connection:
            keep_connection = False
            rlist, wlist, xlist = select.select(conns, [], conns, self.timeout)
            if xlist:
                break
            for r in rlist:
                other = conns[1] if r is conns[0] else conns[0]
                data = r.recv(8192)
                if data:
                    other.sendall(data)
                    keep_connection = True
        conn.close()

    def do_HEAD(self):
        self.do_SPAM()

    def do_GET(self):
        self.do_SPAM()

    def do_POST(self):
        self.do_SPAM()

    def do_SPAM(self):
        req = self
        content_length = int(req.headers.get('Content-Length', 0))
        if content_length > 0:
            reqbody = self.rfile.read(content_length)
        else:
            reqbody = None

        # Give request_handler a chance to hijack the request (return True)
        # or replace its body (return a new body).
        replaced_reqbody = self.request_handler(req, reqbody)
        if replaced_reqbody is True:
            return
        elif replaced_reqbody is not None:
            reqbody = replaced_reqbody
            if 'Content-Length' in req.headers:
                req.headers['Content-Length'] = str(len(reqbody))

        self.remove_hop_by_hop_headers(req.headers)
        if self.upstream_timeout:
            req.headers['Connection'] = 'Keep-Alive'
        else:
            req.headers['Connection'] = 'close'
        if self.proxy_via:
            self.modify_via_header(req.headers)

        try:
            res, resdata = self.request_to_upstream_server(req, reqbody)
        except socket.error:
            return

        # Decode the response body so response_handler can inspect or replace it.
        content_encoding = res.headers.get('Content-Encoding', 'identity')
        resbody = self.decode_content_body(resdata, content_encoding)

        replaced_resbody = self.response_handler(req, reqbody, res, resbody)
        if replaced_resbody is True:
            return
        elif replaced_resbody is not None:
            resdata = self.encode_content_body(replaced_resbody, content_encoding)
            if 'Content-Length' in res.headers:
                res.headers['Content-Length'] = str(len(resdata))
            resbody = replaced_resbody

        self.remove_hop_by_hop_headers(res.headers)
        if self.timeout:
            res.headers['Connection'] = 'Keep-Alive'
        else:
            res.headers['Connection'] = 'close'
        if self.proxy_via:
            self.modify_via_header(res.headers)

        # Relay the status line and headers; folded Set-Cookie values are split
        # back into separate header lines.
        self.send_response(res.status, res.reason)
        for k, v in res.headers.items():
            if k == 'set-cookie':
                for value in self.split_set_cookie_header(v):
                    self.send_header(k, value)
            else:
                self.send_header(k, v)
        self.end_headers()
        if self.command != 'HEAD':
            self.wfile.write(resdata)

        with self.global_lock:
            self.save_handler(req, reqbody, res, resbody)

    def request_to_upstream_server(self, req, reqbody):
        u = urlsplit(req.path)
        origin = (u.scheme, u.netloc)
        req.headers['Host'] = u.netloc
        selector = "%s?%s" % (u.path, u.query) if u.query else u.path

        while True:
            with self.lock_origin(origin):
                conn = self.open_origin(origin)
                try:
                    conn.request(req.command, selector, reqbody, headers=dict(req.headers))
                except socket.error:
                    self.close_origin(origin)
                    raise
                try:
                    res = conn.getresponse(buffering=True)
                except httplib.BadStatusLine as e:
                    if e.line == "''":
                        # The pooled keep-alive connection was closed by the server;
                        # drop it and retry on a fresh connection.
                        self.close_origin(origin)
                        continue
                    else:
                        raise
                resdata = res.read()
                res.headers = res.msg
                if not self.upstream_timeout or 'close' in res.headers.get('Connection', ''):
                    self.close_origin(origin)
                else:
                    self.reset_timer(origin)
                return res, resdata

    def lock_origin(self, origin):
        d = self.conn_table.setdefault(origin, {})
        if 'lock' not in d:
            d['lock'] = Lock()
        return d['lock']

    def open_origin(self, origin):
        # Reuse a pooled connection to this origin if one exists, otherwise open one.
        conn = self.conn_table[origin].get('connection')
        if not conn:
            scheme, netloc = origin
            if scheme == 'https':
                conn = httplib.HTTPSConnection(netloc)
            else:
                conn = httplib.HTTPConnection(netloc)
            self.reset_timer(origin)
            self.conn_table[origin]['connection'] = conn
        return conn

    def reset_timer(self, origin):
        timer = self.conn_table[origin].get('timer')
        if timer:
            timer.cancel()
        if self.upstream_timeout:
            timer = Timer(self.upstream_timeout, self.close_origin, args=[origin])
            timer.daemon = True
            timer.start()
        else:
            timer = None
        self.conn_table[origin]['timer'] = timer

    def close_origin(self, origin):
        timer = self.conn_table[origin]['timer']
        if timer:
            timer.cancel()
        conn = self.conn_table[origin]['connection']
        conn.close()
        del self.conn_table[origin]['connection']

    def remove_hop_by_hop_headers(self, headers):
        hop_by_hop_headers = ['Connection', 'Keep-Alive', 'Proxy-Authenticate',
                              'Proxy-Authorization', 'TE', 'Trailers', 'Trailer',
                              'Transfer-Encoding', 'Upgrade']
        connection = headers.get('Connection')
        if connection:
            keys = re.split(r',\s*', connection)
            hop_by_hop_headers.extend(keys)
        for k in hop_by_hop_headers:
            if k in headers:
                del headers[k]

    def modify_via_header(self, headers):
        via_string = "%s %s" % (self.protocol_version, self.proxy_via)
        via_string = re.sub(r'^HTTP/', '', via_string)
        original = headers.get('Via')
        if original:
            headers['Via'] = original + ', ' + via_string
        else:
            headers['Via'] = via_string

    def decode_content_body(self, data, content_encoding):
        if content_encoding in ('gzip', 'x-gzip'):
            io = StringIO(data)
            with gzip.GzipFile(fileobj=io) as f:
                body = f.read()
        elif content_encoding == 'deflate':
            body = zlib.decompress(data)
        elif content_encoding == 'identity':
            body = data
        else:
            raise Exception("Unknown Content-Encoding: %s" % content_encoding)
        return body

    def encode_content_body(self, body, content_encoding):
        if content_encoding in ('gzip', 'x-gzip'):
            io = StringIO()
            with gzip.GzipFile(fileobj=io, mode='wb') as f:
                f.write(body)
            data = io.getvalue()
        elif content_encoding == 'deflate':
            data = zlib.compress(body)
        elif content_encoding == 'identity':
            data = body
        else:
            raise Exception("Unknown Content-Encoding: %s" % content_encoding)
        return data

    def split_set_cookie_header(self, value):
        # A folded Set-Cookie header can contain commas inside Expires dates,
        # so split it with a regex instead of a plain comma split.
        re_cookies = r'([^=]+=[^,;]+(?:;\s*Expires=[^,]+,[^,;]+|;[^,;]+)*)(?:,\s*)?'
        return re.findall(re_cookies, value, flags=re.IGNORECASE)

    # Override the three hooks below in a subclass to inspect or modify traffic.
    # request_handler/response_handler: return True to stop normal processing,
    # or return a new body to replace the request/response body.
    def request_handler(self, req, reqbody):
        pass

    def response_handler(self, req, reqbody, res, resbody):
        pass

    # save_handler is called once per completed exchange, under global_lock.
    def save_handler(self, req, reqbody, res, resbody):
        pass
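

# Minimal usage sketch (not part of the original script): SimpleHTTPProxyHandler
# is meant to be extended through the three hook methods above. The class name
# and the X-Example-Proxy header below are illustrative assumptions only.
class ExampleProxyHandler(SimpleHTTPProxyHandler):
    def request_handler(self, req, reqbody):
        # Tag every forwarded request; returning None keeps the body unchanged.
        req.headers['X-Example-Proxy'] = '1'

    def response_handler(self, req, reqbody, res, resbody):
        # Returning a new body makes do_SPAM re-encode it and fix Content-Length.
        if res.headers.get('Content-Type', '').startswith('text/plain'):
            return resbody.replace('http://', 'https://')

    def save_handler(self, req, reqbody, res, resbody):
        # Runs under global_lock, so shared logging here is thread-safe.
        print "%s %s -> %s" % (req.command, req.path, res.status)
# To try it, pass HandlerClass=ExampleProxyHandler to test() in the
# __main__ block at the bottom of the file.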


# Standalone test server: the listening port defaults to 80, or can be given
# as the first command-line argument.
def test(HandlerClass=SimpleHTTPProxyHandler, ServerClass=ThreadingHTTPServer, protocol="HTTP/1.1"):
    if sys.argv[1:]:
        port = int(sys.argv[1])
    else:
        port = 80
    server_address = ('', port)

    HandlerClass.protocol_version = protocol
    httpd = ServerClass(server_address, HandlerClass)

    sa = httpd.socket.getsockname()
    print "Serving HTTP on", sa[0], "port", sa[1], "..."
    httpd.serve_forever()


if __name__ == '__main__':
    test()