diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py index 44c93231e..293912321 100644 --- a/python/datafusion/tests/test_functions.py +++ b/python/datafusion/tests/test_functions.py @@ -795,7 +795,7 @@ def test_case(df): assert result.column(2) == pa.array(["Hola", "Mundo", None]) -def test_regr_funcs(df): +def test_regr_funcs_sql(df): # test case base on # https://github.com/apache/arrow-datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2330 ctx = SessionContext() @@ -817,6 +817,68 @@ def test_regr_funcs(df): assert result[0].column(8) == pa.array([0], type=pa.float64()) +def test_regr_funcs_sql_2(): + # test case based on `regr_*() basic tests + # https://github.com/apache/datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2358C1-L2374C1 + ctx = SessionContext() + + # Perform the regression functions using SQL + result_sql = ctx.sql( + "select " + "regr_slope(column2, column1), " + "regr_intercept(column2, column1), " + "regr_count(column2, column1), " + "regr_r2(column2, column1), " + "regr_avgx(column2, column1), " + "regr_avgy(column2, column1), " + "regr_sxx(column2, column1), " + "regr_syy(column2, column1), " + "regr_sxy(column2, column1) " + "from (values (1,2), (2,4), (3,6))" + ).collect() + + # Assertions for SQL results + assert result_sql[0].column(0) == pa.array([2], type=pa.float64()) + assert result_sql[0].column(1) == pa.array([0], type=pa.float64()) + assert result_sql[0].column(2) == pa.array([3], type=pa.float64()) # todo: i would not expect this to be float + assert result_sql[0].column(3) == pa.array([1], type=pa.float64()) + assert result_sql[0].column(4) == pa.array([2], type=pa.float64()) + assert result_sql[0].column(5) == pa.array([4], type=pa.float64()) + assert result_sql[0].column(6) == pa.array([2], type=pa.float64()) + assert result_sql[0].column(7) == pa.array([8], type=pa.float64()) + assert result_sql[0].column(8) == pa.array([4], type=pa.float64()) + + +@pytest.mark.parametrize("func, expected", [ + pytest.param(f.regr_slope, pa.array([2], type=pa.float64()), id="regr_slope"), + pytest.param(f.regr_intercept, pa.array([0], type=pa.float64()), id="regr_intercept"), + pytest.param(f.regr_count, pa.array([3], type=pa.float64()), id="regr_count"), # TODO: I would expect this to return an int array + pytest.param(f.regr_r2, pa.array([1], type=pa.float64()), id="regr_r2"), + pytest.param(f.regr_avgx, pa.array([2], type=pa.float64()), id="regr_avgx"), + pytest.param(f.regr_avgy, pa.array([4], type=pa.float64()), id="regr_avgy"), + pytest.param(f.regr_sxx, pa.array([2], type=pa.float64()), id="regr_sxx"), + pytest.param(f.regr_syy, pa.array([8], type=pa.float64()), id="regr_syy"), + pytest.param(f.regr_sxy, pa.array([4], type=pa.float64()), id="regr_sxy") +]) +def test_regr_funcs_df(func, expected): + + # test case based on `regr_*() basic tests + # https://github.com/apache/datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2358C1-L2374C1 + + + ctx = SessionContext() + + # Create a DataFrame + data = {'column1': [1, 2, 3], 'column2': [2, 4, 6]} + df = ctx.from_pydict(data, name="test_table") + + # Perform the regression function using DataFrame API + result_df = df.aggregate([], [func(f.col("column2"), f.col("column1"))]).collect() + + # Assertion for DataFrame API result + assert result_df[0].column(0) == expected + + def test_first_last_value(df): df = df.aggregate( [],