Skip to content

Commit

Permalink
updates for dict_policy to accept boolean as well as strings
Browse files Browse the repository at this point in the history
  • Loading branch information
mhaseeb123 committed May 7, 2024
1 parent 8ba2258 commit 144dedf
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 8 deletions.
13 changes: 7 additions & 6 deletions python/cudf/cudf/_lib/parquet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def write_parquet(
object partitions_info=None,
object force_nullable_schema=False,
header_version="1.0",
use_dictionary=True,
use_dictionary="ADAPTIVE",
):
"""
Cython function to call into libcudf API, see `write_parquet`.
Expand Down Expand Up @@ -477,11 +477,12 @@ def write_parquet(
"Valid values are '1.0' and '2.0'"
)

dict_policy = (
cudf_io_types.dictionary_policy.ADAPTIVE
if use_dictionary
else cudf_io_types.dictionary_policy.NEVER
)
# Set up the dictionary policy
dict_policy = cudf_io_types.dictionary_policy.ADAPTIVE
if use_dictionary == "ALWAYS":
dict_policy = cudf_io_types.dictionary_policy.ALWAYS
elif use_dictionary == "NEVER":
dict_policy = cudf_io_types.dictionary_policy.NEVER

cdef cudf_io_types.compression_type comp_type = _get_comp_type(compression)
cdef cudf_io_types.statistics_freq stat_freq = _get_stat_freq(statistics)
Expand Down
20 changes: 18 additions & 2 deletions python/cudf/cudf/io/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _write_parquet(
storage_options=None,
force_nullable_schema=False,
header_version="1.0",
use_dictionary=True,
use_dictionary="ADAPTIVE",
):
if is_list_like(paths) and len(paths) > 1:
if partitions_info is None:
Expand Down Expand Up @@ -962,6 +962,22 @@ def to_parquet(
if partition_offsets is not None
else None
)

# Set up the dictionary policy
dict_policy=None
if use_dictionary == True or use_dictionary == "ADAPTIVE":
dict_policy="ADAPTIVE"
elif use_dictionary == False or use_dictionary == "NEVER":
dict_policy="NEVER"
elif use_dictionary == "ALWAYS":
dict_policy="ALWAYS"
else:
dict_policy="ADAPTIVE"
warnings.warn(
"invalid value passed for `use_dictionary`."
"Using the default value `use_dictionary=True`"
)

return _write_parquet(
df,
paths=path if is_list_like(path) else [path],
Expand All @@ -978,7 +994,7 @@ def to_parquet(
storage_options=storage_options,
force_nullable_schema=force_nullable_schema,
header_version=header_version,
use_dictionary=use_dictionary,
use_dictionary=dict_policy,
)

else:
Expand Down
19 changes: 19 additions & 0 deletions python/cudf/cudf/tests/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3093,6 +3093,25 @@ def test_parquet_reader_detect_bad_dictionary(datadir):
with pytest.raises(RuntimeError):
cudf.read_parquet(fname)

@pytest.mark.parametrize("policy", [True, False, "ADAPTIVE", "ALWAYS", "NEVER"])
def test_parquet_dictionary_policy(policy):
buf = BytesIO()
table = cudf.DataFrame(
{
"time64[ms]": cudf.Series([1234, 123, 4123], dtype="timedelta64[ms]"),
"int64": cudf.Series([1234, 123, 4123], dtype="int64"),
"list": list([[1,2],[1,2],[1,2]]),
"datetime[ms]": cudf.Series([1234, 123, 4123], dtype="datetime64[ms]"),
})

# Write parquet with the specified dict policy
table.to_parquet(buf, use_dictionary=policy)

# Read the parquet back
got = cudf.read_parquet(buf)

# Check the tables
assert_eq(table, got)

@pytest.mark.parametrize("data", [{"a": [1, 2, 3, 4]}, {"b": [1, None, 2, 3]}])
@pytest.mark.parametrize("force_nullable_schema", [True, False])
Expand Down

0 comments on commit 144dedf

Please sign in to comment.