Skip to content

Commit

Permalink
cleanup callbacks (#259)
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry authored Dec 13, 2023
1 parent 9be478e commit 7fde798
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 130 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ classifiers = [
requires-python = ">=3.8"
dynamic = ["version"]
dependencies = [
"tqdm>=4.63.1,<5",
"shortuuid>=0.5.0",
"funcy>=1.14",
"fsspec>=2022.10.0",
Expand All @@ -40,6 +39,7 @@ tests = [
"pytest-mock",
"pytest-benchmark",
"reflink",
"tqdm>=4.63.1,<5",
]
dev = [
"dvc-objects[tests]",
Expand Down
47 changes: 34 additions & 13 deletions src/dvc_objects/fs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import os
import shutil
from functools import partial
from functools import partial, wraps
from multiprocessing import cpu_count
from typing import (
IO,
Expand All @@ -26,16 +26,25 @@
from dvc_objects.executors import ThreadPoolExecutor, batch_coros
from dvc_objects.utils import cached_property

from .callbacks import DEFAULT_CALLBACK, Callback
from .callbacks import (
DEFAULT_CALLBACK,
Callback,
CallbackStream,
wrap_and_branch_callback,
)
from .errors import RemoteMissingDepsError

if TYPE_CHECKING:
from typing import BinaryIO, TextIO
from typing import BinaryIO, Callable, TextIO, TypeVar

from fsspec.spec import AbstractFileSystem
from typing_extensions import ParamSpec

from .path import Path

_P = ParamSpec("_P")
_R = TypeVar("_R")


logger = logging.getLogger(__name__)

Expand All @@ -59,6 +68,16 @@ def __init__(self, link: str, fs: "FileSystem", path: str) -> None:
)


def with_callback(callback: "Callback", fn: "Callable[_P, _R]") -> "Callable[_P, _R]":
@wraps(fn)
def wrapped(*args: "_P.args", **kwargs: "_P.kwargs") -> "_R":
res = fn(*args, **kwargs)
callback.relative_update()
return res

return wrapped


class FileSystem:
sep = "/"

Expand Down Expand Up @@ -338,9 +357,10 @@ def exists(
loop,
)
return fut.result()
executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True)
with executor:
return list(executor.map(callback.wrap_fn(self.fs.exists), path))

func = with_callback(callback, self.fs.exists)
with ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True) as executor:
return list(executor.map(func, path))

def lexists(self, path: AnyFSPath) -> bool:
return self.fs.lexists(path)
Expand Down Expand Up @@ -478,10 +498,11 @@ def info(self, path, callback=DEFAULT_CALLBACK, batch_size=None, **kwargs):
loop,
)
return fut.result()
executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True)
with executor:
func = partial(self.fs.info, **kwargs)
return list(executor.map(callback.wrap_fn(func), path))

func = partial(self.fs.info, **kwargs)
wrapped = with_callback(callback, func)
with ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True) as executor:
return list(executor.map(wrapped, path))

