diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index df6e712cfa..d710d8653b 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -218,14 +218,15 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
-          python-version: "3.12"
+          python-version: "3.11" # TODO: upgrade when ray supports 3.12
       - uses: Swatinem/rust-cache@v2
         with:
           workspaces: python
           prefix-key: "manylinux2014" # use this to flush the cache
       - uses: ./.github/workflows/build_linux_wheel
       - name: Install dependencies
-        run:
+        run: |
+          pip install ray[data]
           pip install torch --index-url https://download.pytorch.org/whl/cpu
       - uses: ./.github/workflows/run_integtests
       # Make sure wheels are not included in the Rust cache
diff --git a/python/python/lance/ray/sink.py b/python/python/lance/ray/sink.py
index ea6d7eab33..706b7a2a25 100644
--- a/python/python/lance/ray/sink.py
+++ b/python/python/lance/ray/sink.py
@@ -102,6 +102,7 @@ def __init__(
         uri: str,
         schema: Optional[pa.Schema] = None,
         mode: Literal["create", "append", "overwrite"] = "create",
+        storage_options: Optional[Dict[str, Any]] = None,
         *args,
         **kwargs,
     ):
@@ -112,6 +113,7 @@ def __init__(
         self.mode = mode
 
         self.read_version: int | None = None
+        self.storage_options = storage_options
 
     @property
     def supports_distributed_writes(self) -> bool:
@@ -119,7 +121,7 @@ def supports_distributed_writes(self) -> bool:
 
     def on_write_start(self):
         if self.mode == "append":
-            ds = lance.LanceDataset(self.uri)
+            ds = lance.LanceDataset(self.uri, storage_options=self.storage_options)
             self.read_version = ds.version
             if self.schema is None:
                 self.schema = ds.schema
@@ -139,7 +141,12 @@ def on_write_complete(
             op = lance.LanceOperation.Overwrite(schema, fragments)
         elif self.mode == "append":
             op = lance.LanceOperation.Append(fragments)
-        lance.LanceDataset.commit(self.uri, op, read_version=self.read_version)
+        lance.LanceDataset.commit(
+            self.uri,
+            op,
+            read_version=self.read_version,
+            storage_options=self.storage_options,
+        )
 
 
 class LanceDatasink(_BaseLanceDatasink):
@@ -163,6 +170,8 @@ class LanceDatasink(_BaseLanceDatasink):
         The maximum number of rows per file. Default is 1024 * 1024.
     use_legacy_format : bool, optional
         Set True to use the legacy v1 format. Default is False
+    storage_options : Dict[str, Any], optional
+        The storage options for the writer. Default is None.
     """
 
     NAME = "Lance"
@@ -174,10 +183,18 @@ def __init__(
         mode: Literal["create", "append", "overwrite"] = "create",
         max_rows_per_file: int = 1024 * 1024,
         use_legacy_format: bool = False,
+        storage_options: Optional[Dict[str, Any]] = None,
         *args,
         **kwargs,
     ):
-        super().__init__(uri, schema=schema, mode=mode, *args, **kwargs)
+        super().__init__(
+            uri,
+            schema=schema,
+            mode=mode,
+            storage_options=storage_options,
+            *args,
+            **kwargs,
+        )
 
         self.max_rows_per_file = max_rows_per_file
         self.use_legacy_format = use_legacy_format
@@ -206,6 +223,7 @@ def write(
             schema=self.schema,
             max_rows_per_file=self.max_rows_per_file,
             use_legacy_format=self.use_legacy_format,
+            storage_options=self.storage_options,
         )
         return [
             (pickle.dumps(fragment), pickle.dumps(schema))
@@ -360,7 +378,9 @@ def write_lance(
             storage_options=storage_options,
         ),
         batch_size=max_rows_per_file,
-    ).write_datasink(LanceCommitter(output_uri, schema=schema))
+    ).write_datasink(
+        LanceCommitter(output_uri, schema=schema, storage_options=storage_options)
+    )
 
 
 def _register_hooks():
diff --git a/python/python/tests/test_s3_ddb.py b/python/python/tests/test_s3_ddb.py
index 9191e084c5..b81003993f 100644
--- a/python/python/tests/test_s3_ddb.py
+++ b/python/python/tests/test_s3_ddb.py
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright The Lance Authors
 """
-Integration tests with S3 and DynamoDB.
+Integration tests with S3 and DynamoDB. Also used to test storage_options are
+passed correctly.
 
 See DEVELOPMENT.md under heading "Integration Tests" for more information.
 """
@@ -16,6 +17,8 @@
 import lance
 import pyarrow as pa
 import pytest
+from lance.dependencies import _RAY_AVAILABLE, ray
+from lance.fragment import write_fragments
 
 # These are all keys that are accepted by storage_options
 CONFIG = {
@@ -221,3 +224,42 @@ def test_s3_unsafe(s3_bucket: str):
     assert len(ds.versions()) == 1
     assert ds.count_rows() == 3
     assert ds.to_table() == data
+
+
+@pytest.mark.integration
+def test_s3_ddb_distributed_commit(s3_bucket: str, ddb_table: str):
+    table_name = uuid.uuid4().hex
+    table_dir = f"s3+ddb://{s3_bucket}/{table_name}?ddbTableName={ddb_table}"
+
+    schema = pa.schema([pa.field("a", pa.int64())])
+    fragments = write_fragments(
+        pa.table({"a": pa.array(range(1024))}),
+        f"s3+ddb://{s3_bucket}/distributed_commit?ddbTableName={ddb_table}",
+        storage_options=CONFIG,
+    )
+    operation = lance.LanceOperation.Overwrite(schema, fragments)
+    ds = lance.LanceDataset.commit(table_dir, operation, storage_options=CONFIG)
+    assert ds.count_rows() == 1024
+
+
+@pytest.mark.integration
+@pytest.mark.skipif(not _RAY_AVAILABLE, reason="ray is not available")
+def test_ray_committer(s3_bucket: str, ddb_table: str):
+    from lance.ray.sink import write_lance
+
+    table_name = uuid.uuid4().hex
+    table_dir = f"s3+ddb://{s3_bucket}/{table_name}?ddbTableName={ddb_table}"
+
+    schema = pa.schema([pa.field("id", pa.int64()), pa.field("str", pa.string())])
+
+    ds = ray.data.range(10).map(lambda x: {"id": x["id"], "str": f"str-{x['id']}"})
+    write_lance(ds, table_dir, schema=schema, storage_options=CONFIG)
+
+    ds = lance.dataset(table_dir, storage_options=CONFIG)
+    assert ds.count_rows() == 10
+    assert ds.schema == schema
+
+    tbl = ds.to_table()
+    assert sorted(tbl["id"].to_pylist()) == list(range(10))
+    assert set(tbl["str"].to_pylist()) == set([f"str-{i}" for i in range(10)])
+    assert len(ds.get_fragments()) == 1
diff --git a/python/src/dataset.rs b/python/src/dataset.rs
index 71f2e9d0a2..b8863c4c05 100644
--- a/python/src/dataset.rs
+++ b/python/src/dataset.rs
@@ -1066,10 +1066,13 @@ impl Dataset {
         commit_lock: Option<&PyAny>,
         storage_options: Option<HashMap<String, String>>,
     ) -> PyResult<Self> {
-        let object_store_params = storage_options.map(|storage_options| ObjectStoreParams {
-            storage_options: Some(storage_options),
-            ..Default::default()
-        });
+        let object_store_params =
+            storage_options
+                .as_ref()
+                .map(|storage_options| ObjectStoreParams {
+                    storage_options: Some(storage_options.clone()),
+                    ..Default::default()
+                });
 
         let commit_handler = commit_lock.map(|commit_lock| {
             Arc::new(PyCommitLock::new(commit_lock.to_object(commit_lock.py())))
@@ -1077,7 +1080,14 @@ impl Dataset {
         });
         let ds = RT
             .block_on(commit_lock.map(|cl| cl.py()), async move {
-                let dataset = match DatasetBuilder::from_uri(dataset_uri).load().await {
+                let mut builder = DatasetBuilder::from_uri(dataset_uri);
+                if let Some(storage_options) = storage_options {
+                    builder = builder.with_storage_options(storage_options);
+                }
+                if let Some(read_version) = read_version {
+                    builder = builder.with_version(read_version);
+                }
+                let dataset = match builder.load().await {
                     Ok(ds) => Some(ds),
                     Err(lance::Error::DatasetNotFound { .. }) => None,
                     Err(err) => return Err(err),
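
Usage sketch (reviewer commentary, not part of the patch): how the storage_options
pass-through added above is intended to be used from the Ray sink. The bucket URI,
region, and credential values below are placeholders; only keys already accepted by
Lance's storage_options are shown, and the schema/write_lance call mirrors the new
test_ray_committer test.

# Hypothetical example: write a Ray dataset to S3 through the Lance sink,
# forwarding the same storage_options to both the distributed writer and the reader.
import lance
import pyarrow as pa
import ray

from lance.ray.sink import write_lance

storage_options = {
    "region": "us-east-1",             # placeholder
    "access_key_id": "my-key-id",      # placeholder
    "secret_access_key": "my-secret",  # placeholder
}

schema = pa.schema([pa.field("id", pa.int64())])
data = ray.data.range(10).map(lambda x: {"id": x["id"]})
write_lance(
    data, "s3://my-bucket/my-table", schema=schema, storage_options=storage_options
)

ds = lance.dataset("s3://my-bucket/my-table", storage_options=storage_options)
assert ds.count_rows() == 10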