Skip to content

Commit

Permalink
Add support for USB connections
Browse files Browse the repository at this point in the history
Adds a new transport to manage USB connections.

Signed-off-by: Nate Karstens <[email protected]>
  • Loading branch information
nkarstens committed Feb 21, 2024
1 parent ac67149 commit 4c5ddff
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 4 deletions.
6 changes: 6 additions & 0 deletions pybricksdev/ble/pybricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,12 @@ def _standard_uuid(short: int) -> str:
.. availability:: Since Pybricks protocol v1.0.0.
"""

DEVICE_NAME_UUID = _standard_uuid(0x2A00)
"""Standard Device Name UUID
.. availability:: Since Pybricks protocol v1.0.0.
"""

FW_REV_UUID = _standard_uuid(0x2A26)
"""Standard Firmware Revision String characteristic UUID
Expand Down
172 changes: 168 additions & 4 deletions pybricksdev/connections/pybricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import logging
import os
import struct
from typing import Awaitable, Callable, List, Optional, Tuple, TypeVar
from typing import Awaitable, Callable, List, Optional, Tuple, TypeVar, Union
from uuid import UUID

import reactivex.operators as op
import semver
Expand All @@ -19,9 +20,15 @@
from tqdm.auto import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm

from usb.control import get_descriptor
from usb.core import Device as USBDevice
from usb.core import Endpoint, USBTimeoutError
from usb.util import ENDPOINT_IN, ENDPOINT_OUT, endpoint_direction, find_descriptor

