Skip to content

Commit

Permalink
AsyncGenerator → AsyncIterator
Browse files Browse the repository at this point in the history
Refactor Store list methods.
  • Loading branch information
DimitriPapadopoulos committed Oct 16, 2024
1 parent 0dc2bc1 commit 844ccad
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 80 deletions.
6 changes: 3 additions & 3 deletions docs/roadmap.rst
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,13 @@ expose the required methods as async methods.
async def set_partial_values(self, key_start_values: List[Tuple[str, int, Union[bytes, bytearray, memoryview]]]) -> None:
... # required for writable stores
async def list(self) -> List[str]:
async def list(self) -> AsyncIterator[str]:
... # required for listable stores
async def list_prefix(self, prefix: str) -> List[str]:
async def list_prefix(self, prefix: str) -> AsyncIterator[str]:
... # required for listable stores
async def list_dir(self, prefix: str) -> List[str]:
async def list_dir(self, prefix: str) -> AsyncIterator[str]:
... # required for listable stores
# additional (optional methods)
Expand Down
20 changes: 10 additions & 10 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Iterable
from collections.abc import AsyncGenerator, AsyncIterator, Iterable
from types import TracebackType
from typing import Any, Self, TypeAlias

Expand Down Expand Up @@ -329,17 +329,17 @@ def supports_listing(self) -> bool:
...

@abstractmethod
async def list(self) -> AsyncGenerator[str]:
async def list(self) -> AsyncIterator[str]:
"""Retrieve all keys in the store.
Returns
-------
AsyncGenerator[str, None]
AsyncIterator[str]
"""
yield ""
return iter([])

@abstractmethod
async def list_prefix(self, prefix: str) -> AsyncGenerator[str]:
async def list_prefix(self, prefix: str) -> AsyncIterator[str]:
"""
Retrieve all keys in the store that begin with a given prefix. Keys are returned with the
common leading prefix removed.
Expand All @@ -350,12 +350,12 @@ async def list_prefix(self, prefix: str) -> AsyncGenerator[str]:
Returns
-------
AsyncGenerator[str, None]
AsyncIterator[str]
"""
yield ""
return iter([])

@abstractmethod
async def list_dir(self, prefix: str) -> AsyncGenerator[str]:
async def list_dir(self, prefix: str) -> AsyncIterator[str]:
"""
Retrieve all keys and prefixes with a given prefix and which do not contain the character
“/” after the given prefix.
Expand All @@ -366,9 +366,9 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str]:
Returns
-------
AsyncGenerator[str, None]
AsyncIterator[str]
"""
yield ""
return iter([])

def close(self) -> None:
"""Close the store."""
Expand Down
30 changes: 14 additions & 16 deletions src/zarr/storage/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from zarr.core.common import concurrent_map

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Iterable
from collections.abc import AsyncIterator, Iterable

from zarr.core.buffer import BufferPrototype
from zarr.core.common import AccessModeLiteral
Expand Down Expand Up @@ -217,28 +217,26 @@ async def exists(self, key: str) -> bool:
path = self.root / key
return await asyncio.to_thread(path.is_file)

async def list(self) -> AsyncGenerator[str]:
async def list(self) -> AsyncIterator[str]:
# docstring inherited
to_strip = str(self.root) + "/"
for p in list(self.root.rglob("*")):
if p.is_file():
yield str(p).replace(to_strip, "")
allfiles = [str(p).replace(to_strip, "") for p in self.root.rglob("*") if p.is_file()]
return (onefile for onefile in allfiles)

async def list_prefix(self, prefix: str) -> AsyncGenerator[str]:
async def list_prefix(self, prefix: str) -> AsyncIterator[str]:
# docstring inherited
to_strip = os.path.join(str(self.root / prefix))
for p in (self.root / prefix).rglob("*"):
if p.is_file():
yield str(p.relative_to(to_strip))

async def list_dir(self, prefix: str) -> AsyncGenerator[str]:
allfiles = [
str(p.relative_to(to_strip)) for p in (self.root / prefix).rglob("*") if p.is_file()
]
return (onefile for onefile in allfiles)

async def list_dir(self, prefix: str) -> AsyncIterator[str]:
# docstring inherited
base = self.root / prefix
to_strip = str(base) + "/"

try:
key_iter = base.iterdir()
for key in key_iter:
yield str(key).replace(to_strip, "")
allfiles = [str(key).replace(to_strip, "") for key in base.iterdir()]
return (onefile for onefile in allfiles)
except (FileNotFoundError, NotADirectoryError):
pass
return iter([])
17 changes: 7 additions & 10 deletions src/zarr/storage/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from zarr.abc.store import AccessMode, ByteRangeRequest, Store

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Generator, Iterable
from collections.abc import AsyncIterator, Generator, Iterable

