Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow for batched inference with observational SIR model, add test for batched inference #566

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

rfl-urbaniak
Copy link
Collaborator

The sir observational model wasn't general enough to guide the user in building models for inference with batched data. It made a gesture in this direction by
allowing:

# Note: Here we set the event_dim to 1 if the last dimension of X["I"] is > 1, as the sir_observation_model
    # can be used for both single and multi-dimensional observations.
    event_dim = 1 if X["I"].shape and X["I"].shape[-1] > 1 else 0

But if a model of this type is passed to SVI inference, the log prob shapes are wrong. To ensure their correctness, we also need to introduce a plate of appropriate shape.

To illustrate and to ensure proper functionality, I added dynamical/test_batched_inference.py. For illustration, commenting out the lines introducing the plate in that test model (and un-indenting the pyro.sample statements) will lead to the type of log prob shape error in question.

Accordingly, I revised the dynamical systems notebook. The sir observational model now is:

def sir_observation_model(X: State[torch.Tensor]) -> None:
    # We don't observe the number of susceptible individuals directly.
    
    # Note: Here we set the event_dim to 1 if the last dimension of X["I"] is > 1, as the sir_observation_model
    # can be used for both single and multi-dimensional observations.
    event_dim = 1 if X["I"].shape and X["I"].shape[-1] > 1 else 0

    # Note: such plating while not necessary for this example,
    # would be needed to ensure proper log prob shapes
    # in inference with multiple observed time series,
    # so we include it for illustrative purposes.
    n = X["I"].shape[-2] if len(X["I"].shape) >= 2 else 1
    with pyro.plate("data", n, dim=-2):
        pyro.sample("I_obs", dist.Poisson(X["I"]).to_event(event_dim))  # noisy number of infected actually observed
        pyro.sample("R_obs", dist.Poisson(X["R"]).to_event(event_dim))  # noisy number of recovered actually observed

Otherwise, small changes, including adding plot.show() and partial prediction parallelization:

sir_data = dict(**{k:tr.trace.nodes[k]["value"] for k in ["I_obs", "R_obs"]})```

to

sir_data = dict(**{k:tr.trace.nodes[k]["value"].view(-1) for k in ["I_obs", "R_obs"]})
plt.xlim(start_time, end_time)
plt.xlabel("Time (Months)")
plt.ylabel("# of Individuals (Millions)")
plt.legend(loc="upper right")

to:

plt.xlim(start_time, end_time)
plt.xlabel("Time (Months)")
plt.ylabel("# of Individuals (Millions)")
plt.legend(loc="upper right")
plt.show()

(added plt.show() at a few locations to avoid redundant printing of object names before plotting)

# Generate samples from the posterior predictive distribution
sir_predictive = Predictive(simulated_bayesian_sir, guide=sir_guide, num_samples=num_samples)
sir_posterior_samples = sir_predictive(init_state, start_time, logging_times)

to

# Generate samples from the posterior predictive distribution
sir_predictive = Predictive(simulated_bayesian_sir, guide=sir_guide, num_samples=num_samples, parallel = True)
sir_posterior_samples = sir_predictive(init_state, start_time, logging_times)

and

intervened_sir_predictive = Predictive(intervened_sir, guide=sir_guide, num_samples=num_samples)
intervened_sir_posterior_samples = intervened_sir_predictive(lockdown_start, lockdown_end, lockdown_strength, init_state_lockdown, start_time, logging_times)

to

intervened_sir_predictive = Predictive(intervened_sir, guide=sir_guide, num_samples=num_samples, parallel=True)
intervened_sir_posterior_samples = intervened_sir_predictive(lockdown_start, lockdown_end, lockdown_strength, init_state_lockdown, start_time, logging_times)

There seems to be a small shape-related bug in the notebook that leads to runtime error with parallelizaton at a few locations. It remains unfixed. The locations are:

uncertain_intervened_sir_predictive = Predictive(uncertain_intervened_sir, guide=sir_guide, num_samples=num_samples)
uncertain_intervened_sir_posterior_samples = uncertain_intervened_sir_predictive(lockdown_strength, init_state_lockdown, start_time, logging_times)

and


dynamic_intervened_sir_predictive = Predictive(dynamic_intervened_sir, guide=sir_guide, num_samples=num_samples)
dynamic_intervened_sir_posterior_samples = dynamic_intervened_sir_predictive(lockdown_trigger, lockdown_lift_trigger, lockdown_strength, init_state_lockdown, start_time, logging_times)

and

uncertain_dynamic_intervened_sir_predictive = Predictive(uncertain_dynamic_intervened_sir, guide=sir_guide, num_samples=num_samples)
uncertain_dynamic_intervened_sir_posterior_samples = (uncertain_dynamic_intervened_sir_predictive(lockdown_strength, init_state_lockdown, start_time, logging_times))

The whole notebook has been re-run.

@rfl-urbaniak rfl-urbaniak requested a review from SamWitty August 26, 2024 20:52
@rfl-urbaniak rfl-urbaniak added the status:awaiting review Awaiting response from reviewer label Aug 26, 2024
Copy link
Collaborator

@SamWitty SamWitty left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My preference would be to make the model introduced in the tutorial only as complex as is necessary to demonstrate the functionality in the tutorial itself. If anything, this would suggest removing the event_dim = 1 if ... line altogether, rather than adding a plate to account for observations we don't actually use here. Instead, we can add the additional plate to a future example that does actually condition on multiple trajectories.

As far as the test is concerned, I think it's a good illustration of the issue, but not exactly minimal. I'd be in support of including a much much smaller test using the existing SIR test model with multiple observations.

The other changes (especially the parallel=True) are good and helpful, but should be separated out into a separate standalone PR.

@SamWitty SamWitty added status:awaiting response Awaiting response from creator module:dynamical and removed status:awaiting review Awaiting response from reviewer labels Aug 27, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module:dynamical status:awaiting response Awaiting response from creator
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants