Skip to content

Commit

Permalink
Refactor hypothesis strategies
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Jul 27, 2024
1 parent 22140eb commit 7545f84
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 118 deletions.
125 changes: 125 additions & 0 deletions tests/strategies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import cftime
import dask
import hypothesis.extra.numpy as npst
import hypothesis.strategies as st
import numpy as np

from . import ALL_FUNCS, SCIPY_STATS_FUNCS


def supported_dtypes() -> st.SearchStrategy[np.dtype]:
return (
npst.integer_dtypes(endianness="=")
| npst.unsigned_integer_dtypes(endianness="=")
| npst.floating_dtypes(endianness="=", sizes=(32, 64))
| npst.complex_number_dtypes(endianness="=")
| npst.datetime64_dtypes(endianness="=")
| npst.timedelta64_dtypes(endianness="=")
| npst.unicode_string_dtypes(endianness="=")
)


# TODO: stop excluding everything but U
array_dtype_st = supported_dtypes().filter(lambda x: x.kind not in "cmMU")
by_dtype_st = supported_dtypes()

NON_NUMPY_FUNCS = ["first", "last", "nanfirst", "nanlast", "count", "any", "all"] + list(
SCIPY_STATS_FUNCS
)
SKIPPED_FUNCS = ["var", "std", "nanvar", "nanstd"]

func_st = st.sampled_from(
[f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS]
)
numeric_arrays = npst.arrays(
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtype_st
)
all_arrays = npst.arrays(
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=supported_dtypes()
)


calendars = st.sampled_from(
[
"standard",
"gregorian",
"proleptic_gregorian",
"noleap",
"365_day",
"360_day",
"julian",
"all_leap",
"366_day",
]
)


@st.composite
def units(draw, *, calendar: str):
choices = ["days", "hours", "minutes", "seconds", "milliseconds", "microseconds"]
if calendar == "360_day":
choices += ["months"]
elif calendar == "noleap":
choices += ["common_years"]
time_units = draw(st.sampled_from(choices))

dt = draw(st.datetimes())
year, month, day = dt.year, dt.month, dt.day
if calendar == "360_day":
month %= 30
return f"{time_units} since {year}-{month}-{day}"


@st.composite
def cftime_arrays(draw, *, shape, calendars=calendars, elements=None):
if elements is None:
elements = {"min_value": -10_000, "max_value": 10_000}
cal = draw(calendars)
values = draw(npst.arrays(dtype=np.int64, shape=shape, elements=elements))
unit = draw(units(calendar=cal))
return cftime.num2date(values, units=unit, calendar=cal)


def by_arrays(shape, *, elements=None):
return st.one_of(
npst.arrays(
dtype=npst.integer_dtypes(endianness="=") | npst.unicode_string_dtypes(endianness="="),
shape=shape,
elements=elements,
),
cftime_arrays(shape=shape, elements=elements),
)


@st.composite
def chunks(draw, *, shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
chunks = []
for size in shape:
if size > 1:
nchunks = draw(st.integers(min_value=1, max_value=size - 1))
dividers = sorted(
set(draw(st.integers(min_value=1, max_value=size - 1)) for _ in range(nchunks - 1))
)
chunks.append(tuple(a - b for a, b in zip(dividers + [size], [0] + dividers)))
else:
chunks.append((1,))
return tuple(chunks)


@st.composite
def chunked_arrays(draw, *, chunks=chunks, arrays=numeric_arrays, from_array=dask.array.from_array):
array = draw(arrays)
chunks = draw(chunks(shape=array.shape))

if array.dtype.kind in "cf":
nan_idx = draw(
st.lists(
st.integers(min_value=0, max_value=array.shape[-1] - 1),
max_size=array.shape[-1] - 1,
unique=True,
)
)
if nan_idx:
array[..., nan_idx] = np.nan

return from_array(array, chunks=chunks)
120 changes: 2 additions & 118 deletions tests/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,16 @@
pytest.importorskip("dask")
pytest.importorskip("cftime")

import cftime
import dask
import hypothesis.extra.numpy as npst
import hypothesis.strategies as st
import numpy as np
from hypothesis import assume, given, note

import flox
from flox.core import groupby_reduce, groupby_scan

from . import ALL_FUNCS, SCIPY_STATS_FUNCS, assert_equal
from . import assert_equal
from .strategies import all_arrays, by_arrays, chunked_arrays, func_st, numeric_arrays

dask.config.set(scheduler="sync")

Expand All @@ -32,94 +31,13 @@ def bfill(array, axis, dtype=None):
)[::-1]


