diff --git a/ibis/backends/__init__.py b/ibis/backends/__init__.py
index f1c06f1a4adb..3e85111f26bb 100644
--- a/ibis/backends/__init__.py
+++ b/ibis/backends/__init__.py
@@ -4,6 +4,7 @@
 import collections.abc
 import contextlib
 import functools
+import glob
 import importlib.metadata
 import keyword
 import re
@@ -22,6 +23,7 @@ if TYPE_CHECKING:
     from collections.abc import Iterable, Iterator, Mapping, MutableMapping
+    from io import BytesIO
     from urllib.parse import ParseResult
 
     import pandas as pd
@@ -1269,6 +1271,100 @@ def has_operation(cls, operation: type[ops.Value]) -> bool:
             f"{cls.name} backend has not implemented `has_operation` API"
         )
 
+    @util.experimental
+    def read_parquet(
+        self, path: str | Path | BytesIO, table_name: str | None = None, **kwargs: Any
+    ) -> ir.Table:
+        """Register a parquet file as a table in the current backend.
+
+        This function reads a Parquet file and registers it as a table in the
+        current backend. Note that for the Impala and Trino backends, performance
+        may be suboptimal.
+
+        Parameters
+        ----------
+        path
+            The data source. May be a path to a file, a glob pattern to match
+            Parquet files, a directory of Parquet files, or a BytesIO object.
+        table_name
+            An optional name to use for the created table. This defaults to
+            a sequentially generated name.
+        **kwargs
+            Additional keyword arguments passed to the pyarrow loading function.
+            See https://arrow.apache.org/docs/python/generated/pyarrow.parquet.read_table.html
+            for more information.
+
+        Returns
+        -------
+        ir.Table
+            The just-registered table
+
+        Examples
+        --------
+        Connect to a SQLite database:
+
+        >>> con = ibis.sqlite.connect()
+
+        Read a single parquet file:
+
+        >>> table = con.read_parquet("path/to/file.parquet")
+
+        Read all parquet files in a directory:
+
+        >>> table = con.read_parquet("path/to/parquet_directory/")
+
+        Read parquet files with a glob pattern:
+
+        >>> table = con.read_parquet("path/to/parquet_directory/data_*.parquet")
+
+        Read from Amazon S3:
+
+        >>> table = con.read_parquet("s3://bucket-name/path/to/file.parquet")
+
+        Read from Google Cloud Storage:
+
+        >>> table = con.read_parquet("gs://bucket-name/path/to/file.parquet")
+
+        Read with a custom table name:
+
+        >>> table = con.read_parquet("s3://bucket/data.parquet", table_name="my_table")
+
+        Read with additional pyarrow options:
+
+        >>> table = con.read_parquet("gs://bucket/data.parquet", columns=["col1", "col2"])
+
+        Read from Amazon S3 with explicit credentials:
+
+        >>> from pyarrow import fs
+        >>> s3_fs = fs.S3FileSystem(
+        ...     access_key="YOUR_ACCESS_KEY", secret_key="YOUR_SECRET_KEY", region="YOUR_AWS_REGION"
+        ... )
+        >>> table = con.read_parquet("s3://bucket/data.parquet", filesystem=s3_fs)
+
+        Read from an HTTPS URL:
+
+        >>> import fsspec
+        >>> from io import BytesIO
+        >>> url = "https://example.com/data/file.parquet"
+        >>> credentials = {}
+        >>> f = fsspec.open(url, **credentials).open()
+        >>> reader = BytesIO(f.read())
+        >>> table = con.read_parquet(reader)
+        >>> reader.close()
+        >>> f.close()
+        """
+        import pyarrow.parquet as pq
+
+        table_name = table_name or util.gen_name("read_parquet")
+        paths = list(glob.glob(str(path)))
+        if paths:
+            table = pq.read_table(paths, **kwargs)
+        else:
+            table = pq.read_table(path, **kwargs)
+
+        self.create_table(table_name, table)
+        return self.table(table_name)
+
     def _transpile_sql(self, query: str, *, dialect: str | None = None) -> str:
         # only transpile if dialect was passed
         if dialect is None:
diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_register.py
index 0ed6925a2ce7..e655e81dcdd3 100644
--- a/ibis/backends/tests/test_register.py
+++ b/ibis/backends/tests/test_register.py
@@ -4,6 +4,7 @@
 import csv
 import gzip
 import os
+from io import BytesIO
 from pathlib import Path
 from typing import TYPE_CHECKING
 
@@ -21,7 +22,6 @@
     import pyarrow as pa
 
 pytestmark = [
-    pytest.mark.notimpl(["druid", "exasol", "oracle"]),
     pytest.mark.notyet(
         ["pyspark"], condition=IS_SPARK_REMOTE, raises=PySparkAnalysisException
     ),
@@ -103,6 +103,7 @@ def gzip_csv(data_dir, tmp_path):
         "trino",
     ]
 )
+@pytest.mark.notimpl(["druid", "exasol", "oracle"])
 def test_register_csv(con, data_dir, fname, in_table_name, out_table_name):
     with pushd(data_dir / "csv"):
         with pytest.warns(FutureWarning, match="v9.1"):
@@ -114,7 +115,7 @@ def test_register_csv(con, data_dir, fname, in_table_name, out_table_name):
 
 
 # TODO: rewrite or delete test when register api is removed
-@pytest.mark.notimpl(["datafusion"])
+@pytest.mark.notimpl(["datafusion", "druid", "exasol", "oracle"])
 @pytest.mark.notyet(
     [
         "bigquery",
@@ -154,6 +155,7 @@ def test_register_csv_gz(con, data_dir, gzip_csv):
         "trino",
     ]
 )
+@pytest.mark.notimpl(["druid", "exasol", "oracle"])
 def test_register_with_dotted_name(con, data_dir, tmp_path):
     basename = "foo.bar.baz/diamonds.csv"
     f = tmp_path.joinpath(basename)
@@ -211,6 +213,7 @@ def read_table(path: Path) -> Iterator[tuple[str, pa.Table]]:
         "trino",
     ]
 )
+@pytest.mark.notimpl(["druid", "exasol", "oracle"])
 def test_register_parquet(
     con, tmp_path, data_dir, fname, in_table_name, out_table_name
 ):
@@ -249,6 +252,7 @@ def test_register_parquet(
         "trino",
     ]
 )
+@pytest.mark.notimpl(["druid", "exasol", "oracle"])
 def test_register_iterator_parquet(
     con,
     tmp_path,
@@ -277,7 +281,7 @@ def test_register_iterator_parquet(
 # TODO: remove entirely when `register` is removed
 # This same functionality is implemented across all backends
 # via `create_table` and tested in `test_client.py`
-@pytest.mark.notimpl(["datafusion"])
+@pytest.mark.notimpl(["datafusion", "druid", "exasol", "oracle"])
 @pytest.mark.notyet(
     [
         "bigquery",
@@ -311,7 +315,7 @@ def test_register_pandas(con):
 # TODO: remove entirely when `register` is removed
 # This same functionality is implemented across all backends
 # via `create_table` and tested in `test_client.py`
-@pytest.mark.notimpl(["datafusion", "polars"])
+@pytest.mark.notimpl(["datafusion", "polars", "druid", "exasol", "oracle"])
 @pytest.mark.notyet(
     [
         "bigquery",
@@ -352,6 +356,7 @@ def test_register_pyarrow_tables(con):
         "trino",
     ]
 )
+@pytest.mark.notimpl(["druid", "exasol", "oracle"])
 def test_csv_reregister_schema(con, tmp_path):
     foo = tmp_path.joinpath("foo.csv")
     with foo.open("w", newline="") as csvfile:
@@ -380,10 +385,13 @@ def test_csv_reregister_schema(con, tmp_path):
         "bigquery",
         "clickhouse",
         "datafusion",
+        "druid",
+        "exasol",
         "flink",
         "impala",
         "mysql",
         "mssql",
+        "oracle",
         "polars",
         "postgres",
         "risingwave",
@@ -414,12 +422,17 @@ def test_register_garbage(con, monkeypatch):
         ("functional_alltypes.parquet", "funk_all"),
     ],
 )
-@pytest.mark.notyet(
-    ["flink", "impala", "mssql", "mysql", "postgres", "risingwave", "sqlite", "trino"]
-)
+@pytest.mark.notyet(["flink"])
+@pytest.mark.notimpl(["druid"])
 def test_read_parquet(con, tmp_path, data_dir, fname, in_table_name):
     pq = pytest.importorskip("pyarrow.parquet")
 
+    if con.name in ("trino", "impala"):
+        # TODO: remove after trino and impala have efficient insertion
+        pytest.skip(
+            "Both Impala and Trino lack efficient data insertion methods from Python."
+        )
+
     fname = Path(fname)
     fname = Path(data_dir) / "parquet" / fname.name
     table = pq.read_table(fname)
@@ -445,18 +458,8 @@ def ft_data(data_dir):
     return table.slice(0, nrows)
 
 
-@pytest.mark.notyet(
-    [
-        "flink",
-        "impala",
-        "mssql",
-        "mysql",
-        "postgres",
-        "risingwave",
-        "sqlite",
-        "trino",
-    ]
-)
+@pytest.mark.notyet(["flink"])
+@pytest.mark.notimpl(["druid"])
 def test_read_parquet_glob(con, tmp_path, ft_data):
     pq = pytest.importorskip("pyarrow.parquet")
 
@@ -473,6 +476,30 @@ def test_read_parquet_glob(con, tmp_path, ft_data):
     assert table.count().execute() == nrows * ntables
 
 
+@pytest.mark.notyet(["flink"])
+@pytest.mark.notimpl(["druid"])
+@pytest.mark.never(
+    [
+        "duckdb",
+        "polars",
+        "bigquery",
+        "clickhouse",
+        "datafusion",
+        "snowflake",
+        "pyspark",
+    ],
+    reason="backend implements its own read_parquet",
+)
+def test_read_parquet_bytesio(con, ft_data):
+    pq = pytest.importorskip("pyarrow.parquet")
+
+    bytes_io = BytesIO()
+    pq.write_table(ft_data, bytes_io)
+    bytes_io.seek(0)
+    table = con.read_parquet(bytes_io)
+    assert table.count().execute() == ft_data.num_rows
+
+
 @pytest.mark.notyet(
     [
         "flink",
@@ -485,6 +512,7 @@ def test_read_parquet_glob(con, tmp_path, ft_data):
         "trino",
     ]
 )
+@pytest.mark.notimpl(["druid", "exasol", "oracle"])
 def test_read_csv_glob(con, tmp_path, ft_data):
     pc = pytest.importorskip("pyarrow.csv")
 
@@ -519,6 +547,7 @@ def test_read_csv_glob(con, tmp_path, ft_data):
     raises=ValueError,
     reason="read_json() missing required argument: 'schema'",
 )
+@pytest.mark.notimpl(["druid", "exasol", "oracle"])
 def test_read_json_glob(con, tmp_path, ft_data):
     nrows = len(ft_data)
     ntables = 2
@@ -565,6 +594,7 @@ def num_diamonds(data_dir):
 @pytest.mark.notyet(
     ["flink", "impala", "mssql", "mysql", "postgres", "risingwave", "sqlite", "trino"]
)
+@pytest.mark.notimpl(["druid", "exasol", "oracle"])
 def test_read_csv(con, data_dir, in_table_name, num_diamonds):
     fname = "diamonds.csv"
     with pushd(data_dir / "csv"):
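For context, here is a minimal usage sketch (not part of the diff) of how the new base-class fallback could be exercised against a backend without a native Parquet reader, such as SQLite. The in-memory data, column names, and the table name "example" are purely illustrative.

```python
from io import BytesIO

import pyarrow as pa
import pyarrow.parquet as pq

import ibis

# Build a small Parquet payload entirely in memory.
data = pa.table({"x": [1, 2, 3], "y": ["a", "b", "c"]})
buf = BytesIO()
pq.write_table(data, buf)
buf.seek(0)

# SQLite has no native Parquet reader, so this goes through the new
# pyarrow-based fallback: the payload is read with pyarrow.parquet and
# materialized via create_table.
con = ibis.sqlite.connect()  # in-memory SQLite database
tbl = con.read_parquet(buf, table_name="example")
print(tbl.count().execute())  # 3
```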