Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update infectionswithfeedback process #440

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a10b95c
add function to pad using edge values and tests
sbidari Sep 10, 2024
f89aa9d
modify convolve and infectionwithfeedback to work with 2d array
sbidari Sep 10, 2024
b7fb437
Merge branch 'main' into 435-update-infectionswithfeedback-process-to…
sbidari Sep 10, 2024
b0e8a3d
test convolve scanner for 2d arrays
sbidari Sep 11, 2024
1862aee
Merge branch 'main' of https://github.com/CDCgov/PyRenew into 435-upd…
sbidari Sep 11, 2024
114af6a
add tests for 2d array, infectionsrtfeedback
sbidari Sep 11, 2024
a625fde
more tests
sbidari Sep 11, 2024
91be92f
remove pad to match functions
sbidari Sep 12, 2024
396c2d7
remove tests
sbidari Sep 12, 2024
8197059
use list comprehension in test_infectionsrtfeedback
sbidari Sep 12, 2024
97d19d0
revert to using for loop
sbidari Sep 12, 2024
d933df4
add more test for convolve scanner functions
sbidari Sep 12, 2024
39667c1
Merge branch 'main' into 435-update-infectionswithfeedback-process-to…
sbidari Sep 12, 2024
c210606
add check for initial infections and Rt ndims
sbidari Sep 12, 2024
5e91418
Merge branch '435-update-infectionswithfeedback-process-to-handle-bat…
sbidari Sep 12, 2024
f5557a3
remove typos and superfluous print statements
sbidari Sep 12, 2024
46e32e8
add more tests
sbidari Sep 12, 2024
728aad3
change value of inf_feedback in tests to cover diff scenarios
sbidari Sep 12, 2024
54c0527
add input array tests for infections.py
sbidari Sep 12, 2024
a6ce6c6
add test with plate
sbidari Sep 13, 2024
f37a229
Merge branch 'main' into 435-update-infectionswithfeedback-process-to…
sbidari Sep 13, 2024
cc61d0b
code review changes
sbidari Sep 13, 2024
a4a74f0
Merge branch '435-update-infectionswithfeedback-process-to-handle-bat…
sbidari Sep 13, 2024
0cd56ad
code review suggestion
sbidari Sep 13, 2024
0f3e2fd
replace jnp.dot with einsum
sbidari Sep 13, 2024
542ad84
code review suggestions
sbidari Sep 16, 2024
bafe3b0
code review changes
sbidari Sep 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ InfFeedbackSample = namedtuple(
)
```

The next step is to create the actual class. The bulk of its implementation lies in the function `pyrenew.latent.compute_infections_from_rt_with_feedback()`. We will also use the `pyrenew.arrayutils.pad_x_to_match_y()` function to ensure the passed vectors match their lengths. The following code-chunk shows most of the implementation of the `InfectionsWithFeedback` class:
The next step is to create the actual class. The bulk of its implementation lies in the function `pyrenew.latent.compute_infections_from_rt_with_feedback()`. We will also use the `pyrenew.arrayutils.pad_edges_to_match()` function to ensure the passed vectors match their lengths. The following code-chunk shows most of the implementation of the `InfectionsWithFeedback` class:

```{python}
# | label: new-model-def
Expand Down Expand Up @@ -224,10 +224,9 @@ class InfFeedback(RandomVariable):

inf_feedback_strength = jnp.atleast_1d(inf_feedback_strength)