NON_NUMPY_FUNCS = ["first", "last", "nanfirst", "nanlast", "count", "any", "all"] + list(
SCIPY_STATS_FUNCS
)
SKIPPED_FUNCS = ["var", "std", "nanvar", "nanstd"]
NUMPY_SCAN_FUNCS = {
"nancumsum": np.nancumsum,
"ffill": ffill,
"bfill": bfill,
} # "cumsum": np.cumsum,


def supported_dtypes() -> st.SearchStrategy[np.dtype]:
return (
npst.integer_dtypes(endianness="=")
| npst.unsigned_integer_dtypes(endianness="=")
| npst.floating_dtypes(endianness="=", sizes=(32, 64))
| npst.complex_number_dtypes(endianness="=")
| npst.datetime64_dtypes(endianness="=")
| npst.timedelta64_dtypes(endianness="=")
| npst.unicode_string_dtypes(endianness="=")
)


# TODO: stop excluding everything but U
array_dtype_st = supported_dtypes().filter(lambda x: x.kind not in "cmMU")
by_dtype_st = supported_dtypes()
func_st = st.sampled_from(
[f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS]
)
numeric_arrays = npst.arrays(
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtype_st
)
all_arrays = npst.arrays(
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=supported_dtypes()
)

calendars = st.sampled_from(
[
"standard",
"gregorian",
"proleptic_gregorian",
"noleap",
"365_day",
"360_day",
"julian",
"all_leap",
"366_day",
]
)


@st.composite
def units(draw, *, calendar: str):
choices = ["days", "hours", "minutes", "seconds", "milliseconds", "microseconds"]
if calendar == "360_day":
choices += ["months"]
elif calendar == "noleap":
choices += ["common_years"]
time_units = draw(st.sampled_from(choices))

dt = draw(st.datetimes())
year, month, day = dt.year, dt.month, dt.day
if calendar == "360_day":
month %= 30
return f"{time_units} since {year}-{month}-{day}"


@st.composite
def cftime_arrays(draw, *, shape, calendars=calendars, elements=None):
if elements is None:
elements = {"min_value": -10_000, "max_value": 10_000}
cal = draw(calendars)
values = draw(npst.arrays(dtype=np.int64, shape=shape, elements=elements))
unit = draw(units(calendar=cal))
return cftime.num2date(values, units=unit, calendar=cal)


def by_arrays(shape, *, elements=None):
return st.one_of(
npst.arrays(
dtype=npst.integer_dtypes(endianness="=") | npst.unicode_string_dtypes(endianness="="),
shape=shape,
elements=elements,
),
cftime_arrays(shape=shape, elements=elements),
)


def not_overflowing_array(array) -> bool:
if array.dtype.kind == "f":
info = np.finfo(array.dtype)
Expand All @@ -133,40 +51,6 @@ def not_overflowing_array(array) -> bool:
return result


@st.composite
def chunks(draw, *, shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
chunks = []
for size in shape:
if size > 1:
nchunks = draw(st.integers(min_value=1, max_value=size - 1))
dividers = sorted(
set(draw(st.integers(min_value=1, max_value=size - 1)) for _ in range(nchunks - 1))
)
chunks.append(tuple(a - b for a, b in zip(dividers + [size], [0] + dividers)))
else:
chunks.append((1,))
return tuple(chunks)


@st.composite
def chunked_arrays(draw, *, chunks=chunks, arrays=numeric_arrays, from_array=dask.array.from_array):
array = draw(arrays)
chunks = draw(chunks(shape=array.shape))

if array.dtype.kind in "cf":
nan_idx = draw(
st.lists(
st.integers(min_value=0, max_value=array.shape[-1] - 1),
max_size=array.shape[-1] - 1,
unique=True,
)
)
if nan_idx:
array[..., nan_idx] = np.nan

return from_array(array, chunks=chunks)


# TODO: migrate to by_arrays but with constant value
@given(data=st.data(), array=numeric_arrays, func=func_st)
def test_groupby_reduce(data, array, func):
Expand Down

0 comments on commit 7545f84

Please sign in to comment.