Write pandas extension types to parquet file metadata #8749

Merged
17 changes: 16 additions & 1 deletion python/cudf/cudf/_lib/utils.pyx
@@ -19,6 +19,7 @@ except ImportError:
import json

from cudf.utils.dtypes import (
cudf_dtypes_to_pandas_dtypes,
is_categorical_dtype,
is_decimal_dtype,
is_list_dtype,
@@ -27,6 +28,12 @@ from cudf.utils.dtypes import (
)


PARQUET_META_TYPE_MAP = {
str(cudf_dtype): str(pandas_dtype)
for cudf_dtype, pandas_dtype in cudf_dtypes_to_pandas_dtypes.items()
}


cdef vector[column_view] make_column_views(object columns):
cdef vector[column_view] views
views.reserve(len(columns))
@@ -152,8 +159,16 @@ cpdef generate_pandas_metadata(Table table, index):

md_dict = json.loads(metadata[b"pandas"])

# correct metadata for list and struct types
# correct metadata for list and struct and nullable numeric types
for col_meta in md_dict["columns"]:
if (
col_meta["name"] in table._column_names
and table._data[col_meta["name"]].nullable
and col_meta["numpy_type"] in PARQUET_META_TYPE_MAP
):
col_meta["numpy_type"] = PARQUET_META_TYPE_MAP[
col_meta["numpy_type"]
]
if col_meta["numpy_type"] in ("list", "struct"):
col_meta["numpy_type"] = "object"

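For context, a minimal sketch of what this metadata correction amounts to in plain Python. The mapping below is an assumed subset for illustration only; the real `cudf_dtypes_to_pandas_dtypes` is defined in `cudf.utils.dtypes`.

```python
import numpy as np
import pandas as pd

# Assumed subset of cudf_dtypes_to_pandas_dtypes, for illustration only;
# the real mapping lives in cudf.utils.dtypes.
cudf_dtypes_to_pandas_dtypes = {
    np.dtype("int64"): pd.Int64Dtype(),
    np.dtype("int32"): pd.Int32Dtype(),
    np.dtype("uint8"): pd.UInt8Dtype(),
}

# Same construction as in the diff: stringify both sides so the map can be
# applied to the "numpy_type" strings stored in the pandas parquet metadata.
PARQUET_META_TYPE_MAP = {
    str(cudf_dtype): str(pandas_dtype)
    for cudf_dtype, pandas_dtype in cudf_dtypes_to_pandas_dtypes.items()
}

# A nullable int64 column's metadata entry is rewritten to the pandas
# extension dtype name, so pandas reads it back as Int64 rather than int64.
col_meta = {"name": "a", "numpy_type": "int64"}
if col_meta["numpy_type"] in PARQUET_META_TYPE_MAP:
    col_meta["numpy_type"] = PARQUET_META_TYPE_MAP[col_meta["numpy_type"]]
print(col_meta)  # {'name': 'a', 'numpy_type': 'Int64'}
```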
21 changes: 21 additions & 0 deletions python/cudf/cudf/tests/test_parquet.py
@@ -1956,3 +1956,24 @@ def test_parquet_writer_column_validation():
lfunc_args_and_kwargs=(["cudf.parquet"],),
rfunc_args_and_kwargs=(["pandas.parquet"],),
)


def test_parquet_writer_nulls_pandas_read(tmpdir, pdf):
if "col_bool" in pdf.columns:
pdf.drop(columns="col_bool", inplace=True)
if "col_category" in pdf.columns:
pdf.drop(columns="col_category", inplace=True)
gdf = cudf.from_pandas(pdf)

num_rows = len(gdf)
if num_rows > 0:
for col in gdf.columns:
gdf[col][random.randint(0, num_rows - 1)] = None

fname = tmpdir.join("test_parquet_writer_nulls_pandas_read.parquet")
gdf.to_parquet(fname)
assert os.path.exists(fname)

got = pd.read_parquet(fname)
nullable = True if num_rows > 0 else False
assert_eq(gdf.to_pandas(nullable=nullable), got)
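A round-trip sketch of the behavior this test exercises, outside the test harness. The file name is illustrative, and the expected dtype assumes pandas honors the extension-type metadata written by this change.

```python
import cudf
import pandas as pd

# Write a cudf DataFrame containing a null; with this change the parquet
# "pandas" metadata records the pandas extension dtype for the nullable column.
gdf = cudf.DataFrame({"a": [1, None, 3]})
gdf.to_parquet("nullable_example.parquet")

# pandas should now read the column back as the nullable Int64 extension
# dtype instead of falling back to float64 with NaN.
pdf = pd.read_parquet("nullable_example.parquet")
print(pdf.dtypes)  # expected: a    Int64
```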