diff --git a/pyproject.toml b/pyproject.toml index ef5fc7f..e06b909 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ requires-python = ">=3.8" dynamic = ["version"] dependencies = [ - "tqdm>=4.63.1,<5", "shortuuid>=0.5.0", "funcy>=1.14", "fsspec>=2022.10.0", @@ -40,6 +39,7 @@ tests = [ "pytest-mock", "pytest-benchmark", "reflink", + "tqdm>=4.63.1,<5", ] dev = [ "dvc-objects[tests]", diff --git a/src/dvc_objects/fs/base.py b/src/dvc_objects/fs/base.py index 4b31830..3295587 100644 --- a/src/dvc_objects/fs/base.py +++ b/src/dvc_objects/fs/base.py @@ -3,7 +3,7 @@ import logging import os import shutil -from functools import partial +from functools import partial, wraps from multiprocessing import cpu_count from typing import ( IO, @@ -26,16 +26,25 @@ from dvc_objects.executors import ThreadPoolExecutor, batch_coros from dvc_objects.utils import cached_property -from .callbacks import DEFAULT_CALLBACK, Callback +from .callbacks import ( + DEFAULT_CALLBACK, + Callback, + CallbackStream, + wrap_and_branch_callback, +) from .errors import RemoteMissingDepsError if TYPE_CHECKING: - from typing import BinaryIO, TextIO + from typing import BinaryIO, Callable, TextIO, TypeVar from fsspec.spec import AbstractFileSystem + from typing_extensions import ParamSpec from .path import Path + _P = ParamSpec("_P") + _R = TypeVar("_R") + logger = logging.getLogger(__name__) @@ -59,6 +68,16 @@ def __init__(self, link: str, fs: "FileSystem", path: str) -> None: ) +def with_callback(callback: "Callback", fn: "Callable[_P, _R]") -> "Callable[_P, _R]": + @wraps(fn) + def wrapped(*args: "_P.args", **kwargs: "_P.kwargs") -> "_R": + res = fn(*args, **kwargs) + callback.relative_update() + return res + + return wrapped + + class FileSystem: sep = "/" @@ -338,9 +357,10 @@ def exists( loop, ) return fut.result() - executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True) - with executor: - return list(executor.map(callback.wrap_fn(self.fs.exists), path)) + + func = with_callback(callback, self.fs.exists) + with ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True) as executor: + return list(executor.map(func, path)) def lexists(self, path: AnyFSPath) -> bool: return self.fs.lexists(path) @@ -478,10 +498,11 @@ def info(self, path, callback=DEFAULT_CALLBACK, batch_size=None, **kwargs): loop, ) return fut.result() - executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True) - with executor: - func = partial(self.fs.info, **kwargs) - return list(executor.map(callback.wrap_fn(func), path)) + + func = partial(self.fs.info, **kwargs) + wrapped = with_callback(callback, func) + with ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True) as executor: + return list(executor.map(wrapped, path)) def mkdir( self, path: AnyFSPath, create_parents: bool = True, **kwargs: Any @@ -502,7 +523,7 @@ def put_file( if size: callback.set_size(size) if hasattr(from_file, "read"): - stream = callback.wrap_attr(cast("BinaryIO", from_file)) + stream = cast("BinaryIO", CallbackStream(from_file, callback)) self.upload_fobj(stream, to_info, size=size) else: assert isinstance(from_file, str) @@ -573,7 +594,7 @@ def put( callback.set_size(len(from_infos)) executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True) with executor: - put_file = callback.wrap_and_branch(self.put_file) + put_file = wrap_and_branch_callback(callback, self.put_file) list(executor.imap_unordered(put_file, from_infos, to_infos)) def get( @@ -592,7 +613,7 @@ def get_file(rpath, lpath, **kwargs): localfs.makedirs(localfs.path.parent(lpath), exist_ok=True) self.fs.get_file(rpath, lpath, **kwargs) - get_file = callback.wrap_and_branch(get_file) + get_file = wrap_and_branch_callback(callback, get_file) if isinstance(from_info, list) and isinstance(to_info, list): from_infos: List[AnyFSPath] = from_info diff --git a/src/dvc_objects/fs/callbacks.py b/src/dvc_objects/fs/callbacks.py index c5fda62..b84fa5f 100644 --- a/src/dvc_objects/fs/callbacks.py +++ b/src/dvc_objects/fs/callbacks.py @@ -1,103 +1,44 @@ from contextlib import ExitStack from functools import wraps -from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, Dict, Optional import fsspec from dvc_objects.utils import cached_property if TYPE_CHECKING: - from typing import Awaitable, BinaryIO, Callable, TextIO, Union - - from typing_extensions import ParamSpec + from typing import BinaryIO, Callable, Union from dvc_objects._tqdm import Tqdm - _P = ParamSpec("_P") - _R = TypeVar("_R") - - -class _CallbackProtocol(Protocol): - def relative_update(self, inc: int = 1) -> None: - ... - - def branch( - self, - path_1: "Union[str, BinaryIO]", - path_2: str, - kwargs: Dict[str, Any], - child: Optional["Callback"] = None, - ) -> "Callback": - ... - - -class _DVCCallbackMixin(_CallbackProtocol): - @overload - def wrap_attr(self, fobj: "BinaryIO", method: str = "read") -> "BinaryIO": - ... - - @overload - def wrap_attr(self, fobj: "TextIO", method: str = "read") -> "TextIO": - ... - - def wrap_attr( - self, fobj: "Union[TextIO, BinaryIO]", method: str = "read" - ) -> "Union[TextIO, BinaryIO]": - from tqdm.utils import CallbackIOWrapper - - wrapped = CallbackIOWrapper(self.relative_update, fobj, method) - return cast("Union[TextIO, BinaryIO]", wrapped) - - def wrap_fn(self, fn: "Callable[_P, _R]") -> "Callable[_P, _R]": - @wraps(fn) - def wrapped(*args: "_P.args", **kwargs: "_P.kwargs") -> "_R": - res = fn(*args, **kwargs) - self.relative_update() - return res - - return wrapped - - def wrap_coro( - self, fn: "Callable[_P, Awaitable[_R]]" - ) -> "Callable[_P, Awaitable[_R]]": - @wraps(fn) - async def wrapped(*args: "_P.args", **kwargs: "_P.kwargs") -> "_R": - res = await fn(*args, **kwargs) - self.relative_update() - return res - return wrapped +class CallbackStream: + def __init__(self, stream, callback, method="read"): + self.stream = stream + if method == "write": - def wrap_and_branch(self, fn: "Callable") -> "Callable": - """ - Wraps a function, and pass a new child callback to it. - When the function completes, we increment the parent callback by 1. - """ - wrapped = self.wrap_fn(fn) + @wraps(stream.write) + def write(data, *args, **kwargs): + res = stream.write(data, *args, **kwargs) + callback.relative_update(len(data)) + return res - @wraps(fn) - def func(path1: "Union[str, BinaryIO]", path2: str, **kwargs): - kw: Dict[str, Any] = dict(kwargs) - with self.branch(path1, path2, kw): - return wrapped(path1, path2, **kw) + self.write = write + else: - return func + @wraps(stream.read) + def read(*args, **kwargs): + data = stream.read(*args, **kwargs) + callback.relative_update(len(data)) + return data - def wrap_and_branch_coro(self, fn: "Callable") -> "Callable": - """ - Wraps a coroutine, and pass a new child callback to it. - When the coroutine completes, we increment the parent callback by 1. - """ - wrapped = self.wrap_coro(fn) + self.read = read - @wraps(fn) - async def func(path1: "Union[str, BinaryIO]", path2: str, **kwargs): - kw: Dict[str, Any] = dict(kwargs) - with self.branch(path1, path2, kw): - return await wrapped(path1, path2, **kw) + def __getattr__(self, attr): + return getattr(self.stream, attr) - return func +class ScopedCallback(fsspec.Callback): def __enter__(self): return self @@ -107,21 +48,24 @@ def __exit__(self, *exc_args): def close(self): """Handle here on exit.""" - @classmethod - def as_tqdm_callback( - cls, - callback: Optional[fsspec.callbacks.Callback] = None, - **tqdm_kwargs: Any, + def branch( + self, + path_1: "Union[str, BinaryIO]", + path_2: str, + kwargs: Dict[str, Any], + child: Optional["Callback"] = None, ) -> "Callback": - if callback is None: - return TqdmCallback(**tqdm_kwargs) - if isinstance(callback, Callback): - return callback - return cast("Callback", _FsspecCallbackWrapper(callback)) + child = kwargs["callback"] = child or DEFAULT_CALLBACK + return child -class Callback(fsspec.Callback, _DVCCallbackMixin): - """Callback usable as a context manager, and a few helper methods.""" +class Callback(ScopedCallback): + def __getattr__(self, item): + if item in ["wrap_fn", "wrap_coro", "wrap_and_branch", "wrap_and_branch_coro"]: + raise AttributeError( + f"{type(self).__name__!r} object has no attribute {item!r}" + ) + return super().__getattr__(item) def relative_update(self, inc: int = 1) -> None: inc = inc if inc is not None else 0 @@ -139,17 +83,17 @@ def as_callback( return DEFAULT_CALLBACK if isinstance(maybe_callback, Callback): return maybe_callback - return _FsspecCallbackWrapper(maybe_callback) + return FsspecCallbackWrapper(maybe_callback) - def branch( - self, - path_1: "Union[str, BinaryIO]", - path_2: str, - kwargs: Dict[str, Any], - child: Optional["Callback"] = None, + @classmethod + def as_tqdm_callback( + cls, + callback: Optional[fsspec.callbacks.Callback] = None, + **tqdm_kwargs: Any, ) -> "Callback": - child = kwargs["callback"] = child or DEFAULT_CALLBACK - return child + if callback is None: + return TqdmCallback(**tqdm_kwargs) + return cls.as_callback(callback) class NoOpCallback(Callback, fsspec.callbacks.NoOpCallback): @@ -209,7 +153,7 @@ def branch( return super().branch(path_1, path_2, kwargs, child=child) -class _FsspecCallbackWrapper(fsspec.callbacks.Callback, _DVCCallbackMixin): +class FsspecCallbackWrapper(Callback): def __init__(self, callback: fsspec.callbacks.Callback): object.__setattr__(self, "_callback", callback) @@ -219,8 +163,57 @@ def __getattr__(self, name: str): def __setattr__(self, name: str, value: Any): setattr(self._callback, name, value) - def branch(self, *args, **kwargs): - return _FsspecCallbackWrapper(self._callback.branch(*args, **kwargs)) + def relative_update(self, inc: int = 1) -> None: + inc = inc if inc is not None else 0 + return self._callback.relative_update(inc) + + def absolute_update(self, value: int) -> None: + value = value if value is not None else self.value + return self._callback.absolute_update(value) + + def branch( + self, + path_1: "Union[str, BinaryIO]", + path_2: str, + kwargs: Dict[str, Any], + child: Optional["Callback"] = None, + ) -> "Callback": + if not child: + self._callback.branch(path_1, path_2, kwargs) + child = self.as_callback(kwargs.get("callback")) + return super().branch(path_1, path_2, kwargs, child=child) + + +def wrap_and_branch_callback(callback: "Callback", fn: "Callable") -> "Callable": + """ + Wraps a function, and pass a new child callback to it. + When the function completes, we increment the parent callback by 1. + """ + + @wraps(fn) + def func(path1: "Union[str, BinaryIO]", path2: str, **kwargs): + with callback.branch(path1, path2, kwargs): + res = fn(path1, path2, **kwargs) + callback.relative_update() + return res + + return func + + +def wrap_and_branch_coro(callback: "Callback", fn: "Callable") -> "Callable": + """ + Wraps a coroutine, and pass a new child callback to it. + When the coroutine completes, we increment the parent callback by 1. + """ + + @wraps(fn) + async def func(path1: "Union[str, BinaryIO]", path2: str, **kwargs): + with callback.branch(path1, path2, kwargs): + res = await fn(path1, path2, **kwargs) + callback.relative_update() + return res + + return func DEFAULT_CALLBACK = NoOpCallback() diff --git a/src/dvc_objects/fs/generic.py b/src/dvc_objects/fs/generic.py index 80b3cf3..2d724ed 100644 --- a/src/dvc_objects/fs/generic.py +++ b/src/dvc_objects/fs/generic.py @@ -10,7 +10,7 @@ from dvc_objects.executors import ThreadPoolExecutor, batch_coros -from .callbacks import DEFAULT_CALLBACK +from .callbacks import DEFAULT_CALLBACK, wrap_and_branch_callback, wrap_and_branch_coro from .local import LocalFileSystem, localfs from .utils import as_atomic, umask @@ -103,7 +103,7 @@ def copy( ) jobs = batch_size or to_fs.jobs - put_file = callback.wrap_and_branch(to_fs.put_file) + put_file = wrap_and_branch_callback(callback, to_fs.put_file) put_file_kwargs = {} if hasattr(to_fs.fs, "max_concurrency"): put_file_kwargs["max_concurrency"] = jobs if len(from_path) == 1 else 1 @@ -138,7 +138,7 @@ def _put( on_error: Optional[TransferErrorHandler] = None, ) -> None: jobs = batch_size or to_fs.jobs - put_file = callback.wrap_and_branch(to_fs.put_file) + put_file = wrap_and_branch_callback(callback, to_fs.put_file) put_file_kwargs = {} if hasattr(to_fs.fs, "max_concurrency"): put_file_kwargs["max_concurrency"] = jobs if len(from_paths) == 1 else 1 @@ -156,7 +156,7 @@ def _put_one(from_path: "AnyFSPath", to_path: "AnyFSPath"): return _put_one(from_paths[0], to_paths[0]) if to_fs.fs.async_impl: - put_coro = callback.wrap_and_branch_coro(to_fs.fs._put_file) + put_coro = wrap_and_branch_coro(callback, to_fs.fs._put_file) loop = get_loop() fut = asyncio.run_coroutine_threadsafe( batch_coros( @@ -191,7 +191,7 @@ def _get( # noqa: C901 on_error: Optional[TransferErrorHandler] = None, ) -> None: jobs = batch_size or from_fs.jobs - get_file = callback.wrap_and_branch(from_fs.get_file) + get_file = wrap_and_branch_callback(callback, from_fs.get_file) get_file_kwargs = {} if hasattr(from_fs.fs, "max_concurrency"): get_file_kwargs["max_concurrency"] = jobs if len(from_paths) == 1 else 1 @@ -214,7 +214,7 @@ def _get_one(from_path: "AnyFSPath", to_path: "AnyFSPath"): if from_fs.fs.async_impl: async def _get_one_coro(from_path: "AnyFSPath", to_path: "AnyFSPath"): - get_coro = callback.wrap_and_branch_coro(from_fs.fs._get_file) + get_coro = wrap_and_branch_coro(callback, from_fs.fs._get_file) with as_atomic(localfs, to_path, create_parents=True) as tmp_file: return await get_coro( from_path, tmp_file, callback=callback, **get_file_kwargs diff --git a/src/dvc_objects/fs/utils.py b/src/dvc_objects/fs/utils.py index 9792124..f816777 100644 --- a/src/dvc_objects/fs/utils.py +++ b/src/dvc_objects/fs/utils.py @@ -12,7 +12,7 @@ from dvc_objects.executors import ThreadPoolExecutor from . import system -from .callbacks import DEFAULT_CALLBACK +from .callbacks import DEFAULT_CALLBACK, CallbackStream if TYPE_CHECKING: from .base import AnyFSPath, FileSystem @@ -169,8 +169,8 @@ def copyfile( callback.set_size(total) with open(src, "rb") as fsrc, open(dest, "wb+") as fdest: - wrapped = callback.wrap_attr(fdest, "write") - shutil.copyfileobj(fsrc, wrapped, length=LOCAL_CHUNK_SIZE) # type: ignore[misc] + wrapped = CallbackStream(fdest, callback, "write") + shutil.copyfileobj(fsrc, wrapped, length=LOCAL_CHUNK_SIZE) def tmp_fname(fname: "AnyFSPath" = "") -> "AnyFSPath": diff --git a/tests/fs/test_callbacks.py b/tests/fs/test_callbacks.py index cba3193..ea02ec1 100644 --- a/tests/fs/test_callbacks.py +++ b/tests/fs/test_callbacks.py @@ -47,7 +47,9 @@ def _branch_fn(*args, callback: Optional["Callback"] = None, **kwargs): assert cb.value == 1 assert callback.value == 1 - fn = cb.wrap_and_branch(_branch_fn) - fn("foo", "bar", callback=callback) + with cb.branch("foo", "bar", {}) as child: + _branch_fn("foo", "bar", callback=child) + cb.relative_update() + assert cb.value == 2 assert callback.value == 2