Skip to content

Commit

Permalink
Merge branch 'dev' into led_brightness
Browse files Browse the repository at this point in the history
  • Loading branch information
llluis authored Feb 8, 2024
2 parents 90b5f9a + c0e1b50 commit ba7bf37
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 21 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ build
htmlcov

/.venv/
/examples/.venv/
.mypy_cache/
__pycache__/

Expand Down
36 changes: 28 additions & 8 deletions examples/2mic_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
"""Controls the LEDs on the ReSpeaker 2mic HAT."""
"""Controls the LEDs and GPIO Button on the ReSpeaker 2mic HAT."""
import argparse
import asyncio
import logging
Expand All @@ -10,6 +10,8 @@

import gpiozero
import spidev
import RPi.GPIO as GPIO

from wyoming.asr import Transcript
from wyoming.event import Event
from wyoming.satellite import (
Expand All @@ -21,12 +23,13 @@
)
from wyoming.server import AsyncEventHandler, AsyncServer
from wyoming.vad import VoiceStarted
from wyoming.wake import Detection
from wyoming.wake import Detect, Detection

_LOGGER = logging.getLogger()

NUM_LEDS = 3
LEDS_GPIO = 12
BUTTON_GPIO = 17
RGB_MAP = {
"rgb": [3, 2, 1],
"rbg": [3, 1, 2],
Expand All @@ -43,24 +46,29 @@ async def main() -> None:
parser.add_argument("--uri", required=True, help="unix:// or tcp://")
parser.add_argument("--debug", action="store_true", help="Log DEBUG messages")
parser.add_argument("--led-brightness", type=int, default=31, help="LED brightness (integer from 1 to 31)")
parser.add_argument("--log-format", default=logging.BASIC_FORMAT, help="Format for log messages")
args = parser.parse_args()

logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)
_LOGGER.debug(args)
logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO, format=args.log_format)

_LOGGER.info("Ready")
_LOGGER.debug(args)
_LOGGER.info("Event service Ready")

# Turn on power to LEDs
led_power = gpiozero.LED(LEDS_GPIO, active_high=False)
led_power.on()

leds = APA102(num_led=NUM_LEDS, global_brightness=args.led_brightness)

# GPIO Button
GPIO.setmode(GPIO.BCM)
GPIO.setup(BUTTON_GPIO, GPIO.IN)

# Start server
server = AsyncServer.from_uri(args.uri)

try:
await server.run(partial(LEDsEventHandler, args, leds))
await server.run(partial(EventHandler, args, leds))
except KeyboardInterrupt:
pass
finally:
Expand All @@ -78,7 +86,7 @@ async def main() -> None:
_GREEN = (0, 255, 0)


class LEDsEventHandler(AsyncEventHandler):
class EventHandler(AsyncEventHandler):
"""Event handler for clients."""

def __init__(
Expand All @@ -93,13 +101,25 @@ def __init__(
self.cli_args = cli_args
self.client_id = str(time.monotonic_ns())
self.leds = leds
self.detect_name = None

GPIO.add_event_detect(BUTTON_GPIO, GPIO.RISING, callback=self.button_callback)

_LOGGER.debug("Client connected: %s", self.client_id)


def button_callback(self, button_pin):
_LOGGER.debug("Button pressed #%s", button_pin)
asyncio.run(self.write_event(Detection(name=self.detect_name, timestamp=time.monotonic_ns()).event()))


async def handle_event(self, event: Event) -> bool:
_LOGGER.debug(event)

if StreamingStarted.is_type(event.type):
if Detect.is_type(event.type):
detect = Detect.from_event(event)
self.detect_name = detect.names[0]
elif StreamingStarted.is_type(event.type):
self.color(_YELLOW)
elif Detection.is_type(event.type):
self.color(_BLUE)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_wake_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def __init__(self) -> None:
self.wake_event = asyncio.Event()

async def read_event(self) -> Optional[Event]:
# Input only
return None
# Sends a detection event
return Detection().event()

async def write_event(self, event: Event) -> None:
if Detection.is_type(event.type):
Expand Down
91 changes: 80 additions & 11 deletions wyoming_satellite/satellite.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ async def event_from_server(self, event: Event) -> None:
await self.trigger_detect()
elif Detection.is_type(event.type):
# Wake word detected
_LOGGER.debug("Wake word detected")
_LOGGER.debug("Remote wake word detected")
await self.trigger_detection(Detection.from_event(event))
elif VoiceStarted.is_type(event.type):
# STT start
Expand All @@ -285,7 +285,7 @@ async def event_from_server(self, event: Event) -> None:
if not AudioChunk.is_type(event.type):
await self.forward_event(event)

async def _send_run_pipeline(self) -> None:
async def _send_run_pipeline(self, ask: Optional[bool] = False) -> None:
"""Sends a RunPipeline event with the correct stages."""
if self.settings.wake.enabled:
# Local wake word detection
Expand All @@ -303,6 +303,16 @@ async def _send_run_pipeline(self) -> None:
# No audio output
end_stage = PipelineStage.HANDLE

if ask:
end_stage = PipelineStage.ASR
restart_on_end = False

_LOGGER.debug(
"RunPipeline from %s to %s",
start_stage,
end_stage,
)

run_pipeline = RunPipeline(
start_stage=start_stage, end_stage=end_stage, restart_on_end=restart_on_end
).event()
Expand Down Expand Up @@ -350,8 +360,6 @@ async def _connect_to_services(self) -> None:
self._event_task_proc(), name="event"
)

_LOGGER.info("Connected to services")

async def _disconnect_from_services(self) -> None:
"""Disconnects from running services."""
if self._mic_task is not None:
Expand Down Expand Up @@ -550,7 +558,8 @@ async def _disconnect() -> None:
event.type
):
await _disconnect()
await self.trigger_played()
if not hasattr(event, 'wav'):
await self.trigger_played()
snd_client = None # reconnect on next event
except asyncio.CancelledError:
break
Expand Down Expand Up @@ -596,6 +605,7 @@ async def _play_wav(
samples_per_chunk=self.settings.snd.samples_per_chunk,
volume_multiplier=self.settings.snd.volume_multiplier,
):
event.wav = True
await self.event_to_snd(event)
except Exception:
# Unmute in case of an error
Expand Down Expand Up @@ -720,6 +730,7 @@ async def _disconnect() -> None:
await asyncio.sleep(self.settings.wake.reconnect_seconds)
continue

