From 858944b76dd3cb257b262d1626d71e6696296128 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath <3190405+shwina@users.noreply.github.com> Date: Thu, 2 Sep 2021 10:08:50 -0400 Subject: [PATCH] Use decimal precision metadata when reading from parquet files (#9162) Closes #8354. Authors: - Ashwin Srinath (https://github.com/shwina) Approvers: - GALI PREM SAGAR (https://github.com/galipremsagar) URL: https://github.com/rapidsai/cudf/pull/9162 --- python/cudf/cudf/_lib/parquet.pyx | 23 +++++++---------------- python/cudf/cudf/tests/test_parquet.py | 26 ++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/python/cudf/cudf/_lib/parquet.pyx b/python/cudf/cudf/_lib/parquet.pyx index 95ae2202f68..e12a61f2a49 100644 --- a/python/cudf/cudf/_lib/parquet.pyx +++ b/python/cudf/cudf/_lib/parquet.pyx @@ -185,22 +185,13 @@ cpdef read_parquet(filepaths_or_buffers, columns=None, row_groups=None, update_struct_field_names(df, c_out_table.metadata.schema_info) - if df.empty and meta is not None: - cols_dtype_map = {} - for col in meta['columns']: - cols_dtype_map[col['name']] = col['numpy_type'] - - if not column_names: - column_names = [o['name'] for o in meta['columns']] - if not is_range_index and index_col in cols_dtype_map: - column_names.remove(index_col) - - for col in column_names: - meta_dtype = cols_dtype_map.get(col, None) - df._data[col] = cudf.core.column.column_empty( - row_count=0, - dtype=cudf.dtype(meta_dtype) - ) + # update the decimal precision of each column + if meta is not None: + for col, col_meta in zip(column_names, meta["columns"]): + if isinstance(df._data[col].dtype, cudf.Decimal64Dtype): + df._data[col].dtype.precision = ( + col_meta["metadata"]["precision"] + ) # Set the index column if index_col is not None and len(index_col) > 0: diff --git a/python/cudf/cudf/tests/test_parquet.py b/python/cudf/cudf/tests/test_parquet.py index e4a61a2a37e..26cc36f0e4f 100644 --- a/python/cudf/cudf/tests/test_parquet.py +++ b/python/cudf/cudf/tests/test_parquet.py @@ -2047,3 +2047,29 @@ def test_parquet_writer_nulls_pandas_read(tmpdir, pdf): got = pd.read_parquet(fname) nullable = True if num_rows > 0 else False assert_eq(gdf.to_pandas(nullable=nullable), got) + + +def test_parquet_decimal_precision(tmpdir): + df = cudf.DataFrame({"val": ["3.5", "4.2"]}).astype( + cudf.Decimal64Dtype(5, 2) + ) + assert df.val.dtype.precision == 5 + + fname = tmpdir.join("decimal_test.parquet") + df.to_parquet(fname) + df = cudf.read_parquet(fname) + assert df.val.dtype.precision == 5 + + +def test_parquet_decimal_precision_empty(tmpdir): + df = ( + cudf.DataFrame({"val": ["3.5", "4.2"]}) + .astype(cudf.Decimal64Dtype(5, 2)) + .iloc[:0] + ) + assert df.val.dtype.precision == 5 + + fname = tmpdir.join("decimal_test.parquet") + df.to_parquet(fname) + df = cudf.read_parquet(fname) + assert df.val.dtype.precision == 5