Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cleanup callbacks #259

Merged
merged 1 commit into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ classifiers = [
requires-python = ">=3.8"
dynamic = ["version"]
dependencies = [
"tqdm>=4.63.1,<5",
"shortuuid>=0.5.0",
"funcy>=1.14",
"fsspec>=2022.10.0",
Expand All @@ -40,6 +39,7 @@ tests = [
"pytest-mock",
"pytest-benchmark",
"reflink",
"tqdm>=4.63.1,<5",
]
dev = [
"dvc-objects[tests]",
Expand Down
47 changes: 34 additions & 13 deletions src/dvc_objects/fs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import os
import shutil
from functools import partial
from functools import partial, wraps
from multiprocessing import cpu_count
from typing import (
IO,
Expand All @@ -26,16 +26,25 @@
from dvc_objects.executors import ThreadPoolExecutor, batch_coros
from dvc_objects.utils import cached_property

from .callbacks import DEFAULT_CALLBACK, Callback
from .callbacks import (
DEFAULT_CALLBACK,
Callback,
CallbackStream,
wrap_and_branch_callback,
)
from .errors import RemoteMissingDepsError

if TYPE_CHECKING:
from typing import BinaryIO, TextIO
from typing import BinaryIO, Callable, TextIO, TypeVar

from fsspec.spec import AbstractFileSystem
from typing_extensions import ParamSpec

from .path import Path

_P = ParamSpec("_P")
_R = TypeVar("_R")


logger = logging.getLogger(__name__)

Expand All @@ -59,6 +68,16 @@ def __init__(self, link: str, fs: "FileSystem", path: str) -> None:
)


def with_callback(callback: "Callback", fn: "Callable[_P, _R]") -> "Callable[_P, _R]":
@wraps(fn)
def wrapped(*args: "_P.args", **kwargs: "_P.kwargs") -> "_R":
res = fn(*args, **kwargs)
callback.relative_update()
return res

return wrapped


class FileSystem:
sep = "/"

Expand Down Expand Up @@ -338,9 +357,10 @@ def exists(
loop,
)
return fut.result()
executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True)
with executor:
return list(executor.map(callback.wrap_fn(self.fs.exists), path))

func = with_callback(callback, self.fs.exists)
with ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True) as executor:
return list(executor.map(func, path))

def lexists(self, path: AnyFSPath) -> bool:
return self.fs.lexists(path)
Expand Down Expand Up @@ -478,10 +498,11 @@ def info(self, path, callback=DEFAULT_CALLBACK, batch_size=None, **kwargs):
loop,
)
return fut.result()
executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True)
with executor:
func = partial(self.fs.info, **kwargs)
return list(executor.map(callback.wrap_fn(func), path))

func = partial(self.fs.info, **kwargs)
wrapped = with_callback(callback, func)
with ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True) as executor:
return list(executor.map(wrapped, path))

