diff --git a/python/cudf/cudf/io/csv.py b/python/cudf/cudf/io/csv.py index 01f1fdf9020..4694243ad18 100644 --- a/python/cudf/cudf/io/csv.py +++ b/python/cudf/cudf/io/csv.py @@ -59,17 +59,10 @@ def read_csv( "`read_csv` does not yet support reading multiple files" ) - # Only need to pass byte_ranges to get_filepath_or_buffer - # if `use_python_file_object=False` - byte_ranges = None - if not use_python_file_object and byte_range: - byte_ranges = [byte_range] - filepath_or_buffer, compression = ioutils.get_filepath_or_buffer( path_or_data=filepath_or_buffer, compression=compression, iotypes=(BytesIO, StringIO, NativeFile), - byte_ranges=byte_ranges, use_python_file_object=use_python_file_object, **kwargs, ) diff --git a/python/cudf/cudf/io/parquet.py b/python/cudf/cudf/io/parquet.py index 3e73e0c9e3d..a919b00692d 100644 --- a/python/cudf/cudf/io/parquet.py +++ b/python/cudf/cudf/io/parquet.py @@ -1,14 +1,11 @@ # Copyright (c) 2019-2022, NVIDIA CORPORATION. -import io -import json import warnings from collections import defaultdict from contextlib import ExitStack from typing import Dict, List, Tuple from uuid import uuid4 -import fsspec import numpy as np from pyarrow import dataset as ds, parquet as pq @@ -310,103 +307,6 @@ def _process_dataset( ) -def _get_byte_ranges(file_list, row_groups, columns, fs, **kwargs): - - # This utility is used to collect the footer metadata - # from a parquet file. This metadata is used to define - # the exact byte-ranges that will be needed to read the - # target column-chunks from the file. - # - # This utility is only used for remote storage. - # - # The calculated byte-range information is used within - # cudf.io.ioutils.get_filepath_or_buffer (which uses - # _fsspec_data_transfer to convert non-local fsspec file - # objects into local byte buffers). - - if row_groups is None: - if columns is None: - return None, None, None # No reason to construct this - row_groups = [None for path in file_list] - - # Construct a list of required byte-ranges for every file - all_byte_ranges, all_footers, all_sizes = [], [], [] - for path, rgs in zip(file_list, row_groups): - - # Step 0 - Get size of file - if fs is None: - file_size = path.size - else: - file_size = fs.size(path) - - # Step 1 - Get 32 KB from tail of file. 
- # - # This "sample size" can be tunable, but should - # always be >= 8 bytes (so we can read the footer size) - tail_size = min(kwargs.get("footer_sample_size", 32_000), file_size,) - if fs is None: - path.seek(file_size - tail_size) - footer_sample = path.read(tail_size) - else: - footer_sample = fs.tail(path, tail_size) - - # Step 2 - Read the footer size and re-read a larger - # tail if necessary - footer_size = int.from_bytes(footer_sample[-8:-4], "little") - if tail_size < (footer_size + 8): - if fs is None: - path.seek(file_size - (footer_size + 8)) - footer_sample = path.read(footer_size + 8) - else: - footer_sample = fs.tail(path, footer_size + 8) - - # Step 3 - Collect required byte ranges - byte_ranges = [] - md = pq.ParquetFile(io.BytesIO(footer_sample)).metadata - column_set = None if columns is None else set(columns) - if column_set is not None: - schema = md.schema.to_arrow_schema() - has_pandas_metadata = ( - schema.metadata is not None and b"pandas" in schema.metadata - ) - if has_pandas_metadata: - md_index = [ - ind - for ind in json.loads( - schema.metadata[b"pandas"].decode("utf8") - ).get("index_columns", []) - # Ignore RangeIndex information - if not isinstance(ind, dict) - ] - column_set |= set(md_index) - for r in range(md.num_row_groups): - # Skip this row-group if we are targetting - # specific row-groups - if rgs is None or r in rgs: - row_group = md.row_group(r) - for c in range(row_group.num_columns): - column = row_group.column(c) - name = column.path_in_schema - # Skip this column if we are targetting a - # specific columns - split_name = name.split(".")[0] - if ( - column_set is None - or name in column_set - or split_name in column_set - ): - file_offset0 = column.dictionary_page_offset - if file_offset0 is None: - file_offset0 = column.data_page_offset - num_bytes = column.total_compressed_size - byte_ranges.append((file_offset0, num_bytes)) - - all_byte_ranges.append(byte_ranges) - all_footers.append(footer_sample) - all_sizes.append(file_size) - return all_byte_ranges, all_footers, all_sizes - - @ioutils.doc_read_parquet() def read_parquet( filepath_or_buffer, @@ -418,13 +318,24 @@ def read_parquet( num_rows=None, strings_to_categorical=False, use_pandas_metadata=True, - use_python_file_object=False, + use_python_file_object=True, categorical_partitions=True, + open_file_options=None, *args, **kwargs, ): """{docstring}""" + # Do not allow the user to set file-opening options + # when `use_python_file_object=False` is specified + if use_python_file_object is False: + if open_file_options: + raise ValueError( + "open_file_options is not currently supported when " + "use_python_file_object is set to False." + ) + open_file_options = {} + # Multiple sources are passed as a list. If a single source is passed, # wrap it in a list for unified processing downstream. if not is_list_like(filepath_or_buffer): @@ -470,38 +381,18 @@ def read_parquet( raise ValueError("cudf cannot apply filters to open file objects.") filepath_or_buffer = paths if paths else filepath_or_buffer - # Check if we should calculate the specific byte-ranges - # needed for each parquet file. We always do this when we - # have a file-system object to work with and it is not a - # local filesystem object. 
We can also do it without a - # file-system object for `AbstractBufferedFile` buffers - byte_ranges, footers, file_sizes = None, None, None - if not use_python_file_object: - need_byte_ranges = fs is not None and not ioutils._is_local_filesystem( - fs - ) - if need_byte_ranges or ( - filepath_or_buffer - and isinstance( - filepath_or_buffer[0], fsspec.spec.AbstractBufferedFile, - ) - ): - byte_ranges, footers, file_sizes = _get_byte_ranges( - filepath_or_buffer, row_groups, columns, fs, **kwargs - ) - filepaths_or_buffers = [] + if use_python_file_object: + open_file_options = _default_open_file_options( + open_file_options, columns, row_groups, fs=fs, + ) for i, source in enumerate(filepath_or_buffer): - tmp_source, compression = ioutils.get_filepath_or_buffer( path_or_data=source, compression=None, fs=fs, - byte_ranges=byte_ranges[i] if byte_ranges else None, - footer=footers[i] if footers else None, - file_size=file_sizes[i] if file_sizes else None, - add_par1_magic=True, use_python_file_object=use_python_file_object, + open_file_options=open_file_options, **kwargs, ) @@ -953,3 +844,41 @@ def __enter__(self): def __exit__(self, *args): self.close() + + +def _default_open_file_options( + open_file_options, columns, row_groups, fs=None +): + """ + Set default fields in open_file_options. + + Copies and updates `open_file_options` to + include column and row-group information + under the "precache_options" key. By default, + we set "method" to "parquet", but precaching + will be disabled if the user chooses `method=None` + + Parameters + ---------- + open_file_options : dict or None + columns : list + row_groups : list + fs : fsspec.AbstractFileSystem, Optional + """ + if fs and ioutils._is_local_filesystem(fs): + # Quick return for local fs + return open_file_options or {} + # Assume remote storage if `fs` was not specified + open_file_options = (open_file_options or {}).copy() + precache_options = open_file_options.pop("precache_options", {}).copy() + if precache_options.get("method", "parquet") == "parquet": + precache_options.update( + { + "method": "parquet", + "engine": precache_options.get("engine", "pyarrow"), + "columns": columns, + "row_groups": row_groups, + } + ) + open_file_options["precache_options"] = precache_options + return open_file_options diff --git a/python/cudf/cudf/tests/test_parquet.py b/python/cudf/cudf/tests/test_parquet.py index 519f24b7ca6..21556aad1eb 100644 --- a/python/cudf/cudf/tests/test_parquet.py +++ b/python/cudf/cudf/tests/test_parquet.py @@ -748,7 +748,10 @@ def test_parquet_reader_arrow_nativefile(parquet_path_or_buf): assert_eq(expect, got) -def test_parquet_reader_use_python_file_object(parquet_path_or_buf): +@pytest.mark.parametrize("use_python_file_object", [True, False]) +def test_parquet_reader_use_python_file_object( + parquet_path_or_buf, use_python_file_object +): # Check that the non-default `use_python_file_object=True` # option works as expected expect = cudf.read_parquet(parquet_path_or_buf("filepath")) @@ -756,11 +759,15 @@ def test_parquet_reader_use_python_file_object(parquet_path_or_buf): # Pass open fsspec file with fs.open(paths[0], mode="rb") as fil: - got1 = cudf.read_parquet(fil, use_python_file_object=True) + got1 = cudf.read_parquet( + fil, use_python_file_object=use_python_file_object + ) assert_eq(expect, got1) # Pass path only - got2 = cudf.read_parquet(paths[0], use_python_file_object=True) + got2 = cudf.read_parquet( + paths[0], use_python_file_object=use_python_file_object + ) assert_eq(expect, got2) diff --git 
a/python/cudf/cudf/tests/test_s3.py b/python/cudf/cudf/tests/test_s3.py index 5738e1f0d00..da1ffc1fc16 100644 --- a/python/cudf/cudf/tests/test_s3.py +++ b/python/cudf/cudf/tests/test_s3.py @@ -131,6 +131,9 @@ def pdf_ext(scope="module"): df["Integer"] = np.array([i for i in range(size)]) df["List"] = [[i] for i in range(size)] df["Struct"] = [{"a": i} for i in range(size)] + df["String"] = (["Alpha", "Beta", "Gamma", "Delta"] * (-(size // -4)))[ + :size + ] return df @@ -225,9 +228,16 @@ def test_write_csv(s3_base, s3so, pdf, chunksize): @pytest.mark.parametrize("bytes_per_thread", [32, 1024]) @pytest.mark.parametrize("columns", [None, ["Float", "String"]]) -@pytest.mark.parametrize("use_python_file_object", [False, True]) +@pytest.mark.parametrize("precache", [None, "parquet"]) +@pytest.mark.parametrize("use_python_file_object", [True, False]) def test_read_parquet( - s3_base, s3so, pdf, bytes_per_thread, columns, use_python_file_object + s3_base, + s3so, + pdf, + bytes_per_thread, + columns, + precache, + use_python_file_object, ): fname = "test_parquet_reader.parquet" bname = "parquet" @@ -239,10 +249,15 @@ def test_read_parquet( with s3_context(s3_base=s3_base, bucket=bname, files={fname: buffer}): got1 = cudf.read_parquet( "s3://{}/{}".format(bname, fname), - use_python_file_object=use_python_file_object, + open_file_options=( + {"precache_options": {"method": precache}} + if use_python_file_object + else None + ), storage_options=s3so, bytes_per_thread=bytes_per_thread, columns=columns, + use_python_file_object=use_python_file_object, ) expect = pdf[columns] if columns else pdf assert_eq(expect, got1) @@ -256,25 +271,18 @@ def test_read_parquet( with fs.open("s3://{}/{}".format(bname, fname), mode="rb") as f: got2 = cudf.read_parquet( f, - use_python_file_object=use_python_file_object, bytes_per_thread=bytes_per_thread, columns=columns, + use_python_file_object=use_python_file_object, ) assert_eq(expect, got2) @pytest.mark.parametrize("bytes_per_thread", [32, 1024]) @pytest.mark.parametrize("columns", [None, ["List", "Struct"]]) -@pytest.mark.parametrize("use_python_file_object", [False, True]) @pytest.mark.parametrize("index", [None, "Integer"]) def test_read_parquet_ext( - s3_base, - s3so, - pdf_ext, - bytes_per_thread, - columns, - use_python_file_object, - index, + s3_base, s3so, pdf_ext, bytes_per_thread, columns, index, ): fname = "test_parquet_reader_ext.parquet" bname = "parquet" @@ -290,7 +298,6 @@ def test_read_parquet_ext( with s3_context(s3_base=s3_base, bucket=bname, files={fname: buffer}): got1 = cudf.read_parquet( "s3://{}/{}".format(bname, fname), - use_python_file_object=use_python_file_object, storage_options=s3so, bytes_per_thread=bytes_per_thread, footer_sample_size=3200, @@ -326,12 +333,12 @@ def test_read_parquet_arrow_nativefile(s3_base, s3so, pdf, columns): assert_eq(expect, got) -@pytest.mark.parametrize("python_file", [True, False]) -def test_read_parquet_filters(s3_base, s3so, pdf, python_file): +@pytest.mark.parametrize("precache", [None, "parquet"]) +def test_read_parquet_filters(s3_base, s3so, pdf_ext, precache): fname = "test_parquet_reader_filters.parquet" bname = "parquet" buffer = BytesIO() - pdf.to_parquet(path=buffer) + pdf_ext.to_parquet(path=buffer) buffer.seek(0) filters = [("String", "==", "Omega")] with s3_context(s3_base=s3_base, bucket=bname, files={fname: buffer}): @@ -339,11 +346,11 @@ def test_read_parquet_filters(s3_base, s3so, pdf, python_file): "s3://{}/{}".format(bname, fname), storage_options=s3so, filters=filters, - 
use_python_file_object=python_file, + open_file_options={"precache_options": {"method": precache}}, ) # All row-groups should be filtered out - assert_eq(pdf.iloc[:0], got.reset_index(drop=True)) + assert_eq(pdf_ext.iloc[:0], got.reset_index(drop=True)) @pytest.mark.parametrize("partition_cols", [None, ["String"]]) diff --git a/python/cudf/cudf/utils/ioutils.py b/python/cudf/cudf/utils/ioutils.py index 6f958860dad..8f8a40ae4ab 100644 --- a/python/cudf/cudf/utils/ioutils.py +++ b/python/cudf/cudf/utils/ioutils.py @@ -3,6 +3,7 @@ import datetime import os import urllib +import warnings from io import BufferedWriter, BytesIO, IOBase, TextIOWrapper from threading import Thread @@ -17,6 +18,13 @@ from cudf.utils.docutils import docfmt_partial +try: + import fsspec.parquet as fsspec_parquet + +except ImportError: + fsspec_parquet = None + + _docstring_remote_sources = """ - cuDF supports local and remote data stores. See configuration details for available sources @@ -160,10 +168,17 @@ use_pandas_metadata : boolean, default True If True and dataset has custom PANDAS schema metadata, ensure that index columns are also loaded. -use_python_file_object : boolean, default False +use_python_file_object : boolean, default True If True, Arrow-backed PythonFile objects will be used in place of fsspec - AbstractBufferedFile objects at IO time. This option is likely to improve - performance when making small reads from larger parquet files. + AbstractBufferedFile objects at IO time. Setting this argument to `False` + will require the entire file to be copied to host memory, and is highly + discouraged. +open_file_options : dict, optional + Dictionary of key-value pairs to pass to the function used to open remote + files. By default, this will be `fsspec.parquet.open_parquet_file`. To + deactivate optimized precaching, set the "method" to `None` under the + "precache_options" key. Note that the `open_file_func` key can also be + used to specify a custom file-open function. Returns ------- @@ -1220,6 +1235,100 @@ def _get_filesystem_and_paths(path_or_data, **kwargs): return fs, return_paths +def _set_context(obj, stack): + # Helper function to place open file on context stack + if stack is None: + return obj + return stack.enter_context(obj) + + +def _open_remote_files( + paths, + fs, + context_stack=None, + open_file_func=None, + precache_options=None, + **kwargs, +): + """Return a list of open file-like objects given + a list of remote file paths. + + Parameters + ---------- + paths : list(str) + List of file-path strings. + fs : fsspec.AbstractFileSystem + Fsspec file-system object. + context_stack : contextlib.ExitStack, Optional + Context manager to use for open files. + open_file_func : Callable, Optional + Call-back function to use for opening. If this argument + is specified, all other arguments will be ignored. + precache_options : dict, optional + Dictionary of key-word arguments to pass to use for + precaching. Unless the input contains ``{"method": None}``, + ``fsspec.parquet.open_parquet_file`` will be used for remote + storage. + **kwargs : + Key-word arguments to be passed to format-specific + open functions. + """ + + # Just use call-back function if one was specified + if open_file_func is not None: + return [ + _set_context(open_file_func(path, **kwargs), context_stack) + for path in paths + ] + + # Check if the "precache" option is supported. 
+ # In the future, fsspec should do this check for us + precache_options = (precache_options or {}).copy() + precache = precache_options.pop("method", None) + if precache not in ("parquet", None): + raise ValueError(f"{precache} not a supported `precache` option.") + + # Check that "parts" caching (used for all format-aware file handling) + # is supported by the installed fsspec/s3fs version + if precache == "parquet" and not fsspec_parquet: + warnings.warn( + f"This version of fsspec ({fsspec.__version__}) does " + f"not support parquet-optimized precaching. Please upgrade " + f"to the latest fsspec version for better performance." + ) + precache = None + + if precache == "parquet": + # Use fsspec.parquet module. + # TODO: Use `cat_ranges` to collect "known" + # parts for all files at once. + row_groups = precache_options.pop("row_groups", None) or ( + [None] * len(paths) + ) + return [ + ArrowPythonFile( + _set_context( + fsspec_parquet.open_parquet_file( + path, + fs=fs, + row_groups=rgs, + **precache_options, + **kwargs, + ), + context_stack, + ) + ) + for path, rgs in zip(paths, row_groups) + ] + + # Default open - Use pyarrow filesystem API + pa_fs = PyFileSystem(FSSpecHandler(fs)) + return [ + _set_context(pa_fs.open_input_file(fpath), context_stack) + for fpath in paths + ] + + def get_filepath_or_buffer( path_or_data, compression, @@ -1228,6 +1337,7 @@ def get_filepath_or_buffer( iotypes=(BytesIO, NativeFile), byte_ranges=None, use_python_file_object=False, + open_file_options=None, **kwargs, ): """Return either a filepath string to data, or a memory buffer of data. @@ -1249,6 +1359,9 @@ def get_filepath_or_buffer( use_python_file_object : boolean, default False If True, Arrow-backed PythonFile objects will be used in place of fsspec AbstractBufferedFile objects. + open_file_options : dict, optional + Optional dictionary of key-word arguments to pass to + `_open_remote_files` (used for remote storage only). 
Returns ------- @@ -1282,19 +1395,14 @@ def get_filepath_or_buffer( else: if use_python_file_object: - pa_fs = PyFileSystem(FSSpecHandler(fs)) - path_or_data = [ - pa_fs.open_input_file(fpath) for fpath in paths - ] + path_or_data = _open_remote_files( + paths, fs, **(open_file_options or {}), + ) else: path_or_data = [ BytesIO( _fsspec_data_transfer( - fpath, - fs=fs, - mode=mode, - byte_ranges=byte_ranges, - **kwargs, + fpath, fs=fs, mode=mode, **kwargs, ) ) for fpath in paths @@ -1309,9 +1417,7 @@ def get_filepath_or_buffer( path_or_data = ArrowPythonFile(path_or_data) else: path_or_data = BytesIO( - _fsspec_data_transfer( - path_or_data, mode=mode, byte_ranges=byte_ranges, **kwargs - ) + _fsspec_data_transfer(path_or_data, mode=mode, **kwargs) ) return path_or_data, compression @@ -1545,10 +1651,7 @@ def _ensure_filesystem(passed_filesystem, path, **kwargs): def _fsspec_data_transfer( path_or_fob, fs=None, - byte_ranges=None, - footer=None, file_size=None, - add_par1_magic=None, bytes_per_thread=256_000_000, max_gap=64_000, mode="rb", @@ -1568,48 +1671,22 @@ def _fsspec_data_transfer( file_size = file_size or fs.size(path_or_fob) # Check if a direct read makes the most sense - if not byte_ranges and bytes_per_thread >= file_size: + if bytes_per_thread >= file_size: if file_like: return path_or_fob.read() else: - return fs.open(path_or_fob, mode=mode, cache_type="none").read() + return fs.open(path_or_fob, mode=mode, cache_type="all").read() # Threaded read into "local" buffer buf = np.zeros(file_size, dtype="b") - if byte_ranges: - - # Optimize/merge the ranges - byte_ranges = _merge_ranges( - byte_ranges, max_block=bytes_per_thread, max_gap=max_gap, - ) - - # Call multi-threaded data transfer of - # remote byte-ranges to local buffer - _read_byte_ranges( - path_or_fob, byte_ranges, buf, fs=fs, **kwargs, - ) - - # Add Header & Footer bytes - if footer is not None: - footer_size = len(footer) - buf[-footer_size:] = np.frombuffer( - footer[-footer_size:], dtype="b" - ) - # Add parquet magic bytes (optional) - if add_par1_magic: - buf[:4] = np.frombuffer(b"PAR1", dtype="b") - if footer is None: - buf[-4:] = np.frombuffer(b"PAR1", dtype="b") - - else: - byte_ranges = [ - (b, min(bytes_per_thread, file_size - b)) - for b in range(0, file_size, bytes_per_thread) - ] - _read_byte_ranges( - path_or_fob, byte_ranges, buf, fs=fs, **kwargs, - ) + byte_ranges = [ + (b, min(bytes_per_thread, file_size - b)) + for b in range(0, file_size, bytes_per_thread) + ] + _read_byte_ranges( + path_or_fob, byte_ranges, buf, fs=fs, **kwargs, + ) return buf.tobytes() diff --git a/python/dask_cudf/dask_cudf/io/parquet.py b/python/dask_cudf/dask_cudf/io/parquet.py index a49d73493ec..ac5795fa2ec 100644 --- a/python/dask_cudf/dask_cudf/io/parquet.py +++ b/python/dask_cudf/dask_cudf/io/parquet.py @@ -20,7 +20,9 @@ import cudf from cudf.core.column import as_column, build_categorical_column from cudf.io import write_to_dataset +from cudf.io.parquet import _default_open_file_options from cudf.utils.dtypes import cudf_dtype_from_pa_type +from cudf.utils.ioutils import _is_local_filesystem, _open_remote_files class CudfEngine(ArrowDatasetEngine): @@ -64,6 +66,7 @@ def _read_paths( partitions=None, partitioning=None, partition_keys=None, + open_file_options=None, **kwargs, ): @@ -75,15 +78,15 @@ def _read_paths( # Non-local filesystem handling paths_or_fobs = paths - if not cudf.utils.ioutils._is_local_filesystem(fs): - - # Convert paths to file objects for remote data - paths_or_fobs = [ - stack.enter_context( - 
fs.open(path, mode="rb", cache_type="none") - ) - for path in paths - ] + if not _is_local_filesystem(fs): + paths_or_fobs = _open_remote_files( + paths_or_fobs, + fs, + context_stack=stack, + **_default_open_file_options( + open_file_options, columns, row_groups + ), + ) # Use cudf to read in data df = cudf.read_parquet( @@ -150,6 +153,7 @@ def read_partition( partitions=(), partitioning=None, schema=None, + open_file_options=None, **kwargs, ): @@ -168,7 +172,10 @@ def read_partition( if not isinstance(pieces, list): pieces = [pieces] + # Extract supported kwargs from `kwargs` strings_to_cats = kwargs.get("strings_to_categorical", False) + read_kwargs = kwargs.get("read", {}) + read_kwargs.update(open_file_options or {}) # Assume multi-piece read paths = [] @@ -192,7 +199,7 @@ def read_partition( partitions=partitions, partitioning=partitioning, partition_keys=last_partition_keys, - **kwargs.get("read", {}), + **read_kwargs, ) ) paths = rgs = [] @@ -215,13 +222,13 @@ def read_partition( partitions=partitions, partitioning=partitioning, partition_keys=last_partition_keys, - **kwargs.get("read", {}), + **read_kwargs, ) ) df = cudf.concat(dfs) if len(dfs) > 1 else dfs[0] # Re-set "object" dtypes align with pa schema - set_object_dtypes_from_pa_schema(df, kwargs.get("schema", None)) + set_object_dtypes_from_pa_schema(df, schema) if index and (index[0] in df.columns): df = df.set_index(index[0]) diff --git a/python/dask_cudf/dask_cudf/io/tests/test_s3.py b/python/dask_cudf/dask_cudf/io/tests/test_s3.py index ad53f5cfe0f..83ff1273b36 100644 --- a/python/dask_cudf/dask_cudf/io/tests/test_s3.py +++ b/python/dask_cudf/dask_cudf/io/tests/test_s3.py @@ -6,6 +6,7 @@ from io import BytesIO import pandas as pd +import pyarrow.fs as pa_fs import pytest import dask_cudf @@ -115,7 +116,15 @@ def test_read_csv(s3_base, s3so): assert df.a.sum().compute() == 4 -def test_read_parquet(s3_base, s3so): +@pytest.mark.parametrize( + "open_file_options", + [ + {"precache_options": {"method": None}}, + {"precache_options": {"method": "parquet"}}, + {"open_file_func": None}, + ], +) +def test_read_parquet(s3_base, s3so, open_file_options): pdf = pd.DataFrame({"a": [1, 2, 3, 4], "b": [2.1, 2.2, 2.3, 2.4]}) buffer = BytesIO() pdf.to_parquet(path=buffer) @@ -123,8 +132,15 @@ def test_read_parquet(s3_base, s3so): with s3_context( s3_base=s3_base, bucket="daskparquet", files={"file.parq": buffer} ): + if "open_file_func" in open_file_options: + fs = pa_fs.S3FileSystem( + endpoint_override=s3so["client_kwargs"]["endpoint_url"], + ) + open_file_options["open_file_func"] = fs.open_input_file df = dask_cudf.read_parquet( - "s3://daskparquet/*.parq", storage_options=s3so + "s3://daskparquet/*.parq", + storage_options=s3so, + open_file_options=open_file_options, ) assert df.a.sum().compute() == 10 assert df.b.sum().compute() == 9