Skip to content

Commit

Permalink
Add Aggregation.preserves_dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Aug 7, 2024
1 parent 7e5dbe9 commit 1e8fbf2
Showing 1 changed file with 27 additions and 20 deletions.
47 changes: 27 additions & 20 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def __init__(
final_dtype: DTypeLike | None = None,
reduction_type: Literal["reduce", "argreduce"] = "reduce",
new_dims_func: Callable | None = None,
preserves_dtype: bool = False,
):
"""
Blueprint for computing grouped aggregations.
Expand Down Expand Up @@ -202,6 +203,8 @@ def __init__(
Function that receives finalize_kwargs and returns a tuple of sizes of any new dimensions
added by the reduction. For example, quantile for q=(0.5, 0.85) adds a new dimension of size 2,
so returns (2,)
preserves_dtype: bool
Whether a function preserves the dtype on return, e.g. min, max, first, last, mode.
"""
self.name = name
# preprocess before blockwise
Expand Down Expand Up @@ -238,6 +241,7 @@ def __init__(
self.new_dims_func: Callable = (
returns_empty_tuple if new_dims_func is None else new_dims_func
)
self.preserves_dtype = preserves_dtype

@cached_property
def new_dims(self) -> tuple[Dim]:
Expand Down Expand Up @@ -380,10 +384,14 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
)


min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF)
nanmin = Aggregation("nanmin", chunk="nanmin", combine="nanmin", fill_value=dtypes.NA)
max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF)
nanmax = Aggregation("nanmax", chunk="nanmax", combine="nanmax", fill_value=dtypes.NA)
min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF, preserves_dtype=True)
nanmin = Aggregation(
"nanmin", chunk="nanmin", combine="nanmin", fill_value=dtypes.NA, preserves_dtype=True
)
max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF, preserves_dtype=True)
nanmax = Aggregation(
"nanmax", chunk="nanmax", combine="nanmax", fill_value=dtypes.NA, preserves_dtype=True
)


def argreduce_preprocess(array, axis):
Expand Down Expand Up @@ -471,10 +479,14 @@ def _pick_second(*x):
final_dtype=np.intp,
)

first = Aggregation("first", chunk=None, combine=None, fill_value=None)
last = Aggregation("last", chunk=None, combine=None, fill_value=None)
nanfirst = Aggregation("nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=dtypes.NA)
nanlast = Aggregation("nanlast", chunk="nanlast", combine="nanlast", fill_value=dtypes.NA)
first = Aggregation("first", chunk=None, combine=None, fill_value=None, preserves_dtype=True)
last = Aggregation("last", chunk=None, combine=None, fill_value=None, preserves_dtype=True)
nanfirst = Aggregation(
"nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=dtypes.NA, preserves_dtype=True
)
nanlast = Aggregation(
"nanlast", chunk="nanlast", combine="nanlast", fill_value=dtypes.NA, preserves_dtype=True
)

all_ = Aggregation(
"all",
Expand Down Expand Up @@ -525,8 +537,12 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
final_dtype=np.floating,
new_dims_func=quantile_new_dims_func,
)
mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None)
nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None)
mode = Aggregation(
name="mode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True
)
nanmode = Aggregation(
name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True
)


@dataclass
Expand Down Expand Up @@ -765,16 +781,7 @@ def _initialize_aggregation(
final_dtype = dtypes._normalize_dtype(
dtype_ or agg.dtype_init["final"], array_dtype, fill_value
)
if agg.name not in [
"first",
"last",
"nanfirst",
"nanlast",
"min",
"max",
"nanmin",
"nanmax",
]:
if not agg.preserves_dtype:
final_dtype = dtypes._maybe_promote_int(final_dtype)
agg.dtype = {
"user": dtype, # Save to automatically choose an engine
Expand Down

0 comments on commit 1e8fbf2

Please sign in to comment.