Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update infectionswithfeedback process #440

Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a10b95c
add function to pad using edge values and tests
sbidari Sep 10, 2024
f89aa9d
modify convolve and infectionwithfeedback to work with 2d array
sbidari Sep 10, 2024
b7fb437
Merge branch 'main' into 435-update-infectionswithfeedback-process-to…
sbidari Sep 10, 2024
b0e8a3d
test convolve scanner for 2d arrays
sbidari Sep 11, 2024
1862aee
Merge branch 'main' of https://github.com/CDCgov/PyRenew into 435-upd…
sbidari Sep 11, 2024
114af6a
add tests for 2d array, infectionsrtfeedback
sbidari Sep 11, 2024
a625fde
more tests
sbidari Sep 11, 2024
91be92f
remove pad to match functions
sbidari Sep 12, 2024
396c2d7
remove tests
sbidari Sep 12, 2024
8197059
use list comprehension in test_infectionsrtfeedback
sbidari Sep 12, 2024
97d19d0
revert to using for loop
sbidari Sep 12, 2024
d933df4
add more test for convolve scanner functions
sbidari Sep 12, 2024
39667c1
Merge branch 'main' into 435-update-infectionswithfeedback-process-to…
sbidari Sep 12, 2024
c210606
add check for initial infections and Rt ndims
sbidari Sep 12, 2024
5e91418
Merge branch '435-update-infectionswithfeedback-process-to-handle-bat…
sbidari Sep 12, 2024
f5557a3
remove typos and superfluous print statements
sbidari Sep 12, 2024
46e32e8
add more tests
sbidari Sep 12, 2024
728aad3
change value of inf_feedback in tests to cover diff scenarios
sbidari Sep 12, 2024
54c0527
add input array tests for infections.py
sbidari Sep 12, 2024
a6ce6c6
add test with plate
sbidari Sep 13, 2024
f37a229
Merge branch 'main' into 435-update-infectionswithfeedback-process-to…
sbidari Sep 13, 2024
cc61d0b
code review changes
sbidari Sep 13, 2024
a4a74f0
Merge branch '435-update-infectionswithfeedback-process-to-handle-bat…
sbidari Sep 13, 2024
0cd56ad
code review suggestion
sbidari Sep 13, 2024
0f3e2fd
replace jnp.dot with einsum
sbidari Sep 13, 2024
542ad84
code review suggestions
sbidari Sep 16, 2024
bafe3b0
code review changes
sbidari Sep 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions pyrenew/arrayutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,66 @@
from jax.typing import ArrayLike


def pad_edges_to_match(
sbidari marked this conversation as resolved.
Show resolved Hide resolved
x: ArrayLike,
y: ArrayLike,
axis: int = 0,
pad_direction: str = "end",
fix_y: bool = False,
) -> tuple[ArrayLike, ArrayLike]:
"""
Pad the shorter array at the start or end using the
edge values to match the length of the longer array.

Parameters
----------
x : ArrayLike
First array.
y : ArrayLike
Second array.
axis : int, optional
Axis along which to add padding, by default 0
pad_direction : str, optional
Direction to pad the shorter array, either "start" or "end", by default "end".
fix_y : bool, optional
If True, raise an error when `y` is shorter than `x`, by default False.

Returns
-------
tuple[ArrayLike, ArrayLike]
Tuple of the two arrays with the same length.
"""
x = jnp.atleast_1d(x)
y = jnp.atleast_1d(y)
x_len = x.shape[axis]
y_len = y.shape[axis]
pad_size = abs(x_len - y_len)
pad_width = [(0, 0)] * x.ndim

if pad_direction not in ["start", "end"]:
raise ValueError(
"pad_direction must be either 'start' or 'end'."
f" Got {pad_direction}."
)

pad_width[axis] = {"start": (pad_size, 0), "end": (0, pad_size)}.get(
pad_direction, None
)

if x_len > y_len:
if fix_y:
raise ValueError(
"Cannot fix y when x is longer than y."
f" x_len: {x_len}, y_len: {y_len}."
)
y = jnp.pad(y, pad_width, mode="edge")

elif y_len > x_len:
x = jnp.pad(x, pad_width, mode="edge")

return x, y