from ..ble.lwp3.bytecodes import HubKind
from ..ble.nus import NUS_RX_UUID, NUS_TX_UUID
from ..ble.pybricks import (
DEVICE_NAME_UUID,
FW_REV_UUID,
PNP_ID_UUID,
PYBRICKS_COMMAND_EVENT_UUID,
Expand All @@ -38,6 +45,7 @@
from ..compile import compile_file, compile_multi_file
from ..tools import chunk
from ..tools.checksum import xor_bytes
from ..usb import LegoUsbMsg, LegoUsbPid
from . import ConnectionState

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -138,6 +146,156 @@ def handler(_, data):
await self._client.start_notify(NUS_TX_UUID, handler)


class _USBTransport(_Transport):
_device: USBDevice
_disconnected_callback: Callable
_ep_in: Endpoint
_ep_out: Endpoint
_notify_callbacks = {}
_monitor_task: asyncio.Task
_response: asyncio.Future

def __init__(self, device: USBDevice):
self._device = device
self._notify_callbacks[
LegoUsbMsg.USB_PYBRICKS_MSG_COMMAND_RESPONSE
] = self._response_handler

async def connect(self, disconnected_callback: Callable) -> None:
self._disconnected_callback = disconnected_callback
self._device.set_configuration()

# Save input and output endpoints
cfg = self._device.get_active_configuration()
intf = cfg[(0, 0)]
self._ep_in = find_descriptor(
intf,
custom_match=lambda e: endpoint_direction(e.bEndpointAddress)
== ENDPOINT_IN,
)
self._ep_out = find_descriptor(
intf,
custom_match=lambda e: endpoint_direction(e.bEndpointAddress)
== ENDPOINT_OUT,
)

# Get length of BOS descriptor
bos_descriptor = get_descriptor(self._device, 5, 0x0F, 0)
(ofst, bos_len) = struct.unpack("<BxHx", bos_descriptor)

# Get full BOS descriptor
bos_descriptor = get_descriptor(self._device, bos_len, 0x0F, 0)

while ofst < bos_len:
(len, desc_type, cap_type) = struct.unpack_from(
"<BBB", bos_descriptor, offset=ofst
)

if desc_type != 0x10:
raise Exception("Expected Device Capability descriptor")

# Look for platform descriptors
if cap_type == 0x05:
uuid_bytes = bos_descriptor[ofst + 4 : ofst + 4 + 16]
uuid_str = str(UUID(bytes_le=bytes(uuid_bytes)))

if uuid_str == DEVICE_NAME_UUID:
self._device_name = bytes(
bos_descriptor[ofst + 20 : ofst + len]
).decode()
print("Connected to hub '" + self._device_name + "'")

elif uuid_str == FW_REV_UUID:
fw_version = bytes(bos_descriptor[ofst + 20 : ofst + len])
self._fw_version = Version(fw_version.decode())

elif uuid_str == SW_REV_UUID:
protocol_version = bytes(bos_descriptor[ofst + 20 : ofst + len])
self._protocol_version = semver.VersionInfo.parse(
protocol_version.decode()
)

elif uuid_str == PYBRICKS_HUB_CAPABILITIES_UUID:
caps = bytes(bos_descriptor[ofst + 20 : ofst + len])
(
self._max_write_size,
self._capability_flags,
self._max_user_program_size,
) = unpack_hub_capabilities(caps)

ofst += len

self._monitor_task = asyncio.create_task(self._monitor_usb())

async def disconnect(self) -> None:
# FIXME: Need to make sure this is called when the USB cable is unplugged
self._monitor_task.cancel()
self._disconnected_callback()

async def get_firmware_version(self) -> Version:
return self._fw_version

async def get_protocol_version(self) -> Version:
return self._protocol_version

async def get_hub_type(self) -> Tuple[HubKind, int]:
hub_types = {
LegoUsbPid.SPIKE_PRIME: (HubKind.TECHNIC_LARGE, 0),
LegoUsbPid.ROBOT_INVENTOR: (HubKind.TECHNIC_LARGE, 1),
LegoUsbPid.SPIKE_ESSENTIAL: (HubKind.TECHNIC_SMALL, 0),
}

return hub_types[self._device.idProduct]

async def get_hub_capabilities(self) -> Tuple[int, HubCapabilityFlag, int]:
return (
self._max_write_size,
self._capability_flags,
self._max_user_program_size,
)

async def send_command(self, command: bytes) -> None:
self._response = asyncio.Future()
self._ep_out.write(
struct.pack("B", LegoUsbMsg.USB_PYBRICKS_MSG_COMMAND) + command
)
try:
await asyncio.wait_for(self._response, 1)
if self._response.result() != 0:
print(
f"Received error response for command: {self._response.result()}"
)
except asyncio.TimeoutError:
print("Timed out waiting for a response")

async def set_service_handler(self, callback: Callable) -> None:
self._notify_callbacks[LegoUsbMsg.USB_PYBRICKS_MSG_EVENT] = callback

async def _monitor_usb(self):
loop = asyncio.get_running_loop()

while True:
msg = await loop.run_in_executor(None, self._read_usb)

if msg is None or len(msg) == 0:
continue

callback = self._notify_callbacks.get(msg[0])
if callback is not None:
callback(bytes(msg[1:]))

def _read_usb(self):
try:
msg = self._ep_in.read(self._ep_in.wMaxPacketSize)
return msg
except USBTimeoutError:
return None

def _response_handler(self, data: bytes) -> None:
(response,) = struct.unpack("<I", data)
self._response.set_result(response)


class PybricksHub:
EOL = b"\r\n" # MicroPython EOL

Expand Down Expand Up @@ -326,11 +484,12 @@ def _pybricks_service_handler(self, data: bytes) -> None:
if self._enable_line_handler:
self._handle_line_data(payload)

async def connect(self, device: BLEDevice):
async def connect(self, device: Union[BLEDevice, USBDevice]):
"""Connects to a device that was discovered with :meth:`pybricksdev.ble.find_device`
or :meth:`usb.core.find`
Args:
device: The device to connect to.
device: The device to connect to (`BLEDevice` or `USBDevice`).
Raises:
BleakError: if connecting failed (or old firmware without Device
Expand All @@ -350,7 +509,12 @@ async def connect(self, device: BLEDevice):
self.connection_state_observable.on_next, ConnectionState.DISCONNECTED
)

self._transport = _BLETransport(device)
if isinstance(device, BLEDevice):
self._transport = _BLETransport(device)
elif isinstance(device, USBDevice):
self._transport = _USBTransport(device)
else:
raise TypeError("Unsupported device type")

def handle_disconnect():
logger.info("Disconnected!")
Expand Down
6 changes: 6 additions & 0 deletions pybricksdev/usb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ class LegoUsbPid(_enum.IntEnum):
ROBOT_INVENTOR_DFU = 0x0011


class LegoUsbMsg(_enum.IntEnum):
USB_PYBRICKS_MSG_COMMAND = 0x00
USB_PYBRICKS_MSG_COMMAND_RESPONSE = 0x01
USB_PYBRICKS_MSG_EVENT = 0x02


PYBRICKS_USB_DEVICE_CLASS = 0xFF
PYBRICKS_USB_DEVICE_SUBCLASS = 0xC5
PYBRICKS_USB_DEVICE_PROTOCOL = 0xF5

0 comments on commit 4c5ddff

Please sign in to comment.