Skip to content

Commit

Permalink
Fix SMB Music provider (#540)
Browse files Browse the repository at this point in the history
Create a single connection per action. This a bit slower but much more
reliable.
Now it seems to handle all test cases I throw at it just fine.

Also adjust the configuration a bit and split out the path into
server/host, share and subfolder
  • Loading branch information
marcelveldt authored Mar 17, 2023
1 parent 86edd56 commit 7034b70
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 108 deletions.
4 changes: 2 additions & 2 deletions music_assistant/common/helpers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,15 @@ def _select_free_port():
return await asyncio.to_thread(_select_free_port)


async def get_ip_from_host(dns_name: str) -> str:
async def get_ip_from_host(dns_name: str) -> str | None:
"""Resolve (first) IP-address for given dns name."""

def _resolve():
try:
return socket.gethostbyname(dns_name)
except Exception: # pylint: disable=broad-except
# fail gracefully!
return dns_name
return None

return await asyncio.to_thread(_resolve)

Expand Down
2 changes: 2 additions & 0 deletions music_assistant/common/models/media_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,8 @@ class StreamDetails(DataClassDictMixin):
data: Any = None
# if the url/file is supported by ffmpeg directly, use direct stream
direct: str | None = None
# bool to indicate that the providers 'get_audio_stream' supports seeking of the item
can_seek: bool = True
# callback: optional callback function (or coroutine) to call when the stream completes.
# needed for streaming provivders to report what is playing
# receives the streamdetails as only argument from which to grab
Expand Down
11 changes: 7 additions & 4 deletions music_assistant/server/helpers/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,11 +397,13 @@ async def get_media_stream(
strip_silence_end = False

# collect all arguments for ffmpeg
seek_pos = seek_position if (streamdetails.direct or not streamdetails.can_seek) else 0
args = await _get_ffmpeg_args(
streamdetails=streamdetails,
sample_rate=sample_rate,
bit_depth=bit_depth,
seek_position=seek_position,
# only use ffmpeg seeking if the provider stream does not support seeking
seek_position=seek_pos,
fade_in=fade_in,
)

Expand All @@ -412,7 +414,8 @@ async def writer():
"""Task that grabs the source audio and feeds it to ffmpeg."""
LOGGER.debug("writer started for %s", streamdetails.uri)
music_prov = mass.get_provider(streamdetails.provider)
async for audio_chunk in music_prov.get_audio_stream(streamdetails, seek_position):
seek_pos = seek_position if streamdetails.can_seek else 0
async for audio_chunk in music_prov.get_audio_stream(streamdetails, seek_pos):
await ffmpeg_proc.write(audio_chunk)
# write eof when last packet is received
ffmpeg_proc.write_eof()
Expand Down Expand Up @@ -745,6 +748,8 @@ async def _get_ffmpeg_args(
]
# collect input args
input_args = []
if seek_position:
input_args += ["-ss", str(seek_position)]
if streamdetails.direct:
# ffmpeg can access the inputfile (or url) directly
if streamdetails.direct.startswith("http"):
Expand All @@ -766,8 +771,6 @@ async def _get_ffmpeg_args(
"5xx",
]

if seek_position:
input_args += ["-ss", str(seek_position)]
input_args += ["-i", streamdetails.direct]
else:
# the input is received from pipe/stdin
Expand Down
8 changes: 7 additions & 1 deletion music_assistant/server/providers/filesystem_local/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
PLAYLIST_EXTENSIONS = ("m3u", "pls")
SUPPORTED_EXTENSIONS = TRACK_EXTENSIONS + PLAYLIST_EXTENSIONS
IMAGE_EXTENSIONS = ("jpg", "jpeg", "JPG", "JPEG", "png", "PNG", "gif", "GIF")
SEEKABLE_FILES = (ContentType.MP3, ContentType.WAV, ContentType.FLAC)

SUPPORTED_FEATURES = (
ProviderFeature.LIBRARY_ARTISTS,
Expand Down Expand Up @@ -253,8 +254,9 @@ async def sync_library(
continue

try:
cur_checksums[item.path] = item.checksum
# continue if the item did not change (checksum still the same)
if item.checksum == prev_checksums.get(item.path):
cur_checksums[item.path] = item.checksum
continue

if item.ext in TRACK_EXTENSIONS:
Expand All @@ -275,6 +277,9 @@ async def sync_library(
except Exception as err: # pylint: disable=broad-except
# we don't want the whole sync to crash on one file so we catch all exceptions here
self.logger.exception("Error processing %s - %s", item.path, str(err))
else:
# save item's checksum only if the parse succeeded
cur_checksums[item.path] = item.checksum

# save checksums every 100 processed items
# this allows us to pickup where we leftoff when initial scan gets interrupted
Expand Down Expand Up @@ -624,6 +629,7 @@ async def get_stream_details(self, item_id: str) -> StreamDetails:
sample_rate=prov_mapping.sample_rate,
bit_depth=prov_mapping.bit_depth,
direct=file_item.local_path,
can_seek=prov_mapping.content_type in SEEKABLE_FILES,
)

async def get_audio_stream(
Expand Down
102 changes: 56 additions & 46 deletions music_assistant/server/providers/filesystem_smb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""SMB filesystem provider for Music Assistant."""

import contextvars
import logging
import os
from collections.abc import AsyncGenerator
Expand All @@ -9,7 +8,9 @@
from smb.base import SharedFile

from music_assistant.common.helpers.util import get_ip_from_host
from music_assistant.constants import CONF_PASSWORD, CONF_PATH, CONF_USERNAME
from music_assistant.common.models.errors import LoginFailed
from music_assistant.constants import CONF_PASSWORD, CONF_USERNAME
from music_assistant.server.controllers.cache import use_cache
from music_assistant.server.providers.filesystem_local.base import (
FileSystemItem,
FileSystemProviderBase,
Expand All @@ -21,6 +22,10 @@

from .helpers import AsyncSMB

CONF_HOST = "host"
CONF_SHARE = "share"
CONF_SUBFOLDER = "subfolder"


async def create_item(file_path: str, entry: SharedFile, root_path: str) -> FileSystemItem:
"""Create FileSystemItem from smb.SharedFile."""
Expand All @@ -37,9 +42,6 @@ async def create_item(file_path: str, entry: SharedFile, root_path: str) -> File
)


smb_conn_ctx = contextvars.ContextVar("smb_conn_ctx", default=None)


class SMBFileSystemProvider(FileSystemProviderBase):
"""Implementation of an SMB File System Provider."""

Expand All @@ -51,25 +53,31 @@ class SMBFileSystemProvider(FileSystemProviderBase):
async def setup(self) -> None:
"""Handle async initialization of the provider."""
# silence SMB.SMBConnection logger a bit
logging.getLogger("SMB.SMBConnection").setLevel("INFO")
# extract params from path
if self.config.get_value(CONF_PATH).startswith("\\\\"):
path_parts = self.config.get_value(CONF_PATH)[2:].split("\\", 2)
elif self.config.get_value(CONF_PATH).startswith("//"):
path_parts = self.config.get_value(CONF_PATH)[2:].split("/", 2)
elif self.config.get_value(CONF_PATH).startswith("smb://"):
path_parts = self.config.get_value(CONF_PATH)[6:].split("/", 2)
else:
path_parts = self.config.get_value(CONF_PATH).split(os.sep)
self._remote_name = path_parts[0]
self._service_name = path_parts[1]
if len(path_parts) > 2:
self._root_path = os.sep + path_parts[2]

default_target_ip = await get_ip_from_host(self._remote_name)
self._target_ip = self.config.get_value("target_ip") or default_target_ip
logging.getLogger("SMB.SMBConnection").setLevel("WARNING")

self._remote_name = self.config.get_value(CONF_HOST)
self._service_name = self.config.get_value(CONF_SHARE)

# validate provided path
subfolder: str = self.config.get_value(CONF_SUBFOLDER)
subfolder.replace("\\", "/")
if not subfolder.startswith("/"):
subfolder = "/" + subfolder
if not subfolder.endswith("/"):
subfolder += "/"
self._root_path = subfolder

# resolve dns name to IP
target_ip = await get_ip_from_host(self._remote_name)
if target_ip is None:
raise LoginFailed(
f"Unable to resolve {self._remote_name}, maybe use an IP address as remote host ?"
)
self._target_ip = target_ip

# test connection and return
# this code will raise if the connection did not succeed
async with self._get_smb_connection():
# test connection and return
return

async def listdir(
Expand All @@ -93,21 +101,23 @@ async def listdir(
abs_path = get_absolute_path(self._root_path, path)
async with self._get_smb_connection() as smb_conn:
path_result: list[SharedFile] = await smb_conn.list_path(abs_path)
for entry in path_result:
if entry.filename.startswith("."):
# skip invalid/system files and dirs
continue
file_path = os.path.join(path, entry.filename)
item = await create_item(file_path, entry, self._root_path)
if recursive and item.is_dir:
# yield sublevel recursively
try:
async for subitem in self.listdir(file_path, True):
yield subitem
except (OSError, PermissionError) as err:
self.logger.warning("Skip folder %s: %s", item.path, str(err))
elif item.is_file or item.is_dir:
yield item

for entry in path_result:
if entry.filename.startswith("."):
# skip invalid/system files and dirs
continue
file_path = os.path.join(path, entry.filename)
item = await create_item(file_path, entry, self._root_path)
if recursive and item.is_dir:
# yield sublevel recursively
try:
async for subitem in self.listdir(file_path, True):
yield subitem
except (OSError, PermissionError) as err:
self.logger.warning("Skip folder %s: %s", item.path, str(err))
else:
# yield single item (file or directory)
yield item

async def resolve(self, file_path: str) -> FileSystemItem:
"""Resolve (absolute or relative) path to FileSystemItem."""
Expand All @@ -124,6 +134,7 @@ async def resolve(self, file_path: str) -> FileSystemItem:
file_size=entry.file_size,
)

@use_cache(15 * 60)
async def exists(self, file_path: str) -> bool:
"""Return bool if this FileSystem musicprovider has given file/dir."""
abs_path = get_absolute_path(self._root_path, file_path)
Expand All @@ -147,20 +158,19 @@ async def write_file_content(self, file_path: str, data: bytes) -> None:
@asynccontextmanager
async def _get_smb_connection(self) -> AsyncGenerator[AsyncSMB, None]:
"""Get instance of AsyncSMB."""
# for a task that consists of multiple steps,
# the smb connection may be reused (shared through a contextvar)
if existing := smb_conn_ctx.get():
yield existing
return
# For now we just create a connection per call
# as that is the most reliable (but a bit slower)
# this could be improved by creating a connection pool
# if really needed

async with AsyncSMB(
remote_name=self._remote_name,
service_name=self._service_name,
username=self.config.get_value(CONF_USERNAME),
password=self.config.get_value(CONF_PASSWORD),
target_ip=self._target_ip,
options={key: value.value for key, value in self.config.values.items()},
use_ntlm_v2=self.config.get_value("use_ntlm_v2"),
sign_options=self.config.get_value("sign_options"),
is_direct_tcp=self.config.get_value("is_direct_tcp"),
) as smb_conn:
token = smb_conn_ctx.set(smb_conn)
yield smb_conn
smb_conn_ctx.reset(token)
Loading

0 comments on commit 7034b70

Please sign in to comment.