diff --git a/numpy_groupies/tests/test_generic.py b/numpy_groupies/tests/test_generic.py index 03d7b94..cea8554 100644 --- a/numpy_groupies/tests/test_generic.py +++ b/numpy_groupies/tests/test_generic.py @@ -353,6 +353,24 @@ def test_agg_along_axis(aggregate_all, size, func, axis): # instead we squeeze out the extra dims in actual. np.testing.assert_allclose(actual.squeeze(), expected) + +def test_custom_callable(aggregate_all): + def sum_(array): + return array.sum() + + size = (10,) + axis = -1 + + group_idx = np.zeros(size, dtype=int) + array = np.random.randn(*size) + + expected = array.sum(axis=axis, keepdims=True) + actual = aggregate_all(group_idx, array, axis=axis, func=sum_, fill_value=0) + assert actual.ndim == array.ndim + + np.testing.assert_allclose(actual, expected) + + def test_argreduction_nD_array_1D_idx(aggregate_all): # regression test for GH41 labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0], dtype=int) diff --git a/numpy_groupies/utils_numpy.py b/numpy_groupies/utils_numpy.py index fecf9b4..02616c2 100644 --- a/numpy_groupies/utils_numpy.py +++ b/numpy_groupies/utils_numpy.py @@ -283,7 +283,7 @@ def input_validation(group_idx, a, size=None, order='C', axis=None, else: is_form_3 = group_idx.ndim == 1 and a.ndim > 1 and axis is not None orig_shape = a.shape if is_form_3 else group_idx.shape - if "arg" in func: + if isinstance(func, str) and "arg" in func: unravel_shape = orig_shape else: unravel_shape = None