Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove SampledValue #441

Merged
merged 34 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
04353f7
first pass
damonbayer Sep 10, 2024
0928a1c
removing t_unit and t_start
damonbayer Sep 10, 2024
4e1f9ce
fix some tests
damonbayer Sep 10, 2024
30b24ee
fix deterministicvariable class
damonbayer Sep 10, 2024
62069ab
remove _assert_sample_and_rtype
damonbayer Sep 10, 2024
bd2344e
fix deterministicVariable
damonbayer Sep 10, 2024
c3f35d0
fix test_infection_initialization_method
damonbayer Sep 10, 2024
991b0b3
fix test_iid_random_sequence
damonbayer Sep 10, 2024
24c4f00
fix test_ar_process
damonbayer Sep 10, 2024
65e7b79
fix test_differenced_process
damonbayer Sep 10, 2024
91854fc
fix test_rtperiodicdiff
damonbayer Sep 10, 2024
a6b9af3
fix test_distributional_rv
damonbayer Sep 10, 2024
7389a0c
fix test_transformed_rv_class
damonbayer Sep 10, 2024
6fcae0a
fix test_infection_initialization_process
damonbayer Sep 10, 2024
ed473fb
fix test_random_walk
damonbayer Sep 10, 2024
a6f66c7
actually fix test_transformed_rv_class
damonbayer Sep 10, 2024
535f7e0
finally fix test_transformed_rv_class
damonbayer Sep 10, 2024
b04675a
fix test_latent_infections
damonbayer Sep 10, 2024
46ef499
fix utils.py
damonbayer Sep 10, 2024
215dbb5
fix test_observation_negativebinom
damonbayer Sep 10, 2024
aaaba08
fix test_infectionsrtfeedback
damonbayer Sep 10, 2024
dcb743c
fix test_latent_admissions
damonbayer Sep 10, 2024
c9c3f14
fix test_observation_negativebinom
damonbayer Sep 10, 2024
aa78e36
fix test_model_basic_renewal
damonbayer Sep 10, 2024
271b336
all tests working
damonbayer Sep 10, 2024
9e8ce13
remove time tutorial
damonbayer Sep 10, 2024
c6e13ad
fix basic_renewal_model tutorial
damonbayer Sep 10, 2024
ab0eff7
fix periodic_effects tutorial
damonbayer Sep 10, 2024
ca9f8b6
fix extending_pyrenew
damonbayer Sep 11, 2024
88a856f
fixed day_of_the_week tutorial
damonbayer Sep 11, 2024
2e1ccf0
fix hospital_admissions_model.qmd
damonbayer Sep 11, 2024
e36f061
Remove NullProcess
damonbayer Sep 11, 2024
72af339
remove extraneous code in test_iid_random_sequence
damonbayer Sep 11, 2024
3206d05
Remove extra *_ in tests
damonbayer Sep 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions docs/source/tutorials/basic_renewal_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalVariable(name="I0", distribution=dist.LogNormal(2.5, 1)),
InitializeInfectionsZeroPad(pmf_array.size),
t_unit=1,
)


Expand All @@ -152,9 +151,9 @@ class MyRt(RandomVariable):
rt_init_rv = DistributionalVariable(
name="init_log_rt", distribution=dist.Normal(0, 0.2)
)
init_rt, *_ = rt_init_rv.sample()
init_rt = rt_init_rv.sample()

return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs)
return rt_rv.sample(n=n, init_vals=init_rt, **kwargs)


rt_proc = MyRt()
Expand Down Expand Up @@ -220,11 +219,11 @@ import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 2)

# Rt plot
axs[0].plot(sim_data.Rt.value)
axs[0].plot(sim_data.Rt)
axs[0].set_ylabel("Rt")

# Infections plot
axs[1].plot(sim_data.observed_infections.value)
axs[1].plot(sim_data.observed_infections)
axs[1].set_ylabel("Infections")

