Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use mock instead of testing for strings in layer. #301

Merged
merged 1 commit into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1507,17 +1507,15 @@ def dask_groupby_agg(
group_chunks: tuple[tuple[int | float, ...]]

if method in ["map-reduce", "cohorts"]:
combine: Callable[..., IntermediateDict]
if do_simple_combine:
combine = partial(_simple_combine, reindex=reindex)
combine_name = "simple-combine"
else:
combine = partial(_grouped_combine, engine=engine, sort=sort)
combine_name = "grouped-combine"
combine: Callable[..., IntermediateDict] = (
partial(_simple_combine, reindex=reindex)
if do_simple_combine
else partial(_grouped_combine, engine=engine, sort=sort)
)

tree_reduce = partial(
dask.array.reductions._tree_reduce,
name=f"{name}-reduce-{method}-{combine_name}",
name=f"{name}-reduce-{method}",
dtype=array.dtype,
axis=axis,
keepdims=True,
Expand Down
20 changes: 16 additions & 4 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import warnings
from functools import partial, reduce
from typing import TYPE_CHECKING, Callable
from unittest.mock import MagicMock, patch

import numpy as np
import pandas as pd
import pytest
from numpy_groupies.aggregate_numpy import aggregate

import flox
from flox import xrutils
from flox.aggregations import Aggregation, _initialize_aggregation
from flox.core import (
Expand Down Expand Up @@ -303,6 +305,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
if chunks == -1:
params.extend([("blockwise", None)])

combine_error = RuntimeError("This combine should not have been called.")
for method, reindex in params:
call = partial(
groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs
Expand All @@ -312,13 +315,22 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
with pytest.raises(NotImplementedError):
call()
continue
actual, *groups = call()
if method != "blockwise":

if method == "blockwise":
# no combine necessary
mocks = {
"_simple_combine": MagicMock(side_effect=combine_error),
"_grouped_combine": MagicMock(side_effect=combine_error),
}
else:
if "arg" not in func:
# make sure we use simple combine
assert any("simple-combine" in key for key in actual.dask.layers.keys())
mocks = {"_grouped_combine": MagicMock(side_effect=combine_error)}
else:
assert any("grouped-combine" in key for key in actual.dask.layers.keys())
mocks = {"_simple_combine": MagicMock(side_effect=combine_error)}

with patch.multiple(flox.core, **mocks):
actual, *groups = call()
for actual_group, expect in zip(groups, expected_groups):
assert_equal(actual_group, expect, tolerance)
if "arg" in func:
Expand Down