def mkdir(
self, path: AnyFSPath, create_parents: bool = True, **kwargs: Any
Expand All @@ -502,7 +523,7 @@ def put_file(
if size:
callback.set_size(size)
if hasattr(from_file, "read"):
stream = callback.wrap_attr(cast("BinaryIO", from_file))
stream = cast("BinaryIO", CallbackStream(from_file, callback))
self.upload_fobj(stream, to_info, size=size)
else:
assert isinstance(from_file, str)
Expand Down Expand Up @@ -573,7 +594,7 @@ def put(
callback.set_size(len(from_infos))
executor = ThreadPoolExecutor(max_workers=jobs, cancel_on_error=True)
with executor:
put_file = callback.wrap_and_branch(self.put_file)
put_file = wrap_and_branch_callback(callback, self.put_file)
list(executor.imap_unordered(put_file, from_infos, to_infos))

def get(
Expand All @@ -592,7 +613,7 @@ def get_file(rpath, lpath, **kwargs):
localfs.makedirs(localfs.path.parent(lpath), exist_ok=True)
self.fs.get_file(rpath, lpath, **kwargs)

get_file = callback.wrap_and_branch(get_file)
get_file = wrap_and_branch_callback(callback, get_file)

if isinstance(from_info, list) and isinstance(to_info, list):
from_infos: List[AnyFSPath] = from_info
Expand Down
203 changes: 98 additions & 105 deletions src/dvc_objects/fs/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,103 +1,44 @@
from contextlib import ExitStack
from functools import wraps
from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol, TypeVar, cast, overload
from typing import TYPE_CHECKING, Any, Dict, Optional

import fsspec

from dvc_objects.utils import cached_property

if TYPE_CHECKING:
from typing import Awaitable, BinaryIO, Callable, TextIO, Union

from typing_extensions import ParamSpec
from typing import BinaryIO, Callable, Union

from dvc_objects._tqdm import Tqdm

_P = ParamSpec("_P")
_R = TypeVar("_R")


class _CallbackProtocol(Protocol):
def relative_update(self, inc: int = 1) -> None:
...

def branch(
self,
path_1: "Union[str, BinaryIO]",
path_2: str,
kwargs: Dict[str, Any],
child: Optional["Callback"] = None,
) -> "Callback":
...


class _DVCCallbackMixin(_CallbackProtocol):
@overload
def wrap_attr(self, fobj: "BinaryIO", method: str = "read") -> "BinaryIO":
...

@overload
def wrap_attr(self, fobj: "TextIO", method: str = "read") -> "TextIO":
...

def wrap_attr(
self, fobj: "Union[TextIO, BinaryIO]", method: str = "read"
) -> "Union[TextIO, BinaryIO]":
from tqdm.utils import CallbackIOWrapper

wrapped = CallbackIOWrapper(self.relative_update, fobj, method)
return cast("Union[TextIO, BinaryIO]", wrapped)

def wrap_fn(self, fn: "Callable[_P, _R]") -> "Callable[_P, _R]":
@wraps(fn)
def wrapped(*args: "_P.args", **kwargs: "_P.kwargs") -> "_R":
res = fn(*args, **kwargs)
self.relative_update()
return res

return wrapped

def wrap_coro(
self, fn: "Callable[_P, Awaitable[_R]]"
) -> "Callable[_P, Awaitable[_R]]":
@wraps(fn)
async def wrapped(*args: "_P.args", **kwargs: "_P.kwargs") -> "_R":
res = await fn(*args, **kwargs)
self.relative_update()
return res

return wrapped
class CallbackStream:
def __init__(self, stream, callback, method="read"):
self.stream = stream
if method == "write":

def wrap_and_branch(self, fn: "Callable") -> "Callable":
"""
Wraps a function, and pass a new child callback to it.
When the function completes, we increment the parent callback by 1.
"""
wrapped = self.wrap_fn(fn)
@wraps(stream.write)
def write(data, *args, **kwargs):
res = stream.write(data, *args, **kwargs)
callback.relative_update(len(data))
return res

@wraps(fn)
def func(path1: "Union[str, BinaryIO]", path2: str, **kwargs):
kw: Dict[str, Any] = dict(kwargs)
with self.branch(path1, path2, kw):
return wrapped(path1, path2, **kw)
self.write = write
else:

return func
@wraps(stream.read)
def read(*args, **kwargs):
data = stream.read(*args, **kwargs)
callback.relative_update(len(data))
return data

def wrap_and_branch_coro(self, fn: "Callable") -> "Callable":
"""
Wraps a coroutine, and pass a new child callback to it.
When the coroutine completes, we increment the parent callback by 1.
"""
wrapped = self.wrap_coro(fn)
self.read = read

@wraps(fn)
async def func(path1: "Union[str, BinaryIO]", path2: str, **kwargs):
kw: Dict[str, Any] = dict(kwargs)
with self.branch(path1, path2, kw):
return await wrapped(path1, path2, **kw)
def __getattr__(self, attr):
return getattr(self.stream, attr)

return func

class ScopedCallback(fsspec.Callback):
def __enter__(self):
return self

Expand All @@ -107,21 +48,24 @@ def __exit__(self, *exc_args):
def close(self):
"""Handle here on exit."""

@classmethod
def as_tqdm_callback(
cls,
callback: Optional[fsspec.callbacks.Callback] = None,
**tqdm_kwargs: Any,
def branch(
self,
path_1: "Union[str, BinaryIO]",
path_2: str,
kwargs: Dict[str, Any],
child: Optional["Callback"] = None,
) -> "Callback":
if callback is None:
return TqdmCallback(**tqdm_kwargs)
if isinstance(callback, Callback):
return callback
return cast("Callback", _FsspecCallbackWrapper(callback))
child = kwargs["callback"] = child or DEFAULT_CALLBACK
return child


class Callback(fsspec.Callback, _DVCCallbackMixin):
"""Callback usable as a context manager, and a few helper methods."""
class Callback(ScopedCallback):
def __getattr__(self, item):
if item in ["wrap_fn", "wrap_coro", "wrap_and_branch", "wrap_and_branch_coro"]:
raise AttributeError(
f"{type(self).__name__!r} object has no attribute {item!r}"
)
return super().__getattr__(item)

def relative_update(self, inc: int = 1) -> None:
inc = inc if inc is not None else 0
Expand All @@ -139,17 +83,17 @@ def as_callback(
return DEFAULT_CALLBACK
if isinstance(maybe_callback, Callback):
return maybe_callback
return _FsspecCallbackWrapper(maybe_callback)
return FsspecCallbackWrapper(maybe_callback)

def branch(
self,
path_1: "Union[str, BinaryIO]",
path_2: str,
kwargs: Dict[str, Any],
child: Optional["Callback"] = None,
@classmethod
def as_tqdm_callback(
cls,
callback: Optional[fsspec.callbacks.Callback] = None,
**tqdm_kwargs: Any,
) -> "Callback":
child = kwargs["callback"] = child or DEFAULT_CALLBACK
return child
if callback is None:
return TqdmCallback(**tqdm_kwargs)
return cls.as_callback(callback)


class NoOpCallback(Callback, fsspec.callbacks.NoOpCallback):
Expand Down Expand Up @@ -209,7 +153,7 @@ def branch(
return super().branch(path_1, path_2, kwargs, child=child)


class _FsspecCallbackWrapper(fsspec.callbacks.Callback, _DVCCallbackMixin):
class FsspecCallbackWrapper(Callback):
def __init__(self, callback: fsspec.callbacks.Callback):
object.__setattr__(self, "_callback", callback)

Expand All @@ -219,8 +163,57 @@ def __getattr__(self, name: str):
def __setattr__(self, name: str, value: Any):
setattr(self._callback, name, value)

def branch(self, *args, **kwargs):
return _FsspecCallbackWrapper(self._callback.branch(*args, **kwargs))
def relative_update(self, inc: int = 1) -> None:
inc = inc if inc is not None else 0
return self._callback.relative_update(inc)

def absolute_update(self, value: int) -> None:
value = value if value is not None else self.value
return self._callback.absolute_update(value)

def branch(
self,
path_1: "Union[str, BinaryIO]",
path_2: str,
kwargs: Dict[str, Any],
child: Optional["Callback"] = None,
) -> "Callback":
if not child:
self._callback.branch(path_1, path_2, kwargs)
child = self.as_callback(kwargs.get("callback"))
return super().branch(path_1, path_2, kwargs, child=child)


def wrap_and_branch_callback(callback: "Callback", fn: "Callable") -> "Callable":
"""
Wraps a function, and pass a new child callback to it.
When the function completes, we increment the parent callback by 1.
"""

@wraps(fn)
def func(path1: "Union[str, BinaryIO]", path2: str, **kwargs):
with callback.branch(path1, path2, kwargs):
res = fn(path1, path2, **kwargs)
callback.relative_update()
return res

return func


def wrap_and_branch_coro(callback: "Callback", fn: "Callable") -> "Callable":
"""
Wraps a coroutine, and pass a new child callback to it.
When the coroutine completes, we increment the parent callback by 1.
"""

@wraps(fn)
async def func(path1: "Union[str, BinaryIO]", path2: str, **kwargs):
with callback.branch(path1, path2, kwargs):
res = await fn(path1, path2, **kwargs)
callback.relative_update()
return res

return func


DEFAULT_CALLBACK = NoOpCallback()
Loading

0 comments on commit 7fde798

Please sign in to comment.