Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up pylibcudf test assertations #15892

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/cudf/cudf/pylibcudf_tests/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def assert_column_eq(
assert lhs.equals(rhs)


def assert_table_eq(plc_table: plc.Table, pa_table: pa.Table) -> None:
def assert_table_eq(pa_table: pa.Table, plc_table: plc.Table) -> None:
"""Verify that a pylibcudf table and PyArrow table are equal."""
plc_shape = (plc_table.num_rows(), plc_table.num_columns())
assert plc_shape == pa_table.shape
Expand Down
14 changes: 7 additions & 7 deletions python/cudf/cudf/pylibcudf_tests/test_copying.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def test_gather(target_table, pa_target_table, index_column, pa_index_column):
plc.copying.OutOfBoundsPolicy.DONT_CHECK,
)
expected = pa_target_table.take(pa_index_column)
assert_table_eq(result, expected)
assert_table_eq(expected, result)


def test_gather_map_has_nulls(target_table):
Expand Down Expand Up @@ -240,7 +240,7 @@ def test_scatter_table(
pa_target_table,
)

assert_table_eq(result, expected)
assert_table_eq(expected, result)


def test_scatter_table_num_col_mismatch(
Expand Down Expand Up @@ -315,7 +315,7 @@ def test_scatter_scalars(
pa_target_table,
)

assert_table_eq(result, expected)
assert_table_eq(expected, result)


def test_scatter_scalars_num_scalars_mismatch(
Expand Down Expand Up @@ -574,7 +574,7 @@ def test_slice_table(target_table, pa_target_table):
lower_bounds = bounds[::2]
result = plc.copying.slice(target_table, bounds)
for lb, ub, slice_ in zip(lower_bounds, upper_bounds, result):
assert_table_eq(slice_, pa_target_table[lb:ub])
assert_table_eq(pa_target_table[lb:ub], slice_)


def test_split_column(target_column, pa_target_column):
Expand All @@ -600,7 +600,7 @@ def test_split_table(target_table, pa_target_table):
lower_bounds = [0] + upper_bounds[:-1]
result = plc.copying.split(target_table, upper_bounds)
for lb, ub, split in zip(lower_bounds, upper_bounds, result):
assert_table_eq(split, pa_target_table[lb:ub])
assert_table_eq(pa_target_table[lb:ub], split)


def test_copy_if_else_column_column(
Expand Down Expand Up @@ -753,7 +753,7 @@ def test_boolean_mask_scatter_from_table(
pa_source_table, pa_mask, pa_target_table
)

assert_table_eq(result, expected)
assert_table_eq(expected, result)


def test_boolean_mask_scatter_from_wrong_num_cols(source_table, target_table):
Expand Down Expand Up @@ -828,7 +828,7 @@ def test_boolean_mask_scatter_from_scalars(
pa_target_table,
)

assert_table_eq(result, expected)
assert_table_eq(expected, result)


def test_get_element(input_column, pa_input_column):
Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/pylibcudf_tests/test_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_interleave_columns(reshape_data, reshape_plc_tbl):

expect = pa.concat_arrays(interleaved_data)

assert_column_eq(res, expect)
assert_column_eq(expect, res)


@pytest.mark.parametrize("cnt", [0, 1, 3])
Expand All @@ -40,4 +40,4 @@ def test_tile(reshape_data, reshape_plc_tbl, cnt):
tiled_data, schema=plc.interop.to_arrow(reshape_plc_tbl).schema
)

assert_table_eq(res, expect)
assert_table_eq(expect, res)
6 changes: 3 additions & 3 deletions python/cudf/cudf/pylibcudf_tests/test_string_capitalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,18 @@ def plc_data(pa_data):
def test_capitalize(plc_data, pa_data):
got = plc.strings.capitalize.capitalize(plc_data)
expected = pa.compute.utf8_capitalize(pa_data)
assert_column_eq(got, expected)
assert_column_eq(expected, got)


def test_title(plc_data, pa_data):
got = plc.strings.capitalize.title(
plc_data, plc.strings.char_types.StringCharacterTypes.CASE_TYPES
)
expected = pa.compute.utf8_title(pa_data)
assert_column_eq(got, expected)
assert_column_eq(expected, got)


def test_is_title(plc_data, pa_data):
got = plc.strings.capitalize.is_title(plc_data)
expected = pa.compute.utf8_is_title(pa_data)
assert_column_eq(got, expected)
assert_column_eq(expected, got)
Loading