inf_feedback_strength = au.pad_x_to_match_y(
inf_feedback_strength, _ = au.pad_edges_to_match(
x=inf_feedback_strength,
y=Rt,
fill_value=inf_feedback_strength[0],
)

# Sampling inf feedback and adjusting the shape
Expand Down
60 changes: 16 additions & 44 deletions pyrenew/arrayutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,25 @@
from jax.typing import ArrayLike


def pad_to_match(
def pad_edges_to_match(
sbidari marked this conversation as resolved.
Show resolved Hide resolved
x: ArrayLike,
y: ArrayLike,
fill_value: float = 0.0,
axis: int = 0,
pad_direction: str = "end",
fix_y: bool = False,
) -> tuple[ArrayLike, ArrayLike]:
"""
Pad the shorter array at the start or end to match the length of the longer array.
Pad the shorter array at the start or end using the
edge values to match the length of the longer array.

Parameters
----------
x : ArrayLike
First array.
y : ArrayLike
Second array.
fill_value : float, optional
Value to use for padding, by default 0.0.
axis : int, optional
Axis along which to add padding, by default 0
pad_direction : str, optional
Direction to pad the shorter array, either "start" or "end", by default "end".
fix_y : bool, optional
Expand All @@ -38,64 +39,35 @@ def pad_to_match(
"""
x = jnp.atleast_1d(x)
y = jnp.atleast_1d(y)
x_len = x.size
y_len = y.size
x_len = x.shape[axis]
y_len = y.shape[axis]
pad_size = abs(x_len - y_len)
pad_width = [(0, 0)] * x.ndim

pad_width = {"start": (pad_size, 0), "end": (0, pad_size)}.get(
pad_direction, None
)

if pad_width is None:
if pad_direction not in ["start", "end"]:
raise ValueError(
"pad_direction must be either 'start' or 'end'."
f" Got {pad_direction}."
)

pad_width[axis] = {"start": (pad_size, 0), "end": (0, pad_size)}.get(
pad_direction, None
)

if x_len > y_len:
if fix_y:
raise ValueError(
"Cannot fix y when x is longer than y."
f" x_len: {x_len}, y_len: {y_len}."
)
y = jnp.pad(y, pad_width, constant_values=fill_value)
y = jnp.pad(y, pad_width, mode="edge")

elif y_len > x_len:
x = jnp.pad(x, pad_width, constant_values=fill_value)
x = jnp.pad(x, pad_width, mode="edge")

return x, y


def pad_x_to_match_y(
x: ArrayLike,
y: ArrayLike,
fill_value: float = 0.0,
pad_direction: str = "end",
) -> ArrayLike:
"""
Pad the `x` array at the start or end to match the length of the `y` array.

Parameters
----------
x : ArrayLike
First array.
y : ArrayLike
Second array.
fill_value : float, optional
Value to use for padding, by default 0.0.
pad_direction : str, optional
Direction to pad the shorter array, either "start" or "end", by default "end".

Returns
-------
Array
Padded array.
"""
return pad_to_match(
x, y, fill_value=fill_value, pad_direction=pad_direction, fix_y=True
)[0]


class PeriodicProcessSample(NamedTuple):
"""
A container for holding the output from `process.PeriodicProcess()`.
Expand Down
17 changes: 12 additions & 5 deletions pyrenew/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,12 @@ def _new_scanner(
history_subset: ArrayLike, multiplier: float
) -> tuple[ArrayLike, float]: # numpydoc ignore=GL08
new_val = transform(
multiplier * jnp.dot(array_to_convolve, history_subset)
multiplier
* jnp.einsum("i...,i...->...", array_to_convolve, history_subset)
)
latest = jnp.concatenate(
[history_subset[1:], new_val[jnp.newaxis]], axis=0
)
sbidari marked this conversation as resolved.
Show resolved Hide resolved
latest = jnp.hstack([history_subset[1:], new_val])
return latest, new_val

return _new_scanner
Expand Down Expand Up @@ -158,9 +161,13 @@ def _new_scanner(
multipliers: tuple[float, float],
) -> tuple[ArrayLike, tuple[float, float]]: # numpydoc ignore=GL08
m1, m2 = multipliers
m_net1 = t1(m1 * jnp.dot(arr1, history_subset))
new_val = t2(m2 * m_net1 * jnp.dot(arr2, history_subset))
latest = jnp.hstack([history_subset[1:], new_val])
m_net1 = t1(m1 * jnp.einsum("i...,i...->...", arr1, history_subset))
new_val = t2(
m2 * m_net1 * jnp.einsum("i...,i...->...", arr2, history_subset)
)
latest = jnp.concatenate(
[history_subset[1:], new_val[jnp.newaxis]], axis=0
)
return latest, (new_val, m_net1)

return _new_scanner
Expand Down
11 changes: 9 additions & 2 deletions pyrenew/latent/infections.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,21 @@ def sample(
InfectionsSample
Named tuple with "infections".
"""
if I0.size < gen_int.size:
if I0.shape[0] < gen_int.size:
raise ValueError(
"Initial infections vector must be at least as long as "
"the generation interval. "
f"Initial infections vector length: {I0.size}, "
f"Initial infections vector length: {I0.shape[0]}, "
f"generation interval length: {gen_int.size}."
)

if I0.shape[1:] != Rt.shape[1:]:
raise ValueError(
"Initial infections and Rt must have the same batch shapes. "
f"Got initial infections of batch shape {I0.shape[1:]} "
f"and Rt of batch shape {Rt.shape[1:]}."
)

gen_int_rev = jnp.flip(gen_int)
recent_I0 = I0[-gen_int_rev.size :]

Expand Down
35 changes: 23 additions & 12 deletions pyrenew/latent/infectionswithfeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,18 @@ def sample(
InfectionsWithFeedback
Named tuple with "infections".
"""
if I0.size < gen_int.size:
if I0.shape[0] < gen_int.size:
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
"Initial infections must be at least as long as the "
f"generation interval. Got {I0.size} initial infections "
f"and {gen_int.size} generation interval."
f"generation interval. Got initial infections length {I0.shape[0]}"
f"and generation interval length {gen_int.size}."
)

if I0.shape[1:] != Rt.shape[1:]:
raise ValueError(
"Initial infections and Rt must have the same batch shapes. "
f"Got initial infections of batch shape {I0.shape[1:]} "
f"and Rt of batch shape {Rt.shape[1:]}."
)

gen_int_rev = jnp.flip(gen_int)
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -160,19 +167,23 @@ def sample(
**kwargs,
)
)

if inf_feedback_strength.ndim == Rt.ndim - 1:
inf_feedback_strength = inf_feedback_strength[jnp.newaxis]

# Making sure inf_feedback_strength spans the Rt length
if inf_feedback_strength.size == 1:
inf_feedback_strength = au.pad_x_to_match_y(
if inf_feedback_strength.shape[0] == 1:
inf_feedback_strength = au.pad_edges_to_match(
x=inf_feedback_strength,
y=Rt,
fill_value=inf_feedback_strength[0],
)
elif inf_feedback_strength.size != Rt.size:
axis=0,
)[0]
if inf_feedback_strength.shape != Rt.shape:
raise ValueError(
"Infection feedback strength must be of size 1 "
"or the same size as the reproduction number array. "
f"Got {inf_feedback_strength.size} "
f"and {Rt.size} respectively."
"Infection feedback strength must be of length 1 "
"or the same length as the reproduction number array. "
f"Got {inf_feedback_strength.shape} "
f"and {Rt.shape} respectively."
)

# Sampling inf feedback pmf
Expand Down
75 changes: 47 additions & 28 deletions test/test_arrayutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,53 +8,72 @@
import pyrenew.arrayutils as au


def test_arrayutils_pad_to_match():
def test_pad_edges_to_match():
"""
Verifies extension when required and error when `fix_y` is True.
Test function to verify padding along the edges for 1D and 2D arrays
"""

# test when y gets padded
x = jnp.array([1, 2, 3])
y = jnp.array([1, 2])

x_pad, y_pad = au.pad_to_match(x, y)

x_pad, y_pad = au.pad_edges_to_match(x, y)
assert x_pad.size == y_pad.size
assert x_pad.size == 3
assert y_pad[-1] == y[-1]
assert jnp.array_equal(x_pad, x)

# test when x gets padded
x = jnp.array([1, 2])
y = jnp.array([1, 2, 3])

x_pad, y_pad = au.pad_to_match(x, y)

x_pad, y_pad = au.pad_edges_to_match(x, y)
assert x_pad.size == y_pad.size
assert x_pad.size == 3
assert x_pad[-1] == x[-1]
assert jnp.array_equal(y_pad, y)

# test when no padding required
x = jnp.array([1, 2, 3])
y = jnp.array([1, 2])
y = jnp.array([4, 5, 6])
x_pad, y_pad = au.pad_edges_to_match(x, y)

# Verify that the function raises an error when `fix_y` is True
with pytest.raises(ValueError):
x_pad, y_pad = au.pad_to_match(x, y, fix_y=True)
assert jnp.array_equal(x_pad, x)
assert jnp.array_equal(y_pad, y)

# Verify function works with both padding directions
x_pad, y_pad = au.pad_to_match(x, y, pad_direction="start")
x = jnp.array([1, 2, 3])
y = jnp.array([1, 2])

x_pad, y_pad = au.pad_edges_to_match(x, y, pad_direction="start")

assert x_pad.size == y_pad.size
assert x_pad.size == 3
assert y_pad[0] == y[0]
assert jnp.array_equal(x_pad, x)

# Verify that the function raises an error when `fix_y` is True
with pytest.raises(
ValueError, match="Cannot fix y when x is longer than y"
):
x_pad, y_pad = au.pad_edges_to_match(x, y, fix_y=True)

# Verify function raises an error when pad_direction is not "start" or "end"
with pytest.raises(ValueError):
x_pad, y_pad = au.pad_to_match(x, y, pad_direction="middle")


def test_arrayutils_pad_x_to_match_y():
"""
Verifies extension when required
"""

x = jnp.array([1, 2])
y = jnp.array([1, 2, 3])

x_pad = au.pad_x_to_match_y(x, y)

assert x_pad.size == 3
x_pad, y_pad = au.pad_edges_to_match(x, y, pad_direction="middle")

# test padding for 2D arrays
x = jnp.array([[1, 2], [3, 4]])
y = jnp.array([[5, 6]])

# Padding along axis 0
axis = 0
x_pad, y_pad = au.pad_edges_to_match(x, y, axis=axis, pad_direction="end")

assert jnp.array_equal(x_pad.shape[axis], y_pad.shape[axis])
assert jnp.array_equal(y_pad[-1], y[-1])
assert jnp.array_equal(x_pad, x)

sbidari marked this conversation as resolved.
Show resolved Hide resolved
# padding along axis 1
axis = 1
x_pad, y_pad = au.pad_edges_to_match(x, y, axis=axis, pad_direction="end")
assert jnp.array_equal(x_pad.shape[axis], y_pad.shape[axis])
assert jnp.array_equal(y[:, -1], y_pad[:, -1])
sbidari marked this conversation as resolved.
Show resolved Hide resolved
assert jnp.array_equal(x_pad, x)
Loading