diff --git a/examples/ndsi-recv-events.py b/examples/ndsi-recv-events.py new file mode 100644 index 0000000..e801869 --- /dev/null +++ b/examples/ndsi-recv-events.py @@ -0,0 +1,62 @@ +import time + +# https://github.com/pupil-labs/pyndsi/tree/v1.0 +import ndsi # Main requirement + +EVENT_TYPE = "event" # Type of sensors that we are interested in +SENSORS = {} # Will store connected sensors + + +def main(): + # Start auto-discovery of Pupil Invisible Companion devices + network = ndsi.Network(formats={ndsi.DataFormat.V4}, callbacks=(on_network_event,)) + network.start() + + try: + # Event loop, runs until interrupted + while network.running: + # Check for recently connected/disconnected devices + if network.has_events: + network.handle_event() + + # Iterate over all connected devices + for event_sensor in SENSORS.values(): + # Fetch recent sensor configuration changes, + # required for pyndsi internals + while event_sensor.has_notifications: + event_sensor.handle_notification() + + # Fetch recent event data + for event in event_sensor.fetch_data(): + # Output: EventValue(timestamp, label) + print(event_sensor, event) + + time.sleep(0.1) + + # Catch interruption and disconnect gracefully + except (KeyboardInterrupt, SystemExit): + network.stop() + + +def on_network_event(network, event): + # Handle event sensor attachment + if event["subject"] == "attach" and event["sensor_type"] == EVENT_TYPE: + # Create new sensor, start data streaming, + # and request current configuration + sensor = network.sensor(event["sensor_uuid"]) + sensor.set_control_value("streaming", True) + sensor.refresh_controls() + + # Save sensor s.t. we can fetch data from it in main() + SENSORS[event["sensor_uuid"]] = sensor + print(f"Added sensor {sensor}...") + + # Handle event sensor detachment + if event["subject"] == "detach" and event["sensor_uuid"] in SENSORS: + # Known sensor has disconnected, remove from list + SENSORS[event["sensor_uuid"]].unlink() + del SENSORS[event["sensor_uuid"]] + print(f"Removed sensor {event['sensor_uuid']}...") + + +main() # Execute example diff --git a/ndsi/formatter.py b/ndsi/formatter.py index 3b83bfe..9ad9fc4 100644 --- a/ndsi/formatter.py +++ b/ndsi/formatter.py @@ -12,11 +12,17 @@ __all__ = [ - 'DataFormat', 'DataFormatter', 'DataMessage', - 'VideoDataFormatter', 'VideoValue', - 'GazeDataFormatter', 'GazeValue', - 'AnnotateDataFormatter', 'AnnotateValue', - 'IMUDataFormatter', 'IMUValue', + "DataFormat", + "DataFormatter", + "DataMessage", + "VideoDataFormatter", + "VideoValue", + "GazeDataFormatter", + "GazeValue", + "AnnotateDataFormatter", + "AnnotateValue", + "IMUDataFormatter", + "IMUValue", ] @@ -37,15 +43,16 @@ class DataFormat(enum.Enum): """ `DataFormat` enum represents the format for serializing and deserializing data between NDSI hosts and clients. """ - V3 = 'v3' - V4 = 'v4' + + V3 = "v3" + V4 = "v4" @staticmethod - def latest() -> 'DataFormat': + def latest() -> "DataFormat": return max(DataFormat.supported_formats(), key=lambda f: f.version_major) @staticmethod - def supported_formats() -> typing.Set['DataFormat']: + def supported_formats() -> typing.Set["DataFormat"]: return set(DataFormat) @property @@ -62,13 +69,12 @@ class DataMessage(typing.NamedTuple): body: bytes -DT = typing.TypeVar('DataValue') +DT = typing.TypeVar("DataValue") class DataFormatter(typing.Generic[DT], abc.ABC): - @abc.abstractstaticmethod - def get_formatter(format: DataFormat) -> 'DataFormatter': + def get_formatter(format: DataFormat) -> "DataFormatter": pass @abc.abstractmethod @@ -87,7 +93,8 @@ class UnsupportedFormatter(DataFormatter[typing.Any]): """ Represents a formatter that is not supported for a specific data format and sensor type combination. """ - def get_formatter(format: DataFormat) -> 'UnsupportedFormatter': + + def get_formatter(format: DataFormat) -> "UnsupportedFormatter": return UnsupportedFormatter() def encode_msg(value: typing.Any) -> DataMessage: @@ -114,7 +121,9 @@ def reset(self): @staticmethod @functools.lru_cache(maxsize=1, typed=True) - def get_formatter(format: DataFormat) -> typing.Union['VideoDataFormatter', UnsupportedFormatter]: + def get_formatter( + format: DataFormat, + ) -> typing.Union["VideoDataFormatter", UnsupportedFormatter]: if format == DataFormat.V3: return _VideoDataFormatter_V3() if format == DataFormat.V4: @@ -129,7 +138,7 @@ class _VideoDataFormatter_V3(VideoDataFormatter): def decode_msg(self, data_msg: DataMessage) -> VideoValue: meta_data = struct.unpack(" us + meta_data[4] *= 1e6 # Convert timestamp s -> us meta_data = tuple(meta_data) if meta_data[0] == VIDEO_FRAME_FORMAT_MJPEG: return self._frame_factory.create_jpeg_frame(data_msg.body, meta_data) @@ -138,14 +147,14 @@ def decode_msg(self, data_msg: DataMessage) -> VideoValue: self._newest_h264_frame = frame or self._newest_h264_frame return self._newest_h264_frame else: - raise StreamError('Frame was not of format MJPEG or H264') + raise StreamError("Frame was not of format MJPEG or H264") class _VideoDataFormatter_V4(VideoDataFormatter): def decode_msg(self, data_msg: DataMessage) -> VideoValue: meta_data = struct.unpack(" us + meta_data[4] /= 1e3 # Convert timestamp ns -> us meta_data = tuple(meta_data) if meta_data[0] == VIDEO_FRAME_FORMAT_MJPEG: return self._frame_factory.create_jpeg_frame(data_msg.body, meta_data) @@ -154,7 +163,7 @@ def decode_msg(self, data_msg: DataMessage) -> VideoValue: self._newest_h264_frame = frame or self._newest_h264_frame return self._newest_h264_frame else: - raise StreamError('Frame was not of format MJPEG or H264') + raise StreamError("Frame was not of format MJPEG or H264") ########## @@ -169,7 +178,9 @@ class AnnotateValue(typing.NamedTuple): class AnnotateDataFormatter(DataFormatter[AnnotateValue]): @staticmethod @functools.lru_cache(maxsize=1, typed=True) - def get_formatter(format: DataFormat) -> typing.Union['AnnotateDataFormatter', UnsupportedFormatter]: + def get_formatter( + format: DataFormat, + ) -> typing.Union["AnnotateDataFormatter", UnsupportedFormatter]: if format == DataFormat.V3: return _AnnotateDataFormatter_V3() if format == DataFormat.V4: @@ -207,7 +218,9 @@ class GazeValue(typing.NamedTuple): class GazeDataFormatter(DataFormatter[GazeValue]): @staticmethod @functools.lru_cache(maxsize=1, typed=True) - def get_formatter(format: DataFormat) -> typing.Union['GazeDataFormatter', UnsupportedFormatter]: + def get_formatter( + format: DataFormat, + ) -> typing.Union["GazeDataFormatter", UnsupportedFormatter]: if format == DataFormat.V3: return UnsupportedFormatter() if format == DataFormat.V4: @@ -220,7 +233,7 @@ def encode_msg(self, value: GazeValue) -> DataMessage: class _GazeDataFormatter_V4(GazeDataFormatter): def decode_msg(self, data_msg: DataMessage) -> GazeValue: - ts, = struct.unpack(" typing.Union['IMUDataFormatter', UnsupportedFormatter]: + def get_formatter( + format: DataFormat, + ) -> typing.Union["IMUDataFormatter", UnsupportedFormatter]: if format == DataFormat.V3: return _IMUDataFormatter_V3() if format == DataFormat.V4: @@ -267,7 +282,9 @@ class _IMUDataFormatter_V3(IMUDataFormatter): ) def decode_msg(self, data_msg: DataMessage) -> IMUValue: - content = np.frombuffer(data_msg.body, dtype=self.CONTENT_DTYPE).view(np.recarray) + content = np.frombuffer(data_msg.body, dtype=self.CONTENT_DTYPE).view( + np.recarray + ) return IMUValue(*content) @@ -285,5 +302,56 @@ class _IMUDataFormatter_V4(IMUDataFormatter): ) def decode_msg(self, data_msg: DataMessage) -> IMUValue: - content = np.frombuffer(data_msg.body, dtype=self.CONTENT_DTYPE).view(np.recarray) + content = np.frombuffer(data_msg.body, dtype=self.CONTENT_DTYPE).view( + np.recarray + ) return IMUValue(*content) + + +########## + + +class EventValue(typing.NamedTuple): + timestamp: float + label: str + + +class EventDataFormatter(DataFormatter[EventValue]): + @staticmethod + @functools.lru_cache(maxsize=1, typed=True) + def get_formatter( + format: DataFormat, + ) -> typing.Union["EventDataFormatter", UnsupportedFormatter]: + if format == DataFormat.V3: + return UnsupportedFormatter() + if format == DataFormat.V4: + return _EventDataFormatter_V4() + raise ValueError(format) + + def encode_msg(self, value: EventValue) -> DataMessage: + raise NotImplementedError() + + +class _EventDataFormatter_V4(EventDataFormatter): + + _encoding_lookup = { + 0: "utf-8", + } + + def decode_msg(self, data_msg: DataMessage) -> EventValue: + """ + 1. sensor UUID + 2. header: + - int_64 timestamp_le + - uint32 body_length_le + - uint32 encoding_le + = 0 -> "utf-8" + 3. body: + - `encoding_le` encoded string of lenght `body_length_le` + """ + ts, len_, enc_code = struct.unpack(" typing.Set['SensorType']: + def supported_types() -> typing.Set["SensorType"]: return set(SensorType) @staticmethod - def supported_sensor_type_from_str(sensor_type: str) -> typing.Optional['SensorType']: + def supported_sensor_type_from_str( + sensor_type: str, + ) -> typing.Optional["SensorType"]: try: sensor_type = SensorType(sensor_type) except ValueError: @@ -76,7 +83,6 @@ def __str__(self) -> str: class Sensor: - @staticmethod def class_for_type(sensor_type: SensorType): try: @@ -85,27 +91,29 @@ def class_for_type(sensor_type: SensorType): raise ValueError("Unknown sensor type: {}".format(sensor_type)) @staticmethod - def create_sensor(sensor_type: SensorType, **kwargs) -> 'Sensor': + def create_sensor(sensor_type: SensorType, **kwargs) -> "Sensor": sensor_class = Sensor.class_for_type(sensor_type=sensor_type) # TODO: Passing sensor_type to the class init as str, to preserve API compatibility. # Ideally, the sensor_type passed and stored by Sensor is of type SensorType. - kwargs['sensor_type'] = str(sensor_type) + kwargs["sensor_type"] = str(sensor_type) return sensor_class(**kwargs) - def __init__(self, - format: DataFormat, - host_uuid, - host_name, - sensor_uuid, - sensor_name, - sensor_type, - notify_endpoint, - command_endpoint, - data_endpoint=None, - context=None, - callbacks=()): + def __init__( + self, + format: DataFormat, + host_uuid, + host_name, + sensor_uuid, + sensor_name, + sensor_type, + notify_endpoint, + command_endpoint, + data_endpoint=None, + context=None, + callbacks=(), + ): self.format = format - self.callbacks = [self.on_notification]+list(callbacks) + self.callbacks = [self.on_notification] + list(callbacks) self.context = context or zmq.Context() self.host_uuid = host_uuid self.host_name = host_name @@ -162,22 +170,32 @@ def has_data(self): raise NotDataSubSupportedError() def __str__(self): - return '<{} {}@{} [{}]>'.format(__name__, self.name, self.host_name, self.type) + return "<{} {}@{} [{}]>".format(__name__, self.name, self.host_name, self.type) def handle_notification(self): raw_notification = self.notify_sub.recv_multipart() if len(raw_notification) != 2: - logger.debug('Message for sensor {} has not correct amount of frames: {}'.format(self.uuid,raw_notification)) + logger.debug( + "Message for sensor {} has not correct amount of frames: {}".format( + self.uuid, raw_notification + ) + ) return sender_id = raw_notification[0].decode() notification_payload = raw_notification[1].decode() try: if sender_id != self.uuid: - raise ValueError('Message was destined for {} but was recieved by {}'.format(sender_id, self.uuid)) + raise ValueError( + "Message was destined for {} but was recieved by {}".format( + sender_id, self.uuid + ) + ) notification = serial.loads(notification_payload) - notification['subject'] + notification["subject"] except serial.decoder.JSONDecodeError: - logger.debug('JSONDecodeError for payload: `{}`'.format(notification_payload)) + logger.debug( + "JSONDecodeError for payload: `{}`".format(notification_payload) + ) except Exception: logger.debug(tb.format_exc()) else: @@ -191,31 +209,38 @@ def execute_callbacks(self, event): callback(self, event) def on_notification(self, caller, notification): - if notification['subject'] == 'update': + if notification["subject"] == "update": + class UnsettableDict(dict): def __getitem__(self, key): return self.get(key) + def __setitem__(self, key, value): - raise ValueError('Dictionary is read-only. Use Sensor.set_control_value instead.') + raise ValueError( + "Dictionary is read-only. Use Sensor.set_control_value instead." + ) - ctrl_id_key = notification['control_id'] + ctrl_id_key = notification["control_id"] if ctrl_id_key in self.controls: - self.controls[ctrl_id_key].update(UnsettableDict(notification['changes'])) - else: self.controls[ctrl_id_key] = UnsettableDict(notification['changes']) - elif notification['subject'] == 'remove': + self.controls[ctrl_id_key].update( + UnsettableDict(notification["changes"]) + ) + else: + self.controls[ctrl_id_key] = UnsettableDict(notification["changes"]) + elif notification["subject"] == "remove": try: - del self.controls[notification['control_id']] + del self.controls[notification["control_id"]] except KeyError: pass - def get_data(self,copy=True): + def get_data(self, copy=True): try: return self.data_sub.recv_multipart(copy=copy) except AttributeError: raise NotDataSubSupportedError() def refresh_controls(self): - cmd = serial.dumps({'action': 'refresh_controls'}) + cmd = serial.dumps({"action": "refresh_controls"}) self.command_push.send_string(self.uuid, flags=zmq.SNDMORE) self.command_push.send_string(cmd) @@ -225,36 +250,46 @@ def reset_all_control_values(self): def reset_control_value(self, control_id): if control_id in self.controls: - if 'def' in self.controls[control_id]: - value = self.controls[control_id]['def'] + if "def" in self.controls[control_id]: + value = self.controls[control_id]["def"] self.set_control_value(control_id, value) else: - logger.error(('Could not reset control `{}` because it does not have a default value.').format(control_id)) - else: logger.error('Could not reset unknown control `{}`'.format(control_id)) + logger.error( + ( + "Could not reset control `{}` because it does not have a default value." + ).format(control_id) + ) + else: + logger.error("Could not reset unknown control `{}`".format(control_id)) def set_control_value(self, control_id, value): try: - dtype = self.controls[control_id]['dtype'] - if dtype == 'bool': value = bool(value) - elif dtype == 'string': value = str(value) - elif dtype == 'integer': value = int(value) - elif dtype == 'float': value = float(value) - elif dtype == 'intmapping': value = int(value) - elif dtype == 'strmapping': value = str(value) + dtype = self.controls[control_id]["dtype"] + if dtype == "bool": + value = bool(value) + elif dtype == "string": + value = str(value) + elif dtype == "integer": + value = int(value) + elif dtype == "float": + value = float(value) + elif dtype == "intmapping": + value = int(value) + elif dtype == "strmapping": + value = str(value) except KeyError: pass - cmd = serial.dumps({ - "action": "set_control_value", - "control_id": control_id, - "value": value}) + cmd = serial.dumps( + {"action": "set_control_value", "control_id": control_id, "value": value} + ) self.command_push.send_string(self.uuid, flags=zmq.SNDMORE) self.command_push.send_string(cmd) -SensorFetchDataValue = typing.TypeVar('FetchDataValue') +SensorFetchDataValue = typing.TypeVar("FetchDataValue") -class SensorFetchDataMixin(typing.Generic[SensorFetchDataValue], abc.ABC): +class SensorFetchDataMixin(typing.Generic[SensorFetchDataValue], abc.ABC): @property @abc.abstractmethod def formatter(self) -> DataFormatter[SensorFetchDataValue]: @@ -296,9 +331,9 @@ def get_newest_data_frame(self, timeout=None): if newest_frame is not None: return newest_frame else: - raise StreamError('Operation timed out.') + raise StreamError("Operation timed out.") else: - raise StreamError('Operation timed out.') + raise StreamError("Operation timed out.") class AnnotateSensor(SensorFetchDataMixin[AnnotateValue], Sensor): @@ -328,10 +363,17 @@ def formatter(self) -> IMUDataFormatter: return IMUDataFormatter.get_formatter(format=self.format) +class EventSensor(SensorFetchDataMixin[EventValue], Sensor): + @property + def formatter(self) -> EventDataFormatter: + return EventDataFormatter.get_formatter(format=self.format) + + _SENSOR_TYPE_CLASS_MAP = { SensorType.HARDWARE: Sensor, SensorType.VIDEO: VideoSensor, SensorType.ANNOTATE: AnnotateSensor, SensorType.GAZE: GazeSensor, SensorType.IMU: IMUSensor, + SensorType.EVENT: EventSensor, }