diff --git a/pybricksdev/ble/pybricks.py b/pybricksdev/ble/pybricks.py index cb058b9..d33bdf4 100644 --- a/pybricksdev/ble/pybricks.py +++ b/pybricksdev/ble/pybricks.py @@ -328,6 +328,12 @@ def _standard_uuid(short: int) -> str: .. availability:: Since Pybricks protocol v1.0.0. """ +DEVICE_NAME_UUID = _standard_uuid(0x2A00) +"""Standard Device Name UUID + +.. availability:: Since Pybricks protocol v1.0.0. +""" + FW_REV_UUID = _standard_uuid(0x2A26) """Standard Firmware Revision String characteristic UUID diff --git a/pybricksdev/connections/pybricks.py b/pybricksdev/connections/pybricks.py index e481812..df996d8 100644 --- a/pybricksdev/connections/pybricks.py +++ b/pybricksdev/connections/pybricks.py @@ -8,6 +8,7 @@ import os import struct from typing import Awaitable, Callable, List, Optional, Tuple, TypeVar +from uuid import UUID import reactivex.operators as op import semver @@ -19,9 +20,15 @@ from tqdm.auto import tqdm from tqdm.contrib.logging import logging_redirect_tqdm +from usb.control import get_descriptor +from usb.core import Device as USBDevice +from usb.core import Endpoint, USBTimeoutError +from usb.util import ENDPOINT_IN, ENDPOINT_OUT, endpoint_direction, find_descriptor + from ..ble.lwp3.bytecodes import HubKind from ..ble.nus import NUS_RX_UUID, NUS_TX_UUID from ..ble.pybricks import ( + DEVICE_NAME_UUID, FW_REV_UUID, PNP_ID_UUID, PYBRICKS_COMMAND_EVENT_UUID, @@ -38,6 +45,7 @@ from ..compile import compile_file, compile_multi_file from ..tools import chunk from ..tools.checksum import xor_bytes +from ..usb import LegoUsbPid from . import ConnectionState logger = logging.getLogger(__name__) @@ -138,6 +146,139 @@ def handler(_, data): await self._client.start_notify(NUS_TX_UUID, handler) +class _USBTransport(_Transport): + _device: USBDevice + _disconnected_callback: Callable + _ep_in: Endpoint + _ep_out: Endpoint + _notify_callbacks = {} + _monitor_task: asyncio.Task + + def __init__(self, device: USBDevice): + self._device = device + + async def connect(self, disconnected_callback: Callable) -> None: + self._disconnected_callback = disconnected_callback + self._device.set_configuration() + + # Save input and output endpoints + cfg = self._device.get_active_configuration() + intf = cfg[(0, 0)] + self._ep_in = find_descriptor( + intf, + custom_match=lambda e: endpoint_direction(e.bEndpointAddress) + == ENDPOINT_IN, + ) + self._ep_out = find_descriptor( + intf, + custom_match=lambda e: endpoint_direction(e.bEndpointAddress) + == ENDPOINT_OUT, + ) + + # Get length of BOS descriptor + bos_descriptor = get_descriptor(self._device, 5, 0x0F, 0) + (ofst, bos_len) = struct.unpack(" None: + self._monitor_task.cancel() + self._disconnected_callback() + + async def get_firmware_version(self) -> Version: + return self._fw_version + + async def get_protocol_version(self) -> Version: + return self._protocol_version + + async def get_hub_type(self) -> Tuple[HubKind, int]: + hub_types = { + LegoUsbPid.SPIKE_PRIME: (HubKind.TECHNIC_LARGE, 0), + LegoUsbPid.ROBOT_INVENTOR: (HubKind.TECHNIC_LARGE, 1), + LegoUsbPid.SPIKE_ESSENTIAL: (HubKind.TECHNIC_SMALL, 0), + } + + return hub_types[self._device.idProduct] + + async def get_hub_capabilities(self) -> Tuple[int, HubCapabilityFlag, int]: + return ( + self._max_write_size, + self._capability_flags, + self._max_user_program_size, + ) + + async def send_command(self, command: bytes) -> None: + self._ep_out.write(UUID(PYBRICKS_COMMAND_EVENT_UUID).bytes_le + command) + + async def set_service_handler(self, callback: Callable) -> None: + self._notify_callbacks[PYBRICKS_COMMAND_EVENT_UUID] = callback + + async def _monitor_usb(self): + loop = asyncio.get_running_loop() + + while True: + msg = await loop.run_in_executor(None, self._read_usb) + + if msg is None: + continue + + if len(msg) > 16: + uuid = str(UUID(bytes_le=bytes(msg[:16]))) + if uuid in self._notify_callbacks: + callback = self._notify_callbacks[uuid] + if callback: + callback(bytes(msg[16:])) + + def _read_usb(self): + try: + msg = self._ep_in.read(self._ep_in.wMaxPacketSize) + return msg + except USBTimeoutError: + return None + + class PybricksHub: EOL = b"\r\n" # MicroPython EOL @@ -326,11 +467,12 @@ def _pybricks_service_handler(self, data: bytes) -> None: if self._enable_line_handler: self._handle_line_data(payload) - async def connect(self, device: BLEDevice): + async def connect(self, device): """Connects to a device that was discovered with :meth:`pybricksdev.ble.find_device` + or :meth:`usb.core.find` Args: - device: The device to connect to. + device: The device to connect to (`BLEDevice` or `USBDevice`). Raises: BleakError: if connecting failed (or old firmware without Device @@ -350,7 +492,12 @@ async def connect(self, device: BLEDevice): self.connection_state_observable.on_next, ConnectionState.DISCONNECTED ) - self._transport = _BLETransport(device) + if isinstance(device, BLEDevice): + self._transport = _BLETransport(device) + elif isinstance(device, USBDevice): + self._transport = _USBTransport(device) + else: + raise TypeError("Unsupported device type") def handle_disconnect(): logger.info("Disconnected!")