Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SMB Music provider #540

Merged
merged 4 commits into from
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions music_assistant/common/helpers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,15 @@ def _select_free_port():
return await asyncio.to_thread(_select_free_port)


async def get_ip_from_host(dns_name: str) -> str:
async def get_ip_from_host(dns_name: str) -> str | None:
"""Resolve (first) IP-address for given dns name."""

def _resolve():
try:
return socket.gethostbyname(dns_name)
except Exception: # pylint: disable=broad-except
# fail gracefully!
return dns_name
return None

return await asyncio.to_thread(_resolve)

Expand Down
2 changes: 2 additions & 0 deletions music_assistant/common/models/media_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,8 @@ class StreamDetails(DataClassDictMixin):
data: Any = None
# if the url/file is supported by ffmpeg directly, use direct stream
direct: str | None = None
# bool to indicate that the providers 'get_audio_stream' supports seeking of the item
can_seek: bool = True
# callback: optional callback function (or coroutine) to call when the stream completes.
# needed for streaming provivders to report what is playing
# receives the streamdetails as only argument from which to grab
Expand Down
11 changes: 7 additions & 4 deletions music_assistant/server/helpers/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,11 +397,13 @@ async def get_media_stream(
strip_silence_end = False

# collect all arguments for ffmpeg
seek_pos = seek_position if (streamdetails.direct or not streamdetails.can_seek) else 0
args = await _get_ffmpeg_args(
streamdetails=streamdetails,
sample_rate=sample_rate,
bit_depth=bit_depth,
seek_position=seek_position,
# only use ffmpeg seeking if the provider stream does not support seeking
seek_position=seek_pos,
fade_in=fade_in,
)

Expand All @@ -412,7 +414,8 @@ async def writer():
"""Task that grabs the source audio and feeds it to ffmpeg."""
LOGGER.debug("writer started for %s", streamdetails.uri)
music_prov = mass.get_provider(streamdetails.provider)
async for audio_chunk in music_prov.get_audio_stream(streamdetails, seek_position):
seek_pos = seek_position if streamdetails.can_seek else 0
async for audio_chunk in music_prov.get_audio_stream(streamdetails, seek_pos):
await ffmpeg_proc.write(audio_chunk)
# write eof when last packet is received
ffmpeg_proc.write_eof()
Expand Down Expand Up @@ -745,6 +748,8 @@ async def _get_ffmpeg_args(
]
# collect input args
input_args = []
if seek_position:
input_args += ["-ss", str(seek_position)]
if streamdetails.direct:
# ffmpeg can access the inputfile (or url) directly
if streamdetails.direct.startswith("http"):
Expand All @@ -766,8 +771,6 @@ async def _get_ffmpeg_args(
"5xx",
]

if seek_position:
input_args += ["-ss", str(seek_position)]
input_args += ["-i", streamdetails.direct]
else:
# the input is received from pipe/stdin
Expand Down
8 changes: 7 additions & 1 deletion music_assistant/server/providers/filesystem_local/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
PLAYLIST_EXTENSIONS = ("m3u", "pls")
SUPPORTED_EXTENSIONS = TRACK_EXTENSIONS + PLAYLIST_EXTENSIONS
IMAGE_EXTENSIONS = ("jpg", "jpeg", "JPG", "JPEG", "png", "PNG", "gif", "GIF")
SEEKABLE_FILES = (ContentType.MP3, ContentType.WAV, ContentType.FLAC)

SUPPORTED_FEATURES = (
ProviderFeature.LIBRARY_ARTISTS,
Expand Down Expand Up @@ -253,8 +254,9 @@ async def sync_library(
continue

try:
cur_checksums[item.path] = item.checksum
# continue if the item did not change (checksum still the same)
if item.checksum == prev_checksums.get(item.path):
cur_checksums[item.path] = item.checksum
continue

if item.ext in TRACK_EXTENSIONS:
Expand All @@ -275,6 +277,9 @@ async def sync_library(
except Exception as err: # pylint: disable=broad-except
# we don't want the whole sync to crash on one file so we catch all exceptions here
self.logger.exception("Error processing %s - %s", item.path, str(err))
else:
# save item's checksum only if the parse succeeded
cur_checksums[item.path] = item.checksum

# save checksums every 100 processed items
# this allows us to pickup where we leftoff when initial scan gets interrupted
Expand Down Expand Up @@ -624,6 +629,7 @@ async def get_stream_details(self, item_id: str) -> StreamDetails:
sample_rate=prov_mapping.sample_rate,
bit_depth=prov_mapping.bit_depth,
direct=file_item.local_path,
can_seek=prov_mapping.content_type in SEEKABLE_FILES,
)

async def get_audio_stream(
Expand Down
102 changes: 56 additions & 46 deletions music_assistant/server/providers/filesystem_smb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""SMB filesystem provider for Music Assistant."""

import contextvars
import logging
import os
from collections.abc import AsyncGenerator
Expand All @@ -9,7 +8,9 @@
from smb.base import SharedFile

from music_assistant.common.helpers.util import get_ip_from_host
from music_assistant.constants import CONF_PASSWORD, CONF_PATH, CONF_USERNAME
from music_assistant.common.models.errors import LoginFailed
from music_assistant.constants import CONF_PASSWORD, CONF_USERNAME
from music_assistant.server.controllers.cache import use_cache
from music_assistant.server.providers.filesystem_local.base import (
FileSystemItem,
FileSystemProviderBase,
Expand All @@ -21,6 +22,10 @@

from .helpers import AsyncSMB

CONF_HOST = "host"
CONF_SHARE = "share"
CONF_SUBFOLDER = "subfolder"


async def create_item(file_path: str, entry: SharedFile, root_path: str) -> FileSystemItem:
"""Create FileSystemItem from smb.SharedFile."""
Expand All @@ -37,9 +42,6 @@ async def create_item(file_path: str, entry: SharedFile, root_path: str) -> File
)


smb_conn_ctx = contextvars.ContextVar("smb_conn_ctx", default=None)


class SMBFileSystemProvider(FileSystemProviderBase):
"""Implementation of an SMB File System Provider."""

Expand All @@ -51,25 +53,31 @@ class SMBFileSystemProvider(FileSystemProviderBase):
async def setup(self) -> None:
"""Handle async initialization of the provider."""
# silence SMB.SMBConnection logger a bit
logging.getLogger("SMB.SMBConnection").setLevel("INFO")
# extract params from path
if self.config.get_value(CONF_PATH).startswith("\\\\"):
path_parts = self.config.get_value(CONF_PATH)[2:].split("\\", 2)
elif self.config.get_value(CONF_PATH).startswith("//"):
path_parts = self.config.get_value(CONF_PATH)[2:].split("/", 2)
elif self.config.get_value(CONF_PATH).startswith("smb://"):
path_parts = self.config.get_value(CONF_PATH)[6:].split("/", 2)
else:
path_parts = self.config.get_value(CONF_PATH).split(os.sep)
self._remote_name = path_parts[0]
self._service_name = path_parts[1]
if len(path_parts) > 2:
self._root_path = os.sep + path_parts[2]

default_target_ip = await get_ip_from_host(self._remote_name)
self._target_ip = self.config.get_value("target_ip") or default_target_ip
logging.getLogger("SMB.SMBConnection").setLevel("WARNING")

self._remote_name = self.config.get_value(CONF_HOST)
self._service_name = self.config.get_value(CONF_SHARE)

# validate provided path
subfolder: str = self.config.get_value(CONF_SUBFOLDER)
subfolder.replace("\\", "/")
if not subfolder.startswith("/"):
subfolder = "/" + subfolder
if not subfolder.endswith("/"):
subfolder += "/"
self._root_path = subfolder

# resolve dns name to IP
target_ip = await get_ip_from_host(self._remote_name)
if target_ip is None:
raise LoginFailed(
f"Unable to resolve {self._remote_name}, maybe use an IP address as remote host ?"
)
self._target_ip = target_ip

# test connection and return
# this code will raise if the connection did not succeed
async with self._get_smb_connection():
# test connection and return
return

async def listdir(
Expand All @@ -93,21 +101,23 @@ async def listdir(
abs_path = get_absolute_path(self._root_path, path)
async with self._get_smb_connection() as smb_conn:
path_result: list[SharedFile] = await smb_conn.list_path(abs_path)
for entry in path_result:
if entry.filename.startswith("."):
# skip invalid/system files and dirs
continue
file_path = os.path.join(path, entry.filename)
item = await create_item(file_path, entry, self._root_path)
if recursive and item.is_dir:
# yield sublevel recursively
try:
async for subitem in self.listdir(file_path, True):
yield subitem
except (OSError, PermissionError) as err:
self.logger.warning("Skip folder %s: %s", item.path, str(err))
elif item.is_file or item.is_dir:
yield item

for entry in path_result:
if entry.filename.startswith("."):
# skip invalid/system files and dirs
continue
file_path = os.path.join(path, entry.filename)
item = await create_item(file_path, entry, self._root_path)
if recursive and item.is_dir:
# yield sublevel recursively
try:
async for subitem in self.listdir(file_path, True):
yield subitem
except (OSError, PermissionError) as err:
self.logger.warning("Skip folder %s: %s", item.path, str(err))
else:
# yield single item (file or directory)
yield item

async def resolve(self, file_path: str) -> FileSystemItem:
"""Resolve (absolute or relative) path to FileSystemItem."""
Expand All @@ -124,6 +134,7 @@ async def resolve(self, file_path: str) -> FileSystemItem:
file_size=entry.file_size,
)

@use_cache(15 * 60)
async def exists(self, file_path: str) -> bool:
"""Return bool if this FileSystem musicprovider has given file/dir."""
abs_path = get_absolute_path(self._root_path, file_path)
Expand All @@ -147,20 +158,19 @@ async def write_file_content(self, file_path: str, data: bytes) -> None:
@asynccontextmanager
async def _get_smb_connection(self) -> AsyncGenerator[AsyncSMB, None]:
"""Get instance of AsyncSMB."""
# for a task that consists of multiple steps,
# the smb connection may be reused (shared through a contextvar)
if existing := smb_conn_ctx.get():
yield existing
return
# For now we just create a connection per call
# as that is the most reliable (but a bit slower)
# this could be improved by creating a connection pool
# if really needed

async with AsyncSMB(
remote_name=self._remote_name,
service_name=self._service_name,
username=self.config.get_value(CONF_USERNAME),
password=self.config.get_value(CONF_PASSWORD),
target_ip=self._target_ip,
options={key: value.value for key, value in self.config.values.items()},
use_ntlm_v2=self.config.get_value("use_ntlm_v2"),
sign_options=self.config.get_value("sign_options"),
is_direct_tcp=self.config.get_value("is_direct_tcp"),
) as smb_conn:
token = smb_conn_ctx.set(smb_conn)
yield smb_conn
smb_conn_ctx.reset(token)
Loading