fig.suptitle("Basic renewal model")
Expand All @@ -242,7 +241,7 @@ import jax
model1.run(
num_warmup=2000,
num_samples=1000,
data_observed_infections=sim_data.observed_infections.value,
data_observed_infections=sim_data.observed_infections,
rng_key=jax.random.key(54),
mcmc_args=dict(progress_bar=False, num_chains=2),
)
Expand Down
16 changes: 8 additions & 8 deletions docs/source/tutorials/day_of_the_week.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ I0 = InfectionInitializationProcess(
n_initialization_points,
deterministic.DeterministicVariable(name="rate", value=0.05),
),
t_unit=1,
)

# Generation interval and Rt
Expand All @@ -110,11 +109,11 @@ class MyRt(metaclass.RandomVariable):

def sample(self, n: int, **kwargs) -> tuple:
# Standard deviation of the random walk
sd_rt, *_ = self.sd_rv()
sd_rt = self.sd_rv()

# Random walk step
step_rv = randomvariable.DistributionalVariable(
name="rw_step_rv", distribution=dist.Normal(0, sd_rt.value)
name="rw_step_rv", distribution=dist.Normal(0, sd_rt)
)

rt_init_rv = randomvariable.DistributionalVariable(
Expand All @@ -133,9 +132,9 @@ class MyRt(metaclass.RandomVariable):
base_rv=base_rv,
transforms=transformation.ExpTransform(),
)
init_rt, *_ = rt_init_rv.sample()
init_rt = rt_init_rv.sample()

return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs)
return rt_rv.sample(n=n, init_vals=init_rt, **kwargs)


