Skip to content

Commit

Permalink
Add support for USB connections
Browse files Browse the repository at this point in the history
Adds a new subclass of PybricksHub that manages USB connections.

Signed-off-by: Nate Karstens <[email protected]>
  • Loading branch information
nkarstens committed Nov 29, 2023
1 parent eb28049 commit db05dcf
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 4 deletions.
19 changes: 15 additions & 4 deletions pybricksdev/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,12 @@ def add_parser(self, subparsers: argparse._SubParsersAction):
)

async def run(self, args: argparse.Namespace):
from ..ble import find_device
from usb.core import find as find_usb

from ..ble import find_device as find_ble
from ..connections.ev3dev import EV3Connection
from ..connections.lego import REPLHub
from ..connections.pybricks import PybricksHubBLE
from ..connections.pybricks import PybricksHubBLE, PybricksHubUSB

# Pick the right connection
if args.conntype == "ssh":
Expand All @@ -185,14 +187,23 @@ async def run(self, args: argparse.Namespace):

device_or_address = socket.gethostbyname(args.name)
hub = EV3Connection(device_or_address)

elif args.conntype == "ble":
# It is a Pybricks Hub with BLE. Device name or address is given.
print(f"Searching for {args.name or 'any hub with Pybricks service'}...")
device_or_address = await find_device(args.name)
device_or_address = await find_ble(args.name)
hub = PybricksHubBLE(device_or_address)

elif args.conntype == "usb":
hub = REPLHub()
device_or_address = find_usb(idVendor=0x0483, idProduct=0x5740)

if (
device_or_address is not None
and device_or_address.product == "Pybricks Hub"
):
hub = PybricksHubUSB(device_or_address)
else:
hub = REPLHub()
else:
raise ValueError(f"Unknown connection type: {args.conntype}")

Expand Down
110 changes: 110 additions & 0 deletions pybricksdev/connections/pybricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import struct
from typing import Awaitable, Callable, List, Optional, TypeVar
from uuid import UUID

import reactivex.operators as op
import semver
Expand All @@ -17,6 +18,10 @@
from reactivex.subject import BehaviorSubject, Subject
from tqdm.auto import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
from usb.control import get_descriptor
from usb.core import Device as USBDevice
from usb.core import Endpoint, USBTimeoutError
from usb.util import ENDPOINT_IN, ENDPOINT_OUT, endpoint_direction, find_descriptor

from ..ble.lwp3.bytecodes import HubKind
from ..ble.nus import NUS_RX_UUID, NUS_TX_UUID
Expand Down Expand Up @@ -705,3 +710,108 @@ async def write_gatt_char(self, uuid: str, data, response: bool) -> None:

async def start_notify(self, uuid: str, callback: Callable) -> None:
return await self._client.start_notify(uuid, callback)


class PybricksHubUSB(PybricksHub):
_device: USBDevice
_ep_in: Endpoint
_ep_out: Endpoint
_notify_callbacks = {}
_monitor_task: asyncio.Task

def __init__(self, device: USBDevice):
super().__init__()
self._device = device

async def _client_connect(self) -> bool:
self._device.set_configuration()

# Save input and output endpoints
cfg = self._device.get_active_configuration()
intf = cfg[(0, 0)]
self._ep_in = find_descriptor(
intf,
custom_match=lambda e: endpoint_direction(e.bEndpointAddress)
== ENDPOINT_IN,
)
self._ep_out = find_descriptor(
intf,
custom_match=lambda e: endpoint_direction(e.bEndpointAddress)
== ENDPOINT_OUT,
)

# Set write size to endpoint packet size minus length of UUID
self._max_write_size = self._ep_out.wMaxPacketSize - 16

# Get length of BOS descriptor
bos_descriptor = get_descriptor(self._device, 5, 0x0F, 0)
(ofst, _, bos_len, _) = struct.unpack("<BBHB", bos_descriptor)

# Get full BOS descriptor
bos_descriptor = get_descriptor(self._device, bos_len, 0x0F, 0)

while ofst < bos_len:
(len, desc_type, cap_type) = struct.unpack_from(
"<BBB", bos_descriptor, offset=ofst
)

if desc_type != 0x10:
logger.error("Expected Device Capability descriptor")
exit(1)

# Look for platform descriptors
if cap_type == 0x05:
uuid_bytes = bos_descriptor[ofst + 4 : ofst + 4 + 16]
uuid_str = str(UUID(bytes_le=bytes(uuid_bytes)))

if uuid_str == FW_REV_UUID:
fw_version = bytearray(
bos_descriptor[ofst + 20 : ofst + len - 1]
) # Remove null-terminator
self.fw_version = Version(fw_version.decode())

elif uuid_str == SW_REV_UUID:
self._protocol_version = bytearray(
bos_descriptor[ofst + 20 : ofst + len - 1]
) # Remove null-terminator

elif uuid_str == PYBRICKS_HUB_CAPABILITIES_UUID:
caps = bytearray(bos_descriptor[ofst + 20 : ofst + len])
(
_,
self._capability_flags,
self._max_user_program_size,
) = unpack_hub_capabilities(caps)

ofst += len

self._monitor_task = asyncio.create_task(self._monitor_usb())

return True

async def _client_disconnect(self) -> bool:
self._handle_disconnect()

async def read_gatt_char(self, uuid: str) -> bytearray:
return None

async def write_gatt_char(self, uuid: str, data, response: bool) -> None:
self._ep_out.write(UUID(uuid).bytes_le + data)
# TODO: Handle response

async def start_notify(self, uuid: str, callback: Callable) -> None:
self._notify_callbacks[uuid] = callback

async def _monitor_usb(self):
while True:
try:
msg = self._ep_in.read(self._ep_in.wMaxPacketSize)
except USBTimeoutError:
await asyncio.sleep(0)
continue

if len(msg) > 16:
uuid = UUID(bytes_le=bytes(msg[0:16]))
callback = self._notify_callbacks[str(uuid)]
if callback:
callback(None, bytes(msg[16:]))

0 comments on commit db05dcf

Please sign in to comment.