from zarr.core.buffer import Buffer, BufferPrototype
from zarr.core.common import AccessModeLiteral
Expand Down Expand Up @@ -204,23 +204,20 @@ async def set_partial_values(
with self.log(keys):
return await self._store.set_partial_values(key_start_values=key_start_values)

async def list(self) -> AsyncGenerator[str]:
async def list(self) -> AsyncIterator[str]:
# docstring inherited
with self.log():
async for key in self._store.list():
yield key
return (key async for key in self._store.list())

async def list_prefix(self, prefix: str) -> AsyncGenerator[str]:
async def list_prefix(self, prefix: str) -> AsyncIterator[str]:
# docstring inherited
with self.log(prefix):
async for key in self._store.list_prefix(prefix=prefix):
yield key
return (key async for key in self._store.list_prefix(prefix=prefix))

async def list_dir(self, prefix: str) -> AsyncGenerator[str]:
async def list_dir(self, prefix: str) -> AsyncIterator[str]:
# docstring inherited
with self.log(prefix):
async for key in self._store.list_dir(prefix=prefix):
yield key
return (key async for key in self._store.list_dir(prefix=prefix))

def with_mode(self, mode: AccessModeLiteral) -> Self:
# docstring inherited
Expand Down
18 changes: 7 additions & 11 deletions src/zarr/storage/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from zarr.storage._utils import _normalize_interval_index

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Iterable, MutableMapping
from collections.abc import AsyncIterator, Iterable, MutableMapping

from zarr.core.buffer import BufferPrototype
from zarr.core.common import AccessModeLiteral
Expand Down Expand Up @@ -143,18 +143,15 @@ async def set_partial_values(self, key_start_values: Iterable[tuple[str, int, by
# docstring inherited
raise NotImplementedError

async def list(self) -> AsyncGenerator[str]:
async def list(self) -> AsyncIterator[str]:
# docstring inherited
for key in self._store_dict:
yield key
return (key for key in self._store_dict)

async def list_prefix(self, prefix: str) -> AsyncGenerator[str]:
async def list_prefix(self, prefix: str) -> AsyncIterator[str]:
# docstring inherited
for key in self._store_dict:
if key.startswith(prefix):
yield key.removeprefix(prefix)
return (key.removeprefix(prefix) for key in self._store_dict if key.startswith(prefix))

async def list_dir(self, prefix: str) -> AsyncGenerator[str]:
async def list_dir(self, prefix: str) -> AsyncIterator[str]:
# docstring inherited
if prefix.endswith("/"):
prefix = prefix[:-1]
Expand All @@ -171,8 +168,7 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str]:
if key.startswith(prefix + "/") and key != prefix
}

for key in keys_unique:
yield key
return (key for key in keys_unique)


class GpuMemoryStore(MemoryStore):
Expand Down
23 changes: 12 additions & 11 deletions src/zarr/storage/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from zarr.storage.common import _dereference_path

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Iterable
from collections.abc import AsyncIterator, Iterable

from fsspec.asyn import AsyncFileSystem

Expand Down Expand Up @@ -280,24 +280,25 @@ async def set_partial_values(
# docstring inherited
raise NotImplementedError

async def list(self) -> AsyncGenerator[str]:
async def list(self) -> AsyncIterator[str]:
# docstring inherited
allfiles = await self.fs._find(self.path, detail=False, withdirs=False)
for onefile in (a.replace(self.path + "/", "") for a in allfiles):
yield onefile
return (onefile for onefile in (a.replace(self.path + "/", "") for a in allfiles))

async def list_dir(self, prefix: str) -> AsyncGenerator[str]:
async def list_dir(self, prefix: str) -> AsyncIterator[str]:
# docstring inherited
prefix = f"{self.path}/{prefix.rstrip('/')}"
try:
allfiles = await self.fs._ls(prefix, detail=False)
except FileNotFoundError:
return
for onefile in (a.replace(prefix + "/", "") for a in allfiles):
yield onefile.removeprefix(self.path).removeprefix("/")
return iter([])
return (
onefile.removeprefix(self.path).removeprefix("/")
for onefile in (a.replace(prefix + "/", "") for a in allfiles)
)

async def list_prefix(self, prefix: str) -> AsyncGenerator[str]:
async def list_prefix(self, prefix: str) -> AsyncIterator[str]:
# docstring inherited
find_str = f"{self.path}/{prefix}"
for onefile in await self.fs._find(find_str, detail=False, maxdepth=None, withdirs=False):
yield onefile.removeprefix(find_str)
allfiles = await self.fs._find(find_str, detail=False, maxdepth=None, withdirs=False)
return (onefile.removeprefix(find_str) for onefile in allfiles)
32 changes: 13 additions & 19 deletions src/zarr/storage/zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from zarr.core.buffer import Buffer, BufferPrototype

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Iterable
from collections.abc import AsyncIterator, Iterable

ZipStoreAccessModeLiteral = Literal["r", "w", "a"]

Expand Down Expand Up @@ -231,19 +231,16 @@ async def exists(self, key: str) -> bool:
else:
return True

async def list(self) -> AsyncGenerator[str]:
async def list(self) -> AsyncIterator[str]:
# docstring inherited
with self._lock:
for key in self._zf.namelist():
yield key
return (key for key in self._zf.namelist())

async def list_prefix(self, prefix: str) -> AsyncGenerator[str]:
async def list_prefix(self, prefix: str) -> AsyncIterator[str]:
# docstring inherited
async for key in self.list():
if key.startswith(prefix):
yield key.removeprefix(prefix)
return (key.removeprefix(prefix) async for key in self.list() if key.startswith(prefix))

async def list_dir(self, prefix: str) -> AsyncGenerator[str]:
async def list_dir(self, prefix: str) -> AsyncIterator[str]:
# docstring inherited
if prefix.endswith("/"):
prefix = prefix[:-1]
Expand All @@ -252,14 +249,11 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str]:
seen = set()
if prefix == "":
keys_unique = {k.split("/")[0] for k in keys}
for key in keys_unique:
if key not in seen:
seen.add(key)
yield key
return (key for key in keys_unique if key not in seen and not seen.add(key))
else:
for key in keys:
if key.startswith(prefix + "/") and key != prefix:
k = key.removeprefix(prefix + "/").split("/")[0]
if k not in seen:
seen.add(k)
yield k
return (
k for key in keys
if key.startswith(prefix + "/") and key != prefix
for k in [key.removeprefix(prefix + "/").split("/")[0]]
if k not in seen and not seen.add(k)
)

0 comments on commit 844ccad

Please sign in to comment.