rtproc = MyRt(
Expand Down Expand Up @@ -168,7 +167,7 @@ obs = observation.NegativeBinomialObservation(
)
```

4. And finally, we built the model:
4. And finally, we build the model:

```{python}
# | label: init-model
Expand Down Expand Up @@ -282,10 +281,11 @@ As a result, we can see the posterior distribution of our novel day-of-the-week
# | label: fig-output-day-of-week
# | fig-cap: Day of the week effect
out = hosp_model_dow.plot_posterior(
var="dayofweek_effect", ylab="Day of the Week Effect", samples=500
var="dayofweek_effect_raw", ylab="Day of the Week Effect", samples=500
)

sp = hosp_model_dow.spread_draws(["dayofweek_effect"])
sp = hosp_model_dow.spread_draws(["dayofweek_effect_raw"])
# dayofweek_effect is not recorded
```

The new model with the day-of-the-week effect can be compared to the previous model without the effect. Finally, let's reproduce the figure without the day-of-the-week effect, and then plot the new model with the effect:
Expand Down
33 changes: 15 additions & 18 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ I0 = InfectionInitializationProcess(
gen_int_array.size,
DeterministicVariable(name="rate", value=0.05),
),
t_unit=1,
)

latent_infections = InfectionsWithFeedback(
Expand Down Expand Up @@ -85,9 +84,9 @@ class MyRt(RandomVariable):
rt_init_rv = DistributionalVariable(
name="init_log_rt", distribution=dist.Normal(0, 0.2)
)
init_rt, *_ = rt_init_rv.sample()
init_rt = rt_init_rv.sample()

return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs)
return rt_rv.sample(n=n, init_vals=init_rt, **kwargs)
```

With all the components defined, we can build the model:
Expand Down Expand Up @@ -118,7 +117,7 @@ with numpyro.handlers.seed(rng_seed=223):
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot(model0_samp.latent_infections.value)
ax.plot(model0_samp.latent_infections)
ax.set_xlabel("Time")
ax.set_ylabel("Infections")
plt.show()
Expand Down Expand Up @@ -164,7 +163,7 @@ from collections import namedtuple
# Creating a tuple to store the output
InfFeedbackSample = namedtuple(
typename="InfFeedbackSample",
field_names=["infections", "rt"],
field_names=["post_initialization_infections", "rt"],
defaults=(None, None),
)
```
Expand All @@ -175,7 +174,7 @@ The next step is to create the actual class. The bulk of its implementation lies
# | label: new-model-def
# | code-line-numbers: true
# Creating the class
from pyrenew.metaclass import RandomVariable, SampledValue
from pyrenew.metaclass import RandomVariable
from pyrenew.latent import compute_infections_from_rt_with_feedback
from pyrenew import arrayutils as au
from jax.typing import ArrayLike
Expand Down Expand Up @@ -219,11 +218,11 @@ class InfFeedback(RandomVariable):
I0_vec = I0[-gen_int_rev.size :]

# Sampling inf feedback strength and adjusting the shape
inf_feedback_strength, *_ = self.infection_feedback_strength(
inf_feedback_strength = self.infection_feedback_strength(
**kwargs,
)

inf_feedback_strength = jnp.atleast_1d(inf_feedback_strength.value)
inf_feedback_strength = jnp.atleast_1d(inf_feedback_strength)

inf_feedback_strength = au.pad_x_to_match_y(
x=inf_feedback_strength,
Expand All @@ -232,8 +231,8 @@ class InfFeedback(RandomVariable):
)

# Sampling inf feedback and adjusting the shape
inf_feedback_pmf, *_ = self.infection_feedback_pmf(**kwargs)
inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf.value)
inf_feedback_pmf = self.infection_feedback_pmf(**kwargs)
inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf)

# Generating the infections with feedback
all_infections, Rt_adj = compute_infections_from_rt_with_feedback(
Expand All @@ -250,20 +249,18 @@ class InfFeedback(RandomVariable):
# Preparing theoutput

return InfFeedbackSample(
infections=SampledValue(all_infections),
rt=SampledValue(Rt_adj),
post_initialization_infections=all_infections,
rt=Rt_adj,
)
```

The core of the class is implemented in the `sample()` method. Things to highlight from the above code:

1. **Arguments of `sample`**: The `InfFeedback` class will be used within `RtInfectionsRenewalModel` to generate latent infections. During the sampling process, `InfFeedback()` will receive the reproduction number, the initial number of infections, and the generation interval. `RandomVariable()` calls are expected to include the `**kwargs` argument, even if unused.

2. **Calls to `RandomVariable()`**: All calls to `RandomVariable()` are expected to return a tuple or named tuple. In our implementation, we capture the output of `infection_feedback_strength()` and `infection_feedback_pmf()` in the variables `inf_feedback_strength` and `inf_feedback_pmf`, respectively, disregarding the other outputs (i.e., using `*_`).

3. **Saving computed quantities**: Since `Rt_adj` is not generated via `numpyro.sample()`, we use `numpyro.deterministic()` to record the quantity to a site; allowing us to access it later.
2. **Saving computed quantities**: Since `Rt_adj` is not generated via `numpyro.sample()`, we use `numpyro.deterministic()` to record the quantity to a site; allowing us to access it later.

4. **Return type of `InfFeedback()`**: As said before, the `sample()` method should return a tuple or named tuple. In our case, we return a named tuple `InfFeedbackSample` with two fields: `infections` and `rt`.
3. **Return type of `InfFeedback()`**: As said before, the `sample()` method should return a tuple or named tuple. In our case, we return a named tuple `InfFeedbackSample` with two fields: `infections` and `rt`.

```{python}
# | label: simulation2
Expand Down Expand Up @@ -293,8 +290,8 @@ Comparing `model0` with `model1`, these two should match:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(ncols=2)
ax[0].plot(model0_samp.latent_infections.value)
ax[1].plot(model1_samp.latent_infections.value)
ax[0].plot(model0_samp.latent_infections)
ax[1].plot(model1_samp.latent_infections)
ax[0].set_xlabel("Time (model 0)")
ax[1].set_xlabel("Time (model 1)")
ax[0].set_ylabel("Infections")
Expand Down
12 changes: 5 additions & 7 deletions docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ I0 = InfectionInitializationProcess(
n_initialization_points,
deterministic.DeterministicVariable(name="rate", value=0.05),
),
t_unit=1,
)

# Generation interval and Rt
Expand Down Expand Up @@ -207,9 +206,9 @@ class MyRt(metaclass.RandomVariable):
rt_init_rv = randomvariable.DistributionalVariable(
name="init_log_rt", distribution=dist.Normal(0, 0.2)
)
init_rt, *_ = rt_init_rv.sample()
init_rt = rt_init_rv.sample()

return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs)
return rt_rv.sample(n=n, init_vals=init_rt, **kwargs)


rtproc = MyRt()
Expand Down Expand Up @@ -256,7 +255,6 @@ import numpy as np

timeframe = 120


with numpyro.handlers.seed(rng_seed=223):
simulated_data = hosp_model.sample(n_datapoints=timeframe)
```
Expand All @@ -269,11 +267,11 @@ import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 2)

