From 674531317f24c1d3f7935854a1cdb97e15f482e5 Mon Sep 17 00:00:00 2001 From: Michael-J-Ward Date: Sat, 11 May 2024 01:05:43 -0500 Subject: [PATCH] test: parametrize test_array_functions test_array_functions now has 56 passing test cases and 1 expected failure (`array_slice` being the expected failure Ref #670). test_array_function_flatten was broken out as a single test because it was an outlier in terms of test-input. test_array_function_obj_tests had a different set of asserts, so was broken out for 5 test cases. Ref #671 --- datafusion/tests/test_functions.py | 414 ++++++++++++++++------------- 1 file changed, 223 insertions(+), 191 deletions(-) diff --git a/datafusion/tests/test_functions.py b/datafusion/tests/test_functions.py index d834f587f..dd1d09d06 100644 --- a/datafusion/tests/test_functions.py +++ b/datafusion/tests/test_functions.py @@ -183,324 +183,356 @@ def test_math_functions(): ) -def test_array_functions(): - data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]] - ctx = SessionContext() - batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], names=["arr"]) - df = ctx.create_dataframe([[batch]]) +def py_indexof(arr, v): + try: + return arr.index(v) + 1 + except ValueError: + return np.nan + - def py_indexof(arr, v): +def py_arr_remove(arr, v, n=None): + new_arr = arr[:] + found = 0 + while found != n: try: - return arr.index(v) + 1 + new_arr.remove(v) + found += 1 except ValueError: - return np.nan - - def py_arr_remove(arr, v, n=None): - new_arr = arr[:] - found = 0 - while found != n: - try: - new_arr.remove(v) - found += 1 - except ValueError: - break - - return new_arr - - def py_arr_replace(arr, from_, to, n=None): - new_arr = arr[:] - found = 0 - while found != n: - try: - idx = new_arr.index(from_) - new_arr[idx] = to - found += 1 - except ValueError: - break - - return new_arr - - def py_arr_resize(arr, size, value): - arr = np.asarray(arr) - return np.pad( - arr, - [(0, size - arr.shape[0])], - "constant", - constant_values=value, - ) + break - def py_flatten(arr): - result = [] - for elem in arr: - if isinstance(elem, list): - result.extend(py_flatten(elem)) - else: - result.append(elem) - return result + return new_arr - col = column("arr") - test_items = [ + +def py_arr_replace(arr, from_, to, n=None): + new_arr = arr[:] + found = 0 + while found != n: + try: + idx = new_arr.index(from_) + new_arr[idx] = to + found += 1 + except ValueError: + break + + return new_arr + + +def py_arr_resize(arr, size, value): + arr = np.asarray(arr) + return np.pad( + arr, + [(0, size - arr.shape[0])], + "constant", + constant_values=value, + ) + + +def py_flatten(arr): + result = [] + for elem in arr: + if isinstance(elem, list): + result.extend(py_flatten(elem)) + else: + result.append(elem) + return result + + +@pytest.mark.parametrize( + ("stmt", "py_expr"), + [ [ - f.array_append(col, literal(99.0)), - lambda: [np.append(arr, 99.0) for arr in data], + lambda col: f.array_append(col, literal(99.0)), + lambda data: [np.append(arr, 99.0) for arr in data], ], [ - f.array_push_back(col, literal(99.0)), - lambda: [np.append(arr, 99.0) for arr in data], + lambda col: f.array_push_back(col, literal(99.0)), + lambda data: [np.append(arr, 99.0) for arr in data], ], [ - f.list_append(col, literal(99.0)), - lambda: [np.append(arr, 99.0) for arr in data], + lambda col: f.list_append(col, literal(99.0)), + lambda data: [np.append(arr, 99.0) for arr in data], ], [ - f.list_push_back(col, literal(99.0)), - lambda: [np.append(arr, 99.0) for arr in data], + lambda col: f.list_push_back(col, literal(99.0)), + lambda data: [np.append(arr, 99.0) for arr in data], ], [ - f.array_concat(col, col), - lambda: [np.concatenate([arr, arr]) for arr in data], + lambda col: f.array_concat(col, col), + lambda data: [np.concatenate([arr, arr]) for arr in data], ], [ - f.array_cat(col, col), - lambda: [np.concatenate([arr, arr]) for arr in data], + lambda col: f.array_cat(col, col), + lambda data: [np.concatenate([arr, arr]) for arr in data], ], [ - f.array_dims(col), - lambda: [[len(r)] for r in data], + lambda col: f.array_dims(col), + lambda data: [[len(r)] for r in data], ], [ - f.array_distinct(col), - lambda: [list(set(r)) for r in data], + lambda col: f.array_distinct(col), + lambda data: [list(set(r)) for r in data], ], [ - f.list_distinct(col), - lambda: [list(set(r)) for r in data], + lambda col: f.list_distinct(col), + lambda data: [list(set(r)) for r in data], ], [ - f.list_dims(col), - lambda: [[len(r)] for r in data], + lambda col: f.list_dims(col), + lambda data: [[len(r)] for r in data], ], [ - f.array_element(col, literal(1)), - lambda: [r[0] for r in data], + lambda col: f.array_element(col, literal(1)), + lambda data: [r[0] for r in data], ], [ - f.array_extract(col, literal(1)), - lambda: [r[0] for r in data], + lambda col: f.array_extract(col, literal(1)), + lambda data: [r[0] for r in data], ], [ - f.list_element(col, literal(1)), - lambda: [r[0] for r in data], + lambda col: f.list_element(col, literal(1)), + lambda data: [r[0] for r in data], ], [ - f.list_extract(col, literal(1)), - lambda: [r[0] for r in data], + lambda col: f.list_extract(col, literal(1)), + lambda data: [r[0] for r in data], ], [ - f.array_length(col), - lambda: [len(r) for r in data], + lambda col: f.array_length(col), + lambda data: [len(r) for r in data], ], [ - f.list_length(col), - lambda: [len(r) for r in data], + lambda col: f.list_length(col), + lambda data: [len(r) for r in data], ], [ - f.array_has(col, literal(1.0)), - lambda: [1.0 in r for r in data], + lambda col: f.array_has(col, literal(1.0)), + lambda data: [1.0 in r for r in data], ], [ - f.array_has_all(col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]])), - lambda: [np.all([v in r for v in [1.0, 3.0, 5.0]]) for r in data], + lambda col: f.array_has_all( + col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]]) + ), + lambda data: [np.all([v in r for v in [1.0, 3.0, 5.0]]) for r in data], ], [ - f.array_has_any(col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]])), - lambda: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in data], + lambda col: f.array_has_any( + col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]]) + ), + lambda data: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in data], ], [ - f.array_position(col, literal(1.0)), - lambda: [py_indexof(r, 1.0) for r in data], + lambda col: f.array_position(col, literal(1.0)), + lambda data: [py_indexof(r, 1.0) for r in data], ], [ - f.array_indexof(col, literal(1.0)), - lambda: [py_indexof(r, 1.0) for r in data], + lambda col: f.array_indexof(col, literal(1.0)), + lambda data: [py_indexof(r, 1.0) for r in data], ], [ - f.list_position(col, literal(1.0)), - lambda: [py_indexof(r, 1.0) for r in data], + lambda col: f.list_position(col, literal(1.0)), + lambda data: [py_indexof(r, 1.0) for r in data], ], [ - f.list_indexof(col, literal(1.0)), - lambda: [py_indexof(r, 1.0) for r in data], + lambda col: f.list_indexof(col, literal(1.0)), + lambda data: [py_indexof(r, 1.0) for r in data], ], [ - f.array_positions(col, literal(1.0)), - lambda: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data], + lambda col: f.array_positions(col, literal(1.0)), + lambda data: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data], ], [ - f.list_positions(col, literal(1.0)), - lambda: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data], + lambda col: f.list_positions(col, literal(1.0)), + lambda data: [[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data], ], [ - f.array_ndims(col), - lambda: [np.array(r).ndim for r in data], + lambda col: f.array_ndims(col), + lambda data: [np.array(r).ndim for r in data], ], [ - f.list_ndims(col), - lambda: [np.array(r).ndim for r in data], + lambda col: f.list_ndims(col), + lambda data: [np.array(r).ndim for r in data], ], [ - f.array_prepend(literal(99.0), col), - lambda: [np.insert(arr, 0, 99.0) for arr in data], + lambda col: f.array_prepend(literal(99.0), col), + lambda data: [np.insert(arr, 0, 99.0) for arr in data], ], [ - f.array_push_front(literal(99.0), col), - lambda: [np.insert(arr, 0, 99.0) for arr in data], + lambda col: f.array_push_front(literal(99.0), col), + lambda data: [np.insert(arr, 0, 99.0) for arr in data], ], [ - f.list_prepend(literal(99.0), col), - lambda: [np.insert(arr, 0, 99.0) for arr in data], + lambda col: f.list_prepend(literal(99.0), col), + lambda data: [np.insert(arr, 0, 99.0) for arr in data], ], [ - f.list_push_front(literal(99.0), col), - lambda: [np.insert(arr, 0, 99.0) for arr in data], + lambda col: f.list_push_front(literal(99.0), col), + lambda data: [np.insert(arr, 0, 99.0) for arr in data], ], [ - f.array_pop_back(col), - lambda: [arr[:-1] for arr in data], + lambda col: f.array_pop_back(col), + lambda data: [arr[:-1] for arr in data], ], [ - f.array_pop_front(col), - lambda: [arr[1:] for arr in data], + lambda col: f.array_pop_front(col), + lambda data: [arr[1:] for arr in data], ], [ - f.array_remove(col, literal(3.0)), - lambda: [py_arr_remove(arr, 3.0, 1) for arr in data], + lambda col: f.array_remove(col, literal(3.0)), + lambda data: [py_arr_remove(arr, 3.0, 1) for arr in data], ], [ - f.list_remove(col, literal(3.0)), - lambda: [py_arr_remove(arr, 3.0, 1) for arr in data], + lambda col: f.list_remove(col, literal(3.0)), + lambda data: [py_arr_remove(arr, 3.0, 1) for arr in data], ], [ - f.array_remove_n(col, literal(3.0), literal(2)), - lambda: [py_arr_remove(arr, 3.0, 2) for arr in data], + lambda col: f.array_remove_n(col, literal(3.0), literal(2)), + lambda data: [py_arr_remove(arr, 3.0, 2) for arr in data], ], [ - f.list_remove_n(col, literal(3.0), literal(2)), - lambda: [py_arr_remove(arr, 3.0, 2) for arr in data], + lambda col: f.list_remove_n(col, literal(3.0), literal(2)), + lambda data: [py_arr_remove(arr, 3.0, 2) for arr in data], ], [ - f.array_remove_all(col, literal(3.0)), - lambda: [py_arr_remove(arr, 3.0) for arr in data], + lambda col: f.array_remove_all(col, literal(3.0)), + lambda data: [py_arr_remove(arr, 3.0) for arr in data], ], [ - f.list_remove_all(col, literal(3.0)), - lambda: [py_arr_remove(arr, 3.0) for arr in data], + lambda col: f.list_remove_all(col, literal(3.0)), + lambda data: [py_arr_remove(arr, 3.0) for arr in data], ], [ - f.array_repeat(col, literal(2)), - lambda: [[arr] * 2 for arr in data], + lambda col: f.array_repeat(col, literal(2)), + lambda data: [[arr] * 2 for arr in data], ], [ - f.array_replace(col, literal(3.0), literal(4.0)), - lambda: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data], + lambda col: f.array_replace(col, literal(3.0), literal(4.0)), + lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data], ], [ - f.list_replace(col, literal(3.0), literal(4.0)), - lambda: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data], + lambda col: f.list_replace(col, literal(3.0), literal(4.0)), + lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data], ], [ - f.array_replace_n(col, literal(3.0), literal(4.0), literal(1)), - lambda: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data], + lambda col: f.array_replace_n(col, literal(3.0), literal(4.0), literal(1)), + lambda data: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data], ], [ - f.list_replace_n(col, literal(3.0), literal(4.0), literal(2)), - lambda: [py_arr_replace(arr, 3.0, 4.0, 2) for arr in data], + lambda col: f.list_replace_n(col, literal(3.0), literal(4.0), literal(2)), + lambda data: [py_arr_replace(arr, 3.0, 4.0, 2) for arr in data], ], [ - f.array_replace_all(col, literal(3.0), literal(4.0)), - lambda: [py_arr_replace(arr, 3.0, 4.0) for arr in data], + lambda col: f.array_replace_all(col, literal(3.0), literal(4.0)), + lambda data: [py_arr_replace(arr, 3.0, 4.0) for arr in data], ], [ - f.list_replace_all(col, literal(3.0), literal(4.0)), - lambda: [py_arr_replace(arr, 3.0, 4.0) for arr in data], + lambda col: f.list_replace_all(col, literal(3.0), literal(4.0)), + lambda data: [py_arr_replace(arr, 3.0, 4.0) for arr in data], ], [ - f.array_slice(col, literal(2), literal(4)), - lambda: [arr[1:4] for arr in data], + lambda col: f.array_slice(col, literal(2), literal(4)), + lambda data: [arr[1:4] for arr in data], ], - # [ - # f.list_slice(col, literal(-1), literal(2)), - # lambda: [arr[-1:2] for arr in data], - # ], + pytest.param( + lambda col: f.list_slice(col, literal(-1), literal(2)), + lambda data: [arr[-1:2] for arr in data], + marks=pytest.mark.xfail, + ), [ - f.array_intersect(col, literal([3.0, 4.0])), - lambda: [np.intersect1d(arr, [3.0, 4.0]) for arr in data], + lambda col: f.array_intersect(col, literal([3.0, 4.0])), + lambda data: [np.intersect1d(arr, [3.0, 4.0]) for arr in data], ], [ - f.list_intersect(col, literal([3.0, 4.0])), - lambda: [np.intersect1d(arr, [3.0, 4.0]) for arr in data], + lambda col: f.list_intersect(col, literal([3.0, 4.0])), + lambda data: [np.intersect1d(arr, [3.0, 4.0]) for arr in data], ], [ - f.array_union(col, literal([12.0, 999.0])), - lambda: [np.union1d(arr, [12.0, 999.0]) for arr in data], + lambda col: f.array_union(col, literal([12.0, 999.0])), + lambda data: [np.union1d(arr, [12.0, 999.0]) for arr in data], ], [ - f.list_union(col, literal([12.0, 999.0])), - lambda: [np.union1d(arr, [12.0, 999.0]) for arr in data], + lambda col: f.list_union(col, literal([12.0, 999.0])), + lambda data: [np.union1d(arr, [12.0, 999.0]) for arr in data], ], [ - f.array_except(col, literal([3.0])), - lambda: [np.setdiff1d(arr, [3.0]) for arr in data], + lambda col: f.array_except(col, literal([3.0])), + lambda data: [np.setdiff1d(arr, [3.0]) for arr in data], ], [ - f.list_except(col, literal([3.0])), - lambda: [np.setdiff1d(arr, [3.0]) for arr in data], + lambda col: f.list_except(col, literal([3.0])), + lambda data: [np.setdiff1d(arr, [3.0]) for arr in data], ], [ - f.array_resize(col, literal(10), literal(0.0)), - lambda: [py_arr_resize(arr, 10, 0.0) for arr in data], + lambda col: f.array_resize(col, literal(10), literal(0.0)), + lambda data: [py_arr_resize(arr, 10, 0.0) for arr in data], ], [ - f.list_resize(col, literal(10), literal(0.0)), - lambda: [py_arr_resize(arr, 10, 0.0) for arr in data], + lambda col: f.list_resize(col, literal(10), literal(0.0)), + lambda data: [py_arr_resize(arr, 10, 0.0) for arr in data], ], - [f.flatten(literal(data)), lambda: [py_flatten(data)]], [ - f.range(literal(1), literal(5), literal(2)), - lambda: [np.arange(1, 5, 2)], + lambda col: f.range(literal(1), literal(5), literal(2)), + lambda data: [np.arange(1, 5, 2)], ], - ] + ], +) +def test_array_functions(stmt, py_expr): + data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]] + ctx = SessionContext() + batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], names=["arr"]) + df = ctx.create_dataframe([[batch]]) - for stmt, py_expr in test_items: - query_result = df.select(stmt).collect()[0].column(0) - for a, b in zip(query_result, py_expr()): - np.testing.assert_array_almost_equal( - np.array(a.as_py(), dtype=float), np.array(b, dtype=float) - ) + col = column("arr") + query_result = df.select(stmt(col)).collect()[0].column(0) + for a, b in zip(query_result, py_expr(data)): + np.testing.assert_array_almost_equal( + np.array(a.as_py(), dtype=float), np.array(b, dtype=float) + ) - obj_test_items = [ + +def test_array_function_flatten(): + data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]] + ctx = SessionContext() + batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], names=["arr"]) + df = ctx.create_dataframe([[batch]]) + + stmt = f.flatten(literal(data)) + py_expr = [py_flatten(data)] + query_result = df.select(stmt).collect()[0].column(0) + for a, b in zip(query_result, py_expr): + np.testing.assert_array_almost_equal( + np.array(a.as_py(), dtype=float), np.array(b, dtype=float) + ) + + +@pytest.mark.parametrize( + ("stmt", "py_expr"), + [ [ - f.array_to_string(col, literal(",")), - lambda: [",".join([str(int(v)) for v in r]) for r in data], + f.array_to_string(column("arr"), literal(",")), + lambda data: [",".join([str(int(v)) for v in r]) for r in data], ], [ - f.array_join(col, literal(",")), - lambda: [",".join([str(int(v)) for v in r]) for r in data], + f.array_join(column("arr"), literal(",")), + lambda data: [",".join([str(int(v)) for v in r]) for r in data], ], [ - f.list_to_string(col, literal(",")), - lambda: [",".join([str(int(v)) for v in r]) for r in data], + f.list_to_string(column("arr"), literal(",")), + lambda data: [",".join([str(int(v)) for v in r]) for r in data], ], [ - f.list_join(col, literal(",")), - lambda: [",".join([str(int(v)) for v in r]) for r in data], + f.list_join(column("arr"), literal(",")), + lambda data: [",".join([str(int(v)) for v in r]) for r in data], ], - ] - - for stmt, py_expr in obj_test_items: - query_result = np.array(df.select(stmt).collect()[0].column(0)) - for a, b in zip(query_result, py_expr()): - assert a == b + ], +) +def test_array_function_obj_tests(stmt, py_expr): + data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]] + ctx = SessionContext() + batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], names=["arr"]) + df = ctx.create_dataframe([[batch]]) + query_result = np.array(df.select(stmt).collect()[0].column(0)) + for a, b in zip(query_result, py_expr(data)): + assert a == b def test_string_functions(df):