Skip to content

Commit

Permalink
Write pandas extension types to parquet file metadata (#8749)
Browse files Browse the repository at this point in the history
Prevents nullable columns to be read as float columns with NaNs when reading with pandas.

Fixes #8688

Authors:
  - Devavret Makkar (https://github.com/devavret)
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #8749
  • Loading branch information
devavret authored Jul 16, 2021
1 parent 4540728 commit 4d7ad4f
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
17 changes: 16 additions & 1 deletion python/cudf/cudf/_lib/utils.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ except ImportError:
import json

from cudf.utils.dtypes import (
cudf_dtypes_to_pandas_dtypes,
is_categorical_dtype,
is_decimal_dtype,
is_list_dtype,
Expand All @@ -27,6 +28,12 @@ from cudf.utils.dtypes import (
)


PARQUET_META_TYPE_MAP = {
str(cudf_dtype): str(pandas_dtype)
for cudf_dtype, pandas_dtype in cudf_dtypes_to_pandas_dtypes.items()
}


cdef vector[column_view] make_column_views(object columns):
cdef vector[column_view] views
views.reserve(len(columns))
Expand Down Expand Up @@ -152,8 +159,16 @@ cpdef generate_pandas_metadata(Table table, index):

md_dict = json.loads(metadata[b"pandas"])

# correct metadata for list and struct types
# correct metadata for list and struct and nullable numeric types
for col_meta in md_dict["columns"]:
if (
col_meta["name"] in table._column_names
and table._data[col_meta["name"]].nullable
and col_meta["numpy_type"] in PARQUET_META_TYPE_MAP
):
col_meta["numpy_type"] = PARQUET_META_TYPE_MAP[
col_meta["numpy_type"]
]
if col_meta["numpy_type"] in ("list", "struct"):
col_meta["numpy_type"] = "object"

Expand Down
21 changes: 21 additions & 0 deletions python/cudf/cudf/tests/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1956,3 +1956,24 @@ def test_parquet_writer_column_validation():
lfunc_args_and_kwargs=(["cudf.parquet"],),
rfunc_args_and_kwargs=(["pandas.parquet"],),
)


def test_parquet_writer_nulls_pandas_read(tmpdir, pdf):
if "col_bool" in pdf.columns:
pdf.drop(columns="col_bool", inplace=True)
if "col_category" in pdf.columns:
pdf.drop(columns="col_category", inplace=True)
gdf = cudf.from_pandas(pdf)

num_rows = len(gdf)
if num_rows > 0:
for col in gdf.columns:
gdf[col][random.randint(0, num_rows - 1)] = None

fname = tmpdir.join("test_parquet_writer_nulls_pandas_read.parquet")
gdf.to_parquet(fname)
assert os.path.exists(fname)

got = pd.read_parquet(fname)
nullable = True if num_rows > 0 else False
assert_eq(gdf.to_pandas(nullable=nullable), got)

0 comments on commit 4d7ad4f

Please sign in to comment.