
feat: add storage_options to _BaseLanceDatasink, LanceDatasink, LanceCommitter (#2619)

Adds the ability to pass `storage_options` to `_BaseLanceDatasink` to
allow specifying them for `LanceDatasink` and `LanceCommitter`
(`LanceFragmentWriter` already supports passing these).

We need to customize them because some of our datasets are hosted on `R2`, Cloudflare's S3-compatible object store.
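
For instance, with this change a Ray pipeline can target an R2-hosted (or any S3-compatible) dataset roughly like this. This is a minimal sketch: the endpoint, bucket, and credential values are placeholders, and the option keys shown are examples of what `storage_options` accepts.

```python
import pyarrow as pa
import ray

from lance.ray.sink import write_lance

# Placeholder credentials/endpoint for an S3-compatible store such as R2.
storage_options = {
    "endpoint": "https://<account-id>.r2.cloudflarestorage.com",
    "access_key_id": "<key-id>",
    "secret_access_key": "<secret>",
}

schema = pa.schema([pa.field("id", pa.int64())])
data = ray.data.range(10)

# storage_options is now forwarded to both the fragment writer and the
# committer that write_lance wires together under the hood.
write_lance(
    data,
    "s3://my-bucket/my-table.lance",
    schema=schema,
    storage_options=storage_options,
)
```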

---------

Co-authored-by: Will Jones <[email protected]>
BitPhinix and wjones127 authored Jul 23, 2024
1 parent 0969d9b commit fa089be
Showing 4 changed files with 85 additions and 12 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/python.yml
@@ -218,14 +218,15 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
-          python-version: "3.12"
+          python-version: "3.11" # TODO: upgrade when ray supports 3.12
       - uses: Swatinem/rust-cache@v2
         with:
           workspaces: python
           prefix-key: "manylinux2014" # use this to flush the cache
       - uses: ./.github/workflows/build_linux_wheel
       - name: Install dependencies
-        run:
+        run: |
+          pip install ray[data]
           pip install torch --index-url https://download.pytorch.org/whl/cpu
       - uses: ./.github/workflows/run_integtests
       # Make sure wheels are not included in the Rust cache
28 changes: 24 additions & 4 deletions python/python/lance/ray/sink.py
@@ -102,6 +102,7 @@ def __init__(
         uri: str,
         schema: Optional[pa.Schema] = None,
         mode: Literal["create", "append", "overwrite"] = "create",
+        storage_options: Optional[Dict[str, Any]] = None,
         *args,
         **kwargs,
     ):
@@ -112,14 +113,15 @@ def __init__(
         self.mode = mode
 
         self.read_version: int | None = None
+        self.storage_options = storage_options
 
     @property
    def supports_distributed_writes(self) -> bool:
         return True
 
     def on_write_start(self):
         if self.mode == "append":
-            ds = lance.LanceDataset(self.uri)
+            ds = lance.LanceDataset(self.uri, storage_options=self.storage_options)
             self.read_version = ds.version
             if self.schema is None:
                 self.schema = ds.schema
@@ -139,7 +141,12 @@ def on_write_complete(
             op = lance.LanceOperation.Overwrite(schema, fragments)
         elif self.mode == "append":
             op = lance.LanceOperation.Append(fragments)
-        lance.LanceDataset.commit(self.uri, op, read_version=self.read_version)
+        lance.LanceDataset.commit(
+            self.uri,
+            op,
+            read_version=self.read_version,
+            storage_options=self.storage_options,
+        )
 
 
 class LanceDatasink(_BaseLanceDatasink):
@@ -163,6 +170,8 @@ class LanceDatasink(_BaseLanceDatasink):
         The maximum number of rows per file. Default is 1024 * 1024.
     use_legacy_format : bool, optional
         Set True to use the legacy v1 format. Default is False
+    storage_options : Dict[str, Any], optional
+        The storage options for the writer. Default is None.
     """
 
     NAME = "Lance"
@@ -174,10 +183,18 @@ def __init__(
         mode: Literal["create", "append", "overwrite"] = "create",
         max_rows_per_file: int = 1024 * 1024,
         use_legacy_format: bool = False,
+        storage_options: Optional[Dict[str, Any]] = None,
         *args,
         **kwargs,
     ):
-        super().__init__(uri, schema=schema, mode=mode, *args, **kwargs)
+        super().__init__(
+            uri,
+            schema=schema,
+            mode=mode,
+            storage_options=storage_options,
+            *args,
+            **kwargs,
+        )
 
         self.max_rows_per_file = max_rows_per_file
         self.use_legacy_format = use_legacy_format
@@ -206,6 +223,7 @@ def write(
             schema=self.schema,
             max_rows_per_file=self.max_rows_per_file,
             use_legacy_format=self.use_legacy_format,
+            storage_options=self.storage_options,
         )
         return [
             (pickle.dumps(fragment), pickle.dumps(schema))
@@ -360,7 +378,9 @@ def write_lance(
             storage_options=storage_options,
         ),
         batch_size=max_rows_per_file,
-    ).write_datasink(LanceCommitter(output_uri, schema=schema))
+    ).write_datasink(
+        LanceCommitter(output_uri, schema=schema, storage_options=storage_options)
+    )
 
 
 def _register_hooks():
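Taken together, the sink changes mean one `storage_options` dict now rides along with the datasink itself: `write()` forwards it to `write_fragments` on each worker, and `on_write_complete()` forwards it to `LanceDataset.commit`. A hedged sketch of driving `LanceDatasink` directly (the URI and option values are placeholders):

```python
import pyarrow as pa
import ray

from lance.ray.sink import LanceDatasink

# Placeholder options; the same dict is used for the distributed fragment
# writes and for the final commit.
storage_options = {"endpoint": "https://<account-id>.r2.cloudflarestorage.com"}

sink = LanceDatasink(
    "s3://my-bucket/my-table.lance",
    schema=pa.schema([pa.field("id", pa.int64())]),
    mode="create",
    storage_options=storage_options,
)
ray.data.range(10).write_datasink(sink)
```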
44 changes: 43 additions & 1 deletion python/python/tests/test_s3_ddb.py
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright The Lance Authors
 """
-Integration tests with S3 and DynamoDB.
+Integration tests with S3 and DynamoDB. Also used to test storage_options are
+passed correctly.
 
 See DEVELOPMENT.md under heading "Integration Tests" for more information.
 """
@@ -16,6 +17,8 @@
 import lance
 import pyarrow as pa
 import pytest
+from lance.dependencies import _RAY_AVAILABLE, ray
+from lance.fragment import write_fragments
 
 # These are all keys that are accepted by storage_options
 CONFIG = {
@@ -221,3 +224,42 @@ def test_s3_unsafe(s3_bucket: str):
     assert len(ds.versions()) == 1
     assert ds.count_rows() == 3
     assert ds.to_table() == data
+
+
+@pytest.mark.integration
+def test_s3_ddb_distributed_commit(s3_bucket: str, ddb_table: str):
+    table_name = uuid.uuid4().hex
+    table_dir = f"s3+ddb://{s3_bucket}/{table_name}?ddbTableName={ddb_table}"
+
+    schema = pa.schema([pa.field("a", pa.int64())])
+    fragments = write_fragments(
+        pa.table({"a": pa.array(range(1024))}),
+        f"s3+ddb://{s3_bucket}/distributed_commit?ddbTableName={ddb_table}",
+        storage_options=CONFIG,
+    )
+    operation = lance.LanceOperation.Overwrite(schema, fragments)
+    ds = lance.LanceDataset.commit(table_dir, operation, storage_options=CONFIG)
+    assert ds.count_rows() == 1024
+
+
+@pytest.mark.integration
+@pytest.mark.skipif(not _RAY_AVAILABLE, reason="ray is not available")
+def test_ray_committer(s3_bucket: str, ddb_table: str):
+    from lance.ray.sink import write_lance
+
+    table_name = uuid.uuid4().hex
+    table_dir = f"s3+ddb://{s3_bucket}/{table_name}?ddbTableName={ddb_table}"
+
+    schema = pa.schema([pa.field("id", pa.int64()), pa.field("str", pa.string())])
+
+    ds = ray.data.range(10).map(lambda x: {"id": x["id"], "str": f"str-{x['id']}"})
+    write_lance(ds, table_dir, schema=schema, storage_options=CONFIG)
+
+    ds = lance.dataset(table_dir, storage_options=CONFIG)
+    assert ds.count_rows() == 10
+    assert ds.schema == schema
+
+    tbl = ds.to_table()
+    assert sorted(tbl["id"].to_pylist()) == list(range(10))
+    assert set(tbl["str"].to_pylist()) == set([f"str-{i}" for i in range(10)])
+    assert len(ds.get_fragments()) == 1
20 changes: 15 additions & 5 deletions python/src/dataset.rs
@@ -1066,18 +1066,28 @@ impl Dataset {
         commit_lock: Option<&PyAny>,
         storage_options: Option<HashMap<String, String>>,
     ) -> PyResult<Self> {
-        let object_store_params = storage_options.map(|storage_options| ObjectStoreParams {
-            storage_options: Some(storage_options),
-            ..Default::default()
-        });
+        let object_store_params =
+            storage_options
+                .as_ref()
+                .map(|storage_options| ObjectStoreParams {
+                    storage_options: Some(storage_options.clone()),
+                    ..Default::default()
+                });
 
         let commit_handler = commit_lock.map(|commit_lock| {
             Arc::new(PyCommitLock::new(commit_lock.to_object(commit_lock.py())))
                 as Arc<dyn CommitHandler>
         });
         let ds = RT
             .block_on(commit_lock.map(|cl| cl.py()), async move {
-                let dataset = match DatasetBuilder::from_uri(dataset_uri).load().await {
+                let mut builder = DatasetBuilder::from_uri(dataset_uri);
+                if let Some(storage_options) = storage_options {
+                    builder = builder.with_storage_options(storage_options);
+                }
+                if let Some(read_version) = read_version {
+                    builder = builder.with_version(read_version);
+                }
+                let dataset = match builder.load().await {
                     Ok(ds) => Some(ds),
                     Err(lance::Error::DatasetNotFound { .. }) => None,
                     Err(err) => return Err(err),
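The builder change above is what lets `storage_options` (and `read_version`) take effect when committing against an existing dataset. A minimal Python sketch mirroring the new `test_s3_ddb_distributed_commit` test, with placeholder URI and option values:

```python
import lance
import pyarrow as pa

from lance.fragment import write_fragments

storage_options = {"endpoint": "https://<account-id>.r2.cloudflarestorage.com"}
uri = "s3://my-bucket/my-table.lance"
schema = pa.schema([pa.field("a", pa.int64())])

# Workers write raw fragments without committing anything...
fragments = write_fragments(
    pa.table({"a": pa.array(range(1024))}),
    uri,
    storage_options=storage_options,
)

# ...then a single coordinator commits them, passing the same options
# through to the DatasetBuilder path patched above.
op = lance.LanceOperation.Overwrite(schema, fragments)
ds = lance.LanceDataset.commit(uri, op, storage_options=storage_options)
assert ds.count_rows() == 1024
```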
