Skip to content

Commit

Permalink
Fix error in testing utils
Browse files Browse the repository at this point in the history
Co-authored-by: Lawrence Mitchell <[email protected]>
  • Loading branch information
lithomas1 and wence- committed Jun 25, 2024
1 parent 9a6a896 commit 0ed9af6
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 45 deletions.
14 changes: 11 additions & 3 deletions python/cudf/cudf/pylibcudf_tests/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,15 @@ def metadata_from_arrow_type(
name: str = "",
) -> plc.interop.ColumnMetadata | None:
metadata = plc.interop.ColumnMetadata(name) # None
if pa.types.is_list(pa_type) or pa.types.is_struct(pa_type):
if pa.types.is_list(pa_type):
child_meta = [plc.interop.ColumnMetadata("offsets")]
for i in range(pa_type.num_fields):
field_meta = metadata_from_arrow_type(
pa_type.field(i).type, pa_type.field(i).name
)
child_meta.append(field_meta)
metadata = plc.interop.ColumnMetadata(name, child_meta)
elif pa.types.is_struct(pa_type):
child_meta = []
for i in range(pa_type.num_fields):
field_meta = metadata_from_arrow_type(
Expand Down Expand Up @@ -57,8 +65,8 @@ def assert_column_eq(
if isinstance(rhs, pa.ChunkedArray):
rhs = rhs.combine_chunks()

# print(lhs)
# print(rhs)
print(lhs)
print(rhs)
assert lhs.equals(rhs)


Expand Down
68 changes: 26 additions & 42 deletions python/cudf/cudf/pylibcudf_tests/test_copying.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,6 @@
from cudf._lib import pylibcudf as plc


@pytest.fixture
def nested_list_skip(request):
"""
Fixture that xfails a test if we encounter a nested list.
(as of right now, we are encountering some segfaults/memoryerrors
in interop)
"""
if "target_table" in request.fixturenames:
pa_table, _ = request.getfixturevalue("target_table")
if any(is_nested_list(col.type) for col in pa_table.columns):
pytest.skip(reason="pylibcudf interop fails with nested list")
elif "target_column" or "input_column" in request.fixturenames:
if "target_column" in request.fixturenames:
pa_col, _ = request.getfixturevalue("target_column")
else:
pa_col, _ = request.getfixturevalue("input_column")
if is_nested_list(pa_col.type):
pytest.skip(reason="pylibcudf interop fails with nested list")


xfail_nested_struct = pytest.mark.usefixtures("nested_struct_xfail")
skip_nested_list = pytest.mark.usefixtures("nested_list_skip")


# TODO: consider moving this to conftest and "pairing"
# it with pa_type, so that they don't get out of sync
# TODO: Test nullable data
Expand Down Expand Up @@ -194,7 +170,6 @@ def mask(target_column):
return pa_mask, plc.interop.from_arrow(pa_mask)


@skip_nested_list
def test_gather(target_table, index_column):
pa_target_table, plc_target_table = target_table
pa_index_column, plc_index_column = index_column
Expand Down Expand Up @@ -250,7 +225,6 @@ def _pyarrow_boolean_mask_scatter_table(source, mask, target_table):
)


@skip_nested_list
def test_scatter_table(
source_table,
index_column,
Expand Down Expand Up @@ -280,9 +254,17 @@ def test_scatter_table(
)

if pa.types.is_list(dtype := pa_target_table[0].type):
expected = pa.table(
[pa.array([[4], [1], [2, 3], [3], [9], [10]])] * 3, [""] * 3
)
if is_nested_list(dtype):
expected = pa.table(
[pa.array([[[4]], [[1]], [[2, 3]], [[3]], [[9]], [[10]]])]
* 3,
[""] * 3,
)
else:
expected = pa.table(
[pa.array([[4], [1], [2, 3], [3], [9], [10]])] * 3,
[""] * 3,
)
elif pa.types.is_struct(dtype):
if is_nested_struct(dtype):
expected = pa.table(
Expand Down Expand Up @@ -392,7 +374,6 @@ def test_scatter_table_type_mismatch(source_table, index_column, target_table):
)


@skip_nested_list
def test_scatter_scalars(
source_scalar,
index_column,
Expand Down Expand Up @@ -670,7 +651,6 @@ def test_shift_type_mismatch(target_column):
plc.copying.shift(plc_target_column, 2, fill_value)


@skip_nested_list
def test_slice_column(target_column):
pa_target_column, plc_target_column = target_column
bounds = list(range(6))
Expand Down Expand Up @@ -699,7 +679,6 @@ def test_slice_column_out_of_bounds(target_column):
plc.copying.slice(plc_target_column, list(range(2, 8)))


@skip_nested_list
def test_slice_table(target_table):
pa_target_table, plc_target_table = target_table
bounds = list(range(6))
Expand All @@ -710,7 +689,6 @@ def test_slice_table(target_table):
assert_table_eq(pa_target_table[lb:ub], slice_)


@skip_nested_list
def test_split_column(target_column):
upper_bounds = [1, 3, 5]
lower_bounds = [0] + upper_bounds[:-1]
Expand All @@ -732,7 +710,6 @@ def test_split_column_out_of_bounds(target_column):
plc.copying.split(plc_target_column, list(range(5, 8)))


@skip_nested_list
def test_split_table(target_table):
pa_target_table, plc_target_table = target_table

Expand All @@ -743,7 +720,6 @@ def test_split_table(target_table):
assert_table_eq(pa_target_table[lb:ub], split)


@skip_nested_list
def test_copy_if_else_column_column(target_column, mask, source_scalar):
pa_target_column, plc_target_column = target_column
pa_source_scalar, _ = source_scalar
Expand Down Expand Up @@ -818,7 +794,6 @@ def test_copy_if_else_wrong_size_mask(target_column):
)


@skip_nested_list
@pytest.mark.parametrize("array_left", [True, False])
def test_copy_if_else_column_scalar(
target_column,
Expand Down Expand Up @@ -852,7 +827,6 @@ def test_copy_if_else_column_scalar(
assert_column_eq(expected, result)


@skip_nested_list
def test_boolean_mask_scatter_from_table(
source_table,
target_table,
Expand All @@ -879,9 +853,21 @@ def test_boolean_mask_scatter_from_table(
)

if pa.types.is_list(dtype := pa_target_table[0].type):
expected = pa.table(
[pa.array([[1], [5, 6], [2, 3], [8], [3], [10]])] * 3, [""] * 3
)
if is_nested_list(dtype):
expected = pa.table(
[
pa.array(
[[[1]], [[5, 6]], [[2, 3]], [[8]], [[3]], [[10]]]
)
]
* 3,
[""] * 3,
)
else:
expected = pa.table(
[pa.array([[1], [5, 6], [2, 3], [8], [3], [10]])] * 3,
[""] * 3,
)
elif pa.types.is_struct(dtype):
if is_nested_struct(dtype):
expected = pa.table(
Expand Down Expand Up @@ -989,7 +975,6 @@ def test_boolean_mask_scatter_from_wrong_mask_type(source_table, target_table):
)


@skip_nested_list
def test_boolean_mask_scatter_from_scalars(
source_scalar,
target_table,
Expand All @@ -1013,7 +998,6 @@ def test_boolean_mask_scatter_from_scalars(
assert_table_eq(expected, result)


@skip_nested_list
def test_get_element(input_column):
index = 1
pa_input_column, plc_input_column = input_column
Expand Down

0 comments on commit 0ed9af6

Please sign in to comment.