Skip to content

Commit

Permalink
Fix decimal metadata in parquet writer (#10224)
Browse files Browse the repository at this point in the history
Fixes: #10172 

`pa.pandas_compat.construct_metadata` constructs the correct metadata for decimal columns, but the special `list` & `struct` handling logic then overrides it as `string` instead of retaining it as `object`. This PR fixes the issue and modifies the existing tests to validate the fix.

Authors:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - https://github.com/brandon-b-miller
  - Devavret Makkar (https://github.com/devavret)

URL: #10224
  • Loading branch information
galipremsagar authored Feb 4, 2022
1 parent 84ae8ab commit e5ba292
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
3 changes: 2 additions & 1 deletion python/cudf/cudf/_lib/utils.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
# Copyright (c) 2020-2022, NVIDIA CORPORATION.

import numpy as np
import pyarrow as pa
Expand Down Expand Up @@ -190,6 +190,7 @@ cpdef generate_pandas_metadata(table, index):
col_meta["name"] in table._column_names
and table._data[col_meta["name"]].nullable
and col_meta["numpy_type"] in PARQUET_META_TYPE_MAP
and col_meta["pandas_type"] != "decimal"
):
col_meta["numpy_type"] = PARQUET_META_TYPE_MAP[
col_meta["numpy_type"]
Expand Down
15 changes: 7 additions & 8 deletions python/cudf/cudf/tests/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2265,18 +2265,17 @@ def test_parquet_writer_nested(tmpdir, data):
"decimal_type",
[cudf.Decimal32Dtype, cudf.Decimal64Dtype, cudf.Decimal128Dtype],
)
def test_parquet_writer_decimal(tmpdir, decimal_type):

gdf = cudf.DataFrame({"val": [0.00, 0.01, 0.02]})
@pytest.mark.parametrize("data", [[1, 2, 3], [0.00, 0.01, None, 0.5]])
def test_parquet_writer_decimal(decimal_type, data):
gdf = cudf.DataFrame({"val": data})

gdf["dec_val"] = gdf["val"].astype(decimal_type(7, 2))

fname = tmpdir.join("test_parquet_writer_decimal.parquet")
gdf.to_parquet(fname)
assert os.path.exists(fname)
buff = BytesIO()
gdf.to_parquet(buff)

got = pd.read_parquet(fname)
assert_eq(gdf, got)
got = pd.read_parquet(buff, use_nullable_dtypes=True)
assert_eq(gdf.to_pandas(nullable=True), got)


def test_parquet_writer_column_validation():
Expand Down

0 comments on commit e5ba292

Please sign in to comment.