def pad_to_match(
x: ArrayLike,
y: ArrayLike,
Expand Down
8 changes: 6 additions & 2 deletions pyrenew/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def _new_scanner(
new_val = transform(
multiplier * jnp.dot(array_to_convolve, history_subset)
sbidari marked this conversation as resolved.
Show resolved Hide resolved
)
latest = jnp.hstack([history_subset[1:], new_val])
latest = jnp.concatenate(
[history_subset[1:], jnp.expand_dims(new_val, axis=0)]
)
sbidari marked this conversation as resolved.
Show resolved Hide resolved
sbidari marked this conversation as resolved.
Show resolved Hide resolved
return latest, new_val

return _new_scanner
Expand Down Expand Up @@ -160,7 +162,9 @@ def _new_scanner(
m1, m2 = multipliers
m_net1 = t1(m1 * jnp.dot(arr1, history_subset))
sbidari marked this conversation as resolved.
Show resolved Hide resolved
new_val = t2(m2 * m_net1 * jnp.dot(arr2, history_subset))
sbidari marked this conversation as resolved.
Show resolved Hide resolved
latest = jnp.hstack([history_subset[1:], new_val])
latest = jnp.concatenate(
[history_subset[1:], jnp.expand_dims(new_val, axis=0)]
)
sbidari marked this conversation as resolved.
Show resolved Hide resolved
return latest, (new_val, m_net1)

return _new_scanner
Expand Down
4 changes: 2 additions & 2 deletions pyrenew/latent/infections.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ def sample(
InfectionsSample
Named tuple with "infections".
"""
if I0.size < gen_int.size:
if I0.shape[0] < gen_int.size:
raise ValueError(
"Initial infections vector must be at least as long as "
"the generation interval. "
f"Initial infections vector length: {I0.size}, "
f"Initial infections vector length: {I0.shape[0]}, "
f"generation interval length: {gen_int.size}."
)

Expand Down
22 changes: 11 additions & 11 deletions pyrenew/latent/infectionswithfeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,10 @@
InfectionsWithFeedback
Named tuple with "infections".
"""
if I0.size < gen_int.size:
if I0.shape[0] < gen_int.size:
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
"Initial infections must be at least as long as the "
f"generation interval. Got {I0.size} initial infections "
f"generation interval. Got {I0.shape[0]} initial infections "
f"and {gen_int.size} generation interval."
)

Expand All @@ -161,18 +161,18 @@
)
)
# Making sure inf_feedback_strength spans the Rt length
if inf_feedback_strength.size == 1:
inf_feedback_strength = au.pad_x_to_match_y(
if inf_feedback_strength.shape[0] == 1:
inf_feedback_strength = au.pad_edges_to_match(

Check warning on line 165 in pyrenew/latent/infectionswithfeedback.py

View check run for this annotation

Codecov / codecov/patch

pyrenew/latent/infectionswithfeedback.py#L165

Added line #L165 was not covered by tests
x=inf_feedback_strength,
y=Rt,
fill_value=inf_feedback_strength[0],
)
elif inf_feedback_strength.size != Rt.size:
axis=0,
)[0]
elif inf_feedback_strength.shape != Rt.shape:
raise ValueError(
"Infection feedback strength must be of size 1 "
"or the same size as the reproduction number array. "
f"Got {inf_feedback_strength.size} "
f"and {Rt.size} respectively."
"Infection feedback strength must be of length 1 "
"or the same length as the reproduction number array. "
f"Got {inf_feedback_strength.shape} "
f"and {Rt.shape} respectively."
)

# Sampling inf feedback pmf
Expand Down
66 changes: 66 additions & 0 deletions test/test_arrayutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,69 @@ def test_arrayutils_pad_x_to_match_y():
x_pad = au.pad_x_to_match_y(x, y)

assert x_pad.size == 3


def test_pad_edges_to_match():
"""
Test function to verify padding along the edges for 1D and 2D arrays
"""

# test when y gets padded
x = jnp.array([1, 2, 3])
y = jnp.array([1, 2])

x_pad, y_pad = au.pad_edges_to_match(x, y)
assert x_pad.size == y_pad.size
assert y_pad[-1] == y[-1]

# test when x gets padded
x = jnp.array([1, 2])
y = jnp.array([1, 2, 3])

x_pad, y_pad = au.pad_edges_to_match(x, y)
assert x_pad.size == y_pad.size
assert x_pad[-1] == x[-1]

# test when no padding required
x = jnp.array([1, 2, 3])
y = jnp.array([4, 5, 6])
x_pad, y_pad = au.pad_edges_to_match(x, y)

assert jnp.array_equal(x_pad, x)
assert jnp.array_equal(y_pad, y)

# Verify function works with both padding directions
x = jnp.array([1, 2, 3])
y = jnp.array([1, 2])

x_pad, y_pad = au.pad_edges_to_match(x, y, pad_direction="start")

assert x_pad.size == y_pad.size
assert y_pad[0] == y[0]

# Verify that the function raises an error when `fix_y` is True
with pytest.raises(
ValueError, match="Cannot fix y when x is longer than y"
):
x_pad, y_pad = au.pad_edges_to_match(x, y, fix_y=True)

# Verify function raises an error when pad_direction is not "start" or "end"
with pytest.raises(ValueError):
x_pad, y_pad = au.pad_edges_to_match(x, y, pad_direction="middle")

# test padding for 2D arrays
x = jnp.array([[1, 2], [3, 4]])
y = jnp.array([[5, 6]])

# Padding along axis 0
axis = 0
x_pad, y_pad = au.pad_edges_to_match(x, y, axis=axis, pad_direction="end")

assert jnp.array_equal(x_pad.shape[axis], y_pad.shape[axis])
assert jnp.array_equal(y_pad[-1], y[-1])

sbidari marked this conversation as resolved.
Show resolved Hide resolved
# padding along axis 1
axis = 1
x_pad, y_pad = au.pad_edges_to_match(x, y, axis=axis, pad_direction="end")
assert jnp.array_equal(x_pad.shape[axis], y_pad.shape[axis])
assert jnp.array_equal(y[:, -1], y_pad[:, -1])
sbidari marked this conversation as resolved.
Show resolved Hide resolved
41 changes: 35 additions & 6 deletions test/test_convolve_scanners.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,33 @@
import jax
import jax.numpy as jnp
import numpy as np
import pytest
from numpy.testing import assert_array_equal

import pyrenew.convolve as pc


def test_double_scanner_reduces_to_single():
@pytest.mark.parametrize(
["inits", "to_scan_a", "multipliers"],
[
[
jnp.array([0.352, 5.2, -3]),
jnp.array([0.5, 0.3, 0.2]),
jnp.array(np.random.normal(0, 0.5, size=500)),
],
[
jnp.array(np.array([0.352, 5.2, -3] * 3).reshape(3, 3)),
jnp.array([0.5, 0.3, 0.2]),
jnp.array(np.random.normal(0, 0.5, size=(500, 3))),
],
],
)
def test_double_scanner_reduces_to_single(inits, to_scan_a, multipliers):
"""
Test that new_double_scanner() yields a function
that is equivalent to a single scanner if the first
scan is chosen appropriately
"""
inits = jnp.array([0.352, 5.2, -3])
to_scan_a = jnp.array([0.5, 0.3, 0.2])

multipliers = jnp.array(np.random.normal(0, 0.5, size=500))

def transform_a(x: any):
"""
Expand All @@ -39,10 +51,27 @@ def transform_a(x: any):
"""
return 4 * x + 0.025

def transform_ones_like(x: any):
sbidari marked this conversation as resolved.
Show resolved Hide resolved
"""
Generate an array of ones with the same shape as the input array.

Parameters
----------
x : any
Input value

Returns
-------
ArrayLike
An array of ones with the same shape as the input value `x`.
"""
return jnp.ones_like(x)

scanner_a = pc.new_convolve_scanner(to_scan_a, transform_a)

double_scanner_a = pc.new_double_convolve_scanner(
(jnp.array([523, 2, -0.5233]), to_scan_a), (lambda x: 1, transform_a)
(jnp.array([523, 2, -0.5233]), to_scan_a),
(transform_ones_like, transform_a),
sbidari marked this conversation as resolved.
Show resolved Hide resolved
)

_, result_a = jax.lax.scan(f=scanner_a, init=inits, xs=multipliers)
Expand Down
44 changes: 42 additions & 2 deletions test/test_infection_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import jax.numpy as jnp
import pytest
from numpy.testing import assert_array_equal

from pyrenew.latent import infection_functions as inf
Expand Down Expand Up @@ -54,6 +55,45 @@ def test_compute_infections_from_rt_with_feedback():
)

assert_array_equal(Rt_adj, Rt_raw)
pass
pass
return None


@pytest.mark.parametrize(
["I0", "gen_int", "inf_pmf", "Rt_raw"],
[
[
jnp.array([[5.0, 0.2]]),
jnp.array([1.0]),
jnp.array([1.0]),
jnp.ones((5, 2)),
],
[
3.5235 * jnp.ones((35, 3)),
jnp.ones(35) / 35,
jnp.ones(35),
jnp.zeros((253, 3)),
],
],
)
def test_compute_infections_from_rt_with_feedback_2d(
I0, gen_int, inf_pmf, Rt_raw
):
"""
Test implementation of infection feedback
when I0 and Rt are 2d arrays.
"""
(
infs_feedback,
Rt_adj,
) = inf.compute_infections_from_rt_with_feedback(
I0, Rt_raw, jnp.zeros_like(Rt_raw), gen_int, inf_pmf
)
print(inf.compute_infections_from_rt(I0, Rt_raw, gen_int))
assert_array_equal(
inf.compute_infections_from_rt(I0, Rt_raw, gen_int),
infs_feedback,
)

assert_array_equal(Rt_adj, Rt_raw)

return None
Loading