diff --git a/nvtabular/io/dataset.py b/nvtabular/io/dataset.py
index 8677fa65947..f83136477f4 100644
--- a/nvtabular/io/dataset.py
+++ b/nvtabular/io/dataset.py
@@ -197,6 +197,13 @@ class Dataset:
         Optional reference to the original "base" Dataset object used
         to construct the current Dataset instance. This object is used
         to preserve file-partition mapping information.
+    **kwargs :
+        Keyword arguments to pass through to the dask.dataframe IO function.
+        For the Parquet engine(s), notable arguments include `filters`,
+        `aggregate_files`, and `gather_statistics`. Note that users who
+        do not need to know the number of rows in their dataset (and do
+        not plan to preserve a file-partition mapping) may wish to use
+        `gather_statistics=False` for better client-side performance.
     """

     def __init__(
@@ -454,6 +461,10 @@ def shuffle_by_keys(self, keys, hive_data=None, npartitions=None):
             for c in keys:
                 typ = ddf._meta[c].dtype
                 if c in cols:
+                    if typ == "category":
+                        # Cannot cast directly to categorical unless we
+                        # first cast to the underlying dtype of the categories
+                        hive_mapping[c] = hive_mapping[c].astype(typ.categories.dtype)
                     hive_mapping[c] = hive_mapping[c].astype(typ)

         # Generate simple-shuffle plan
diff --git a/nvtabular/io/dataset_engine.py b/nvtabular/io/dataset_engine.py
index 6f7db339458..94e6b76d7e7 100644
--- a/nvtabular/io/dataset_engine.py
+++ b/nvtabular/io/dataset_engine.py
@@ -26,7 +26,10 @@ def __init__(self, paths, part_size, cpu=False, storage_options=None):
         self.paths = paths
         self.part_size = part_size
         self.storage_options = storage_options
-        fs, fs_token, _ = get_fs_token_paths(paths, mode="rb", storage_options=self.storage_options)
+        fs, fs_token, paths2 = get_fs_token_paths(
+            paths, mode="rb", storage_options=self.storage_options
+        )
+        self.stripped_paths = paths2
         self.fs = fs
         self.fs_token = fs_token
         self.cpu = cpu
diff --git a/nvtabular/io/parquet.py b/nvtabular/io/parquet.py
index b2cdab42d20..8081737175d 100644
--- a/nvtabular/io/parquet.py
+++ b/nvtabular/io/parquet.py
@@ -30,17 +30,20 @@
     import cudf
     import dask_cudf
     from cudf.io.parquet import ParquetWriter as pwriter_cudf
+    from dask_cudf.io.parquet import CudfEngine
 except ImportError:
     cudf = None
 import dask
 import dask.dataframe as dd
 import fsspec
-import numpy as np
 import pandas as pd
 import pyarrow as pa
+import pyarrow.dataset as pa_ds
 import toolz as tlz
 from dask.base import tokenize
 from dask.dataframe.core import _concat, new_dd_object
+from dask.dataframe.io.parquet.arrow import ArrowDatasetEngine
+from dask.dataframe.io.parquet.core import apply_filters
 from dask.dataframe.io.parquet.utils import _analyze_paths
 from dask.delayed import Delayed
 from dask.highlevelgraph import HighLevelGraph
@@ -49,6 +52,11 @@
 from pyarrow import parquet as pq
 from pyarrow.parquet import ParquetWriter as pwriter_pyarrow

+if LooseVersion(dask.__version__) >= "2021.07.1":
+    from dask.dataframe.io.parquet.core import aggregate_row_groups
+else:
+    aggregate_row_groups = None
+
 from .dataset_engine import DatasetEngine
 from .shuffle import Shuffle, _shuffle_df
 from .writer import ThreadedWriter
@@ -56,6 +64,126 @@

 LOG = logging.getLogger("nvtabular")
+
+
+class CPUParquetEngine(ArrowDatasetEngine):
+    @staticmethod
+    def read_metadata(*args, **kwargs):
+        return _override_read_metadata(ArrowDatasetEngine, *args, **kwargs)
+
+    @classmethod
+    def multi_support(cls):
+        return hasattr(ArrowDatasetEngine, "multi_support") and ArrowDatasetEngine.multi_support()
+
+
+# Define GPUParquetEngine if cudf is available
+if cudf is not None:
+
+    class GPUParquetEngine(CudfEngine):
+        @staticmethod
+        def read_metadata(*args, **kwargs):
+            return _override_read_metadata(CudfEngine, *args, **kwargs)
+
+        @classmethod
+        def multi_support(cls):
+            return hasattr(CudfEngine, "multi_support") and CudfEngine.multi_support()
+
+
+def _override_read_metadata(
+    engine,
+    fs,
+    paths,
+    index=None,
+    gather_statistics=None,
+    split_row_groups=None,
+    filters=None,
+    aggregate_files=None,
+    dataset=None,
+    chunksize=None,
+    **global_kwargs,
+):
+    # This function is used by both CPU and GPU-backed
+    # ParquetDatasetEngine instances to override the `read_metadata`
+    # component of the upstream `read_parquet` logic. This provides
+    # NVTabular with direct access to the final partitioning behavior.
+
+    # For now, disallow the user from setting `chunksize`
+    if chunksize:
+        raise ValueError(
+            "NVTabular does not yet support the explicit use of Dask's `chunksize` argument."
+        )
+
+    # Extract metadata_collector from the dataset "container"
+    dataset = dataset or {}
+    metadata_collector = dataset.pop("metadata_collector", None)
+
+    # Gather statistics by default.
+    # This enables optimized length calculations
+    if gather_statistics is None:
+        gather_statistics = True
+
+    # Use a local_kwargs dictionary to make it easier to exclude
+    # `aggregate_files` for older Dask versions
+    local_kwargs = {
+        "index": index,
+        "filters": filters,
+        # Use chunksize=1 to "ensure" statistics are gathered
+        # if `gather_statistics=True`. Note that Dask will bail
+        # from statistics gathering if it does not expect statistics
+        # to be "used" after `read_metadata` returns.
+        "chunksize": 1 if gather_statistics else None,
+        "gather_statistics": gather_statistics,
+        "split_row_groups": split_row_groups,
+    }
+    if aggregate_row_groups is not None:
+        # File aggregation is only available for Dask>=2021.07.1
+        local_kwargs["aggregate_files"] = aggregate_files
+    elif aggregate_files:
+        raise ValueError("This version of Dask does not support the `aggregate_files` argument.")
+
+    # Start with "super-class" read_metadata logic
+    read_metadata_result = engine.read_metadata(
+        fs,
+        paths,
+        **local_kwargs,
+        **global_kwargs,
+    )
+    parts = read_metadata_result[2].copy()
+    statistics = read_metadata_result[1].copy()
+
+    # Process the statistics.
+    # Note that these steps are usually performed after
+    # `engine.read_metadata` returns (in Dask), but we are doing
+    # it ourselves in NVTabular (to capture the expected output
+    # partitioning plan)
+    if statistics:
+        result = list(
+            zip(*[(part, stats) for part, stats in zip(parts, statistics) if stats["num-rows"] > 0])
+        )
+        parts, statistics = result or [[], []]
+
+    # Apply filters
+    if filters:
+        parts, statistics = apply_filters(parts, statistics, filters)
+
+    # Apply file aggregation
+    if aggregate_row_groups is not None:
+
+        # Convert `aggregate_files` to an integer `aggregation_depth`
+        aggregation_depth = False
+        if len(parts) and aggregate_files:
+            aggregation_depth = parts[0].get("aggregation_depth", aggregation_depth)
+
+        # Aggregate parts/statistics if we are splitting by row-group
+        if chunksize or (split_row_groups and int(split_row_groups) > 1):
+            parts, statistics = aggregate_row_groups(
+                parts, statistics, chunksize, split_row_groups, fs, aggregation_depth
+            )
+
+    # Update `metadata_collector` and return the "original" `read_metadata_result`
+    metadata_collector["stats"] = statistics
+    metadata_collector["parts"] = parts
+    return read_metadata_result
+
+
 class ParquetDatasetEngine(DatasetEngine):
     """ParquetDatasetEngine is a Dask-based version of cudf.read_parquet."""

@@ -68,12 +196,21 @@ def __init__(
         legacy=False,
         batch_size=None,  # Ignored
         cpu=False,
+        **kwargs,
     ):
         super().__init__(paths, part_size, cpu=cpu, storage_options=storage_options)
         self._pp_map = None
         self._pp_nrows = None
+        self._pp_metadata = None
+
+        # Process `kwargs`
+        self.read_parquet_kwargs = kwargs.copy()
+        self.aggregate_files = self.read_parquet_kwargs.pop("aggregate_files", False)
+        self.filters = self.read_parquet_kwargs.pop("filters", None)
+        self.dataset_kwargs = self.read_parquet_kwargs.pop("dataset", {})
+
         if row_groups_per_part is None:
-            path0 = self._dataset.pieces[0].path
+            path0 = next(self._dataset.get_fragments()).path
             if cpu:
                 with self.fs.open(path0, "rb") as f0:
                     # Use pyarrow for CPU version.
@@ -105,7 +242,9 @@ def __init__(

     @property
     @functools.lru_cache(1)
-    def _dataset(self):
+    def _legacy_dataset(self):
+        # TODO: Remove this after finding a way to avoid
+        # the use of `ParquetDataset` in `validate_dataset`
         paths = self.paths
         fs = self.fs
         if len(paths) > 1:
@@ -119,6 +258,19 @@ def _dataset(self):
             dataset = pq.ParquetDataset(paths[0], filesystem=fs)
         return dataset

+    @property
+    @functools.lru_cache(1)
+    def _dataset(self):
+        paths = self.stripped_paths
+        fs = self.fs
+        if len(paths) > 1:
+            # This is a list of files
+            dataset = pa_ds.dataset(paths, filesystem=fs)
+        else:
+            # This is a directory or a single file
+            dataset = pa_ds.dataset(paths[0], filesystem=fs)
+        return dataset
+
     @property
     def _file_partition_map(self):
         if self._pp_map is None:
@@ -135,94 +287,65 @@ def _partition_lens(self):
     def num_rows(self):
         # TODO: Avoid parsing metadata once upstream dask
        # can get the length efficiently (in all practical cases)
-        return sum(self._partition_lens)
+        if self._partition_lens:
+            return sum(self._partition_lens)
+        return len(self.to_ddf().index)

     def _process_parquet_metadata(self):
         # Utility shared by `_file_partition_map` and `_partition_lens`
         # to collect useful information from the parquet metadata
-        _pp_nrows = []
+        # First, we need to populate `self._pp_metadata`
+        if self._pp_metadata is None:
+            _ = self.to_ddf()

-        def _update_partition_lens(md, num_row_groups, rg_offset=None):
-            # Helper function to calculate the row count for each
-            # output partition (and add it to `_pp_nrows`)
-            rg_offset = rg_offset or 0
-            for rg_i in range(0, num_row_groups, self.row_groups_per_part):
-                rg_f = min(rg_i + self.row_groups_per_part, num_row_groups)
-                _pp_nrows.append(
-                    sum([md.row_group(rg + rg_offset).num_rows for rg in range(rg_i, rg_f)])
-                )
-            return
-
-        dataset = self._dataset
-        if dataset.metadata:
-            # We have a metadata file.
-            # Determing the row-group count per file.
-            _path_row_groups = defaultdict(int)
-            for rg in range(dataset.metadata.num_row_groups):
-                fn = dataset.metadata.row_group(rg).column(0).file_path
-                _path_row_groups[fn] += 1
-
-            # Convert the per-file row-group count to the
-            # file-to-partition mapping
-            ind, rg = 0, 0
-            _pp_map = defaultdict(list)
-            for fn, num_row_groups in _path_row_groups.items():
-                part_count = math.ceil(num_row_groups / self.row_groups_per_part)
-                _pp_map[fn] = np.arange(ind, ind + part_count)
-                _update_partition_lens(dataset.metadata, num_row_groups, rg_offset=rg)
-                ind += part_count
-                rg += num_row_groups
-        else:
-            # No metadata file. Construct file-to-partition map manually
-            ind = 0
-            _pp_map = {}
-            for piece in dataset.pieces:
-                md = piece.get_metadata()
-                num_row_groups = md.num_row_groups
-                part_count = math.ceil(num_row_groups / self.row_groups_per_part)
-                fn = piece.path.split(self.fs.sep)[-1]
-                _pp_map[fn] = np.arange(ind, ind + part_count)
-                _update_partition_lens(md, num_row_groups)
-                ind += part_count
+        # Second, we can use the path and num-rows information
+        # in parts and stats
+        parts = self._pp_metadata["parts"]
+        stats = self._pp_metadata["stats"]
+        _pp_map = {}
+        _pp_nrows = []
+        distinct_files = True
+        for i, (part, stat) in enumerate(zip(parts, stats)):
+            if distinct_files:
+                if isinstance(part, list):
+                    if len(part) > 1:
+                        distinct_files = False
+                    else:
+                        path = part[0]["piece"][0]
+                        _pp_map[path] = i
+                else:
+                    path = part["piece"][0]
+                    _pp_map[path] = i
+            _pp_nrows.append(stat["num-rows"])

-        self._pp_map = _pp_map
         self._pp_nrows = _pp_nrows
+        self._pp_map = _pp_map

     def to_ddf(self, columns=None, cpu=None):
-        # Check if we are using cpu
+        # Check if we are using cpu or gpu backend
         cpu = self.cpu if cpu is None else cpu
+        backend_engine = CPUParquetEngine if cpu else GPUParquetEngine

-        if cpu:
-            # Return a Dask-Dataframe in CPU memory
-            for try_engine in ["pyarrow-dataset", "pyarrow"]:
-                # Try to use the "pyarrow-dataset" engine, if
-                # available, but fall back on vanilla "pyarrow"
-                # for older Dask versions.
-                try:
-                    return dd.read_parquet(
-                        self.paths,
-                        engine=try_engine,
-                        columns=columns,
-                        index=None if columns is None else False,
-                        gather_statistics=False,
-                        split_row_groups=self.row_groups_per_part,
-                        storage_options=self.storage_options,
-                    )
-                except ValueError:
-                    pass
-            raise RuntimeError("dask.dataframe.read_parquet failed.")
-
-        return dask_cudf.read_parquet(
+        # Use dask-dataframe with appropriate engine
+        metadata_collector = {"stats": [], "parts": []}
+        dataset_kwargs = {"metadata_collector": metadata_collector}
+        dataset_kwargs.update(self.dataset_kwargs)
+        ddf = dd.read_parquet(
             self.paths,
             columns=columns,
-            # can't omit reading the index in if we aren't being passed columns
-            index=None if columns is None else False,
-            gather_statistics=False,
+            engine=backend_engine,
+            index=False,
+            aggregate_files=self.aggregate_files,
+            filters=self.filters,
             split_row_groups=self.row_groups_per_part,
             storage_options=self.storage_options,
+            dataset=dataset_kwargs,
+            **self.read_parquet_kwargs,
         )
+        self._pp_metadata = metadata_collector
+        return ddf

     def to_cpu(self):
         self.cpu = True
@@ -299,7 +422,7 @@ def validate_dataset(
         file_min_size = parse_bytes(file_min_size)

         # Get dataset and path list
-        pa_dataset = self._dataset
+        pa_dataset = self._legacy_dataset
         paths = [p.path for p in pa_dataset.pieces]
         root_dir, fns = _analyze_paths(paths, self.fs)

diff --git a/tests/unit/test_io.py b/tests/unit/test_io.py
index 2322468c894..3cc2f12b3a6 100644
--- a/tests/unit/test_io.py
+++ b/tests/unit/test_io.py
@@ -638,10 +638,18 @@ def test_validate_and_regenerate_dataset(tmpdir):
     ds2.validate_dataset(file_min_size=1)

     # Check that dataset content is correct
-    assert_eq(ddf, ds2.to_ddf().compute())
+    assert_eq(
+        ddf.reset_index(drop=False),
+        ds2.to_ddf().compute(),
+        check_index=False,
+    )

     # Check cpu version of `to_ddf`
-    assert_eq(ddf, ds2.engine.to_ddf(cpu=True).compute())
+    assert_eq(
+        ddf.reset_index(drop=False),
+        ds2.engine.to_ddf(cpu=True).compute(),
+        check_index=False,
+    )


 @pytest.mark.parametrize("preserve_files", [True, False])
@@ -780,7 +788,7 @@ def test_hive_partitioned_data(tmpdir, cpu):
     os.remove(os.path.join(path, "schema.pbtxt"))

     # Read back with dask.dataframe and check the data
-    df_check = dd.read_parquet(path).compute()
+    df_check = dd.read_parquet(path, engine="pyarrow").compute()
     df_check["name"] = df_check["name"].astype("object")
     df_check["timestamp"] = df_check["timestamp"].astype("int64")
     df_check = df_check.sort_values(["id", "x", "y"]).reset_index(drop=True)
@@ -847,3 +855,89 @@ def test_dataset_shuffle_on_keys(tmpdir, cpu, partition_on, keys, npartitions):
     for col in df1:
         # Order of columns can change after round-trip partitioning
         assert_eq(df1[col], df2[col], check_index=False)
+
+
+@pytest.mark.parametrize("cpu", [True, False])
+def test_parquet_filtered_flat(tmpdir, cpu):
+
+    # Initial flat (non-hive) dataset with three files,
+    # each containing a distinct value of "a".
+    path = str(tmpdir)
+    ddf1 = dd.from_pandas(pd.DataFrame({"a": [1] * 10}), 1)
+    ddf1.to_parquet(path, engine="pyarrow", write_index=False)
+    ddf2 = dd.from_pandas(pd.DataFrame({"a": [2] * 10}), 1)
+    ddf2.to_parquet(path, engine="pyarrow", append=True, write_index=False)
+    ddf3 = dd.from_pandas(pd.DataFrame({"a": [3] * 10}), 1)
+    ddf3.to_parquet(path, engine="pyarrow", append=True, write_index=False)
+
+    # Convert to nvt.Dataset
+    ds = nvt.Dataset(path, cpu=cpu, engine="parquet", filters=[("a", ">", 1)])
+
+    # Make sure partitions were filtered
+    assert len(ds.to_ddf().a.unique()) == 2
+
+
+@pytest.mark.parametrize("cpu", [True, False])
+def test_parquet_filtered_hive(tmpdir, cpu):
+
+    # Initial timeseries dataset (in cpu memory).
+    # Round the full "timestamp" to the day for partitioning.
+    path = str(tmpdir)
+    ddf = dask.datasets.timeseries(
+        start="2000-01-01",
+        end="2000-01-03",
+        freq="600s",
+        partition_freq="6h",
+        seed=42,
+    ).reset_index()
+    ddf["timestamp"] = ddf["timestamp"].dt.round("D").dt.day
+    ddf.to_parquet(path, partition_on=["timestamp"], engine="pyarrow")
+
+    # Convert to nvt.Dataset
+    ds = nvt.Dataset(path, cpu=cpu, engine="parquet", filters=[("timestamp", "==", 1)])
+
+    # Make sure partitions were filtered
+    assert len(ds.to_ddf().timestamp.unique()) == 1
+
+
+@pytest.mark.skipif(
+    LooseVersion(dask.__version__) < "2021.07.1",
+    reason="Dask>=2021.07.1 required for file aggregation",
+)
+@pytest.mark.parametrize("cpu", [True, False])
+def test_parquet_aggregate_files(tmpdir, cpu):
+
+    # Initial timeseries dataset (in cpu memory).
+    # Round the full "timestamp" to the day for partitioning.
+    path = str(tmpdir)
+    ddf = dask.datasets.timeseries(
+        start="2000-01-01",
+        end="2000-01-03",
+        freq="600s",
+        partition_freq="6h",
+        seed=42,
+    ).reset_index()
+    ddf["timestamp"] = ddf["timestamp"].dt.round("D").dt.day
+    ddf.to_parquet(path, partition_on=["timestamp"], engine="pyarrow")
+
+    # Setting `aggregate_files=True` should result
+    # in one large partition
+    ds = nvt.Dataset(path, cpu=cpu, engine="parquet", aggregate_files=True, part_size="1GB")
+    assert ds.to_ddf().npartitions == 1
+
+    # Setting `aggregate_files="timestamp"` should result
+    # in one partition for each unique value of "timestamp"
+    ds = nvt.Dataset(path, cpu=cpu, engine="parquet", aggregate_files="timestamp", part_size="1GB")
+    assert ds.to_ddf().npartitions == len(ddf.timestamp.unique())
+
+    # Combining `aggregate_files` and `filters` should work
+    ds = nvt.Dataset(
+        path,
+        cpu=cpu,
+        engine="parquet",
+        aggregate_files="timestamp",
+        filters=[("timestamp", "==", 1)],
+        part_size="1GB",
+    )
+    assert ds.to_ddf().npartitions == 1
+    assert len(ds.to_ddf().timestamp.unique()) == 1
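
Usage sketch (reviewer note, not part of the patch): how the new keyword arguments travel from `nvt.Dataset` through `ParquetDatasetEngine` into `dd.read_parquet`. The path and column name below ("/data/parquet", "day") are hypothetical; `filters`, `aggregate_files`, `part_size`, and `gather_statistics` are the arguments exercised by the docstring and tests above.

    import nvtabular as nvt

    # Filter row-groups/partitions at read time and combine small files
    # that share a "day" partition value (Dask>=2021.07.1 is required
    # for `aggregate_files`).
    ds = nvt.Dataset(
        "/data/parquet",  # hypothetical dataset path
        engine="parquet",
        part_size="256MB",
        filters=[("day", "==", 1)],
        aggregate_files="day",
    )
    print(ds.to_ddf().npartitions)

    # If row counts and the file-to-partition mapping are not needed,
    # skipping statistics collection keeps graph construction cheap.
    ds_fast = nvt.Dataset("/data/parquet", engine="parquet", gather_statistics=False)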
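
A second sketch, relying on internal attributes introduced by this patch (`Dataset.engine` and `_pp_metadata`) and therefore subject to change: after `to_ddf()` runs, the partitioning plan captured by `_override_read_metadata` can be inspected directly.

    ds = nvt.Dataset("/data/parquet", engine="parquet")  # hypothetical path
    ddf = ds.to_ddf()

    # The `metadata_collector` dict is stashed on the engine as
    # `_pp_metadata`; "parts" and "stats" are expected to line up with
    # the output partitions of `ddf`.
    plan = ds.engine._pp_metadata
    print(len(plan["parts"]), ddf.npartitions)
    print(sum(stat["num-rows"] for stat in plan["stats"]))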
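
Finally, a minimal pandas sketch of the casting order that the `shuffle_by_keys` hunk guards against: hive-partition values are parsed from paths as strings, so casting straight to an integer-backed categorical dtype mismatches the categories, while casting to `typ.categories.dtype` first lines the values up. The series and dtype below are hypothetical.

    import pandas as pd

    raw = pd.Series(["1", "2", "3"])  # values parsed from hive directory names
    typ = pd.CategoricalDtype(categories=[1, 2, 3])

    # Direct cast: the string values do not match the integer categories,
    # so (with pandas) every element becomes null.
    assert raw.astype(typ).isna().all()

    # Two-step cast, as in the patch: first to the categories' dtype,
    # then to the categorical dtype itself.
    fixed = raw.astype(typ.categories.dtype).astype(typ)
    assert not fixed.isna().any()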