From c24664c1f106750bd534c72de3729e96b0b4372e Mon Sep 17 00:00:00 2001
From: Thomas Li
Date: Fri, 7 Jun 2024 18:25:06 +0000
Subject: [PATCH] update and start writing tests

---
 python/cudf/cudf/_lib/json.pyx                |   2 +-
 .../cudf/cudf/pylibcudf_tests/common/utils.py |  36 ++++++-
 python/cudf/cudf/pylibcudf_tests/conftest.py  | 100 +++++++++++++++---
 python/cudf/cudf/pylibcudf_tests/test_json.py |  22 ++++
 4 files changed, 141 insertions(+), 19 deletions(-)

diff --git a/python/cudf/cudf/_lib/json.pyx b/python/cudf/cudf/_lib/json.pyx
index 26ee1dc4554..d0e0875feb1 100644
--- a/python/cudf/cudf/_lib/json.pyx
+++ b/python/cudf/cudf/_lib/json.pyx
@@ -174,7 +174,7 @@ def write_json(
             plc.io.SinkInfo([path_or_buf]),
             plc.io.TableWithMetadata(
                 plc.Table([
-                    c.to_pylibcudf(mode="read") for c in table._data.columns
+                    c.to_pylibcudf(mode="read") for c in table._columns
                 ]),
                 colnames
             ),
diff --git a/python/cudf/cudf/pylibcudf_tests/common/utils.py b/python/cudf/cudf/pylibcudf_tests/common/utils.py
index 54d38f1a8cf..9eaee15e89a 100644
--- a/python/cudf/cudf/pylibcudf_tests/common/utils.py
+++ b/python/cudf/cudf/pylibcudf_tests/common/utils.py
@@ -136,8 +136,36 @@ def is_fixed_width(plc_dtype: plc.DataType):
     )
 
 
-# We must explicitly specify this type via a field to ensure we don't include
-# nullability accidentally.
-DEFAULT_STRUCT_TESTING_TYPE = pa.struct(
-    [pa.field("v", pa.int64(), nullable=False)]
+NUMERIC_PA_TYPES = [pa.int64(), pa.float64(), pa.uint64()]
+STRING_PA_TYPES = [pa.string()]
+BOOL_PA_TYPES = [pa.bool_()]
+LIST_PA_TYPES = [
+    pa.list_(pa.int64()),
+    # Nested case
+    pa.list_(pa.list_(pa.int64())),
+]
+
+DEFAULT_PA_STRUCT_TESTING_TYPES = [
+    # We must explicitly specify this type via a field to ensure we don't include
+    # nullability accidentally.
+    pa.struct([pa.field("v", pa.int64(), nullable=False)]),
+    # Nested case
+    pa.struct(
+        [
+            pa.field("a", pa.int64(), nullable=False),
+            pa.field(
+                "b_struct",
+                pa.struct([pa.field("b", pa.float64(), nullable=False)]),
+                nullable=False,
+            ),
+        ]
+    ),
+]
+
+DEFAULT_PA_TYPES = (
+    NUMERIC_PA_TYPES
+    + STRING_PA_TYPES
+    + BOOL_PA_TYPES
+    + LIST_PA_TYPES
+    + DEFAULT_PA_STRUCT_TESTING_TYPES
 )
diff --git a/python/cudf/cudf/pylibcudf_tests/conftest.py b/python/cudf/cudf/pylibcudf_tests/conftest.py
index f3c6584ef8c..893116e8261 100644
--- a/python/cudf/cudf/pylibcudf_tests/conftest.py
+++ b/python/cudf/cudf/pylibcudf_tests/conftest.py
@@ -4,6 +4,7 @@
 import os
 import sys
 
+import numpy as np
 import pyarrow as pa
 import pytest
 
@@ -11,7 +12,7 @@
 
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "common"))
 
-from utils import DEFAULT_STRUCT_TESTING_TYPE
+from utils import DEFAULT_PA_TYPES, NUMERIC_PA_TYPES
 
 
 # This fixture defines the standard set of types that all tests should default to
@@ -20,14 +21,7 @@
 # across modules. Otherwise it may be defined on a per-module basis.
 @pytest.fixture(
     scope="session",
-    params=[
-        pa.int64(),
-        pa.float64(),
-        pa.string(),
-        pa.bool_(),
-        pa.list_(pa.int64()),
-        DEFAULT_STRUCT_TESTING_TYPE,
-    ],
+    params=DEFAULT_PA_TYPES,
 )
 def pa_type(request):
     return request.param
