### A Pluto.jl notebook ###
# v0.19.40
using Markdown
using InteractiveUtils
# ╔═╡ c593a2a0-d7f5-11ee-0931-d9f65ae84a72
# hideall
let
docs_dir = dirname(dirname(@__DIR__))
pkg_dir = dirname(docs_dir)
using Pkg: Pkg
Pkg.activate(docs_dir)
Pkg.develop(; path = pkg_dir)
Pkg.instantiate()
end;
# ╔═╡ da479d8d-1312-4b98-b0af-5be52dffaf3f
begin
using EpiAware
using Turing
using Distributions
using StatsPlots
using Random
using DynamicPPL
using Statistics
using DataFramesMeta
using LinearAlgebra
using Transducers
using ReverseDiff
using Pathfinder
end
# ╔═╡ 3ebc8384-f73d-4597-83a7-07a3744fed61
md"
# Getting started with `EpiAware`
This tutorial introduces the basic functionality of `EpiAware`, a package for making inferences from epidemiological case data, and other observations determined by infections, using a model-based approach.
## `EpiAware` models
The models we consider are discrete-time, $t = 1,\dots, T$, with a latent random process $Z_t$ generating stochasticity in the number of new infections $I_t$ at each time step. Observations are treated as downstream random variables determined by the actual infections and a model of the infection-to-observation delay.
#### Mathematical definition
```math
\begin{align}
Z_\cdot &\sim \mathcal{P}(\mathbb{R}^T) | \theta_Z, \\
I_0 &\sim f_0(\theta_I), \\
I_t &\sim g_I(\{I_s, Z_s\}_{s < t}, \theta_{I}), \\
y_t &\sim f_O(\{I_s\}_{s \leq t}, \theta_{O}).
\end{align}
```
where $\mathcal{P}(\mathbb{R}^T) | \theta_Z$ is a parametric process on $\mathbb{R}^T$; $f_0$ and $f_O$ are parametric distributions on, respectively, the initial number of infections and the observed case data conditional on underlying infections; and $g_I$ is the distribution of new infections conditional on past infections and the past of the latent process. Note that we assume new infections are conditional on the strict past, whereas new observations can depend on infections on the same time step.
#### Code structure outline
An `EpiAware` model in code is created from three modular components:
- A `LatentModel`: This defines the model for $Z_\cdot$. This chooses $\mathcal{P}(\mathbb{R}^T) | \theta_Z$.
- An `EpiModel`: This defines a generative process for infections conditional on the latent process. This chooses $f_0(\theta_I)$, and $g_I(\{I_s, Z_s\}_{s < t}, \theta_{I})$.
- An `ObservationModel`: This defines the observation model. This chooses $f_O(\{I_s\}_{s \leq t}, \theta_{O})$.
#### Reproductive number
`EpiAware` models do not need to specify a time-varying reproductive number $\mathcal{R}_t$ to generate $I_\cdot$; however, this is often a quantity of interest. When it is not used directly, we will typically back-calculate $\mathcal{R}_t$ from the generated infections:
```math
\mathcal{R}_t = {I_t \over \sum_{s \geq 1} g_s I_{t-s} }.
```
where $g_s$ is the probability mass of a discrete generation interval at lag $s$. For this reason, even when not using a reproductive number approach directly, we ask for a generation interval.
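
As a minimal sketch, the back-calculation above can be written in plain Julia. `backcalc_Rt` and the toy inputs are hypothetical illustrations, not part of the `EpiAware` API. The sketch assumes the generation interval is stored as a vector whose entry `s` is ``g_s`` for lags ``s = 1, \dots, n``, and it only returns ``\mathcal{R}_t`` for time points with a full window of ``n`` past infections.

```julia
using LinearAlgebra

# Back-calculate R_t = I_t / (g_1 I_{t-1} + ... + g_n I_{t-n}) for every t
# that has a full window of n past infections.
# Assumes gen_int[s] is the generation-interval probability at lag s (s = 1, ..., n).
function backcalc_Rt(infs::AbstractVector, gen_int::AbstractVector)
    n = length(gen_int)
    # reverse(gen_int) lines up g_n, ..., g_1 with I_{t-n}, ..., I_{t-1}
    return [infs[t] / dot(reverse(gen_int), infs[(t - n):(t - 1)])
            for t in (n + 1):length(infs)]
end

# Toy check: constant infections with a generation interval summing to one give R_t = 1
toy_gen_int = [0.2, 0.5, 0.3]
toy_infs = fill(100.0, 10)
backcalc_Rt(toy_infs, toy_gen_int)
```
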
"
# ╔═╡ 5a0d5ab8-e985-4126-a1ac-58fe08beee38
md"
## Random walk `LatentModel`
As an example, we choose the latent process as a random walk with parameters $\theta_Z$:
- ``Z_0``: Initial position.
- ``\sigma_{RW}``: The step-size standard deviation (the step-size variance is ``\sigma_{RW}^2``).

Conditional on the parameters, the random walk is then generated by white noise:
```math
\begin{align}
Z_t &= Z_0 + \sigma_{RW} \sum_{s = 1}^t \epsilon_s, \\
\epsilon_t &\sim \mathcal{N}(0,1).
\end{align}
```
In `EpiAware` we provide a constructor for random walk latent models with priors for $\theta_Z$. We choose priors,
```math
\begin{align}
Z_0 &\sim \mathcal{N}(0,1),\\
\sigma_{RW} &\sim \text{HalfNormal}\Big(0.1 ~\sqrt{{\pi \over 2}}\Big).
\end{align}
```
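
Below is a plain-Julia sketch of this generative process that relies only on the formulas above. `simulate_random_walk` is a hypothetical helper for illustration, not the implementation behind `EpiAware.RandomWalk`. The prior draw for ``\sigma_{RW}`` uses the fact that ``|\sigma W|``, with ``W \sim \mathcal{N}(0,1)``, is half-normal with scale ``\sigma``.

```julia
using Random

# Random walk: Z_t = Z_0 + σ_RW (ε_1 + ... + ε_t), with ε_s ~ N(0, 1), for t = 1, ..., T
function simulate_random_walk(rng::AbstractRNG, Z0::Real, σ_RW::Real, T::Integer)
    ε = randn(rng, T)               # white-noise innovations
    return Z0 .+ σ_RW .* cumsum(ε)  # partial sums of the noise give the walk
end

# Draw (Z_0, σ_RW) from the priors above, then simulate 30 steps
rng = MersenneTwister(1)
Z0 = randn(rng)                             # Z_0 ~ N(0, 1)
σ_RW = abs(0.1 * sqrt(π / 2) * randn(rng))  # σ_RW ~ HalfNormal(0.1 sqrt(π / 2))
Z = simulate_random_walk(rng, Z0, σ_RW, 30)
```

Building the walk from `cumsum` keeps the sketch vectorised and matches the partial-sum form of ``Z_t`` directly.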
"
# ╔═╡ 56ae496b-0094-460b-89cb-526627991717
rwp = EpiAware.RandomWalk(
init_prior = Normal(),
std_prior = EpiAware.EpiLatentModels._make_halfnormal_prior(0.1))
# ╔═╡ 767beffd-1ef5-4e6c-9ac6-edb52e60fb44
md"
## Direct infection `EpiModel`
This is a simple model where the unobserved log-infections are directly generated by the latent process $Z$.
```math
\log I_t = \log I_0 + Z_t.
```
As discussed above, we still ask for a defined generation interval, which can be used to calculate $\mathcal{R}_t$.
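
Here is a minimal plain-Julia sketch of this deterministic map from the latent process to infections. `direct_infections` is a hypothetical name used only for illustration; `DirectInfections` in `EpiAware` may be organised differently.

```julia
# Direct infections: log I_t = log I_0 + Z_t, i.e. I_t = I_0 exp(Z_t)
direct_infections(log_I0::Real, Z::AbstractVector) = exp.(log_I0 .+ Z)

# A flat latent process keeps infections at I_0 = 100 ...
flat = direct_infections(log(100.0), zeros(10))
# ... while a latent trend of 0.05 per step multiplies infections by exp(0.05) each step
growing = direct_infections(log(100.0), 0.05 .* (1:10))
```
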
"
# ╔═╡ 9e43cbe3-94de-44fc-a788-b9c7adb34218
truth_GI = Gamma(2, 5)
# ╔═╡ f067284f-a1a6-44a6-9b79-f8c2de447673
md"
The `EpiData` constructor performs double interval censoring to convert our _continuous_ estimate of the generation interval into a discretized version. We also implement right truncation using the keyword `D_gen`.
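
The following is a minimal sketch of one common form of double interval censoring. It assumes the primary event time is uniformly distributed within its day and the secondary event is recorded to the day, which gives ``p_d = \int_0^1 [F(d + 1 - u) - F(d - u)]\,du``. The pmf is truncated at 10 days, loosely mirroring `D_gen = 10.0`, and then renormalised. `EpiData` may differ in details, such as its handling of the zero-day bin and the truncation point. To avoid needing any packages, the `Gamma(2, 5)` CDF (shape 2, scale 5) is written out in closed form, and the integral is approximated with a midpoint rule. Both helper names are hypothetical.

```julia
# Closed-form CDF of a Gamma distribution with shape 2 and scale θ:
# F(x) = 1 - exp(-x / θ) (1 + x / θ) for x > 0, and 0 otherwise.
gamma2_cdf(x; θ = 5.0) = x <= 0 ? 0.0 : 1 - exp(-x / θ) * (1 + x / θ)

# Double-interval-censored pmf for delays d = 0, ..., D - 1. The integral over the
# uniform primary event time u is approximated by a midpoint rule, and the result
# is renormalised so that the right-truncated pmf sums to one.
function censored_pmf(cdf, D::Integer; nquad::Integer = 1000)
    us = ((1:nquad) .- 0.5) ./ nquad  # midpoints of [0, 1]
    pmf = [sum(cdf(d + 1 - u) - cdf(d - u) for u in us) / nquad for d in 0:(D - 1)]
    return pmf ./ sum(pmf)
end

gen_pmf = censored_pmf(gamma2_cdf, 10)
```
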
"
# ╔═╡ c0662d48-4b54-4b6d-8c91-ddf4b0e3aa43
model_data = EpiData(gen_distribution = truth_GI, D_gen = 10.0)
# ╔═╡ fd72094f-1b95-4d07-a8b0-ef47dc560dfc
md"
We can supply a prior for the initial log-infections, $\log I_0$.
"
# ╔═╡ 6639e66f-7725-4976-81b2-6472419d1a62
log_I0_prior = Normal(log(100.0), 1.0)
# ╔═╡ df5e59f8-3185-4bed-9cca-7c266df17cec
md"
And construct the `EpiModel`.
"
# ╔═╡ 6fbdd8e6-2323-4352-9185-1f31a9cf9012
epi_model = DirectInfections(model_data, log_I0_prior)
# ╔═╡ 10c750db-6d00-4ef6-9caa-3cf7b3c0d711
latent = generate_latent_infs(epi_model, 20)
# ╔═╡ 45b287b8-22b5-4f09-9a93-51df82477b01
rand(latent)
# ╔═╡ 5e62a50a-71f4-4902-b1c9-fdf51fe145fa
md"
### Delayed Observations `ObservationModel`
The observation model is a negative binomial distribution parameterised with mean $\mu$ and 'successes' parameter $r$. The standard deviation _relative_ to the mean $\sigma_{\text{rel}} = \sigma / \mu$ for negative binomial observations is,
```math
\sigma_{\text{rel}} = \sqrt{{1 \over \mu} + {1 \over r}}.
```
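
This expression follows from the negative binomial variance ``\mu + \mu^2 / r``. The short plain-Julia check below uses hypothetical helper names.

```julia
# Negative binomial with mean μ and 'successes' parameter r: Var(y) = μ + μ^2 / r.
# In the (r, p) parameterisation, e.g. Distributions.NegativeBinomial(r, p),
# this corresponds to p = r / (r + μ).
negbin_var(μ, r) = μ + μ^2 / r

# Relative standard deviation σ / μ, which simplifies to sqrt(1 / μ + 1 / r)
rel_sd(μ, r) = sqrt(negbin_var(μ, r)) / μ

# For large means the overdispersion term dominates and σ_rel approaches 1 / sqrt(r)
rel_sd(1e6, 100.0)
```
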
It is standard to use a half-t distribution for standard deviation priors (e.g. as argued in this [paper](http://www.stat.columbia.edu/~gelman/research/published/taumain.pdf)); we specialise this to a Half-Normal prior and use an _a priori_ assumption that a typical observation fluctuation around the mean (when the mean is $\sim\mathcal{O}(10^2)$) would be 10%. This implies a standard deviation prior,
```math
1 / \sqrt{r} \sim \text{HalfNormal}\Big(0.1 ~\sqrt{{\pi \over 2}}\Big).
```
The $\sqrt{{\pi \over 2}}$ factor ensures the correct prior mean (see [here](https://en.wikipedia.org/wiki/Half-normal_distribution)).
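
The scaling can be checked directly. A half-normal distribution with scale ``\sigma`` has mean ``\sigma\sqrt{2/\pi}``, so scale ``0.1\sqrt{\pi/2}`` gives ``1/\sqrt{r}`` a prior mean of 0.1. The sketch below confirms this both analytically and by Monte Carlo, using only the Julia standard library.

```julia
using Random, Statistics

# If X = |σ W| with W ~ N(0, 1), then X ~ HalfNormal(σ) and E[X] = σ sqrt(2 / π).
# Setting σ = 0.1 sqrt(π / 2) therefore gives a prior mean of 0.1.
σ = 0.1 * sqrt(π / 2)
analytic_mean = σ * sqrt(2 / π)

# Monte Carlo check of the same mean
draws = abs.(σ .* randn(MersenneTwister(42), 10^6))
mc_mean = mean(draws)
```
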
The expected observed cases are delayed infections. Delays are implemented as the action of a sparse kernel $K$ on the infections $I_t$:
```math
y_t \sim \text{NegBinomial}\Big(\mu_t = \sum_{s\geq 0} K[t, t-s] \, I_{t-s},~ r\Big).
```
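
Here is a sketch of such a kernel as a sparse lower-triangular matrix with ``K[t, t-d] = p_d``. It uses the delay distribution chosen below, in which infections are observed 0, 1, 2 or 3 days later with probability 0.25 each. `delay_kernel` is a hypothetical helper, and how `DelayObservations` builds and applies its kernel internally, including at the start of the series, may differ.

```julia
using SparseArrays

# Build a T × T kernel with K[t, t - d] = p[d + 1] for delays d = 0, ..., length(p) - 1,
# so that (K * I)[t] = sum over d of p_d I_{t-d}. Early time points only receive
# the delays that fit inside the window.
function delay_kernel(p::AbstractVector, T::Integer)
    rows, cols, vals = Int[], Int[], Float64[]
    for t in 1:T, d in 0:(length(p) - 1)
        if t - d >= 1
            push!(rows, t); push!(cols, t - d); push!(vals, p[d + 1])
        end
    end
    return sparse(rows, cols, vals, T, T)
end

K = delay_kernel(fill(0.25, 4), 30)
expected_cases = K * fill(100.0, 30)  # negative binomial mean μ_t at each time t
```

A sparse matrix is a natural fit because each row holds at most `length(p)` non-zero entries.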
"
# ╔═╡ e813d547-6100-4c43-b84c-8cebe306bda8
md"
We also set the time horizon for inference to 30 days.
"
# ╔═╡ c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0
time_horizon = 30
# ╔═╡ 0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f
md"
We choose a simple observation model where infections are observed 0, 1, 2, 3 days later with equal probability.
"
# ╔═╡ 448669bc-99f4-4823-b15e-fcc9040ba31b
obs_model = DelayObservations(
fill(0.25, 4),
time_horizon,
EpiAware.EpiLatentModels._make_halfnormal_prior(0.1)
)
# ╔═╡ e49713e8-4840-4083-8e3f-fc52d791be7b
md"
## Generate cases from the `EpiAware` model
Having chosen an `EpiModel`, `LatentModel` and `ObservationModel`, we can implement the model as a [`Turing`](https://turinglang.org/dev/) model using `make_epi_aware`.
By giving `missing` to the first argument, we indicate that case data will be _generated_ from the model rather than treated as fixed.
"
# ╔═╡ abeff860-58c3-4644-9325-66ffd4446b6d
full_epi_aware_mdl = make_epi_aware(missing, time_horizon;
epi_model = epi_model,
latent_model = rwp,
observation_model = obs_model)
# ╔═╡ 821628fb-8044-48b0-aa4f-0b7b57a2f45a
md"
We choose some fixed parameters:
- Initial incidence is 100.
- In the direct infection model, the initial incidence and the initial value of the random walk form a non-identifiable pair. Therefore, we fix $Z_0 = 0$.
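
To see why the pair is non-identifiable, note that with a random walk ``\log I_t = \log I_0 + Z_0 + \sigma_{RW} \sum_{s=1}^t \epsilon_s``, so only the sum ``\log I_0 + Z_0`` affects infections. The plain-Julia sketch below (the helper name is hypothetical) shows that shifting one term up and the other down leaves the infections unchanged.

```julia
using Random

# Direct infections driven by a random walk:
# log I_t = log I_0 + Z_0 + σ_RW (ε_1 + ... + ε_t).
rw_direct_infections(log_I0, Z0, σ_RW, ε) = exp.(log_I0 .+ Z0 .+ σ_RW .* cumsum(ε))

ε = randn(MersenneTwister(7), 20)
infs_a = rw_direct_infections(log(100.0), 0.0, 0.1, ε)
# Shift log I_0 up and Z_0 down by the same amount: infections are unchanged
infs_b = rw_direct_infections(log(100.0) + 1.3, -1.3, 0.1, ε)
```
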
"
# ╔═╡ 36b34fd2-2891-42ca-b5dc-abb482e516ee
fixed_parameters = (rw_init = 0.0, init_incidence = log(100.0))
# ╔═╡ 0aadd9e3-7f91-4b45-9663-67d11335f0d0
md"
We fix these parameters using `fix`, and generate a random epidemic.
"
# ╔═╡ 7e0e6012-8648-4f84-a25a-8b0138c4b72a
cond_generative_model = fix(full_epi_aware_mdl, fixed_parameters)
# ╔═╡ b20c28be-7b07-410c-a33b-ea5ad6828c12
random_epidemic = rand(cond_generative_model)
# ╔═╡ d073e63b-62da-4743-ace0-78ef7806bc0b
true_infections = generated_quantities(cond_generative_model, random_epidemic).I_t
# ╔═╡ a04f3c1b-7e11-4800-9c2a-9fc0021de6e7
generated_obs = generated_quantities(cond_generative_model, random_epidemic).generated_y_t
# ╔═╡ f68b4e41-ac5c-42cd-a8c2-8761d66f7543
let
plot(true_infections,
label = "I_t",
xlabel = "Time",
ylabel = "Infections",
title = "Generated Infections")
scatter!(generated_obs, lab = "generated cases")
end
# ╔═╡ b5bc8f05-b538-4abf-aa84-450bf2dff3d9
md"
## Inference
We fixed $Z_0 = 0$ for the random walk because, in this model, $Z_0$ and $\log I_0$ are non-identifiable.
We now treat the generated data as `truth_data`, keep $Z_0 = 0$ fixed, and infer all other parameters, including $\log I_0$.
We do the inference by MCMC using the `Turing` NUTS sampler with the default number of warm-up steps.
"
# ╔═╡ 4a4c6e91-8d8f-4bbf-bb7e-a36dc281e312
md"
The observation model supports partially complete data. To test this we set some of the generated observations to be `missing`.
"
# ╔═╡ 525aa98c-d0e5-4ffa-b808-d90fc986204c
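truth_data = let
    # Copy the generated observations into a vector that can hold `missing`,
    # then mask some time points to test inference with partially observed data
    data = Union{Int, Missing}[generated_obs...]
    data[vcat([3, 5], 10:20)] .= missing
    data
end
# ╔═╡ 259a7042-e74f-43c7-aeb4-97a3beeb7776
# Number of masked (missing) observations
count(ismissing, truth_data)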
# ╔═╡ 32638954-2c99-4d4e-8e03-52154030c657
md"
We now make the inference model, conditioned on `truth_data`, again fixing the initial condition of the random walk to 0.
"
# ╔═╡ b4033728-b321-4100-8194-1fd9fe2d268d
inference_mdl = fix(
make_epi_aware(truth_data, time_horizon; epi_model = epi_model,
latent_model = rwp, observation_model = obs_model),
(rw_init = 0.0,)
)
# ╔═╡ 9222b436-9445-4039-abbf-25c8cddb7f63
md"
### Initialising inference
It is possible for the default warm-up process for NUTS to get stuck in low probability or otherwise degenerate regions of parameter space.
To make NUTS more robust we provide `manypathfinder`, which is built on pathfinder variational inference from [Pathfinder.jl](https://mlcolab.github.io/Pathfinder.jl/stable/). `manypathfinder` runs `nruns` pathfinder processes on the inference problem and returns the pathfinder run with the maximum estimated ELBO.
`manypathfinder` differs from `Pathfinder.multipathfinder`. `multipathfinder` is aimed at sampling from a potentially non-Gaussian target distribution, which it first approximates as a uniformly weighted collection of normal approximations from pathfinder runs. In contrast, `manypathfinder` is aimed at moving rapidly to a 'good' part of parameter space and is robust to runs that fail. The next cell makes a single `pathfinder` run from Pathfinder.jl.
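
The selection logic described above can be sketched without any inference libraries. This is a hypothetical illustration of the idea of taking the best of `nruns` runs while skipping failures; it is not the `manypathfinder` source. `run_once` stands in for a single pathfinder run and is assumed to return a named tuple with an `elbo` field, or to throw if the run fails.

```julia
# Best-of-n selection that is robust to failed runs (hypothetical sketch)
function best_of_runs(run_once, nruns)
    best = nothing
    for i in 1:nruns
        result = try
            run_once(i)
        catch
            nothing  # a failed run is skipped rather than aborting the search
        end
        result === nothing && continue
        if best === nothing || result.elbo > best.elbo
            best = result
        end
    end
    return best  # `nothing` if every run failed
end

# Toy usage: even-numbered runs fail; odd runs report an ELBO of -abs(i - 5)
toy_run(i) = iseven(i) ? throw(DomainError(i)) : (run = i, elbo = -abs(i - 5))
best_of_runs(toy_run, 8)
```
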
"
# ╔═╡ 83389965-7e63-4cf6-bada-cb521b6a6257
best_pf = pathfinder(inference_mdl; adtype = AutoReverseDiff(true))
# ╔═╡ 073a1d40-456a-450e-969f-11b23eb7fd1f
md"
We can use draws from the pathfinder run to initialise NUTS, taking one draw per chain.
"
# ╔═╡ 0379b058-4c35-440a-bc01-aafa0178bdbf
best_pf.draws_transformed
# ╔═╡ a7798f71-9bb5-4506-9476-0cc11553b9e2
init_params = collect.(eachrow(best_pf.draws_transformed.value[1:4, :, 1]))
# ╔═╡ 4deb3a51-781d-48c4-91f6-6adf2b1affcf
md"
**NB: We are running this inference run for speed rather than accuracy as a demonstration. You might want to use a higher target acceptance and more samples in a typical workflow.**
"
# ╔═╡ 946b1c43-e750-40c9-9f14-79da9735e437
target_acc_prob = 0.8
# ╔═╡ 3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c
chn = sample(inference_mdl,
NUTS(target_acc_prob; adtype = AutoReverseDiff(true)),
MCMCThreads(),
250,
4;
init_params,
drop_warmup = true)
# ╔═╡ 30498cc7-16a5-441a-b8cd-c19b220c60c1
md"
### Predictive plotting
We can spaghetti plot generated case data from the version of the model _which hasn't conditioned on case data_ using posterior parameters inferred from the version conditioned on observed data. This is known as _posterior predictive checking_, and is a useful diagnostic tool for Bayesian inference (see [here](http://www.stat.columbia.edu/~gelman/book/BDA3.pdf)).
Because we are using synthetic data, we can also plot the model predictions for the _unobserved_ infections and check that (at least in this example) we were able to capture some of the unobserved, latent variables in the process accurately.
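
The spaghetti plot can be complemented by a simple numerical summary: pointwise posterior predictive intervals and the share of observed data points falling inside them. `ppc_coverage` is a hypothetical helper. It assumes the simulated cases are arranged as in the plotting cell below, with time along the rows and one posterior draw per column, as produced by `mapreduce(hcat, ...)`. Missing observations are skipped.

```julia
using Statistics

# Pointwise posterior predictive intervals from a T × n_draws matrix of simulated
# cases, plus the share of non-missing observations falling inside them.
function ppc_coverage(sims::AbstractMatrix, obs::AbstractVector; level = 0.9)
    α = (1 - level) / 2
    lo = [quantile(sims[t, :], α) for t in axes(sims, 1)]
    hi = [quantile(sims[t, :], 1 - α) for t in axes(sims, 1)]
    observed = [t for t in eachindex(obs) if !ismissing(obs[t])]  # skip missing data
    inside = count(t -> lo[t] <= obs[t] <= hi[t], observed)
    return (lo = lo, hi = hi, coverage = inside / length(observed))
end

# Toy usage: every time point has simulated draws 1, 2, ..., 100
toy_sims = repeat(reshape(1.0:100.0, 1, :), 5, 1)
toy_obs = [50, missing, 200, 10, 90]
ppc_coverage(toy_sims, toy_obs).coverage
```
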
"
# ╔═╡ e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5
let
post_check_mdl = fix(full_epi_aware_mdl, (rw_init = 0.0,))
post_check_y_t = mapreduce(hcat, generated_quantities(post_check_mdl, chn)) do gen
gen.generated_y_t
end
predicted_I_t = mapreduce(hcat, generated_quantities(inference_mdl, chn)) do gen
gen.I_t
end
p1 = plot(post_check_y_t, c = :grey, alpha = 0.05, lab = "")
scatter!(p1, truth_data,
lab = "Observed cases",
xlabel = "Time",
ylabel = "Cases",
title = "Post. predictive checking: cases",
ylims = (-0.5, maximum(skipmissing(truth_data)) * 1.5),
c = :green)
p2 = plot(predicted_I_t, c = :grey, alpha = 0.05, lab = "")
scatter!(p2, true_infections,
lab = "Actual infections",
xlabel = "Time",
ylabel = "Unobserved Infections",
title = "Post. predictions: infections",
ylims = (-0.5, maximum(true_infections) * 1.5),
c = :red)
plot(p1, p2,
layout = (1, 2),
size = (700, 400))
end
# ╔═╡ fd6321b1-4c3a-4123-b0dc-c45b951e0b80
md"
As well as checking the posterior predictions for latent infections, we can also check how well inference recovered unknown parameters, such as the random walk standard deviation $\sigma_{RW}$ or the cluster factor of the negative binomial observations.
"
# ╔═╡ 10d8fe24-83a6-47ac-97b7-a374481473d3
let
parameters_to_plot = (:σ_RW, :neg_bin_cluster_factor)
plts = map(parameters_to_plot) do name
var_samples = chn[name] |> vec
histogram(var_samples,
bins = 50,
norm = :pdf,
lw = 0,
fillalpha = 0.5,
lab = "MCMC")
vline!([getfield(random_epidemic, name)], lab = "True value")
title!(string(name))
end
plot(plts..., layout = (2, 1))
end
# ╔═╡ 81efe8ca-b753-4a12-bafc-a887a999377b
md"
## Reproductive number back-calculation
As mentioned at the top, we _don't_ directly use the concept of reproductive numbers in this notebook. However, we can back-calculate the implied $\mathcal{R}_t$ values, conditional on the specified generation interval being correct.
Here we spaghetti plot posterior samples of the time-varying reproductive number against the true values.
"
# ╔═╡ 15b9f37f-8d5f-460d-8c28-d7f2271fd099
let
n = epi_model.data.len_gen_int
Rt_denom = [dot(reverse(epi_model.data.gen_int), true_infections[(t - n):(t - 1)])
for t in (n + 1):length(true_infections)]
true_Rt = true_infections[(n + 1):end] ./ Rt_denom
predicted_Rt = mapreduce(hcat, generated_quantities(inference_mdl, chn)) do gen
_It = gen.I_t
_Rt_denom = [dot(reverse(epi_model.data.gen_int), _It[(t - n):(t - 1)])
for t in (n + 1):length(_It)]
Rt = _It[(n + 1):end] ./ _Rt_denom
end
plt = plot((n + 1):time_horizon, predicted_Rt, c = :grey, alpha = 0.05, lab = "")
plot!(plt, (n + 1):time_horizon, true_Rt,
lab = "true Rt",
xlabel = "Time",
ylabel = "Rt",
title = "Post. predictions: reproductive number",
c = :red,
lw = 2)
end
# ╔═╡ Cell order:
# ╠═c593a2a0-d7f5-11ee-0931-d9f65ae84a72
# ╟─3ebc8384-f73d-4597-83a7-07a3744fed61
# ╠═da479d8d-1312-4b98-b0af-5be52dffaf3f
# ╟─5a0d5ab8-e985-4126-a1ac-58fe08beee38
# ╠═56ae496b-0094-460b-89cb-526627991717
# ╟─767beffd-1ef5-4e6c-9ac6-edb52e60fb44
# ╠═9e43cbe3-94de-44fc-a788-b9c7adb34218
# ╟─f067284f-a1a6-44a6-9b79-f8c2de447673
# ╠═c0662d48-4b54-4b6d-8c91-ddf4b0e3aa43
# ╟─fd72094f-1b95-4d07-a8b0-ef47dc560dfc
# ╠═6639e66f-7725-4976-81b2-6472419d1a62
# ╟─df5e59f8-3185-4bed-9cca-7c266df17cec
# ╠═6fbdd8e6-2323-4352-9185-1f31a9cf9012
# ╠═10c750db-6d00-4ef6-9caa-3cf7b3c0d711
# ╠═45b287b8-22b5-4f09-9a93-51df82477b01
# ╟─5e62a50a-71f4-4902-b1c9-fdf51fe145fa
# ╟─e813d547-6100-4c43-b84c-8cebe306bda8
# ╠═c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0
# ╟─0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f
# ╠═448669bc-99f4-4823-b15e-fcc9040ba31b
# ╟─e49713e8-4840-4083-8e3f-fc52d791be7b
# ╠═abeff860-58c3-4644-9325-66ffd4446b6d
# ╟─821628fb-8044-48b0-aa4f-0b7b57a2f45a
# ╠═36b34fd2-2891-42ca-b5dc-abb482e516ee
# ╟─0aadd9e3-7f91-4b45-9663-67d11335f0d0
# ╠═7e0e6012-8648-4f84-a25a-8b0138c4b72a
# ╠═b20c28be-7b07-410c-a33b-ea5ad6828c12
# ╠═d073e63b-62da-4743-ace0-78ef7806bc0b
# ╠═a04f3c1b-7e11-4800-9c2a-9fc0021de6e7
# ╟─f68b4e41-ac5c-42cd-a8c2-8761d66f7543
# ╟─b5bc8f05-b538-4abf-aa84-450bf2dff3d9
# ╟─4a4c6e91-8d8f-4bbf-bb7e-a36dc281e312
# ╠═525aa98c-d0e5-4ffa-b808-d90fc986204c
# ╠═259a7042-e74f-43c7-aeb4-97a3beeb7776
# ╟─32638954-2c99-4d4e-8e03-52154030c657
# ╠═b4033728-b321-4100-8194-1fd9fe2d268d
# ╟─9222b436-9445-4039-abbf-25c8cddb7f63
# ╠═83389965-7e63-4cf6-bada-cb521b6a6257
# ╟─073a1d40-456a-450e-969f-11b23eb7fd1f
# ╠═0379b058-4c35-440a-bc01-aafa0178bdbf
# ╠═a7798f71-9bb5-4506-9476-0cc11553b9e2
# ╟─4deb3a51-781d-48c4-91f6-6adf2b1affcf
# ╠═946b1c43-e750-40c9-9f14-79da9735e437
# ╠═3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c
# ╟─30498cc7-16a5-441a-b8cd-c19b220c60c1
# ╠═e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5
# ╟─fd6321b1-4c3a-4123-b0dc-c45b951e0b80
# ╠═10d8fe24-83a6-47ac-97b7-a374481473d3
# ╟─81efe8ca-b753-4a12-bafc-a887a999377b
# ╠═15b9f37f-8d5f-460d-8c28-d7f2271fd099