From a10b95cbb823aba00a4006cad6eddd71d34429c9 Mon Sep 17 00:00:00 2001 From: sbidari Date: Tue, 10 Sep 2024 14:08:15 -0400 Subject: [PATCH 01/21] add function to pad using edge values and tests --- pyrenew/arrayutils.py | 60 +++++++++++++++++++++++++++++++++++++ test/test_arrayutils.py | 66 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 126 insertions(+) diff --git a/pyrenew/arrayutils.py b/pyrenew/arrayutils.py index 402fbe2c..9225bcec 100644 --- a/pyrenew/arrayutils.py +++ b/pyrenew/arrayutils.py @@ -8,6 +8,66 @@ from jax.typing import ArrayLike +def pad_edges_to_match( + x: ArrayLike, + y: ArrayLike, + axis: int = 0, + pad_direction: str = "end", + fix_y: bool = False, +) -> tuple[ArrayLike, ArrayLike]: + """ + Pad the shorter array at the start or end using the + edge values to match the length of the longer array. + + Parameters + ---------- + x : ArrayLike + First array. + y : ArrayLike + Second array. + axis : int, optional + Axis along which to add padding, by default 0 + pad_direction : str, optional + Direction to pad the shorter array, either "start" or "end", by default "end". + fix_y : bool, optional + If True, raise an error when `y` is shorter than `x`, by default False. + + Returns + ------- + tuple[ArrayLike, ArrayLike] + Tuple of the two arrays with the same length. + """ + x = jnp.atleast_1d(x) + y = jnp.atleast_1d(y) + x_len = x.shape[axis] + y_len = y.shape[axis] + pad_size = abs(x_len - y_len) + + pad_width = [(0, 0)] * x.ndim + pad_width[axis] = {"start": (pad_size, 0), "end": (0, pad_size)}.get( + pad_direction, None + ) + + if pad_direction not in ["start", "end"]: + raise ValueError( + "pad_direction must be either 'start' or 'end'." + f" Got {pad_direction}." + ) + + if x_len > y_len: + if fix_y: + raise ValueError( + "Cannot fix y when x is longer than y." + f" x_len: {x_len}, y_len: {y_len}." + ) + y = jnp.pad(y, pad_width, mode="edge") + + elif y_len > x_len: + x = jnp.pad(x, pad_width, mode="edge") + + return x, y + + def pad_to_match( x: ArrayLike, y: ArrayLike, diff --git a/test/test_arrayutils.py b/test/test_arrayutils.py index 9e168de1..38685a5c 100644 --- a/test/test_arrayutils.py +++ b/test/test_arrayutils.py @@ -58,3 +58,69 @@ def test_arrayutils_pad_x_to_match_y(): x_pad = au.pad_x_to_match_y(x, y) assert x_pad.size == 3 + + +def test_pad_edges_to_match(): + """ + Test function to verify padding along the edges for 1D and 2D arrays + """ + + # test when y gets padded + x = jnp.array([1, 2, 3]) + y = jnp.array([1, 2]) + + x_pad, y_pad = au.pad_edges_to_match(x, y) + assert x_pad.size == y_pad.size + assert y_pad[-1] == y[-1] + + # test when x gets padded + x = jnp.array([1, 2]) + y = jnp.array([1, 2, 3]) + + x_pad, y_pad = au.pad_edges_to_match(x, y) + assert x_pad.size == y_pad.size + assert x_pad[-1] == x[-1] + + # test when no padding required + x = jnp.array([1, 2, 3]) + y = jnp.array([4, 5, 6]) + x_pad, y_pad = au.pad_edges_to_match(x, y) + + assert jnp.array_equal(x_pad, x) + assert jnp.array_equal(y_pad, y) + + # Verify function works with both padding directions + x = jnp.array([1, 2, 3]) + y = jnp.array([1, 2]) + + x_pad, y_pad = au.pad_edges_to_match(x, y, pad_direction="start") + + assert x_pad.size == y_pad.size + assert y_pad[0] == y[0] + + # Verify that the function raises an error when `fix_y` is True + with pytest.raises( + ValueError, match="Cannot fix y when x is longer than y" + ): + x_pad, y_pad = au.pad_edges_to_match(x, y, fix_y=True) + + # Verify function raises an error when pad_direction is not "start" or "end" + with pytest.raises(ValueError): + x_pad, y_pad = au.pad_edges_to_match(x, y, pad_direction="middle") + + # test padding for 2D arrays + x = jnp.array([[1, 2], [3, 4]]) + y = jnp.array([[5, 6]]) + + # Padding along axis 0 + axis = 0 + x_pad, y_pad = au.pad_edges_to_match(x, y, axis=axis, pad_direction="end") + + assert jnp.array_equal(x_pad.shape[axis], y_pad.shape[axis]) + assert jnp.array_equal(y_pad[-1], y[-1]) + + # padding along axis 1 + axis = 1 + x_pad, y_pad = au.pad_edges_to_match(x, y, axis=axis, pad_direction="end") + assert jnp.array_equal(x_pad.shape[axis], y_pad.shape[axis]) + assert jnp.array_equal(y[:, -1], y_pad[:, -1]) From f89aa9d7f73abce1dc6564a5f3f8d20ca96cc0f3 Mon Sep 17 00:00:00 2001 From: sbidari Date: Tue, 10 Sep 2024 14:14:21 -0400 Subject: [PATCH 02/21] modify convolve and infectionwithfeedback to work with 2d array --- pyrenew/convolve.py | 5 ++++- pyrenew/latent/infectionswithfeedback.py | 26 +++++++++++++----------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/pyrenew/convolve.py b/pyrenew/convolve.py index 44d98e4c..bd427615 100755 --- a/pyrenew/convolve.py +++ b/pyrenew/convolve.py @@ -11,6 +11,7 @@ :py:func:`jax.lax.scan` with an appropriate array to scan along. """ + from __future__ import annotations from typing import Callable @@ -159,7 +160,9 @@ def _new_scanner( m1, m2 = multipliers m_net1 = t1(m1 * jnp.dot(arr1, history_subset)) new_val = t2(m2 * m_net1 * jnp.dot(arr2, history_subset)) - latest = jnp.hstack([history_subset[1:], new_val]) + latest = jnp.concatenate( + [history_subset[1:], jnp.expand_dims(new_val, axis=0)] + ) return latest, (new_val, m_net1) return _new_scanner diff --git a/pyrenew/latent/infectionswithfeedback.py b/pyrenew/latent/infectionswithfeedback.py index 833087e2..f9e6ada5 100644 --- a/pyrenew/latent/infectionswithfeedback.py +++ b/pyrenew/latent/infectionswithfeedback.py @@ -147,10 +147,10 @@ def sample( InfectionsWithFeedback Named tuple with "infections". """ - if I0.size < gen_int.size: + if I0.shape[0] < gen_int.size: raise ValueError( "Initial infections must be at least as long as the " - f"generation interval. Got {I0.size} initial infections " + f"generation interval. Got {I0.shape[0]} initial infections " f"and {gen_int.size} generation interval." ) @@ -162,21 +162,23 @@ def sample( inf_feedback_strength = jnp.atleast_1d( self.infection_feedback_strength( **kwargs, - )[0].value + )[ + 0 + ].value # [jnp.newaxis, :] ) # Making sure inf_feedback_strength spans the Rt length - if inf_feedback_strength.size == 1: - inf_feedback_strength = au.pad_x_to_match_y( + if inf_feedback_strength.shape[0] == 1: + inf_feedback_strength = au.pad_edges_to_match( x=inf_feedback_strength, y=Rt, - fill_value=inf_feedback_strength[0], - ) - elif inf_feedback_strength.size != Rt.size: + axis=0, + )[0] + elif inf_feedback_strength.shape != Rt.shape: raise ValueError( - "Infection feedback strength must be of size 1 " - "or the same size as the reproduction number array. " - f"Got {inf_feedback_strength.size} " - f"and {Rt.size} respectively." + "Infection feedback strength must be of length 1 " + "or the same length as the reproduction number array. " + f"Got {inf_feedback_strength.shape} " + f"and {Rt.shape} respectively." ) # Sampling inf feedback pmf From b0e8a3d6f4de132dd8bc1ca81a32be51a0888503 Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 11 Sep 2024 11:23:32 -0400 Subject: [PATCH 03/21] test convolve scanner for 2d arrays --- pyrenew/arrayutils.py | 8 +++---- pyrenew/convolve.py | 4 +++- test/test_convolve_scanners.py | 41 +++++++++++++++++++++++++++++----- 3 files changed, 42 insertions(+), 11 deletions(-) diff --git a/pyrenew/arrayutils.py b/pyrenew/arrayutils.py index 9225bcec..d5681ee4 100644 --- a/pyrenew/arrayutils.py +++ b/pyrenew/arrayutils.py @@ -42,11 +42,7 @@ def pad_edges_to_match( x_len = x.shape[axis] y_len = y.shape[axis] pad_size = abs(x_len - y_len) - pad_width = [(0, 0)] * x.ndim - pad_width[axis] = {"start": (pad_size, 0), "end": (0, pad_size)}.get( - pad_direction, None - ) if pad_direction not in ["start", "end"]: raise ValueError( @@ -54,6 +50,10 @@ def pad_edges_to_match( f" Got {pad_direction}." ) + pad_width[axis] = {"start": (pad_size, 0), "end": (0, pad_size)}.get( + pad_direction, None + ) + if x_len > y_len: if fix_y: raise ValueError( diff --git a/pyrenew/convolve.py b/pyrenew/convolve.py index 03fa6322..575d140b 100755 --- a/pyrenew/convolve.py +++ b/pyrenew/convolve.py @@ -80,7 +80,9 @@ def _new_scanner( new_val = transform( multiplier * jnp.dot(array_to_convolve, history_subset) ) - latest = jnp.hstack([history_subset[1:], new_val]) + latest = jnp.concatenate( + [history_subset[1:], jnp.expand_dims(new_val, axis=0)] + ) return latest, new_val return _new_scanner diff --git a/test/test_convolve_scanners.py b/test/test_convolve_scanners.py index 12430004..77071686 100644 --- a/test/test_convolve_scanners.py +++ b/test/test_convolve_scanners.py @@ -6,21 +6,33 @@ import jax import jax.numpy as jnp import numpy as np +import pytest from numpy.testing import assert_array_equal import pyrenew.convolve as pc -def test_double_scanner_reduces_to_single(): +@pytest.mark.parametrize( + ["inits", "to_scan_a", "multipliers"], + [ + [ + jnp.array([0.352, 5.2, -3]), + jnp.array([0.5, 0.3, 0.2]), + jnp.array(np.random.normal(0, 0.5, size=500)), + ], + [ + jnp.array(np.array([0.352, 5.2, -3] * 3).reshape(3, 3)), + jnp.array([0.5, 0.3, 0.2]), + jnp.array(np.random.normal(0, 0.5, size=(500, 3))), + ], + ], +) +def test_double_scanner_reduces_to_single(inits, to_scan_a, multipliers): """ Test that new_double_scanner() yields a function that is equivalent to a single scanner if the first scan is chosen appropriately """ - inits = jnp.array([0.352, 5.2, -3]) - to_scan_a = jnp.array([0.5, 0.3, 0.2]) - - multipliers = jnp.array(np.random.normal(0, 0.5, size=500)) def transform_a(x: any): """ @@ -39,10 +51,27 @@ def transform_a(x: any): """ return 4 * x + 0.025 + def transform_ones_like(x: any): + """ + Generate an array of ones with the same shape as the input array. + + Parameters + ---------- + x : any + Input value + + Returns + ------- + ArrayLike + An array of ones with the same shape as the input value `x`. + """ + return jnp.ones_like(x) + scanner_a = pc.new_convolve_scanner(to_scan_a, transform_a) double_scanner_a = pc.new_double_convolve_scanner( - (jnp.array([523, 2, -0.5233]), to_scan_a), (lambda x: 1, transform_a) + (jnp.array([523, 2, -0.5233]), to_scan_a), + (transform_ones_like, transform_a), ) _, result_a = jax.lax.scan(f=scanner_a, init=inits, xs=multipliers) From 114af6a94f1ea64d693edda40d202a2766397838 Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 11 Sep 2024 15:28:36 -0400 Subject: [PATCH 04/21] add tests for 2d array, infectionsrtfeedback --- pyrenew/latent/infections.py | 4 +- test/test_infectionsrtfeedback.py | 88 ++++++++++++++++++++++--------- 2 files changed, 64 insertions(+), 28 deletions(-) diff --git a/pyrenew/latent/infections.py b/pyrenew/latent/infections.py index 887d0e5a..168e8d6e 100644 --- a/pyrenew/latent/infections.py +++ b/pyrenew/latent/infections.py @@ -80,11 +80,11 @@ def sample( InfectionsSample Named tuple with "infections". """ - if I0.size < gen_int.size: + if I0.shape[0] < gen_int.size: raise ValueError( "Initial infections vector must be at least as long as " "the generation interval. " - f"Initial infections vector length: {I0.size}, " + f"Initial infections vector length: {I0.shape[0]}, " f"generation interval length: {gen_int.size}." ) diff --git a/test/test_infectionsrtfeedback.py b/test/test_infectionsrtfeedback.py index d241783c..9cb7b43a 100644 --- a/test/test_infectionsrtfeedback.py +++ b/test/test_infectionsrtfeedback.py @@ -5,6 +5,7 @@ import jax.numpy as jnp import numpy as np import numpyro +import pytest from jax.typing import ArrayLike from numpy.testing import assert_array_almost_equal, assert_array_equal @@ -39,32 +40,53 @@ def _infection_w_feedback_alt( tuple """ - Rt = np.array(Rt) # coerce from jax to use numpy-like operations T = len(Rt) + Rt = np.array(Rt).reshape( + T, -1 + ) # coerce from jax to use numpy-like operations len_gen = len(gen_int) - I_vec = np.concatenate([I0, np.zeros(T)]) - Rt_adj = np.zeros(T) - - for t in range(T): - Rt_adj[t] = Rt[t] * np.exp( - inf_feedback_strength[t] - * np.dot(I_vec[t : t + len_gen], np.flip(inf_feedback_pmf)) - ) - - I_vec[t + len_gen] = Rt_adj[t] * np.dot( - I_vec[t : t + len_gen], np.flip(gen_int) - ) - - return {"post_initialization_infections": I_vec[I0.size :], "rt": Rt_adj} - - -def test_infectionsrtfeedback(): + I_vec = np.concatenate([I0.reshape(T, -1), np.zeros(Rt.shape)]) + Rt_adj = np.zeros(Rt.shape) + inf_feedback_strength = np.array(inf_feedback_strength).reshape(T, -1) + + for n in range(Rt.shape[1]): + for t in range(Rt.shape[0]): + Rt_adj[t, n] = Rt[t, n] * np.exp( + inf_feedback_strength[t, n] + * np.dot(I_vec[t : t + len_gen, n], np.flip(inf_feedback_pmf)) + ) + + I_vec[t + len_gen, n] = Rt_adj[t, n] * np.dot( + I_vec[t : t + len_gen, n], np.flip(gen_int) + ) + + return { + "post_initialization_infections": np.squeeze(I_vec[I0.shape[0] :]), + "rt": np.squeeze(Rt_adj), + } + + +@pytest.mark.parametrize( + ["Rt", "I0"], + [ + [ + jnp.array([0.5, 0.6, 0.7, 0.8, 2, 0.5, 2.25]), + jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]), + ], + [ + jnp.array( + np.array([0.5, 0.6, 0.7, 0.8, 2, 0.5, 2.25] * 3) + ).reshape((7, 3)), + jnp.array( + np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] * 3) + ).reshape((7, 3)), + ], + ], +) +def test_infectionsrtfeedback(Rt, I0): """ Test the InfectionsWithFeedback matching the Infections class. """ - - Rt = jnp.array([0.5, 0.6, 0.7, 0.8, 2, 0.5, 2.25]) - I0 = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]) gen_int = jnp.array([0.4, 0.25, 0.25, 0.1, 0.0, 0.0, 0.0]) # By doing the infection feedback strength 0, Rt = Rt_adjusted @@ -104,17 +126,31 @@ def test_infectionsrtfeedback(): return None -def test_infectionsrtfeedback_feedback(): +@pytest.mark.parametrize( + ["Rt", "I0"], + [ + [ + jnp.array([0.5, 0.6, 0.7, 0.8, 2, 0.5, 2.25]), + jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]), + ], + [ + jnp.array( + np.array([0.5, 0.6, 0.7, 0.8, 2, 0.5, 2.25] * 3) + ).reshape((7, 3)), + jnp.array( + np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] * 3) + ).reshape((7, 3)), + ], + ], +) +def test_infectionsrtfeedback_feedback(Rt, I0): """ Test the InfectionsWithFeedback with feedback """ - - Rt = jnp.array([0.5, 0.6, 1.5, 2.523, 0.7, 0.8]) - I0 = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]) gen_int = jnp.array([0.4, 0.25, 0.25, 0.1, 0.0, 0.0, 0.0]) inf_feed_strength = DeterministicVariable( - name="inf_feed_strength", value=jnp.repeat(0.5, len(Rt)) + name="inf_feed_strength", value=0.5 * jnp.ones_like(Rt) ) inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", value=gen_int) From a625fde0ed6111fb41dbc257a208931dc3159a1b Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 11 Sep 2024 16:45:14 -0400 Subject: [PATCH 05/21] more tests --- test/test_infection_functions.py | 44 ++++++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/test/test_infection_functions.py b/test/test_infection_functions.py index f8e1a0b6..b4ef7626 100644 --- a/test/test_infection_functions.py +++ b/test/test_infection_functions.py @@ -4,6 +4,7 @@ """ import jax.numpy as jnp +import pytest from numpy.testing import assert_array_equal from pyrenew.latent import infection_functions as inf @@ -54,6 +55,45 @@ def test_compute_infections_from_rt_with_feedback(): ) assert_array_equal(Rt_adj, Rt_raw) - pass - pass + return None + + +@pytest.mark.parametrize( + ["I0", "gen_int", "inf_pmf", "Rt_raw"], + [ + [ + jnp.array([[5.0, 0.2]]), + jnp.array([1.0]), + jnp.array([1.0]), + jnp.ones((5, 2)), + ], + [ + 3.5235 * jnp.ones((35, 3)), + jnp.ones(35) / 35, + jnp.ones(35), + jnp.zeros((253, 3)), + ], + ], +) +def test_compute_infections_from_rt_with_feedback_2d( + I0, gen_int, inf_pmf, Rt_raw +): + """ + Test implementation of infection feedback + when I0 and Rt are 2d arrays. + """ + ( + infs_feedback, + Rt_adj, + ) = inf.compute_infections_from_rt_with_feedback( + I0, Rt_raw, jnp.zeros_like(Rt_raw), gen_int, inf_pmf + ) + print(inf.compute_infections_from_rt(I0, Rt_raw, gen_int)) + assert_array_equal( + inf.compute_infections_from_rt(I0, Rt_raw, gen_int), + infs_feedback, + ) + + assert_array_equal(Rt_adj, Rt_raw) + return None From 91be92fe0bb12aad6d43148dc95ab03966a8102f Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 12 Sep 2024 11:52:39 -0400 Subject: [PATCH 06/21] remove pad to match functions --- docs/source/tutorials/extending_pyrenew.qmd | 7 +- pyrenew/arrayutils.py | 88 --------------------- 2 files changed, 3 insertions(+), 92 deletions(-) diff --git a/docs/source/tutorials/extending_pyrenew.qmd b/docs/source/tutorials/extending_pyrenew.qmd index 2cef0d2a..34e92a15 100644 --- a/docs/source/tutorials/extending_pyrenew.qmd +++ b/docs/source/tutorials/extending_pyrenew.qmd @@ -168,7 +168,7 @@ InfFeedbackSample = namedtuple( ) ``` -The next step is to create the actual class. The bulk of its implementation lies in the function `pyrenew.latent.compute_infections_from_rt_with_feedback()`. We will also use the `pyrenew.arrayutils.pad_x_to_match_y()` function to ensure the passed vectors match their lengths. The following code-chunk shows most of the implementation of the `InfectionsWithFeedback` class: +The next step is to create the actual class. The bulk of its implementation lies in the function `pyrenew.latent.compute_infections_from_rt_with_feedback()`. We will also use the `pyrenew.arrayutils.pad_edges_to_match()` function to ensure the passed vectors match their lengths. The following code-chunk shows most of the implementation of the `InfectionsWithFeedback` class: ```{python} # | label: new-model-def @@ -224,11 +224,10 @@ class InfFeedback(RandomVariable): inf_feedback_strength = jnp.atleast_1d(inf_feedback_strength) - inf_feedback_strength = au.pad_x_to_match_y( + inf_feedback_strength = au.pad_edges_to_match( x=inf_feedback_strength, y=Rt, - fill_value=inf_feedback_strength[0], - ) + )[0] # Sampling inf feedback and adjusting the shape inf_feedback_pmf = self.infection_feedback_pmf(**kwargs) diff --git a/pyrenew/arrayutils.py b/pyrenew/arrayutils.py index 915c6816..f8ae8590 100644 --- a/pyrenew/arrayutils.py +++ b/pyrenew/arrayutils.py @@ -68,94 +68,6 @@ def pad_edges_to_match( return x, y -def pad_to_match( - x: ArrayLike, - y: ArrayLike, - fill_value: float = 0.0, - pad_direction: str = "end", - fix_y: bool = False, -) -> tuple[ArrayLike, ArrayLike]: - """ - Pad the shorter array at the start or end to match the length of the longer array. - - Parameters - ---------- - x : ArrayLike - First array. - y : ArrayLike - Second array. - fill_value : float, optional - Value to use for padding, by default 0.0. - pad_direction : str, optional - Direction to pad the shorter array, either "start" or "end", by default "end". - fix_y : bool, optional - If True, raise an error when `y` is shorter than `x`, by default False. - - Returns - ------- - tuple[ArrayLike, ArrayLike] - Tuple of the two arrays with the same length. - """ - x = jnp.atleast_1d(x) - y = jnp.atleast_1d(y) - x_len = x.size - y_len = y.size - pad_size = abs(x_len - y_len) - - pad_width = {"start": (pad_size, 0), "end": (0, pad_size)}.get( - pad_direction, None - ) - - if pad_width is None: - raise ValueError( - "pad_direction must be either 'start' or 'end'." - f" Got {pad_direction}." - ) - - if x_len > y_len: - if fix_y: - raise ValueError( - "Cannot fix y when x is longer than y." - f" x_len: {x_len}, y_len: {y_len}." - ) - y = jnp.pad(y, pad_width, constant_values=fill_value) - - elif y_len > x_len: - x = jnp.pad(x, pad_width, constant_values=fill_value) - - return x, y - - -def pad_x_to_match_y( - x: ArrayLike, - y: ArrayLike, - fill_value: float = 0.0, - pad_direction: str = "end", -) -> ArrayLike: - """ - Pad the `x` array at the start or end to match the length of the `y` array. - - Parameters - ---------- - x : ArrayLike - First array. - y : ArrayLike - Second array. - fill_value : float, optional - Value to use for padding, by default 0.0. - pad_direction : str, optional - Direction to pad the shorter array, either "start" or "end", by default "end". - - Returns - ------- - Array - Padded array. - """ - return pad_to_match( - x, y, fill_value=fill_value, pad_direction=pad_direction, fix_y=True - )[0] - - class PeriodicProcessSample(NamedTuple): """ A container for holding the output from `process.PeriodicProcess()`. From 396c2d70dbc518248d7218d3f630e822bf71534f Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 12 Sep 2024 11:54:01 -0400 Subject: [PATCH 07/21] remove tests --- test/test_arrayutils.py | 52 ----------------------------------------- 1 file changed, 52 deletions(-) diff --git a/test/test_arrayutils.py b/test/test_arrayutils.py index 38685a5c..9bf2882d 100644 --- a/test/test_arrayutils.py +++ b/test/test_arrayutils.py @@ -8,58 +8,6 @@ import pyrenew.arrayutils as au -def test_arrayutils_pad_to_match(): - """ - Verifies extension when required and error when `fix_y` is True. - """ - - x = jnp.array([1, 2, 3]) - y = jnp.array([1, 2]) - - x_pad, y_pad = au.pad_to_match(x, y) - - assert x_pad.size == y_pad.size - assert x_pad.size == 3 - - x = jnp.array([1, 2]) - y = jnp.array([1, 2, 3]) - - x_pad, y_pad = au.pad_to_match(x, y) - - assert x_pad.size == y_pad.size - assert x_pad.size == 3 - - x = jnp.array([1, 2, 3]) - y = jnp.array([1, 2]) - - # Verify that the function raises an error when `fix_y` is True - with pytest.raises(ValueError): - x_pad, y_pad = au.pad_to_match(x, y, fix_y=True) - - # Verify function works with both padding directions - x_pad, y_pad = au.pad_to_match(x, y, pad_direction="start") - - assert x_pad.size == y_pad.size - assert x_pad.size == 3 - - # Verify function raises an error when pad_direction is not "start" or "end" - with pytest.raises(ValueError): - x_pad, y_pad = au.pad_to_match(x, y, pad_direction="middle") - - -def test_arrayutils_pad_x_to_match_y(): - """ - Verifies extension when required - """ - - x = jnp.array([1, 2]) - y = jnp.array([1, 2, 3]) - - x_pad = au.pad_x_to_match_y(x, y) - - assert x_pad.size == 3 - - def test_pad_edges_to_match(): """ Test function to verify padding along the edges for 1D and 2D arrays From 81970595f222ad498753f9974e68b649bd39b3c2 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 12 Sep 2024 14:03:51 -0400 Subject: [PATCH 08/21] use list comprehension in test_infectionsrtfeedback --- test/test_infectionsrtfeedback.py | 54 ++++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 11 deletions(-) diff --git a/test/test_infectionsrtfeedback.py b/test/test_infectionsrtfeedback.py index 9cb7b43a..e9d0c903 100644 --- a/test/test_infectionsrtfeedback.py +++ b/test/test_infectionsrtfeedback.py @@ -45,23 +45,55 @@ def _infection_w_feedback_alt( T, -1 ) # coerce from jax to use numpy-like operations len_gen = len(gen_int) - I_vec = np.concatenate([I0.reshape(T, -1), np.zeros(Rt.shape)]) + infs = np.concatenate([I0.reshape(T, -1), np.zeros(Rt.shape)]) Rt_adj = np.zeros(Rt.shape) inf_feedback_strength = np.array(inf_feedback_strength).reshape(T, -1) - for n in range(Rt.shape[1]): - for t in range(Rt.shape[0]): - Rt_adj[t, n] = Rt[t, n] * np.exp( - inf_feedback_strength[t, n] - * np.dot(I_vec[t : t + len_gen, n], np.flip(inf_feedback_pmf)) - ) + def compute_Rt_adj( + Rt, inf_feedback_strength, infs, inf_feedback_pmf, len_gen, t, n + ): # numpydoc ignore=GL08 + return Rt[t, n] * np.exp( + inf_feedback_strength[t, n] + * np.dot(infs[t : t + len_gen, n], np.flip(inf_feedback_pmf)) + ) + + Rt_adj = np.array( + [ + [ + compute_Rt_adj( + Rt, + inf_feedback_strength, + infs, + inf_feedback_pmf, + len_gen, + t, + n, + ) + for n in range(Rt.shape[1]) + ] + for t in range(Rt.shape[0]) + ] + ) - I_vec[t + len_gen, n] = Rt_adj[t, n] * np.dot( - I_vec[t : t + len_gen, n], np.flip(gen_int) - ) + def compute_infections( + Rt_adj, infs, len_gen, gen_int, t, n + ): # numpydoc ignore=GL08 + return Rt_adj[t, n] * np.dot( + infs[t : t + len_gen, n], np.flip(gen_int) + ) + + infs[len_gen : T + len_gen] = np.array( + [ + [ + compute_infections(Rt_adj, infs, len_gen, gen_int, t, n) + for n in range(Rt.shape[1]) + ] + for t in range(Rt.shape[0]) + ] + ) return { - "post_initialization_infections": np.squeeze(I_vec[I0.shape[0] :]), + "post_initialization_infections": np.squeeze(infs[I0.shape[0] :]), "rt": np.squeeze(Rt_adj), } From 97d19d05ed295cd7547376ff4958ea347c743bba Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 12 Sep 2024 14:17:03 -0400 Subject: [PATCH 09/21] revert to using for loop --- test/test_infectionsrtfeedback.py | 53 ++++++------------------------- 1 file changed, 10 insertions(+), 43 deletions(-) diff --git a/test/test_infectionsrtfeedback.py b/test/test_infectionsrtfeedback.py index e9d0c903..47841922 100644 --- a/test/test_infectionsrtfeedback.py +++ b/test/test_infectionsrtfeedback.py @@ -39,58 +39,25 @@ def _infection_w_feedback_alt( ------- tuple """ - T = len(Rt) Rt = np.array(Rt).reshape( T, -1 ) # coerce from jax to use numpy-like operations len_gen = len(gen_int) - infs = np.concatenate([I0.reshape(T, -1), np.zeros(Rt.shape)]) + infs = np.pad(I0.reshape(T, -1), ((0, Rt.shape[0]), (0, 0))) Rt_adj = np.zeros(Rt.shape) inf_feedback_strength = np.array(inf_feedback_strength).reshape(T, -1) - def compute_Rt_adj( - Rt, inf_feedback_strength, infs, inf_feedback_pmf, len_gen, t, n - ): # numpydoc ignore=GL08 - return Rt[t, n] * np.exp( - inf_feedback_strength[t, n] - * np.dot(infs[t : t + len_gen, n], np.flip(inf_feedback_pmf)) - ) - - Rt_adj = np.array( - [ - [ - compute_Rt_adj( - Rt, - inf_feedback_strength, - infs, - inf_feedback_pmf, - len_gen, - t, - n, - ) - for n in range(Rt.shape[1]) - ] - for t in range(Rt.shape[0]) - ] - ) + for n in range(Rt.shape[1]): + for t in range(Rt.shape[0]): + Rt_adj[t, n] = Rt[t, n] * np.exp( + inf_feedback_strength[t, n] + * np.dot(infs[t : t + len_gen, n], np.flip(inf_feedback_pmf)) + ) - def compute_infections( - Rt_adj, infs, len_gen, gen_int, t, n - ): # numpydoc ignore=GL08 - return Rt_adj[t, n] * np.dot( - infs[t : t + len_gen, n], np.flip(gen_int) - ) - - infs[len_gen : T + len_gen] = np.array( - [ - [ - compute_infections(Rt_adj, infs, len_gen, gen_int, t, n) - for n in range(Rt.shape[1]) - ] - for t in range(Rt.shape[0]) - ] - ) + infs[t + len_gen, n] = Rt_adj[t, n] * np.dot( + infs[t : t + len_gen, n], np.flip(gen_int) + ) return { "post_initialization_infections": np.squeeze(infs[I0.shape[0] :]), From d933df457b666addc862dee4f301ac11252f24e8 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 12 Sep 2024 17:26:01 -0400 Subject: [PATCH 10/21] add more test for convolve scanner functions --- pyrenew/latent/infections.py | 7 ++ pyrenew/latent/infectionswithfeedback.py | 11 +- test/test_convolve_scanners.py | 146 +++++++++++++++++++++++ 3 files changed, 162 insertions(+), 2 deletions(-) diff --git a/pyrenew/latent/infections.py b/pyrenew/latent/infections.py index 168e8d6e..628c7cc4 100644 --- a/pyrenew/latent/infections.py +++ b/pyrenew/latent/infections.py @@ -88,6 +88,13 @@ def sample( f"generation interval length: {gen_int.size}." ) + if I0.shape != Rt.shape: + raise ValueError( + "Initial infections and Rt must have the same shape. " + f"Got initial infections of shape {I0.shape} " + f"and Rt of shape {Rt.shape}." + ) + gen_int_rev = jnp.flip(gen_int) recent_I0 = I0[-gen_int_rev.size :] diff --git a/pyrenew/latent/infectionswithfeedback.py b/pyrenew/latent/infectionswithfeedback.py index 7cacb581..32a4ac18 100644 --- a/pyrenew/latent/infectionswithfeedback.py +++ b/pyrenew/latent/infectionswithfeedback.py @@ -146,8 +146,15 @@ def sample( if I0.shape[0] < gen_int.size: raise ValueError( "Initial infections must be at least as long as the " - f"generation interval. Got {I0.shape[0]} initial infections " - f"and {gen_int.size} generation interval." + f"generation interval. Got initial infections length {I0.shape[0]}" + f"and generation interval length {gen_int.size}." + ) + + if I0.shape != Rt.shape: + raise ValueError( + "Initial infections and Rt must have the same shape. " + f"Got initial infections of shape {I0.shape} " + f"and Rt of shape {Rt.shape}." ) gen_int_rev = jnp.flip(gen_int) diff --git a/test/test_convolve_scanners.py b/test/test_convolve_scanners.py index 77071686..3b772a1e 100644 --- a/test/test_convolve_scanners.py +++ b/test/test_convolve_scanners.py @@ -10,6 +10,7 @@ from numpy.testing import assert_array_equal import pyrenew.convolve as pc +import pyrenew.transformation as t @pytest.mark.parametrize( @@ -82,3 +83,148 @@ def transform_ones_like(x: any): assert_array_equal(result_a_double[1], jnp.ones_like(multipliers)) assert_array_equal(result_a_double[0], result_a) + + +@pytest.mark.parametrize( + ["arr", "history", "multipliers", "transform"], + [ + [ + jnp.array([1.0, 2.0]), + jnp.array([3.0, 4.0]), + jnp.array([1, 2, 3]), + t.IdentityTransform(), + ], + [ + jnp.ones(3), + jnp.array(np.array([0.5, 0.3, 0.2] * 3)).reshape(3, 3), + jnp.ones((3, 3)), + t.ExpTransform(), + ], + ], +) +def test_convolve_scanner_using_scan(arr, history, multipliers, transform): + """ + Tests the output of new convolve scanner function + used with `jax.lax.scan` against values calculated + using a for loop + """ + scanner = pc.new_convolve_scanner(arr, transform) + + _, result = jax.lax.scan(f=scanner, init=history, xs=multipliers) + + result_not_scanned = [] + for multiplier in multipliers: + history, new_val = scanner(history, multiplier) + result_not_scanned.append(new_val) + + assert jnp.array_equal(result, result_not_scanned) + + +@pytest.mark.parametrize( + ["arr1", "arr2", "history", "m1", "m2", "transform"], + [ + [ + jnp.array([1.0, 2.0]), + jnp.array([2.0, 1.0]), + jnp.array([0.1, 0.4]), + jnp.array([1, 2, 3]), + jnp.ones(3), + (t.IdentityTransform(), t.IdentityTransform()), + ], + [ + jnp.array([1.0, 2.0, 0.3]), + jnp.array([2.0, 1.0, 0.5]), + jnp.array(np.array([0.5, 0.3, 0.2] * 3)).reshape(3, 3), + jnp.ones((3, 3)), + jnp.ones((3, 3)), + (t.ExpTransform(), t.IdentityTransform()), + ], + ], +) +def test_double_convolve_scanner_using_scan( + arr1, arr2, history, m1, m2, transform +): + """ + Tests the output of new convolve double scanner function + used with `jax.lax.scan` against values calculated + using a for loop + """ + arr1 = jnp.array([1.0, 2.0]) + arr2 = jnp.array([2.0, 1.0]) + transform = (t.IdentityTransform(), t.IdentityTransform()) + history = jnp.array([0.1, 0.4]) + m1, m2 = (jnp.array([1, 2, 3]), jnp.ones(3)) + + scanner = pc.new_double_convolve_scanner((arr1, arr2), transform) + + _, result = jax.lax.scan(f=scanner, init=history, xs=(m1, m2)) + + res1, res2 = [], [] + for m1, m2 in zip(m1, m2): + history, new_val = scanner(history, (m1, m2)) + res1.append(new_val[0]) + res2.append(new_val[1]) + + assert jnp.array_equal(result, (res1, res2)) + + +@pytest.mark.parametrize( + ["arr", "history", "multiplier", "transform"], + [ + [ + jnp.array([1.0, 2.0]), + jnp.array([3.0, 4.0]), + jnp.array(2), + t.IdentityTransform(), + ], + [ + jnp.ones(3), + jnp.array(np.array([0.5, 0.3, 0.2] * 3)).reshape(3, 3), + jnp.ones(3), + t.ExpTransform(), + ], + ], +) +def test_convolve_scanner(arr, history, multiplier, transform): + """ + Tests new convolve scanner function + """ + scanner = pc.new_convolve_scanner(arr, transform) + latest, new_val = scanner(history, multiplier) + assert jnp.array_equal( + new_val, transform(multiplier * jnp.dot(arr, history)) + ) + + +@pytest.mark.parametrize( + ["arr1", "arr2", "history", "m1", "m2", "transforms"], + [ + [ + jnp.array([1.0, 2.0]), + jnp.array([2.0, 1.0]), + jnp.array([0.1, 0.4]), + jnp.array(1), + jnp.array(3), + (t.IdentityTransform(), t.IdentityTransform()), + ], + [ + jnp.array([1.0, 2.0, 0.3]), + jnp.array([2.0, 1.0, 0.5]), + jnp.array(np.array([0.5, 0.3, 0.2] * 3)).reshape(3, 3), + jnp.ones(3), + 0.1 * jnp.ones(3), + (t.ExpTransform(), t.IdentityTransform()), + ], + ], +) +def test_double_convolve_scanner(arr1, arr2, history, m1, m2, transforms): + """ + Tests new double convolve scanner function + """ + double_scanner = pc.new_double_convolve_scanner((arr1, arr2), transforms) + latest, (new_val, m_net) = double_scanner(history, (m1, m2)) + + assert jnp.array_equal(m_net, transforms[0](m1 * jnp.dot(arr1, history))) + assert jnp.array_equal( + new_val, transforms[1](m2 * m_net * jnp.dot(arr2, history)) + ) From c210606867a4d5ac4b4e8d08e0e4190f0686c205 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 12 Sep 2024 17:32:05 -0400 Subject: [PATCH 11/21] add check for initial infections and Rt ndims --- pyrenew/latent/infections.py | 4 ++-- pyrenew/latent/infectionswithfeedback.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyrenew/latent/infections.py b/pyrenew/latent/infections.py index 628c7cc4..0a6df04c 100644 --- a/pyrenew/latent/infections.py +++ b/pyrenew/latent/infections.py @@ -88,9 +88,9 @@ def sample( f"generation interval length: {gen_int.size}." ) - if I0.shape != Rt.shape: + if I0.ndim != Rt.ndim: raise ValueError( - "Initial infections and Rt must have the same shape. " + "Initial infections and Rt must have the same dimensions. " f"Got initial infections of shape {I0.shape} " f"and Rt of shape {Rt.shape}." ) diff --git a/pyrenew/latent/infectionswithfeedback.py b/pyrenew/latent/infectionswithfeedback.py index 32a4ac18..879cfb4b 100644 --- a/pyrenew/latent/infectionswithfeedback.py +++ b/pyrenew/latent/infectionswithfeedback.py @@ -150,9 +150,9 @@ def sample( f"and generation interval length {gen_int.size}." ) - if I0.shape != Rt.shape: + if I0.ndim != Rt.ndim: raise ValueError( - "Initial infections and Rt must have the same shape. " + "Initial infections and Rt must have the dimensions. " f"Got initial infections of shape {I0.shape} " f"and Rt of shape {Rt.shape}." ) From f5557a3b36566ab49efa5b040b01b15682676d3b Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 12 Sep 2024 18:03:04 -0400 Subject: [PATCH 12/21] remove typos and superfluous print statements --- pyrenew/latent/infectionswithfeedback.py | 2 +- test/test_infection_functions.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyrenew/latent/infectionswithfeedback.py b/pyrenew/latent/infectionswithfeedback.py index 879cfb4b..c841536c 100644 --- a/pyrenew/latent/infectionswithfeedback.py +++ b/pyrenew/latent/infectionswithfeedback.py @@ -152,7 +152,7 @@ def sample( if I0.ndim != Rt.ndim: raise ValueError( - "Initial infections and Rt must have the dimensions. " + "Initial infections and Rt must have the same dimensions. " f"Got initial infections of shape {I0.shape} " f"and Rt of shape {Rt.shape}." ) diff --git a/test/test_infection_functions.py b/test/test_infection_functions.py index b4ef7626..94a3ad5b 100644 --- a/test/test_infection_functions.py +++ b/test/test_infection_functions.py @@ -88,7 +88,7 @@ def test_compute_infections_from_rt_with_feedback_2d( ) = inf.compute_infections_from_rt_with_feedback( I0, Rt_raw, jnp.zeros_like(Rt_raw), gen_int, inf_pmf ) - print(inf.compute_infections_from_rt(I0, Rt_raw, gen_int)) + assert_array_equal( inf.compute_infections_from_rt(I0, Rt_raw, gen_int), infs_feedback, From 46e32e8d1930cc966fb6d48fb04c8cdcb2989525 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 12 Sep 2024 18:18:38 -0400 Subject: [PATCH 13/21] add more tests --- test/test_infectionsrtfeedback.py | 45 +++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/test/test_infectionsrtfeedback.py b/test/test_infectionsrtfeedback.py index 47841922..c2fab26c 100644 --- a/test/test_infectionsrtfeedback.py +++ b/test/test_infectionsrtfeedback.py @@ -193,3 +193,48 @@ def test_infectionsrtfeedback_feedback(Rt, I0): assert_array_almost_equal(samp1.rt, res["rt"]) return None + + +def test_infections_with_feedback_invalid_inputs(): + """ + Test the InfectionsWithFeedback class cannot + be sampled when Rt and I0 have invalid input shapes + """ + I0_1d = jnp.array([0.5, 0.6, 0.7, 0.8]) + I0_2d = jnp.array( + np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] * 3) + ).reshape((7, -1)) + Rt = jnp.ones(10) + gen_int = jnp.array([0.4, 0.25, 0.25, 0.1, 0.0, 0.0, 0.0]) + + inf_feed_strength = DeterministicVariable( + name="inf_feed_strength", value=jnp.zeros_like(Rt) + ) + inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", value=gen_int) + + # Test the InfectionsWithFeedback class + InfectionsWithFeedback = latent.InfectionsWithFeedback( + infection_feedback_strength=inf_feed_strength, + infection_feedback_pmf=inf_feedback_pmf, + ) + + with numpyro.handlers.seed(rng_seed=0): + with pytest.raises( + ValueError, + match="Initial infections must be at least as long as the generation interval.", + ): + InfectionsWithFeedback( + gen_int=gen_int, + Rt=Rt, + I0=I0_1d, + ) + + with pytest.raises( + ValueError, + match="Initial infections and Rt must have the same dimensions.", + ): + InfectionsWithFeedback( + gen_int=gen_int, + Rt=Rt, + I0=I0_2d, + ) From 728aad38e69196ebbba86bbb0a914622146783e1 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 12 Sep 2024 18:29:05 -0400 Subject: [PATCH 14/21] change value of inf_feedback in tests to cover diff scenarios --- test/test_infectionsrtfeedback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_infectionsrtfeedback.py b/test/test_infectionsrtfeedback.py index c2fab26c..384e1829 100644 --- a/test/test_infectionsrtfeedback.py +++ b/test/test_infectionsrtfeedback.py @@ -208,7 +208,7 @@ def test_infections_with_feedback_invalid_inputs(): gen_int = jnp.array([0.4, 0.25, 0.25, 0.1, 0.0, 0.0, 0.0]) inf_feed_strength = DeterministicVariable( - name="inf_feed_strength", value=jnp.zeros_like(Rt) + name="inf_feed_strength", value=0.5 ) inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", value=gen_int) From 54c0527c05f9546671e0c0d5c4caaf57ae59ea0f Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 12 Sep 2024 18:34:42 -0400 Subject: [PATCH 15/21] add input array tests for infections.py --- test/test_infectionsrtfeedback.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/test_infectionsrtfeedback.py b/test/test_infectionsrtfeedback.py index 384e1829..6346be56 100644 --- a/test/test_infectionsrtfeedback.py +++ b/test/test_infectionsrtfeedback.py @@ -218,6 +218,8 @@ def test_infections_with_feedback_invalid_inputs(): infection_feedback_pmf=inf_feedback_pmf, ) + infections = latent.Infections() + with numpyro.handlers.seed(rng_seed=0): with pytest.raises( ValueError, @@ -229,6 +231,16 @@ def test_infections_with_feedback_invalid_inputs(): I0=I0_1d, ) + with pytest.raises( + ValueError, + match="Initial infections vector must be at least as long as the generation interval.", + ): + infections( + gen_int=gen_int, + Rt=Rt, + I0=I0_1d, + ) + with pytest.raises( ValueError, match="Initial infections and Rt must have the same dimensions.", @@ -238,3 +250,13 @@ def test_infections_with_feedback_invalid_inputs(): Rt=Rt, I0=I0_2d, ) + + with pytest.raises( + ValueError, + match="Initial infections and Rt must have the same dimensions.", + ): + infections( + gen_int=gen_int, + Rt=Rt, + I0=I0_2d, + ) From a6ce6c6e36fe6e9c1af307c10f73266f055ffdb8 Mon Sep 17 00:00:00 2001 From: sbidari Date: Fri, 13 Sep 2024 09:07:35 -0400 Subject: [PATCH 16/21] add test with plate --- pyrenew/latent/infectionswithfeedback.py | 4 + ...est_infection_and_infectionwithfeedback.py | 79 +++++++++++++++++++ test/test_infectionsrtfeedback.py | 69 +--------------- ...fectionwithfeedback_plate_compatibility.py | 44 +++++++++++ 4 files changed, 128 insertions(+), 68 deletions(-) create mode 100644 test/test_infection_and_infectionwithfeedback.py create mode 100644 test/test_infectionwithfeedback_plate_compatibility.py diff --git a/pyrenew/latent/infectionswithfeedback.py b/pyrenew/latent/infectionswithfeedback.py index c841536c..949fdb38 100644 --- a/pyrenew/latent/infectionswithfeedback.py +++ b/pyrenew/latent/infectionswithfeedback.py @@ -167,6 +167,10 @@ def sample( **kwargs, ) ) + + if inf_feedback_strength.ndim < Rt.ndim: + inf_feedback_strength = jnp.expand_dims(inf_feedback_strength, 0) + # Making sure inf_feedback_strength spans the Rt length if inf_feedback_strength.shape[0] == 1: inf_feedback_strength = au.pad_edges_to_match( diff --git a/test/test_infection_and_infectionwithfeedback.py b/test/test_infection_and_infectionwithfeedback.py new file mode 100644 index 00000000..e8915546 --- /dev/null +++ b/test/test_infection_and_infectionwithfeedback.py @@ -0,0 +1,79 @@ +""" +Test to verify Infection and InfectionsWithFeedback class +return error when input array shape for I0 and Rt are invalid +""" + +import jax.numpy as jnp +import numpy as np +import numpyro +import pytest + +import pyrenew.latent as latent +from pyrenew.deterministic import DeterministicPMF, DeterministicVariable + + +def test_infections_with_feedback_invalid_inputs(): + """ + Test the InfectionsWithFeedback class cannot + be sampled when Rt and I0 have invalid input shapes + """ + I0_1d = jnp.array([0.5, 0.6, 0.7, 0.8]) + I0_2d = jnp.array( + np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] * 3) + ).reshape((7, -1)) + Rt = jnp.ones(10) + gen_int = jnp.array([0.4, 0.25, 0.25, 0.1, 0.0, 0.0, 0.0]) + + inf_feed_strength = DeterministicVariable( + name="inf_feed_strength", value=0.5 + ) + inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", value=gen_int) + + # Test the InfectionsWithFeedback class + InfectionsWithFeedback = latent.InfectionsWithFeedback( + infection_feedback_strength=inf_feed_strength, + infection_feedback_pmf=inf_feedback_pmf, + ) + + infections = latent.Infections() + + with numpyro.handlers.seed(rng_seed=0): + with pytest.raises( + ValueError, + match="Initial infections must be at least as long as the generation interval.", + ): + InfectionsWithFeedback( + gen_int=gen_int, + Rt=Rt, + I0=I0_1d, + ) + + with pytest.raises( + ValueError, + match="Initial infections vector must be at least as long as the generation interval.", + ): + infections( + gen_int=gen_int, + Rt=Rt, + I0=I0_1d, + ) + + with pytest.raises( + ValueError, + match="Initial infections and Rt must have the same dimensions.", + ): + InfectionsWithFeedback( + gen_int=gen_int, + Rt=Rt, + I0=I0_2d, + ) + + with pytest.raises( + ValueError, + match="Initial infections and Rt must have the same dimensions.", + ): + infections( + gen_int=gen_int, + Rt=Rt, + I0=I0_2d, + ) diff --git a/test/test_infectionsrtfeedback.py b/test/test_infectionsrtfeedback.py index 6346be56..3151d62b 100644 --- a/test/test_infectionsrtfeedback.py +++ b/test/test_infectionsrtfeedback.py @@ -91,7 +91,7 @@ def test_infectionsrtfeedback(Rt, I0): # By doing the infection feedback strength 0, Rt = Rt_adjusted # So infection should be equal in both inf_feed_strength = DeterministicVariable( - name="inf_feed_strength", value=jnp.zeros_like(Rt) + name="inf_feed_strength", value=jnp.array(0) ) inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", value=gen_int) @@ -193,70 +193,3 @@ def test_infectionsrtfeedback_feedback(Rt, I0): assert_array_almost_equal(samp1.rt, res["rt"]) return None - - -def test_infections_with_feedback_invalid_inputs(): - """ - Test the InfectionsWithFeedback class cannot - be sampled when Rt and I0 have invalid input shapes - """ - I0_1d = jnp.array([0.5, 0.6, 0.7, 0.8]) - I0_2d = jnp.array( - np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] * 3) - ).reshape((7, -1)) - Rt = jnp.ones(10) - gen_int = jnp.array([0.4, 0.25, 0.25, 0.1, 0.0, 0.0, 0.0]) - - inf_feed_strength = DeterministicVariable( - name="inf_feed_strength", value=0.5 - ) - inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", value=gen_int) - - # Test the InfectionsWithFeedback class - InfectionsWithFeedback = latent.InfectionsWithFeedback( - infection_feedback_strength=inf_feed_strength, - infection_feedback_pmf=inf_feedback_pmf, - ) - - infections = latent.Infections() - - with numpyro.handlers.seed(rng_seed=0): - with pytest.raises( - ValueError, - match="Initial infections must be at least as long as the generation interval.", - ): - InfectionsWithFeedback( - gen_int=gen_int, - Rt=Rt, - I0=I0_1d, - ) - - with pytest.raises( - ValueError, - match="Initial infections vector must be at least as long as the generation interval.", - ): - infections( - gen_int=gen_int, - Rt=Rt, - I0=I0_1d, - ) - - with pytest.raises( - ValueError, - match="Initial infections and Rt must have the same dimensions.", - ): - InfectionsWithFeedback( - gen_int=gen_int, - Rt=Rt, - I0=I0_2d, - ) - - with pytest.raises( - ValueError, - match="Initial infections and Rt must have the same dimensions.", - ): - infections( - gen_int=gen_int, - Rt=Rt, - I0=I0_2d, - ) diff --git a/test/test_infectionwithfeedback_plate_compatibility.py b/test/test_infectionwithfeedback_plate_compatibility.py new file mode 100644 index 00000000..535902ab --- /dev/null +++ b/test/test_infectionwithfeedback_plate_compatibility.py @@ -0,0 +1,44 @@ +""" +Test the InfectionsWithFeedback class works well within numpyro plate +""" + +import jax.numpy as jnp +import numpy as np +import numpyro +import numpyro.distributions as dist + +import pyrenew.latent as latent +from pyrenew.deterministic import DeterministicPMF +from pyrenew.randomvariable import DistributionalVariable + + +def test_infections_with_feedback_plate_compatibility(): + """ + Test the InfectionsWithFeedback matching the Infections class. + """ + I0 = jnp.array( + np.array([0.0, 0.0, 0.0, 0.5, 0.6, 0.7, 0.8] * 5).reshape(-1, 5) + ) + Rt = jnp.ones((10, 5)) + gen_int = jnp.array([0.4, 0.25, 0.25, 0.1]) + + inf_feed_strength = DistributionalVariable( + "inf_feed_strength", dist.Beta(1, 1) + ) + inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", value=gen_int) + + # Test the InfectionsWithFeedback class + InfectionsWithFeedback = latent.InfectionsWithFeedback( + infection_feedback_strength=inf_feed_strength, + infection_feedback_pmf=inf_feedback_pmf, + ) + + with numpyro.handlers.seed(rng_seed=0): + with numpyro.plate("test_plate", 5): + samp = InfectionsWithFeedback( + gen_int=gen_int, + Rt=Rt, + I0=I0, + ) + + assert samp.rt.shape == Rt.shape From cc61d0b2f3a973109808ffa2c222aaeb319be3c3 Mon Sep 17 00:00:00 2001 From: sbidari Date: Fri, 13 Sep 2024 15:57:15 -0400 Subject: [PATCH 17/21] code review changes --- pyrenew/latent/infectionswithfeedback.py | 4 ++-- test/test_convolve_scanners.py | 18 +----------------- test/test_infectionsrtfeedback.py | 14 +++++++++----- 3 files changed, 12 insertions(+), 24 deletions(-) diff --git a/pyrenew/latent/infectionswithfeedback.py b/pyrenew/latent/infectionswithfeedback.py index 949fdb38..3cc8c6e2 100644 --- a/pyrenew/latent/infectionswithfeedback.py +++ b/pyrenew/latent/infectionswithfeedback.py @@ -168,7 +168,7 @@ def sample( ) ) - if inf_feedback_strength.ndim < Rt.ndim: + if inf_feedback_strength.ndim == Rt.ndim - 1: inf_feedback_strength = jnp.expand_dims(inf_feedback_strength, 0) # Making sure inf_feedback_strength spans the Rt length @@ -178,7 +178,7 @@ def sample( y=Rt, axis=0, )[0] - elif inf_feedback_strength.shape != Rt.shape: + if inf_feedback_strength.shape != Rt.shape: raise ValueError( "Infection feedback strength must be of length 1 " "or the same length as the reproduction number array. " diff --git a/test/test_convolve_scanners.py b/test/test_convolve_scanners.py index 3b772a1e..63bdd6a1 100644 --- a/test/test_convolve_scanners.py +++ b/test/test_convolve_scanners.py @@ -52,27 +52,11 @@ def transform_a(x: any): """ return 4 * x + 0.025 - def transform_ones_like(x: any): - """ - Generate an array of ones with the same shape as the input array. - - Parameters - ---------- - x : any - Input value - - Returns - ------- - ArrayLike - An array of ones with the same shape as the input value `x`. - """ - return jnp.ones_like(x) - scanner_a = pc.new_convolve_scanner(to_scan_a, transform_a) double_scanner_a = pc.new_double_convolve_scanner( (jnp.array([523, 2, -0.5233]), to_scan_a), - (transform_ones_like, transform_a), + (jnp.ones_like, transform_a), ) _, result_a = jax.lax.scan(f=scanner_a, init=inits, xs=multipliers) diff --git a/test/test_infectionsrtfeedback.py b/test/test_infectionsrtfeedback.py index 3151d62b..8620bf10 100644 --- a/test/test_infectionsrtfeedback.py +++ b/test/test_infectionsrtfeedback.py @@ -66,11 +66,14 @@ def _infection_w_feedback_alt( @pytest.mark.parametrize( - ["Rt", "I0"], + ["Rt", "I0", "inf_feed_strength"], [ [ jnp.array([0.5, 0.6, 0.7, 0.8, 2, 0.5, 2.25]), jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]), + DeterministicVariable( + name="inf_feed_strength", value=jnp.array(0) + ), ], [ jnp.array( @@ -79,10 +82,13 @@ def _infection_w_feedback_alt( jnp.array( np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] * 3) ).reshape((7, 3)), + DeterministicVariable( + name="inf_feed_strength", value=jnp.zeros(3) + ), ], ], ) -def test_infectionsrtfeedback(Rt, I0): +def test_infectionsrtfeedback(Rt, I0, inf_feed_strength): """ Test the InfectionsWithFeedback matching the Infections class. """ @@ -90,9 +96,7 @@ def test_infectionsrtfeedback(Rt, I0): # By doing the infection feedback strength 0, Rt = Rt_adjusted # So infection should be equal in both - inf_feed_strength = DeterministicVariable( - name="inf_feed_strength", value=jnp.array(0) - ) + inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", value=gen_int) # Test the InfectionsWithFeedback class From 0cd56ad169a4d39941a24f5885fc9e34b06bc5f0 Mon Sep 17 00:00:00 2001 From: sbidari Date: Fri, 13 Sep 2024 16:45:39 -0400 Subject: [PATCH 18/21] code review suggestion --- docs/source/tutorials/extending_pyrenew.qmd | 4 ++-- test/test_arrayutils.py | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/source/tutorials/extending_pyrenew.qmd b/docs/source/tutorials/extending_pyrenew.qmd index 34e92a15..3834d858 100644 --- a/docs/source/tutorials/extending_pyrenew.qmd +++ b/docs/source/tutorials/extending_pyrenew.qmd @@ -224,10 +224,10 @@ class InfFeedback(RandomVariable): inf_feedback_strength = jnp.atleast_1d(inf_feedback_strength) - inf_feedback_strength = au.pad_edges_to_match( + inf_feedback_strength, _ = au.pad_edges_to_match( x=inf_feedback_strength, y=Rt, - )[0] + ) # Sampling inf feedback and adjusting the shape inf_feedback_pmf = self.infection_feedback_pmf(**kwargs) diff --git a/test/test_arrayutils.py b/test/test_arrayutils.py index 9bf2882d..0048fade 100644 --- a/test/test_arrayutils.py +++ b/test/test_arrayutils.py @@ -20,6 +20,7 @@ def test_pad_edges_to_match(): x_pad, y_pad = au.pad_edges_to_match(x, y) assert x_pad.size == y_pad.size assert y_pad[-1] == y[-1] + assert jnp.array_equal(x_pad, x) # test when x gets padded x = jnp.array([1, 2]) @@ -28,6 +29,7 @@ def test_pad_edges_to_match(): x_pad, y_pad = au.pad_edges_to_match(x, y) assert x_pad.size == y_pad.size assert x_pad[-1] == x[-1] + assert jnp.array_equal(y_pad, y) # test when no padding required x = jnp.array([1, 2, 3]) @@ -45,6 +47,7 @@ def test_pad_edges_to_match(): assert x_pad.size == y_pad.size assert y_pad[0] == y[0] + assert jnp.array_equal(x_pad, x) # Verify that the function raises an error when `fix_y` is True with pytest.raises( @@ -66,9 +69,11 @@ def test_pad_edges_to_match(): assert jnp.array_equal(x_pad.shape[axis], y_pad.shape[axis]) assert jnp.array_equal(y_pad[-1], y[-1]) + assert jnp.array_equal(x_pad, x) # padding along axis 1 axis = 1 x_pad, y_pad = au.pad_edges_to_match(x, y, axis=axis, pad_direction="end") assert jnp.array_equal(x_pad.shape[axis], y_pad.shape[axis]) assert jnp.array_equal(y[:, -1], y_pad[:, -1]) + assert jnp.array_equal(x_pad, x) From 0f3e2fda4fde306a5a271e7a30150dbb68dc8df0 Mon Sep 17 00:00:00 2001 From: sbidari Date: Fri, 13 Sep 2024 17:11:32 -0400 Subject: [PATCH 19/21] replace jnp.dot with einsum --- pyrenew/convolve.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pyrenew/convolve.py b/pyrenew/convolve.py index 575d140b..25fb68d7 100755 --- a/pyrenew/convolve.py +++ b/pyrenew/convolve.py @@ -78,10 +78,11 @@ def _new_scanner( history_subset: ArrayLike, multiplier: float ) -> tuple[ArrayLike, float]: # numpydoc ignore=GL08 new_val = transform( - multiplier * jnp.dot(array_to_convolve, history_subset) + multiplier + * jnp.einsum("i...,i...->...", array_to_convolve, history_subset) ) latest = jnp.concatenate( - [history_subset[1:], jnp.expand_dims(new_val, axis=0)] + [history_subset[1:], jnp.expand_dims(new_val, axis=0)], axis=0 ) return latest, new_val @@ -160,10 +161,12 @@ def _new_scanner( multipliers: tuple[float, float], ) -> tuple[ArrayLike, tuple[float, float]]: # numpydoc ignore=GL08 m1, m2 = multipliers - m_net1 = t1(m1 * jnp.dot(arr1, history_subset)) - new_val = t2(m2 * m_net1 * jnp.dot(arr2, history_subset)) + m_net1 = t1(m1 * jnp.einsum("i...,i...->...", arr1, history_subset)) + new_val = t2( + m2 * m_net1 * jnp.einsum("i...,i...->...", arr2, history_subset) + ) latest = jnp.concatenate( - [history_subset[1:], jnp.expand_dims(new_val, axis=0)] + [history_subset[1:], jnp.expand_dims(new_val, axis=0)], axis=0 ) return latest, (new_val, m_net1) From 542ad84548055207f1b825589881ce83c886b077 Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 16 Sep 2024 11:12:39 -0400 Subject: [PATCH 20/21] code review suggestions --- pyrenew/convolve.py | 4 ++-- pyrenew/latent/infections.py | 8 ++++---- pyrenew/latent/infectionswithfeedback.py | 8 ++++---- test/test_infection_and_infectionwithfeedback.py | 4 ++-- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/pyrenew/convolve.py b/pyrenew/convolve.py index 25fb68d7..8b99853c 100755 --- a/pyrenew/convolve.py +++ b/pyrenew/convolve.py @@ -82,7 +82,7 @@ def _new_scanner( * jnp.einsum("i...,i...->...", array_to_convolve, history_subset) ) latest = jnp.concatenate( - [history_subset[1:], jnp.expand_dims(new_val, axis=0)], axis=0 + [history_subset[1:], new_val[jnp.newaxis]], axis=0 ) return latest, new_val @@ -166,7 +166,7 @@ def _new_scanner( m2 * m_net1 * jnp.einsum("i...,i...->...", arr2, history_subset) ) latest = jnp.concatenate( - [history_subset[1:], jnp.expand_dims(new_val, axis=0)], axis=0 + [history_subset[1:], new_val[jnp.newaxis]], axis=0 ) return latest, (new_val, m_net1) diff --git a/pyrenew/latent/infections.py b/pyrenew/latent/infections.py index 0a6df04c..7a8b02a8 100644 --- a/pyrenew/latent/infections.py +++ b/pyrenew/latent/infections.py @@ -88,11 +88,11 @@ def sample( f"generation interval length: {gen_int.size}." ) - if I0.ndim != Rt.ndim: + if I0.shape[1:] != Rt.shape[1:]: raise ValueError( - "Initial infections and Rt must have the same dimensions. " - f"Got initial infections of shape {I0.shape} " - f"and Rt of shape {Rt.shape}." + "Initial infections and Rt must have the same batch shapes. " + f"Got initial infections of batch shape {I0.shape[1:]} " + f"and Rt of batch shape {Rt.shape[1:]}." ) gen_int_rev = jnp.flip(gen_int) diff --git a/pyrenew/latent/infectionswithfeedback.py b/pyrenew/latent/infectionswithfeedback.py index 3cc8c6e2..934fedfd 100644 --- a/pyrenew/latent/infectionswithfeedback.py +++ b/pyrenew/latent/infectionswithfeedback.py @@ -150,11 +150,11 @@ def sample( f"and generation interval length {gen_int.size}." ) - if I0.ndim != Rt.ndim: + if I0.shape[1:] != Rt.shape[1:]: raise ValueError( - "Initial infections and Rt must have the same dimensions. " - f"Got initial infections of shape {I0.shape} " - f"and Rt of shape {Rt.shape}." + "Initial infections and Rt must have the same batch shapes. " + f"Got initial infections of batch shape {I0.shape[1:]} " + f"and Rt of batch shape {Rt.shape[1:]}." ) gen_int_rev = jnp.flip(gen_int) diff --git a/test/test_infection_and_infectionwithfeedback.py b/test/test_infection_and_infectionwithfeedback.py index e8915546..5e1f44ef 100644 --- a/test/test_infection_and_infectionwithfeedback.py +++ b/test/test_infection_and_infectionwithfeedback.py @@ -60,7 +60,7 @@ def test_infections_with_feedback_invalid_inputs(): with pytest.raises( ValueError, - match="Initial infections and Rt must have the same dimensions.", + match="Initial infections and Rt must have the same batch shapes.", ): InfectionsWithFeedback( gen_int=gen_int, @@ -70,7 +70,7 @@ def test_infections_with_feedback_invalid_inputs(): with pytest.raises( ValueError, - match="Initial infections and Rt must have the same dimensions.", + match="Initial infections and Rt must have the same batch shapes.", ): infections( gen_int=gen_int, From bafe3b04021d389fe626e972297228cc586e99ce Mon Sep 17 00:00:00 2001 From: sbidari Date: Tue, 17 Sep 2024 16:02:15 -0400 Subject: [PATCH 21/21] code review changes --- pyrenew/latent/infectionswithfeedback.py | 2 +- ...est_infection_and_infectionwithfeedback.py | 79 ------------------- test/test_infectionsrtfeedback.py | 67 ++++++++++++++++ 3 files changed, 68 insertions(+), 80 deletions(-) delete mode 100644 test/test_infection_and_infectionwithfeedback.py diff --git a/pyrenew/latent/infectionswithfeedback.py b/pyrenew/latent/infectionswithfeedback.py index 934fedfd..b344a1b5 100644 --- a/pyrenew/latent/infectionswithfeedback.py +++ b/pyrenew/latent/infectionswithfeedback.py @@ -169,7 +169,7 @@ def sample( ) if inf_feedback_strength.ndim == Rt.ndim - 1: - inf_feedback_strength = jnp.expand_dims(inf_feedback_strength, 0) + inf_feedback_strength = inf_feedback_strength[jnp.newaxis] # Making sure inf_feedback_strength spans the Rt length if inf_feedback_strength.shape[0] == 1: diff --git a/test/test_infection_and_infectionwithfeedback.py b/test/test_infection_and_infectionwithfeedback.py deleted file mode 100644 index 5e1f44ef..00000000 --- a/test/test_infection_and_infectionwithfeedback.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -Test to verify Infection and InfectionsWithFeedback class -return error when input array shape for I0 and Rt are invalid -""" - -import jax.numpy as jnp -import numpy as np -import numpyro -import pytest - -import pyrenew.latent as latent -from pyrenew.deterministic import DeterministicPMF, DeterministicVariable - - -def test_infections_with_feedback_invalid_inputs(): - """ - Test the InfectionsWithFeedback class cannot - be sampled when Rt and I0 have invalid input shapes - """ - I0_1d = jnp.array([0.5, 0.6, 0.7, 0.8]) - I0_2d = jnp.array( - np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] * 3) - ).reshape((7, -1)) - Rt = jnp.ones(10) - gen_int = jnp.array([0.4, 0.25, 0.25, 0.1, 0.0, 0.0, 0.0]) - - inf_feed_strength = DeterministicVariable( - name="inf_feed_strength", value=0.5 - ) - inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", value=gen_int) - - # Test the InfectionsWithFeedback class - InfectionsWithFeedback = latent.InfectionsWithFeedback( - infection_feedback_strength=inf_feed_strength, - infection_feedback_pmf=inf_feedback_pmf, - ) - - infections = latent.Infections() - - with numpyro.handlers.seed(rng_seed=0): - with pytest.raises( - ValueError, - match="Initial infections must be at least as long as the generation interval.", - ): - InfectionsWithFeedback( - gen_int=gen_int, - Rt=Rt, - I0=I0_1d, - ) - - with pytest.raises( - ValueError, - match="Initial infections vector must be at least as long as the generation interval.", - ): - infections( - gen_int=gen_int, - Rt=Rt, - I0=I0_1d, - ) - - with pytest.raises( - ValueError, - match="Initial infections and Rt must have the same batch shapes.", - ): - InfectionsWithFeedback( - gen_int=gen_int, - Rt=Rt, - I0=I0_2d, - ) - - with pytest.raises( - ValueError, - match="Initial infections and Rt must have the same batch shapes.", - ): - infections( - gen_int=gen_int, - Rt=Rt, - I0=I0_2d, - ) diff --git a/test/test_infectionsrtfeedback.py b/test/test_infectionsrtfeedback.py index 8620bf10..9be03312 100644 --- a/test/test_infectionsrtfeedback.py +++ b/test/test_infectionsrtfeedback.py @@ -197,3 +197,70 @@ def test_infectionsrtfeedback_feedback(Rt, I0): assert_array_almost_equal(samp1.rt, res["rt"]) return None + + +def test_infections_with_feedback_invalid_inputs(): + """ + Test the InfectionsWithFeedback class cannot + be sampled when Rt and I0 have invalid input shapes + """ + I0_1d = jnp.array([0.5, 0.6, 0.7, 0.8]) + I0_2d = jnp.array( + np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] * 3) + ).reshape((7, -1)) + Rt = jnp.ones(10) + gen_int = jnp.array([0.4, 0.25, 0.25, 0.1, 0.0, 0.0, 0.0]) + + inf_feed_strength = DeterministicVariable( + name="inf_feed_strength", value=0.5 + ) + inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", value=gen_int) + + # Test the InfectionsWithFeedback class + InfectionsWithFeedback = latent.InfectionsWithFeedback( + infection_feedback_strength=inf_feed_strength, + infection_feedback_pmf=inf_feedback_pmf, + ) + + infections = latent.Infections() + + with numpyro.handlers.seed(rng_seed=0): + with pytest.raises( + ValueError, + match="Initial infections must be at least as long as the generation interval.", + ): + InfectionsWithFeedback( + gen_int=gen_int, + Rt=Rt, + I0=I0_1d, + ) + + with pytest.raises( + ValueError, + match="Initial infections vector must be at least as long as the generation interval.", + ): + infections( + gen_int=gen_int, + Rt=Rt, + I0=I0_1d, + ) + + with pytest.raises( + ValueError, + match="Initial infections and Rt must have the same batch shapes.", + ): + InfectionsWithFeedback( + gen_int=gen_int, + Rt=Rt, + I0=I0_2d, + ) + + with pytest.raises( + ValueError, + match="Initial infections and Rt must have the same batch shapes.", + ): + infections( + gen_int=gen_int, + Rt=Rt, + I0=I0_2d, + )