Skip to content

Commit

Permalink
Add support for s3_folders (#450)
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton authored Jan 14, 2025
1 parent 1eaa79f commit 08bf77d
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/litdata/__about__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import time

__version__ = "0.2.35"
__version__ = "0.2.36"
__author__ = "Lightning AI et al."
__author_email__ = "[email protected]"
__license__ = "Apache-2.0"
Expand Down
14 changes: 13 additions & 1 deletion src/litdata/processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,22 @@ def _remove_target(input_dir: Dir, cache_dir: str, queue_in: Queue) -> None:
if os.path.exists(path):
os.remove(path)

elif os.path.exists(path) and "s3_connections" not in path:
elif keep_path(path) and os.path.exists(path):
os.remove(path)


def keep_path(path: str) -> bool:
paths = [
"efs_connections",
"efs_folders",
"gcs_connections",
"s3_connections",
"s3_folders",
"snowflake_connections",
]
return all(p not in path for p in paths)


def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_dir: Dir) -> None:
"""Upload optimised chunks from a local to remote dataset directory."""
obj = parse.urlparse(output_dir.url if output_dir.url else output_dir.path)
Expand Down
9 changes: 6 additions & 3 deletions src/litdata/streaming/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,12 @@ def _apply_delete(self, chunk_index: int) -> None:
chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)]
self._item_loader.delete(chunk_index, chunk_filepath)

locak_chunk_path = chunk_filepath + ".lock"
if os.path.exists(locak_chunk_path):
os.remove(locak_chunk_path)
try:
locak_chunk_path = chunk_filepath + ".lock"
if os.path.exists(locak_chunk_path):
os.remove(locak_chunk_path)
except FileNotFoundError:
pass

def stop(self) -> None:
"""Receive the list of the chunk indices to download for the current epoch."""
Expand Down
2 changes: 1 addition & 1 deletion src/litdata/streaming/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def _resolve_s3_folders(dir_path: str) -> Dir:
if not data_connection:
raise ValueError(f"We didn't find any matching data connection with the provided name `{target_name}`.")

return Dir(path=dir_path, url=data_connection[0].s3_folder.source)
return Dir(path=dir_path, url=os.path.join(data_connection[0].s3_folder.source, *dir_path.split("/")[4:]))


def _resolve_datasets(dir_path: str) -> Dir:
Expand Down
8 changes: 7 additions & 1 deletion src/litdata/utilities/dataset_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,13 @@ def _should_replace_path(path: Optional[str]) -> bool:
if path is None or path == "":
return True

return path.startswith("/teamspace/datasets/") or path.startswith("/teamspace/s3_connections/")
return (
path.startswith("/teamspace/datasets/")
or path.startswith("/teamspace/s3_connections/")
or path.startswith("/teamspace/s3_folders/")
or path.startswith("/teamspace/gcs_folders/")
or path.startswith("/teamspace/gcs_connections/")
)


def _read_updated_at(input_dir: Optional[Dir], storage_options: Optional[Dict] = {}) -> str:
Expand Down
2 changes: 1 addition & 1 deletion tests/streaming/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def test_src_resolver_s3_folders(monkeypatch, lightning_cloud_mock):

expected = "s3://imagenet-bucket"
assert resolver._resolve_dir("/teamspace/s3_folders/debug_folder").url == expected

assert resolver._resolve_dir("/teamspace/s3_folders/debug_folder/a/b/c").url == expected + "/a/b/c"
auth.clear()


Expand Down
3 changes: 3 additions & 0 deletions tests/utilities/test_dataset_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def test_should_replace_path():
assert not _should_replace_path(".../s3__connections/...")
assert _should_replace_path("/teamspace/datasets/...")
assert _should_replace_path("/teamspace/s3_connections/...")
assert _should_replace_path("/teamspace/s3_folders/...")
assert _should_replace_path("/teamspace/gcs_folders/...")
assert _should_replace_path("/teamspace/gcs_connections/...")
assert not _should_replace_path("something_else")


Expand Down

0 comments on commit 08bf77d

Please sign in to comment.