Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make stream recorder work concurrently #73478

Merged
merged 5 commits into from
Jun 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion homeassistant/components/stream/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,6 @@ async def async_record(
recorder.video_path = video_path

await self.start()
self._logger.debug("Started a stream recording of %s seconds", duration)

# Take advantage of lookback
hls: HlsStreamOutput = cast(HlsStreamOutput, self.outputs().get(HLS_PROVIDER))
Expand All @@ -512,6 +511,9 @@ async def async_record(
await hls.recv()
recorder.prepend(list(hls.get_segments())[-num_segments - 1 : -1])
uvjustin marked this conversation as resolved.
Show resolved Hide resolved

self._logger.debug("Started a stream recording of %s seconds", duration)
await recorder.async_record()

async def async_get_image(
self,
width: int | None = None,
Expand Down
1 change: 0 additions & 1 deletion homeassistant/components/stream/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,6 @@ def cleanup(self) -> None:
"""Handle cleanup."""
self._event.set()
self.idle_timer.clear()
self._segments = deque(maxlen=self._segments.maxlen)


class StreamView(HomeAssistantView):
Expand Down
5 changes: 5 additions & 0 deletions homeassistant/components/stream/hls.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ def name(self) -> str:
"""Return provider name."""
return HLS_PROVIDER

def cleanup(self) -> None:
"""Handle cleanup."""
super().cleanup()
self._segments.clear()

@property
def target_duration(self) -> float:
"""Return the target duration."""
Expand Down
219 changes: 116 additions & 103 deletions homeassistant/components/stream/recorder.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
"""Provide functionality to record stream."""
from __future__ import annotations

from collections import deque
from io import BytesIO
import logging
import os
import threading

import av
from av.container import OutputContainer

from homeassistant.core import HomeAssistant, callback

Expand All @@ -27,99 +24,9 @@ def async_setup_recorder(hass: HomeAssistant) -> None:
"""Only here so Provider Registry works."""


def recorder_save_worker(file_out: str, segments: deque[Segment]) -> None:
"""Handle saving stream."""

if not segments:
_LOGGER.error("Recording failed to capture anything")
return

os.makedirs(os.path.dirname(file_out), exist_ok=True)

pts_adjuster: dict[str, int | None] = {"video": None, "audio": None}
output: OutputContainer | None = None
output_v = None
output_a = None

last_stream_id = None
# The running duration of processed segments. Note that this is in av.time_base
# units which seem to be defined inversely to how stream time_bases are defined
running_duration = 0

last_sequence = float("-inf")
for segment in segments:
# Because the stream_worker is in a different thread from the record service,
# the lookback segments may still have some overlap with the recorder segments
if segment.sequence <= last_sequence:
continue
last_sequence = segment.sequence

# Open segment
source = av.open(
BytesIO(segment.init + segment.get_data()),
"r",
format=SEGMENT_CONTAINER_FORMAT,
)
# Skip this segment if it doesn't have data
if source.duration is None:
source.close()
continue
source_v = source.streams.video[0]
source_a = source.streams.audio[0] if len(source.streams.audio) > 0 else None

# Create output on first segment
if not output:
output = av.open(
file_out,
"w",
format=RECORDER_CONTAINER_FORMAT,
container_options={
"video_track_timescale": str(int(1 / source_v.time_base))
},
)

# Add output streams if necessary
if not output_v:
output_v = output.add_stream(template=source_v)
context = output_v.codec_context
context.flags |= "GLOBAL_HEADER"
if source_a and not output_a:
output_a = output.add_stream(template=source_a)

# Recalculate pts adjustments on first segment and on any discontinuity
# We are assuming time base is the same across all discontinuities
if last_stream_id != segment.stream_id:
last_stream_id = segment.stream_id
pts_adjuster["video"] = int(
(running_duration - source.start_time)
/ (av.time_base * source_v.time_base)
)
if source_a:
pts_adjuster["audio"] = int(
(running_duration - source.start_time)
/ (av.time_base * source_a.time_base)
)

# Remux video
for packet in source.demux():
if packet.dts is None:
continue
packet.pts += pts_adjuster[packet.stream.type]
packet.dts += pts_adjuster[packet.stream.type]
packet.stream = output_v if packet.stream.type == "video" else output_a
output.mux(packet)

running_duration += source.duration - source.start_time

source.close()

if output is not None:
output.close()


@PROVIDERS.register(RECORDER_PROVIDER)
class RecorderOutput(StreamOutput):
"""Represents HLS Output formats."""
"""Represents the Recorder Output format."""

def __init__(
self,
Expand All @@ -141,13 +48,119 @@ def prepend(self, segments: list[Segment]) -> None:
self._segments.extendleft(reversed(segments))

def cleanup(self) -> None:
"""Write recording and clean up."""
_LOGGER.debug("Starting recorder worker thread")
thread = threading.Thread(
name="recorder_save_worker",
target=recorder_save_worker,
args=(self.video_path, self._segments.copy()),
)
thread.start()

"""Handle cleanup."""
self.idle_timer.idle = True
super().cleanup()

async def async_record(self) -> None:
"""Handle saving stream."""

os.makedirs(os.path.dirname(self.video_path), exist_ok=True)

pts_adjuster: dict[str, int | None] = {"video": None, "audio": None}
output: av.container.OutputContainer | None = None
output_v = None
output_a = None

last_stream_id = -1
# The running duration of processed segments. Note that this is in av.time_base
# units which seem to be defined inversely to how stream time_bases are defined
running_duration = 0

last_sequence = float("-inf")

def write_segment(segment: Segment) -> None:
"""Write a segment to output."""
nonlocal output, output_v, output_a, last_stream_id, running_duration, last_sequence
# Because the stream_worker is in a different thread from the record service,
# the lookback segments may still have some overlap with the recorder segments
if segment.sequence <= last_sequence:
return
last_sequence = segment.sequence

# Open segment
source = av.open(
BytesIO(segment.init + segment.get_data()),
"r",
format=SEGMENT_CONTAINER_FORMAT,
)
# Skip this segment if it doesn't have data
if source.duration is None:
source.close()
return
source_v = source.streams.video[0]
source_a = (
source.streams.audio[0] if len(source.streams.audio) > 0 else None
)

# Create output on first segment
if not output:
output = av.open(
self.video_path + ".tmp",
"w",
format=RECORDER_CONTAINER_FORMAT,
container_options={
"video_track_timescale": str(int(1 / source_v.time_base))
},
)

# Add output streams if necessary
if not output_v:
output_v = output.add_stream(template=source_v)
context = output_v.codec_context
context.flags |= "GLOBAL_HEADER"
if source_a and not output_a:
output_a = output.add_stream(template=source_a)

# Recalculate pts adjustments on first segment and on any discontinuity
# We are assuming time base is the same across all discontinuities
if last_stream_id != segment.stream_id:
last_stream_id = segment.stream_id
pts_adjuster["video"] = int(
(running_duration - source.start_time)
/ (av.time_base * source_v.time_base)
)
if source_a:
pts_adjuster["audio"] = int(
(running_duration - source.start_time)
/ (av.time_base * source_a.time_base)
)

# Remux video
for packet in source.demux():
if packet.dts is None:
continue
packet.pts += pts_adjuster[packet.stream.type]
packet.dts += pts_adjuster[packet.stream.type]
packet.stream = output_v if packet.stream.type == "video" else output_a
output.mux(packet)

running_duration += source.duration - source.start_time

source.close()

# Write lookback segments
while len(self._segments) > 1: # The last segment is in progress
await self._hass.async_add_executor_job(
write_segment, self._segments.popleft()
)
# Make sure the first segment has been added
uvjustin marked this conversation as resolved.
Show resolved Hide resolved
if not self._segments:
await self.recv()
# Write segments as soon as they are completed
while not self.idle:
await self.recv()
await self._hass.async_add_executor_job(
write_segment, self._segments.popleft()
)
# Write remaining segments
# Should only have 0 or 1 segments, but loop through just in case
while self._segments:
await self._hass.async_add_executor_job(
write_segment, self._segments.popleft()
)
if output is None:
_LOGGER.error("Recording failed to capture anything")
else:
output.close()
os.rename(self.video_path + ".tmp", self.video_path)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I put this part in (writing to a file with a different name first) because I believe some users currently use the existence of the output file to check whether the recording is done yet.

59 changes: 1 addition & 58 deletions tests/components/stream/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,16 @@
from __future__ import annotations

import asyncio
from collections import deque
from http import HTTPStatus
import logging
import threading
from typing import Generator
from unittest.mock import Mock, patch

from aiohttp import web
import async_timeout
import pytest

from homeassistant.components.stream.core import Segment, StreamOutput
from homeassistant.components.stream.core import StreamOutput
from homeassistant.components.stream.worker import StreamState

from .common import generate_h264_video, stream_teardown
Expand Down Expand Up @@ -73,61 +71,6 @@ def stream_worker_sync(hass):
yield sync


class SaveRecordWorkerSync:
"""
Test fixture to manage RecordOutput thread for recorder_save_worker.

This is used to assert that the worker is started and stopped cleanly
to avoid thread leaks in tests.
"""

def __init__(self, hass):
"""Initialize SaveRecordWorkerSync."""
self._hass = hass
self._save_event = None
self._segments = None
self._save_thread = None
self.reset()

def recorder_save_worker(self, file_out: str, segments: deque[Segment]):
"""Mock method for patch."""
logging.debug("recorder_save_worker thread started")
assert self._save_thread is None
self._segments = segments
self._save_thread = threading.current_thread()
self._hass.loop.call_soon_threadsafe(self._save_event.set)

async def get_segments(self):
"""Return the recorded video segments."""
async with async_timeout.timeout(TEST_TIMEOUT):
await self._save_event.wait()
return self._segments

async def join(self):
"""Verify save worker was invoked and block on shutdown."""
async with async_timeout.timeout(TEST_TIMEOUT):
await self._save_event.wait()
self._save_thread.join(timeout=TEST_TIMEOUT)
assert not self._save_thread.is_alive()

def reset(self):
"""Reset callback state for reuse in tests."""
self._save_thread = None
self._save_event = asyncio.Event()


@pytest.fixture()
def record_worker_sync(hass):
"""Patch recorder_save_worker for clean thread shutdown for test."""
sync = SaveRecordWorkerSync(hass)
with patch(
"homeassistant.components.stream.recorder.recorder_save_worker",
side_effect=sync.recorder_save_worker,
autospec=True,
):
yield sync


class HLSSync:
"""Test fixture that intercepts stream worker calls to StreamOutput."""

Expand Down
14 changes: 8 additions & 6 deletions tests/components/stream/test_hls.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,10 +506,12 @@ async def test_remove_incomplete_segment_on_exit(
assert len(segments) == 3
assert not segments[-1].complete
stream_worker_sync.resume()
stream._thread_quit.set()
stream._thread.join()
stream._thread = None
await hass.async_block_till_done()
assert segments[-1].complete
assert len(segments) == 2
with patch("homeassistant.components.stream.Stream.remove_provider"):
# Patch remove_provider so the deque is not cleared
stream._thread_quit.set()
stream._thread.join()
stream._thread = None
await hass.async_block_till_done()
assert segments[-1].complete
assert len(segments) == 2
await stream.stop()
Loading