# Rt plot
axs[0].plot(simulated_data.Rt.value)
axs[0].plot(simulated_data.Rt)
axs[0].set_ylabel("Simulated Rt")

# Admissions plot
axs[1].plot(simulated_data.observed_hosp_admissions.value, "-o")
axs[1].plot(simulated_data.observed_hosp_admissions, "-o")
axs[1].set_ylabel("Simulated Admissions")

fig.suptitle("Basic renewal model")
Expand Down Expand Up @@ -483,7 +481,7 @@ def compute_eti(dataset, eti_prob):
eti_bdry = dataset.quantile(
((1 - eti_prob) / 2, 1 / 2 + eti_prob / 2), dim=("chain", "draw")
)
return eti_bdry.values.T
return eti_bdry.T


fig, axes = plt.subplots(figsize=(6, 5))
Expand Down
7 changes: 2 additions & 5 deletions docs/source/tutorials/periodic_effects.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ with numpyro.handlers.seed(rng_seed=20):
# Plotting the Rt values
import matplotlib.pyplot as plt

plt.step(np.arange(len(sim_data.rt.value)), sim_data.rt.value, where="post")
plt.step(np.arange(len(sim_data)), sim_data, where="post")
plt.xlabel("Time")
plt.ylabel("Rt")
plt.title("Simulated Rt values")
Expand Down Expand Up @@ -79,7 +79,6 @@ dayofweek = process.DayOfWeekEffect(
quantity_to_broadcast=randomvariable.DistributionalVariable(
name="simp", distribution=mysimplex
),
t_start=0,
)
```

Expand All @@ -92,9 +91,7 @@ with numpyro.handlers.seed(rng_seed=20):
# Plotting the effect values
import matplotlib.pyplot as plt

plt.step(
np.arange(len(sim_data.value.value)), sim_data.value.value, where="post"
)
plt.step(np.arange(len(sim_data)), sim_data, where="post")
plt.xlabel("Time")
plt.ylabel("Effect size")
plt.title("Simulated Day of Week Effect values")
Expand Down
51 changes: 0 additions & 51 deletions docs/source/tutorials/time.qmd

This file was deleted.

5 changes: 0 additions & 5 deletions docs/source/tutorials/time.rst

This file was deleted.

2 changes: 1 addition & 1 deletion pyrenew/arrayutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
value: ArrayLike | None = None

def __repr__(self):
return f"PeriodicProcessSample(value={self.value})"
return f"PeriodicProcessSample(value={self})"

Check warning on line 112 in pyrenew/arrayutils.py

View check run for this annotation

Codecov / codecov/patch

pyrenew/arrayutils.py#L112

Added line #L112 was not covered by tests


def tile_until_n(
Expand Down
9 changes: 1 addition & 8 deletions pyrenew/deterministic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,11 @@

from pyrenew.deterministic.deterministic import DeterministicVariable
from pyrenew.deterministic.deterministicpmf import DeterministicPMF
from pyrenew.deterministic.nullrv import (
NullObservation,
NullProcess,
NullVariable,
)
from pyrenew.deterministic.process import DeterministicProcess
from pyrenew.deterministic.nullrv import NullObservation, NullVariable

__all__ = [
"DeterministicVariable",
"DeterministicPMF",
"DeterministicProcess",
"NullVariable",
"NullProcess",
"NullObservation",
]
Loading