changes for feat:add-parquet-support PR review (#460)
deependujha authored Feb 6, 2025
1 parent 0e208ff commit c6edaa7
Showing 4 changed files with 34 additions and 12 deletions.
10 changes: 8 additions & 2 deletions README.md
@@ -643,14 +643,18 @@ The `overwrite` mode will delete the existing data and start fresh.
</details>

<details>
-<summary> ✅ Index parquet datasets</summary>
+<summary> ✅ Stream parquet datasets</summary>
&nbsp;

-If your dataset is already in Parquet format, you can index it directly and use it with StreamingDataset & DataLoader.
+You can stream Parquet datasets directly, without first converting them to the LitData optimized binary format.
+
+If your dataset is already in Parquet format, you can index it and then use it with `StreamingDataset` and `DataLoader` for efficient streaming.

+Assumption:
+Your dataset directory contains one or more Parquet files.
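For example, a directory layout like the following would qualify (the file names here are hypothetical):

```
my-parquet-data/
├── part-000.parquet
├── part-001.parquet
└── part-002.parquet
```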

- **Index Parquet dataset**:

```python
import litdata as ld

@@ -659,6 +663,8 @@
pq_data_uri = "gs://deep-litdata-parquet/my-parquet-data"
ld.index_parquet_dataset(pq_data_uri)
```

+- **Stream the dataset with `StreamingDataset` and `ParquetLoader`**
+
When using a `StreamingDataset`, ensure you use `ParquetLoader`:

```python
import litdata as ld
from litdata.streaming.item_loader import ParquetLoader

# A minimal sketch: stream the indexed Parquet dataset by passing
# ParquetLoader as the item_loader.
pq_data_uri = "gs://deep-litdata-parquet/my-parquet-data"
ds = ld.StreamingDataset(pq_data_uri, item_loader=ParquetLoader())

for sample in ds:
    print(sample)
```
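For batched loading, the same dataset can be wrapped in a loader. A minimal sketch, assuming a locally indexed Parquet directory (the path and batch size are illustrative):

```python
import litdata as ld
from litdata.streaming.item_loader import ParquetLoader

# Sketch: wrap the Parquet-backed dataset in litdata's StreamingDataLoader.
# Assumes ld.index_parquet_dataset("data/") has already been run on this directory.
ds = ld.StreamingDataset("data/", item_loader=ParquetLoader())
dl = ld.StreamingDataLoader(ds, batch_size=2)

for batch in dl:
    print(batch)
```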
2 changes: 1 addition & 1 deletion src/litdata/constants.py
@@ -36,7 +36,7 @@
_AZURE_STORAGE_AVAILABLE = RequirementCache("azure.storage.blob")
_TQDM_AVAILABLE = RequirementCache("tqdm")
_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")
_POLARS_AVAILABLE = RequirementCache("polars")
_POLARS_AVAILABLE = RequirementCache("polars>1.0.0")
_DEBUG = bool(int(os.getenv("DEBUG", "1")))

_MAX_WAIT_TIME = int(os.getenv("MAX_WAIT_TIME", "120"))
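With the bumped specifier, the flag is truthy only when the installed polars satisfies `>1.0.0`, not merely when the package is importable. A rough sketch of the pattern, assuming `RequirementCache` comes from `lightning-utilities` (an assumption; the import is not shown in this diff):

```python
from lightning_utilities.core.imports import RequirementCache

# Truthy only when an installed polars satisfies the version specifier.
_POLARS_AVAILABLE = RequirementCache("polars>1.0.0")

if not _POLARS_AVAILABLE:
    raise ModuleNotFoundError("Please run: `pip install 'polars>1.0.0'`")
```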
10 changes: 8 additions & 2 deletions src/litdata/streaming/item_loader.py
@@ -423,7 +423,10 @@ def encode_data(cls, data: List[bytes], _: List[int], flattened: List[Any]) -> T
class ParquetLoader(BaseItemLoader):
    def __init__(self) -> None:
        if not _POLARS_AVAILABLE:
-            raise ModuleNotFoundError("Please, run: `pip install polars`")
+            raise ModuleNotFoundError(
+                "You are using the Parquet item loader, which depends on `polars > 1.0.0`. "
+                "Please run: `pip install 'polars>1.0.0'`"
+            )
        self._chunk_filepaths: Dict[str, bool] = {}

    def setup(
@@ -458,7 +461,10 @@ def generate_intervals(self) -> List[Interval]:

    def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
        """Logic to load the chunk in the background to gain some time."""
-        pass
+        import polars as pl
+
+        if chunk_filepath not in self._df:
+            self._df[chunk_filepath] = pl.scan_parquet(chunk_filepath).collect()

    def load_item_from_chunk(
        self,
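The new `pre_load_chunk` eagerly materializes each Parquet file into an in-memory cache so that later item lookups skip the file read. A standalone sketch of the same caching pattern (the class and method names are illustrative, not litdata's API):

```python
from typing import Dict

import polars as pl


class ParquetCache:
    """Illustrative cache: load each Parquet file once, reuse it afterwards."""

    def __init__(self) -> None:
        self._df: Dict[str, pl.DataFrame] = {}

    def pre_load(self, filepath: str) -> None:
        # scan_parquet builds a lazy frame; collect() materializes it.
        if filepath not in self._df:
            self._df[filepath] = pl.scan_parquet(filepath).collect()

    def get_row(self, filepath: str, row: int) -> tuple:
        self.pre_load(filepath)  # no-op when the file is already cached
        return self._df[filepath].row(row)
```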
24 changes: 17 additions & 7 deletions tests/streaming/test_writer.py
@@ -15,6 +15,7 @@
import os
import random
import sys
+from collections import OrderedDict

import numpy as np
import pytest
@@ -278,14 +279,16 @@ def test_parquet_index_write(tmpdir):

    os.mkdir(os.path.join(tmpdir, "data"))

+    pq_data = OrderedDict(
+        {
+            "name": ["Tom", "Jerry", "Micky", "Oggy", "Doraemon"],
+            "weight": [57.9, 72.5, 53.6, 83.1, 69.4],  # (kg)
+            "height": [1.56, 1.77, 1.65, 1.75, 1.63],  # (m)
+        }
+    )
+
    for i in range(5):
-        df = pl.DataFrame(
-            {
-                "name": ["Tom", "Jerry", "Micky", "Oggy", "Doraemon"],
-                "weight": [57.9, 72.5, 53.6, 83.1, 69.4],  # (kg)
-                "height": [1.56, 1.77, 1.65, 1.75, 1.63],  # (m)
-            }
-        )
+        df = pl.DataFrame(pq_data)
        file_path = os.path.join(tmpdir, "data", f"tmp-{i}.parquet")
        df.write_parquet(file_path)

@@ -307,3 +310,10 @@
    ds = StreamingDataset(os.path.join(tmpdir, "data"), item_loader=ParquetLoader())

    assert len(ds) == 25  # 5 rows in each of the 5 Parquet files
+
+    for i, _ds in enumerate(ds):
+        idx = i % 5
+        assert len(_ds) == 3
+        assert _ds[0] == pq_data["name"][idx]
+        assert _ds[1] == pq_data["weight"][idx]
+        assert _ds[2] == pq_data["height"][idx]
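Per these assertions, each streamed sample behaves like a three-field row in column order (name, weight, height). A rough sketch of consuming such samples outside the test, with an illustrative local path:

```python
import litdata as ld
from litdata.streaming.item_loader import ParquetLoader

# Sketch: each sample indexes like a row whose field order follows the
# Parquet columns (name, weight, height), as asserted in the test above.
ds = ld.StreamingDataset("data/", item_loader=ParquetLoader())
for sample in ds:
    name, weight, height = sample
    print(f"{name}: {weight} kg, {height} m")
```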
