From dd1313b748f007710882d91c1d17c2523120b146 Mon Sep 17 00:00:00 2001 From: brandon-b-miller Date: Mon, 14 Aug 2023 12:13:55 -0700 Subject: [PATCH 1/2] fix --- python/cudf/udf_cpp/shim.cu | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/cudf/udf_cpp/shim.cu b/python/cudf/udf_cpp/shim.cu index 0959b6ba53f..686e39e7036 100644 --- a/python/cudf/udf_cpp/shim.cu +++ b/python/cudf/udf_cpp/shim.cu @@ -643,9 +643,8 @@ __device__ double BlockCorr(T* const lhs_ptr, T* const rhs_ptr, int64_t size) { auto numerator = BlockCoVar(lhs_ptr, rhs_ptr, size); auto denominator = BlockStd(lhs_ptr, size) * BlockStd(rhs_ptr, size); - if (denominator == 0.0) { - return 0.0; + return std::numeric_limits::quiet_NaN(); } else { return numerator / denominator; } From 26ad24a7530563f553bbff3c297e4794aa3fee53 Mon Sep 17 00:00:00 2001 From: brandon-b-miller Date: Fri, 18 Aug 2023 12:33:26 -0700 Subject: [PATCH 2/2] Add simple test --- python/cudf/cudf/tests/test_groupby.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py index b01b44da201..e578e1061ca 100644 --- a/python/cudf/cudf/tests/test_groupby.py +++ b/python/cudf/cudf/tests/test_groupby.py @@ -449,6 +449,21 @@ def func(group): run_groupby_apply_jit_test(groupby_jit_data, func, keys) +@pytest.mark.parametrize("dtype", ["int32", "int64"]) +def test_groupby_apply_jit_correlation_zero_variance(dtype): + # pearson correlation is undefined when the variance of either + # variable is zero. This test ensures that the jit implementation + # returns the same result as pandas in this case. + data = DataFrame( + {"a": [0, 0, 0, 0, 0], "b": [1, 1, 1, 1, 1], "c": [2, 2, 2, 2, 2]} + ) + + def func(group): + return group["b"].corr(group["c"]) + + run_groupby_apply_jit_test(data, func, ["a"]) + + @pytest.mark.parametrize("dtype", ["float64"]) @pytest.mark.parametrize("func", ["min", "max", "sum", "mean", "var", "std"]) @pytest.mark.parametrize("special_val", [np.nan, np.inf, -np.inf])