def mkdir(
self, path: AnyFSPath, create_parents: bool = True, **kwargs: Any
Expand All @@ -502,7 +523,7 @@ def put_file(
if size:
callback.set_size(size)
if hasattr(from_file, "read"):
stream = callback.wrap_attr(cast("BinaryIO", from_file))
stream = cast("BinaryIO", CallbackStream(from_file, callback))
self.upload_fobj(stream, to_info, size=size)
else:
assert isinstance(from_file, str)
Expand Down Expand Up @@ -573,7 +594,7 @@ def put(
callback.set_size(len(from_infos))
executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True)
with executor:
put_file = callback.wrap_and_branch(self.put_file)
put_file = wrap_and_branch_callback(callback, self.put_file)
list(executor.imap_unordered(put_file, from_infos, to_infos))

def get(
Expand All @@ -592,7 +613,7 @@ def get_file(rpath, lpath, **kwargs):
localfs.makedirs(localfs.path.parent(lpath), exist_ok=True)
self.fs.get_file(rpath, lpath, **kwargs)

get_file = callback.wrap_and_branch(get_file)
get_file = wrap_and_branch_callback(callback, get_file)

if isinstance(from_info, list) and isinstance(to_info, list):
from_infos: List[AnyFSPath] = from_info
Expand Down
203 changes: 98 additions & 105 deletions src/dvc_objects/fs/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,103 +1,44 @@
from contextlib import ExitStack
from functools import wraps
from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol, TypeVar, cast, overload
from typing import TYPE_CHECKING, Any, Dict, Optional

import fsspec

from dvc_objects.utils import cached_property

if TYPE_CHECKING:
from typing import Awaitable, BinaryIO, Callable, TextIO, Union

from typing_extensions import ParamSpec
from typing import BinaryIO, Callable, Union

from dvc_objects._tqdm import Tqdm

_P = ParamSpec("_P")
_R = TypeVar("_R")


class _CallbackProtocol(Protocol):
def relative_update(self, inc: int = 1) -> None:
...

def branch(
self,
path_1: "Union[str, BinaryIO]",
path_2: str,
kwargs: Dict[str, Any],
child: Optional["Callback"] = None,
) -> "Callback":
...


class _DVCCallbackMixin(_CallbackProtocol):
@overload
def wrap_attr(self, fobj: "BinaryIO", method: str = "read") -> "BinaryIO":
...

@overload
def wrap_attr(self, fobj: "TextIO", method: str = "read") -> "TextIO":
...

def wrap_attr(
self, fobj: "Union[TextIO, BinaryIO]", method: str = "read"
) -> "Union[TextIO, BinaryIO]":
from tqdm.utils import CallbackIOWrapper

wrapped = CallbackIOWrapper(self.relative_update, fobj, method)
return cast("Union[TextIO, BinaryIO]", wrapped)

def wrap_fn(self, fn: "Callable[_P, _R]") -> "Callable[_P, _R]":
@wraps(fn)
def wrapped(*args: "_P.args", **kwargs: "_P.kwargs") -> "_R":
res = fn(*args, **kwargs)
self.relative_update()
return res

return wrapped

def wrap_coro(
self, fn: "Callable[_P, Awaitable[_R]]"
) -> "Callable[_P, Awaitable[_R]]":
@wraps(fn)
async def wrapped(*args: "_P.args", **kwargs: "_P.kwargs") -> "_R":
res = await fn(*args, **kwargs)
self.relative_update()
return res

return wrapped
class CallbackStream:
def __init__(self, stream, callback, method="read"):
self.stream = stream
if method == "write":

def wrap_and_branch(self, fn: "Callable") -> "Callable":
"""
Wraps a function, and pass a new child callback to it.
When the function completes, we increment the parent callback by 1.
"""
wrapped = self.wrap_fn(fn)
@wraps(stream.write)
def write(data, *args, **kwargs):
res = stream.write(data, *args, **kwargs)
callback.relative_update(len(data))
return res

@wraps(fn)
def func(path1: "Union[str, BinaryIO]", path2: str, **kwargs):
kw: Dict[str, Any] = dict(kwargs)
with self.branch(path1, path2, kw):
return wrapped(path1, path2, **kw)
self.write = write
else:

return func
@wraps(stream.read)
def read(*args, **kwargs):
data = stream.read(*args, **kwargs)
callback.relative_update(len(data))
return data

def wrap_and_branch_coro(self, fn: "Callable") -> "Callable":
"""
Wraps a coroutine, and pass a new child callback to it.
When the coroutine completes, we increment the parent callback by 1.
"""
wrapped = self.wrap_coro(fn)
self.read = read

@wraps(fn)
async def func(path1: "Union[str, BinaryIO]", path2: str, **kwargs):
kw: Dict[str, Any] = dict(kwargs)
with self.branch(path1, path2, kw):
return await wrapped(path1, path2, **kw)
def __getattr__(self, attr):
return getattr(self.stream, attr)

return func

class ScopedCallback(fsspec.Callback):
def __enter__(self):
return self

Expand All @@ -107,21 +48,24 @@ def __exit__(self, *exc_args):
def close(self):
"""Handle here on exit."""

@classmethod
def as_tqdm_callback(
cls,
callback: Optional[fsspec.callbacks.Callback] = None,
**tqdm_kwargs: Any,
def branch(
self,
path_1: "Union[str, BinaryIO]",
path_2: str,
kwargs: Dict[str, Any],
child: Optional["Callback"] = None,
) -> "Callback":
if callback is None:
return TqdmCallback(**tqdm_kwargs)
if isinstance(callback, Callback):
return callback
return cast("Callback", _FsspecCallbackWrapper(callback))
child = kwargs["callback"] = child or DEFAULT_CALLBACK
return child


class Callback(fsspec.Callback, _DVCCallbackMixin):
"""Callback usable as a context manager, and a few helper methods."""
class Callback(ScopedCallback):
def __getattr__(self, item):
if item in ["wrap_fn", "wrap_coro", "wrap_and_branch", "wrap_and_branch_coro"]:
raise AttributeError(
f"{type(self).__name__!r} object has no attribute {item!r}"
)
return super().__getattr__(item)

def relative_update(self, inc: int = 1) -> None:
inc = inc if inc is not None else 0
Expand All @@ -139,17 +83,17 @@ def as_callback(
return DEFAULT_CALLBACK
if isinstance(maybe_callback, Callback):
return maybe_callback
return _FsspecCallbackWrapper(maybe_callback)
return FsspecCallbackWrapper(maybe_callback)

def branch(
self,
path_1: "Union[str, BinaryIO]",
path_2: str,
kwargs: Dict[str, Any],
child: Optional["Callback"] = None,
@classmethod
def as_tqdm_callback(
cls,
callback: Optional[fsspec.callbacks.Callback] = None,
**tqdm_kwargs: Any,
) -> "Callback":
child = kwargs["callback"] = child or DEFAULT_CALLBACK
return child
if callback is None:
return TqdmCallback(**tqdm_kwargs)
return cls.as_callback(callback)


class NoOpCallback(Callback, fsspec.callbacks.NoOpCallback):
Expand Down Expand Up @@ -209,7 +153,7 @@ def branch(
return super().branch(path_1, path_2, kwargs, child=child)


class _FsspecCallbackWrapper(fsspec.callbacks.Callback, _DVCCallbackMixin):
class FsspecCallbackWrapper(Callback):
def __init__(self, callback: fsspec.callbacks.Callback):
object.__setattr__(self, "_callback", callback)

Expand All @@ -219,8 +163,57 @@ def __getattr__(self, name: str):
def __setattr__(self, name: str, value: Any):
setattr(self._callback, name, value)

def branch(self, *args, **kwargs):
return _FsspecCallbackWrapper(self._callback.branch(*args, **kwargs))
def relative_update(self, inc: int = 1) -> None:
inc = inc if inc is not None else 0
return self._callback.relative_update(inc)

def absolute_update(self, value: int) -> None:
value = value if value is not None else self.value
return self._callback.absolute_update(value)

def branch(
self,
path_1: "Union[str, BinaryIO]",
path_2: str,
kwargs: Dict[str, Any],
child: Optional["Callback"] = None,
) -> "Callback":
if not child:
self._callback.branch(path_1, path_2, kwargs)
child = self.as_callback(kwargs.get("callback"))
return super().branch(path_1, path_2, kwargs, child=child)


def wrap_and_branch_callback(callback: "Callback", fn: "Callable") -> "Callable":
"""
Wraps a function, and pass a new child callback to it.
When the function completes, we increment the parent callback by 1.
"""

@wraps(fn)
def func(path1: "Union[str, BinaryIO]", path2: str, **kwargs):
with callback.branch(path1, path2, kwargs):
res = fn(path1, path2, **kwargs)
callback.relative_update()
return res

return func


def wrap_and_branch_coro(callback: "Callback", fn: "Callable") -> "Callable":
"""
Wraps a coroutine, and pass a new child callback to it.
When the coroutine completes, we increment the parent callback by 1.
"""

@wraps(fn)
async def func(path1: "Union[str, BinaryIO]", path2: str, **kwargs):
with callback.branch(path1, path2, kwargs):
res = await fn(path1, path2, **kwargs)
callback.relative_update()
return res

return func


DEFAULT_CALLBACK = NoOpCallback()
Loading