diff --git a/python/cudf/cudf/_lib/parquet.pyx b/python/cudf/cudf/_lib/parquet.pyx index b85941d109f..fbbaba69b62 100644 --- a/python/cudf/cudf/_lib/parquet.pyx +++ b/python/cudf/cudf/_lib/parquet.pyx @@ -402,7 +402,7 @@ def write_parquet( object partitions_info=None, object force_nullable_schema=False, header_version="1.0", - use_dictionary=True, + use_dictionary="ADAPTIVE", ): """ Cython function to call into libcudf API, see `write_parquet`. @@ -477,11 +477,12 @@ def write_parquet( "Valid values are '1.0' and '2.0'" ) - dict_policy = ( - cudf_io_types.dictionary_policy.ADAPTIVE - if use_dictionary - else cudf_io_types.dictionary_policy.NEVER - ) + # Set up the dictionary policy + dict_policy = cudf_io_types.dictionary_policy.ADAPTIVE + if use_dictionary == "ALWAYS": + dict_policy = cudf_io_types.dictionary_policy.ALWAYS + elif use_dictionary == "NEVER": + dict_policy = cudf_io_types.dictionary_policy.NEVER cdef cudf_io_types.compression_type comp_type = _get_comp_type(compression) cdef cudf_io_types.statistics_freq stat_freq = _get_stat_freq(statistics) diff --git a/python/cudf/cudf/io/parquet.py b/python/cudf/cudf/io/parquet.py index e7f1ad0751f..72a204fd0db 100644 --- a/python/cudf/cudf/io/parquet.py +++ b/python/cudf/cudf/io/parquet.py @@ -67,7 +67,7 @@ def _write_parquet( storage_options=None, force_nullable_schema=False, header_version="1.0", - use_dictionary=True, + use_dictionary="ADAPTIVE", ): if is_list_like(paths) and len(paths) > 1: if partitions_info is None: @@ -962,6 +962,22 @@ def to_parquet( if partition_offsets is not None else None ) + + # Set up the dictionary policy + dict_policy=None + if use_dictionary == True or use_dictionary == "ADAPTIVE": + dict_policy="ADAPTIVE" + elif use_dictionary == False or use_dictionary == "NEVER": + dict_policy="NEVER" + elif use_dictionary == "ALWAYS": + dict_policy="ALWAYS" + else: + dict_policy="ADAPTIVE" + warnings.warn( + "invalid value passed for `use_dictionary`." + "Using the default value `use_dictionary=True`" + ) + return _write_parquet( df, paths=path if is_list_like(path) else [path], @@ -978,7 +994,7 @@ def to_parquet( storage_options=storage_options, force_nullable_schema=force_nullable_schema, header_version=header_version, - use_dictionary=use_dictionary, + use_dictionary=dict_policy, ) else: diff --git a/python/cudf/cudf/tests/test_parquet.py b/python/cudf/cudf/tests/test_parquet.py index 6fb1d3d8ba5..a5da536d8dc 100644 --- a/python/cudf/cudf/tests/test_parquet.py +++ b/python/cudf/cudf/tests/test_parquet.py @@ -3093,6 +3093,25 @@ def test_parquet_reader_detect_bad_dictionary(datadir): with pytest.raises(RuntimeError): cudf.read_parquet(fname) +@pytest.mark.parametrize("policy", [True, False, "ADAPTIVE", "ALWAYS", "NEVER"]) +def test_parquet_dictionary_policy(policy): + buf = BytesIO() + table = cudf.DataFrame( + { + "time64[ms]": cudf.Series([1234, 123, 4123], dtype="timedelta64[ms]"), + "int64": cudf.Series([1234, 123, 4123], dtype="int64"), + "list": list([[1,2],[1,2],[1,2]]), + "datetime[ms]": cudf.Series([1234, 123, 4123], dtype="datetime64[ms]"), + }) + + # Write parquet with the specified dict policy + table.to_parquet(buf, use_dictionary=policy) + + # Read the parquet back + got = cudf.read_parquet(buf) + + # Check the tables + assert_eq(table, got) @pytest.mark.parametrize("data", [{"a": [1, 2, 3, 4]}, {"b": [1, None, 2, 3]}]) @pytest.mark.parametrize("force_nullable_schema", [True, False])