usbmux.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. #!/usr/bin/python3
  2. # -*- coding: utf-8 -*-
  3. import sys
  4. import socket
  5. import struct
  6. import select
  7. import plistlib
# Flag checked by PlistProtocol.__init__ before it will construct itself.
# plistlib is imported unconditionally above, so the flag is always True here.
have_plist = True
  9. class MuxError(Exception):
  10. pass
  11. class MuxVersionError(MuxError):
  12. pass
  13. class SafeStreamSocket:
  14. def __init__(self, address: str, family: 'AddressFamily'):
  15. self.sock = socket.socket(family, socket.SOCK_STREAM)
  16. self.sock.connect(address)
  17. def send(self, msg: bytes) -> None:
  18. total_sent = 0
  19. while total_sent < len(msg):
  20. sent = self.sock.send(msg[total_sent:])
  21. if sent == 0:
  22. raise MuxError("socket connection broken")
  23. total_sent = total_sent + sent
  24. def recv(self, size: int) -> bytes:
  25. msg = bytes()
  26. while len(msg) < size:
  27. chunk = self.sock.recv(size - len(msg))
  28. if chunk == b'':
  29. raise MuxError("socket connection broken")
  30. msg = msg + chunk
  31. return msg
  32. class MuxDevice(object):
  33. def __init__(self, devid: int, usbprod: int, serial: str, location: int):
  34. self.devid = devid
  35. self.usbprod = usbprod
  36. self.serial = serial
  37. self.location = location
  38. def __str__(self):
  39. return f"<MuxDevice: ID %d ProdID 0x%04x Serial '%s' Location 0x%x>" % (
  40. self.devid, self.usbprod, self.serial, self.location)
  41. class BinaryProtocol(object):
  42. TYPE_RESULT = 1
  43. TYPE_CONNECT = 2
  44. TYPE_LISTEN = 3
  45. TYPE_DEVICE_ADD = 4
  46. TYPE_DEVICE_REMOVE = 5
  47. VERSION = 0
  48. def __init__(self, socket: socket):
  49. self.socket = socket
  50. self.connected = False
  51. def _pack(self, req, payload):
  52. if req == self.TYPE_CONNECT:
  53. return struct.pack("IH", payload['DeviceID'], payload['PortNumber']) + b"\x00\x00"
  54. elif req == self.TYPE_LISTEN:
  55. return bytes()
  56. else:
  57. raise ValueError(f"Invalid outgoing request type {req}")
  58. def _unpack(self, resp, payload):
  59. if resp == self.TYPE_RESULT:
  60. return {'Number': struct.unpack("I", payload)[0]}
  61. elif resp == self.TYPE_DEVICE_ADD:
  62. dev_id, usb_pid, serial, pad, location = struct.unpack("IH256sHI", payload)
  63. serial = serial.split(b"\0")[0]
  64. return {'DeviceID': dev_id,
  65. 'Properties': {'LocationID': location, 'SerialNumber': serial, 'ProductID': usb_pid}}
  66. elif resp == self.TYPE_DEVICE_REMOVE:
  67. dev_id = struct.unpack("I", payload)[0]
  68. return {'DeviceID': dev_id}
  69. else:
  70. raise MuxError("Invalid incoming request type")
  71. def send_packet(self, req, tag, payload=None):
  72. if payload is None:
  73. payload = {}
  74. payload = self._pack(req, payload)
  75. if self.connected:
  76. raise MuxError("Mux is connected, cannot issue control packets")
  77. length = 16 + len(payload)
  78. data = struct.pack("IIII", length, self.VERSION, req, tag) + payload
  79. self.socket.send(data)
  80. def get_packet(self):
  81. if self.connected:
  82. raise MuxError("Mux is connected, cannot issue control packets")
  83. d_len = self.socket.recv(4)
  84. d_len = struct.unpack("I", d_len)[0]
  85. body = self.socket.recv(d_len - 4)
  86. version, resp, tag = struct.unpack("III", body[:0xc])
  87. if version != self.VERSION:
  88. raise MuxVersionError(f"Version mismatch: expected {self.VERSION}, got {version}")
  89. payload = self._unpack(resp, body[0xc:])
  90. return resp, tag, payload
  91. class PlistProtocol(BinaryProtocol):
  92. TYPE_RESULT = "Result"
  93. TYPE_CONNECT = "Connect"
  94. TYPE_LISTEN = "Listen"
  95. TYPE_DEVICE_ADD = "Attached"
  96. TYPE_DEVICE_REMOVE = "Detached"
  97. TYPE_PLIST = 8
  98. VERSION = 1
  99. def __init__(self, socket: socket):
  100. if not have_plist:
  101. raise Exception("You need the plistlib module")
  102. BinaryProtocol.__init__(self, socket)
  103. def _pack(self, req, payload: bytes) -> bytes:
  104. return payload
  105. def _unpack(self, resp, payload: bytes) -> bytes:
  106. return payload
  107. def send_packet(self, req, tag, payload=None) -> None:
  108. if payload is None:
  109. payload = {}
  110. payload['ClientVersionString'] = 'usbmux.py'
  111. if isinstance(req, int):
  112. req = [self.TYPE_CONNECT, self.TYPE_LISTEN][req - 2]
  113. payload['MessageType'] = req
  114. payload['ProgName'] = 'tcprelay'
  115. BinaryProtocol.send_packet(self, self.TYPE_PLIST, tag, plistlib.dumps(payload))
  116. def get_packet(self) -> tuple:
  117. resp, tag, payload = BinaryProtocol.get_packet(self)
  118. if resp != self.TYPE_PLIST:
  119. raise MuxError(f"Received non-plist type {resp}")
  120. payload = plistlib.loads(payload)
  121. return payload['MessageType'], tag, payload
  122. class MuxConnection(object):
  123. def __init__(self, socketpath: str, protoclass: type):
  124. self.socketpath = socketpath
  125. if sys.platform in ['win32', 'cygwin']:
  126. family = socket.AF_INET
  127. address = ('127.0.0.1', 27015)
  128. else:
  129. family = socket.AF_UNIX
  130. address = self.socketpath
  131. self.socket = SafeStreamSocket(address, family)
  132. self.proto = protoclass(self.socket)
  133. self.pkttag = 1
  134. self.devices = list()
  135. def _getreply(self) -> tuple or None:
  136. while True:
  137. resp, tag, data = self.proto.get_packet()
  138. if resp == self.proto.TYPE_RESULT:
  139. return tag, data
  140. else:
  141. raise MuxError(f"Invalid packet type received: {resp}")
  142. def _processpacket(self) -> None:
  143. resp, tag, data = self.proto.get_packet()
  144. if resp == self.proto.TYPE_DEVICE_ADD:
  145. self.devices.append(
  146. MuxDevice(data['DeviceID'], data['Properties']['ProductID'], data['Properties']['SerialNumber'],
  147. data['Properties']['LocationID']))
  148. elif resp == self.proto.TYPE_DEVICE_REMOVE:
  149. for item in self.devices:
  150. if item.devid == data['DeviceID']:
  151. self.devices.remove(item)
  152. elif resp == self.proto.TYPE_RESULT:
  153. raise MuxError(f"Unexpected result: {resp}")
  154. else:
  155. raise MuxError(f"Invalid packet type received: {resp}")
  156. def _exchange(self, req: str, payload: dict = None) -> str:
  157. if payload is None:
  158. payload = {}
  159. mytag = self.pkttag
  160. self.pkttag += 1
  161. self.proto.send_packet(req, mytag, payload)
  162. recvtag, data = self._getreply()
  163. if recvtag != mytag:
  164. raise MuxError(f"Reply tag mismatch: expected {mytag}, got {recvtag}")
  165. return data['Number']
  166. def listen(self) -> None:
  167. ret = self._exchange(self.proto.TYPE_LISTEN)
  168. if ret != 0:
  169. raise MuxError(f"Listen failed: error {ret}")
  170. def process(self, timeout: int or float = None) -> None:
  171. if self.proto.connected:
  172. raise MuxError("Socket is connected, cannot process listener events")
  173. rlo, wlo, xlo = select.select([self.socket.sock], [], [self.socket.sock], timeout)
  174. if xlo:
  175. self.socket.sock.close()
  176. raise MuxError("Exception in listener socket")
  177. if rlo:
  178. self._processpacket()
  179. def connect(self, device: MuxDevice, port: int) -> socket:
  180. ret = self._exchange(self.proto.TYPE_CONNECT,
  181. {'DeviceID': device.devid, 'PortNumber': ((port << 8) & 0xFF00) | (port >> 8)})
  182. if ret != 0:
  183. raise MuxError(f"Connect failed: error {ret}")
  184. self.proto.connected = True
  185. return self.socket.sock
  186. def close(self) -> None:
  187. self.socket.sock.close()
  188. class USBMux(object):
  189. def __init__(self, socket_path: str = None):
  190. if socket_path is None:
  191. if sys.platform == 'darwin':
  192. socket_path = "/var/run/usbmuxd"
  193. else:
  194. socket_path = "/var/run/usbmuxd"
  195. self.socketpath = socket_path
  196. self.listener = MuxConnection(socket_path, BinaryProtocol)
  197. try:
  198. self.listener.listen()
  199. self.version = 0
  200. self.protoclass = BinaryProtocol
  201. except MuxVersionError:
  202. self.listener = MuxConnection(socket_path, PlistProtocol)
  203. self.listener.listen()
  204. self.protoclass = PlistProtocol
  205. self.version = 1
  206. self.devices = self.listener.devices
  207. def process(self, timeout=None) -> None:
  208. self.listener.process(timeout)
  209. def connect(self, device, port) -> socket:
  210. connector = MuxConnection(self.socketpath, self.protoclass)
  211. return connector.connect(device, port)
  212. if __name__ == "__main__":
  213. mux = USBMux()
  214. print("Waiting for devices...")
  215. if not mux.devices:
  216. mux.process(0.1)
  217. while True:
  218. print("Devices:")
  219. for dev in mux.devices:
  220. print(dev)
  221. mux.process()