@@ -35,16 +29,94 @@ def pa_type(request):
 
 @pytest.fixture(
     scope="session",
-    params=[
-        pa.int64(),
-        pa.float64(),
-        pa.uint64(),
-    ],
+    params=NUMERIC_PA_TYPES,
 )
 def numeric_pa_type(request):
     return request.param
 
 
+@pytest.fixture(scope="session", params=[0, 100])
+def plc_table_w_meta(request):
+    """
+    The default TableWithMetadata you should be using for testing
+    pylibcudf I/O writers.
+
+    Contains one of each category (e.g. int, bool, list, struct)
+    of dtypes.
+    """
+    nrows = request.param
+
+    table_dict = dict()
+    # Colnames in the format expected by
+    # plc.io.TableWithMetadata
+    colnames = []
+
+    for typ in DEFAULT_PA_TYPES:
+        rand_vals = np.random.randint(0, nrows, nrows)
+        child_colnames = []
+
+        if isinstance(typ, pa.ListType):
+
+            def _generate_list_data(typ):
+                child_colnames = []
+                if isinstance(typ, pa.ListType):
+                    # recurse to get vals
+                    rand_arrs, grandchild_colnames = _generate_list_data(
+                        typ.value_type
+                    )
+                    pa_array = pa.array(
+                        [list(row_vals) for row_vals in zip(rand_arrs)],
+                        type=typ,
+                    )
+                    child_colnames.append(("", grandchild_colnames))
+                else:
+                    # typ is scalar type
+                    pa_array = pa.array(rand_vals).cast(typ)
+                    child_colnames.append(("", []))
+                return pa_array, child_colnames
+
+            rand_arr, child_colnames = _generate_list_data(typ)
+        elif isinstance(typ, pa.StructType):
+
+            def _generate_struct_data(typ):
+                child_colnames = []
+                if isinstance(typ, pa.StructType):
+                    # recurse to get vals
+                    rand_arrs = []
+                    for i in range(typ.num_fields):
+                        rand_arr, grandchild_colnames = _generate_struct_data(
+                            typ.field(i).type
+                        )
+                        rand_arrs.append(rand_arr)
+                        child_colnames.append(
+                            (typ.field(i).name, grandchild_colnames)
+                        )
+
+                    pa_array = pa.StructArray.from_arrays(
+                        [rand_arr for rand_arr in rand_arrs],
+                        names=[
+                            typ.field(i).name for i in range(typ.num_fields)
+                        ],
+                    )
+                else:
+                    # typ is scalar type
+                    pa_array = pa.array(rand_vals).cast(typ)
+                return pa_array, child_colnames
+
+            rand_arr, child_colnames = _generate_struct_data(typ)
+        else:
+            rand_arr = pa.array(rand_vals).cast(typ)
+
+        table_dict[f"col_{typ}"] = rand_arr
+        colnames.append((f"col_{typ}", child_colnames))
+
+    pa_table = pa.Table.from_pydict(table_dict)
+
+    return plc.io.TableWithMetadata(
+        plc.interop.from_arrow(pa_table), column_names=colnames
+    )
+
+
 @pytest.fixture(
     scope="session", params=[opt for opt in plc.types.Interpolation]
 )
diff --git a/python/cudf/cudf/pylibcudf_tests/test_json.py b/python/cudf/cudf/pylibcudf_tests/test_json.py
index e69de29bb2d..b3ec7f884eb 100644
--- a/python/cudf/cudf/pylibcudf_tests/test_json.py
+++ b/python/cudf/cudf/pylibcudf_tests/test_json.py
@@ -0,0 +1,22 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+import io
+import os
+import pathlib
+
+import pytest
+
+import cudf._lib.pylibcudf as plc
+
+
+@pytest.mark.parametrize(
+    "sink", ["a.txt", pathlib.Path("a.txt"), io.BytesIO(), io.StringIO()]
+)
+def test_write_json_basic(plc_table_w_meta, sink, tmp_path):
+    if isinstance(sink, str):
+        sink = f"{tmp_path}/{sink}"
+    elif isinstance(sink, os.PathLike):
+        sink = tmp_path.joinpath(sink)
+    plc.io.json.write_json(
+        plc.io.SinkInfo([sink]),
+        plc_table_w_meta,
+    )