Skip to content

Commit

Permalink
Use mock instead of testing for strings in layer.
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Dec 28, 2023
1 parent 20b662a commit 1a18ee3
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 12 deletions.
14 changes: 6 additions & 8 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1507,17 +1507,15 @@ def dask_groupby_agg(
group_chunks: tuple[tuple[int | float, ...]]

if method in ["map-reduce", "cohorts"]:
combine: Callable[..., IntermediateDict]
if do_simple_combine:
combine = partial(_simple_combine, reindex=reindex)
combine_name = "simple-combine"
else:
combine = partial(_grouped_combine, engine=engine, sort=sort)
combine_name = "grouped-combine"
combine: Callable[..., IntermediateDict] = (
partial(_simple_combine, reindex=reindex)
if do_simple_combine
else partial(_grouped_combine, engine=engine, sort=sort)
)

tree_reduce = partial(
dask.array.reductions._tree_reduce,
name=f"{name}-reduce-{method}-{combine_name}",
name=f"{name}-reduce-{method}",
dtype=array.dtype,
axis=axis,
keepdims=True,
Expand Down
20 changes: 16 additions & 4 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import warnings
from functools import partial, reduce
from typing import TYPE_CHECKING, Callable
from unittest.mock import MagicMock, patch

import numpy as np
import pandas as pd
import pytest
from numpy_groupies.aggregate_numpy import aggregate

import flox
from flox import xrutils
from flox.aggregations import Aggregation, _initialize_aggregation
from flox.core import (
Expand Down Expand Up @@ -303,6 +305,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
if chunks == -1:
params.extend([("blockwise", None)])

combine_error = RuntimeError("This combine should not have been called.")
for method, reindex in params:
call = partial(
groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs
Expand All @@ -312,13 +315,22 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
with pytest.raises(NotImplementedError):
call()
continue
actual, *groups = call()
if method != "blockwise":

if method == "blockwise":
# no combine necessary
mocks = {
"_simple_combine": MagicMock(side_effect=combine_error),
"_grouped_combine": MagicMock(side_effect=combine_error),
}
else:
if "arg" not in func:
# make sure we use simple combine
assert any("simple-combine" in key for key in actual.dask.layers.keys())
mocks = {"_grouped_combine": MagicMock(side_effect=combine_error)}
else:
assert any("grouped-combine" in key for key in actual.dask.layers.keys())
mocks = {"_simple_combine": MagicMock(side_effect=combine_error)}

with patch.multiple(flox.core, **mocks):
actual, *groups = call()
for actual_group, expect in zip(groups, expected_groups):
assert_equal(actual_group, expect, tolerance)
if "arg" in func:
Expand Down

0 comments on commit 1a18ee3

Please sign in to comment.