Skip to content

Commit

Permalink
fix all nested struct cases
Browse files Browse the repository at this point in the history
  • Loading branch information
lithomas1 committed Jun 24, 2024
1 parent e6c3ec7 commit 624d444
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 82 deletions.
27 changes: 15 additions & 12 deletions python/cudf/cudf/pylibcudf_tests/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def assert_column_eq(
if isinstance(rhs, pa.ChunkedArray):
rhs = rhs.combine_chunks()

# print(lhs)
# print(rhs)
assert lhs.equals(rhs)


Expand Down Expand Up @@ -189,19 +191,20 @@ def sink_to_str(sink):
# Flat struct type used by the struct test cases: a single non-nullable
# int64 field named "v" (test data for it looks like {"v": 4}).
DEFAULT_STRUCT_TESTING_TYPE = pa.struct(
[pa.field("v", pa.int64(), nullable=False)]
)
# Struct-within-a-struct type used by the nested struct test cases:
#   "a"        -> non-nullable int64
#   "b_struct" -> non-nullable struct holding a non-nullable float64 "b"
# Test data for it looks like {"a": 4, "b_struct": {"b": 4.0}}.
NESTED_STRUCT_TESTING_TYPE = pa.struct(
[
pa.field("a", pa.int64(), nullable=False),
pa.field(
"b_struct",
pa.struct([pa.field("b", pa.float64(), nullable=False)]),
nullable=False,
),
]
)

DEFAULT_PA_STRUCT_TESTING_TYPES = [DEFAULT_STRUCT_TESTING_TYPE] + [
# Nested case
pa.struct(
[
pa.field("a", pa.int64(), nullable=False),
pa.field(
"b_struct",
pa.struct([pa.field("b", pa.float64(), nullable=False)]),
nullable=False,
),
]
),
DEFAULT_PA_STRUCT_TESTING_TYPES = [
DEFAULT_STRUCT_TESTING_TYPE,
NESTED_STRUCT_TESTING_TYPE,
]

DEFAULT_PA_TYPES = (
Expand Down
143 changes: 73 additions & 70 deletions python/cudf/cudf/pylibcudf_tests/test_copying.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from utils import (
DEFAULT_STRUCT_TESTING_TYPE,
NESTED_STRUCT_TESTING_TYPE,
assert_column_eq,
assert_table_eq,
cudf_raises,
Expand All @@ -20,40 +21,6 @@
from cudf._lib import pylibcudf as plc


@pytest.fixture
def nested_struct_xfail(request):
    """
    Mark the requesting test as xfail when its input holds a nested struct.

    The types inspected come from the ``target_table`` fixture when the test
    requests it; otherwise from ``target_column``, falling back to
    ``input_column``. A test requesting none of these is left unmarked.
    (As of right now, interop runs into errors on nested structs.)
    """
    requested = request.fixturenames
    if "target_table" in requested:
        fixture_name = "target_table"
    elif "target_column" in requested:
        fixture_name = "target_column"
    elif "input_column" in requested:
        fixture_name = "input_column"
    else:
        # Nothing to inspect, so no marker is applied.
        return

    # Each of these fixtures unpacks into a pair; only the first element
    # (the object whose type is checked) is needed here.
    pa_obj, _ = request.getfixturevalue(fixture_name)
    if fixture_name == "target_table":
        has_nested_struct = any(
            is_nested_struct(col.type) for col in pa_obj.columns
        )
    else:
        has_nested_struct = is_nested_struct(pa_obj.type)

    request.applymarker(
        pytest.mark.xfail(
            condition=has_nested_struct,
            reason="pylibcudf interop fails with nested struct",
        )
    )


@pytest.fixture
def nested_list_skip(request):
"""
Expand Down Expand Up @@ -284,7 +251,6 @@ def _pyarrow_boolean_mask_scatter_table(source, mask, target_table):


@skip_nested_list
@xfail_nested_struct
def test_scatter_table(
source_table,
index_column,
Expand Down Expand Up @@ -318,23 +284,42 @@ def test_scatter_table(
[pa.array([[4], [1], [2, 3], [3], [9], [10]])] * 3, [""] * 3
)
elif pa.types.is_struct(dtype):
expected = pa.table(
[
pa.array(
[
{"v": 4},
{"v": 1},
{"v": 2},
{"v": 3},
{"v": 8},
{"v": 9},
],
type=DEFAULT_STRUCT_TESTING_TYPE,
)
]
* 3,
[""] * 3,
)
if is_nested_struct(dtype):
expected = pa.table(
[
pa.array(
[
{"a": 4, "b_struct": {"b": 4.0}},
{"a": 1, "b_struct": {"b": 1.0}},
{"a": 2, "b_struct": {"b": 2.0}},
{"a": 3, "b_struct": {"b": 3.0}},
{"a": 8, "b_struct": {"b": 8.0}},
{"a": 9, "b_struct": {"b": 9.0}},
],
type=NESTED_STRUCT_TESTING_TYPE,
)
]
* 3,
[""] * 3,
)
else:
expected = pa.table(
[
pa.array(
[
{"v": 4},
{"v": 1},
{"v": 2},
{"v": 3},
{"v": 8},
{"v": 9},
],
type=DEFAULT_STRUCT_TESTING_TYPE,
)
]
* 3,
[""] * 3,
)
else:
expected = _pyarrow_boolean_mask_scatter_table(
pa_source_table,
Expand Down Expand Up @@ -868,7 +853,6 @@ def test_copy_if_else_column_scalar(


@skip_nested_list
@xfail_nested_struct
def test_boolean_mask_scatter_from_table(
source_table,
target_table,
Expand Down Expand Up @@ -899,23 +883,42 @@ def test_boolean_mask_scatter_from_table(
[pa.array([[1], [5, 6], [2, 3], [8], [3], [10]])] * 3, [""] * 3
)
elif pa.types.is_struct(dtype):
expected = pa.table(
[
pa.array(
[
{"v": 1},
{"v": 5},
{"v": 2},
{"v": 7},
{"v": 3},
{"v": 9},
],
type=DEFAULT_STRUCT_TESTING_TYPE,
)
]
* 3,
[""] * 3,
)
if is_nested_struct(dtype):
expected = pa.table(
[
pa.array(
[
{"a": 1, "b_struct": {"b": 1.0}},
{"a": 5, "b_struct": {"b": 5.0}},
{"a": 2, "b_struct": {"b": 2.0}},
{"a": 7, "b_struct": {"b": 7.0}},
{"a": 3, "b_struct": {"b": 3.0}},
{"a": 9, "b_struct": {"b": 9.0}},
],
type=NESTED_STRUCT_TESTING_TYPE,
)
]
* 3,
[""] * 3,
)
else:
expected = pa.table(
[
pa.array(
[
{"v": 1},
{"v": 5},
{"v": 2},
{"v": 7},
{"v": 3},
{"v": 9},
],
type=DEFAULT_STRUCT_TESTING_TYPE,
)
]
* 3,
[""] * 3,
)
else:
expected = _pyarrow_boolean_mask_scatter_table(
pa_source_table, pa_mask, pa_target_table
Expand Down

0 comments on commit 624d444

Please sign in to comment.