_LOGGER.debug("Event received from wake service")
await self.event_from_wake(event)

except asyncio.CancelledError:
Expand Down Expand Up @@ -837,22 +848,69 @@ async def _disconnect() -> None:
if self._event_queue is None:
self._event_queue = asyncio.Queue()

event = await self._event_queue.get()

if event_client is None:
event_client = self._make_event_client()
assert event_client is not None
await event_client.connect()
_LOGGER.debug("Connected to event service")

await event_client.write_event(event)
# Reset
from_client_task = None
to_client_task = None
pending = set()
self._event_queue = asyncio.Queue()

# Inform event service of the wake word handled by this satellite instance
await self.forward_event(Detect(names=self.settings.wake.names).event())

# Read/write in "parallel"
if to_client_task is None:
# From satellite to event service
to_client_task = asyncio.create_task(
self._event_queue.get(), name="event_to_client"
)
pending.add(to_client_task)

if from_client_task is None:
# From event service to satellite
from_client_task = asyncio.create_task(
event_client.read_event(), name="event_from_client"
)
pending.add(from_client_task)

done, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_COMPLETED
)

if to_client_task in done:
# Forward event to event service for handling
assert to_client_task is not None
event = to_client_task.result()
to_client_task = None
await event_client.write_event(event)

if from_client_task in done:
# Event from event service (button for detection)
assert from_client_task is not None
event = from_client_task.result()
from_client_task = None

if event is None:
_LOGGER.warning("Event service disconnected")
await _disconnect()
event_client = None # reconnect
await asyncio.sleep(self.settings.event.reconnect_seconds)
continue

_LOGGER.debug("Event received from event service")
if Detection.is_type(event.type):
await self.event_from_wake(event)
except asyncio.CancelledError:
break
except Exception:
_LOGGER.exception("Unexpected error in event read task")
await _disconnect()
event_client = None # reconnect
self._event_queue = None
await asyncio.sleep(self.settings.event.reconnect_seconds)

await _disconnect()
Expand All @@ -866,6 +924,7 @@ class AlwaysStreamingSatellite(SatelliteBase):

def __init__(self, settings: SatelliteSettings) -> None:
super().__init__(settings)
_LOGGER.debug("Initiating an AlwaysStreamingSatellite")
self.is_streaming = False

if settings.vad.enabled:
Expand Down Expand Up @@ -932,6 +991,7 @@ def __init__(self, settings: SatelliteSettings) -> None:
raise ValueError("VAD is not enabled")

super().__init__(settings)
_LOGGER.debug("Initiating a VadStreamingSatellite")
self.is_streaming = False
self.vad = SileroVad(
threshold=settings.vad.threshold, trigger_level=settings.vad.trigger_level
Expand Down Expand Up @@ -1093,6 +1153,7 @@ def __init__(self, settings: SatelliteSettings) -> None:
raise ValueError("Local wake word detection is not enabled")

super().__init__(settings)
_LOGGER.debug("Initiating a WakeStreamingSatellite")
self.is_streaming = False

# Timestamp in the future when the refractory period is over (set with
Expand All @@ -1116,7 +1177,14 @@ async def event_from_server(self, event: Event) -> None:
is_transcript = False
is_error = False

if RunSatellite.is_type(event.type):
if Detection.is_type(event.type):
if ((event.data.get("name") == "remote") or (event.data.get("name") == "ask")):
_LOGGER.debug("Detection called. Name: %s", event.data.get("name"))
# Remote request for Detection
await self.event_from_wake(event)
return

elif RunSatellite.is_type(event.type):
is_run_satellite = True
self._is_paused = False

Expand Down Expand Up @@ -1207,6 +1275,7 @@ async def event_from_wake(self, event: Event) -> None:
return

if Detection.is_type(event.type):
_LOGGER.debug("Detection triggered from event")
detection = Detection.from_event(event)

# Check refractory period to avoid multiple back-to-back detections
Expand Down Expand Up @@ -1238,7 +1307,7 @@ async def event_from_wake(self, event: Event) -> None:
# No refractory period
self.refractory_timestamp.pop(detection.name, None)

await self._send_run_pipeline()
await self._send_run_pipeline(event.data.get("name") == "ask")
await self.forward_event(event) # forward to event service
await self.trigger_detection(Detection.from_event(event))
await self.trigger_streaming_start()

0 comments on commit ba7bf37

Please sign in to comment.