Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade the testmode pipeline script to pipeline scripts: prior predictive + full #445

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
0fd303c
remove most compat from pipeline
SamuelBrand1 Aug 16, 2024
7d6a4a2
Update test_pipelinefunctions.jl
SamuelBrand1 Aug 16, 2024
ac21878
Fix the truthdata output to be missing or Int
SamuelBrand1 Aug 16, 2024
5704a7a
Merge branch 'fix-truthdata-ouput' into 406-consider-moving-from-broa…
SamuelBrand1 Aug 16, 2024
843bbf6
change to daily increments in latent processes
SamuelBrand1 Aug 16, 2024
bc8d716
Change default AD mode to ReverseDiff{true}
SamuelBrand1 Aug 16, 2024
c74b525
Create changelog.md
SamuelBrand1 Aug 16, 2024
7fbe16d
reformat
SamuelBrand1 Aug 16, 2024
390090e
change filename
SamuelBrand1 Aug 17, 2024
7d54d44
Update simulate.jl
SamuelBrand1 Aug 17, 2024
e6f3d18
add a testmode to pipelinetypes
SamuelBrand1 Aug 17, 2024
98e7a6c
tighten typing
SamuelBrand1 Aug 17, 2024
1a50101
Merge branch 'main' into 404-upgrade-the-testmode-pipeline-script-to-…
SamuelBrand1 Aug 30, 2024
fabb335
fix constructor
SamuelBrand1 Aug 30, 2024
60526a8
fix constructor
SamuelBrand1 Aug 30, 2024
f1e0b74
pipeline helper functions
SamuelBrand1 Aug 30, 2024
bc712e9
unit tests and end-to-end tests
SamuelBrand1 Aug 30, 2024
913dcc8
remove old scripts
SamuelBrand1 Aug 30, 2024
6c939a5
Fix y_t type
SamuelBrand1 Aug 31, 2024
c5b5bd2
Merge branch 'main' into 404-upgrade-the-testmode-pipeline-script-to-…
SamuelBrand1 Aug 31, 2024
7069b03
bring inference prefixing into line with truthdata prefixing
SamuelBrand1 Aug 31, 2024
2ee8e49
Merge branch '404-upgrade-the-testmode-pipeline-script-to-just-pipeli…
SamuelBrand1 Aug 31, 2024
863eb3b
Merge branch 'main' into 433-forecast-function-in-pipeline-is-broken-…
SamuelBrand1 Sep 3, 2024
1755377
Merge branch 'main' into 404-upgrade-the-testmode-pipeline-script-to-…
SamuelBrand1 Sep 3, 2024
e5c0857
Merge branch '433-forecast-function-in-pipeline-is-broken-after-move-…
SamuelBrand1 Sep 3, 2024
5562326
reformat
SamuelBrand1 Sep 10, 2024
b625dcd
interim commit
SamuelBrand1 Sep 19, 2024
775f75b
rename and collect plot functions
SamuelBrand1 Sep 19, 2024
d046932
tidy up tests
SamuelBrand1 Sep 19, 2024
5943a73
basic plots to Makie
SamuelBrand1 Sep 20, 2024
38bfdf2
new plot tests and test reorganisation
SamuelBrand1 Sep 20, 2024
7de189f
Update runtests.jl
SamuelBrand1 Sep 30, 2024
5d03fc8
reformat
SamuelBrand1 Sep 30, 2024
e4951b4
455 plotting methods of prior predictive (#462)
SamuelBrand1 Oct 1, 2024
f6ac3ae
Merge branch 'main' into 404-upgrade-the-testmode-pipeline-script-to-…
SamuelBrand1 Oct 3, 2024
f93efeb
Merge branch 'main' into 404-upgrade-the-testmode-pipeline-script-to-…
SamuelBrand1 Oct 3, 2024
657ee3c
Make pipeline inference tests more focused (#475)
SamuelBrand1 Oct 4, 2024
5cd5040
Full priorpred check pipeline script (#476)
SamuelBrand1 Oct 7, 2024
ae8e3ca
Add oneexpy to pipeline (#480)
SamuelBrand1 Oct 7, 2024
234c77f
Changed saved prior predictive data (#478)
SamuelBrand1 Oct 7, 2024
66367e0
Merge branch 'main' into 404-upgrade-the-testmode-pipeline-script-to-…
SamuelBrand1 Oct 7, 2024
2515fe3
Update run_priorpred_pipeline.jl
SamuelBrand1 Oct 7, 2024
6a58e3e
Only save strings rather than Exception objects
SamuelBrand1 Oct 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pipeline/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Expand Down
25 changes: 0 additions & 25 deletions pipeline/scripts/run_pipeline.jl

This file was deleted.

17 changes: 13 additions & 4 deletions pipeline/scripts/run_priorpred_pipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ using Pkg
Pkg.activate(joinpath(@__DIR__(), ".."))
using Dagger

@assert !isempty(ARGS) "Test mode script requires the number of draws as an argument."
ndraws = parse(Int64, ARGS[1])

@info("""
Running the analysis pipeline.
Running the prior predictive pipeline in test mode with $(ndraws) draws per model.
--------------------------------------------
""")

Expand All @@ -15,11 +18,17 @@ pids = addprocs(; exeflags = ["--project=$(Base.active_project())"])

@everywhere using EpiAwarePipeline

# Create an instance of the pipeline behaviour
pipeline = RtwithoutRenewalPriorPipeline()
# Create instances of the pipeline behaviour

pipelines = [
SmoothOutbreakPipeline(ndraws = ndraws, nchains = 1, priorpredictive = true),
MeasuresOutbreakPipeline(ndraws = ndraws, nchains = 1, priorpredictive = true),
SmoothEndemicPipeline(ndraws = ndraws, nchains = 1, priorpredictive = true),
RoughEndemicPipeline(ndraws = ndraws, nchains = 1, priorpredictive = true)
]

# Run the pipeline
do_pipeline(pipeline)
do_pipeline(pipelines)

# Remove the workers
rmprocs(pids)
13 changes: 7 additions & 6 deletions pipeline/src/EpiAwarePipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ with execution determined by available computational resources.
module EpiAwarePipeline

using CSV, Dagger, DataFramesMeta, Dates, Distributions, DocStringExtensions, DrWatson,
EpiAware, Plots, Statistics, ADTypes, AbstractMCMC, Plots, JLD2, MCMCChains, Turing,
DynamicPPL, LogExpFunctions, RCall, LinearAlgebra, Random, AlgebraOfGraphics,
CairoMakie, ReverseDiff
EpiAware, Statistics, ADTypes, AbstractMCMC, JLD2, MCMCChains, Turing, DynamicPPL,
LogExpFunctions, RCall, LinearAlgebra, Random, AlgebraOfGraphics, CairoMakie,
ReverseDiff

using EpiAware.EpiInfModels: oneexpy

# Exported pipeline types
export AbstractEpiAwarePipeline, EpiAwarePipeline, AbstractRtwithoutRenewalPipeline,
Expand Down Expand Up @@ -56,7 +58,7 @@ export make_prediction_dataframe_from_output, make_truthdata_dataframe,
export figureone, figuretwo

# Exported functions: plot functions
export plot_truth_data, plot_Rt
export plot_truth_data, plot_Rt, prior_predictive_plot

include("docstrings.jl")
include("pipeline/pipeline.jl")
Expand All @@ -67,6 +69,5 @@ include("infer/infer.jl")
include("forecast/forecast.jl")
include("scoring/scoring.jl")
include("analysis/analysis.jl")
include("mainplots/mainplots.jl")
include("plot_functions.jl")
include("plotting/plotting.jl")
end
38 changes: 22 additions & 16 deletions pipeline/src/constructors/make_inference_method.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Constructs an inference method for the given pipeline. This is a default method.
- An inference method.

"""
function make_inference_method(pipeline::AbstractEpiAwarePipeline; ndraws::Integer = 2000,
function make_inference_method(ndraws::Integer, pipeline::AbstractEpiAwarePipeline;
mcmc_ensemble::AbstractMCMC.AbstractMCMCEnsemble = MCMCSerial(),
nruns_pthf::Integer = 4, maxiters_pthf::Integer = 100, nchains::Integer = 4)
return EpiMethod(
Expand All @@ -19,27 +19,28 @@ function make_inference_method(pipeline::AbstractEpiAwarePipeline; ndraws::Integ
end

"""
Method for sampling from prior predictive distribution of the model.
Example pipeline.
"""
function make_inference_method(pipeline::RtwithoutRenewalPriorPipeline; n_samples = 2_000)
function make_inference_method(
pipeline::EpiAwareExamplePipeline; ndraws::Integer = 20,
mcmc_ensemble::AbstractMCMC.AbstractMCMCEnsemble = MCMCThreads(),
nruns_pthf::Integer = 4, maxiters_pthf::Integer = 100, nchains::Integer = 4)
return EpiMethod(
pre_sampler_steps = AbstractEpiOptMethod[],
sampler = DirectSample(n_samples = n_samples)
pre_sampler_steps = [ManyPathfinder(nruns = nruns_pthf, maxiters = maxiters_pthf)],
sampler = NUTSampler(
target_acceptance = 0.9, adtype = AutoReverseDiff(; compile = true),
ndraws = ndraws, nchains = nchains, mcmc_parallel = mcmc_ensemble)
)
end

"""
Pipeline test mode method for sampling from prior predictive distribution of the model.
Method for sampling from prior predictive distribution of the model.
"""
function make_inference_method(
pipeline::EpiAwareExamplePipeline; ndraws::Integer = 20,
mcmc_ensemble::AbstractMCMC.AbstractMCMCEnsemble = MCMCThreads(),
nruns_pthf::Integer = 4, maxiters_pthf::Integer = 100, nchains::Integer = 4)
pipeline::AbstractRtwithoutRenewalPipeline, ::Val{:priorpredictive})
return EpiMethod(
pre_sampler_steps = [ManyPathfinder(nruns = nruns_pthf, maxiters = maxiters_pthf)],
sampler = NUTSampler(
target_acceptance = 0.9, adtype = AutoReverseDiff(; compile = true), ndraws = ndraws,
nchains = nchains, mcmc_parallel = mcmc_ensemble)
pre_sampler_steps = AbstractEpiOptMethod[],
sampler = DirectSample(n_samples = pipeline.ndraws)
)
end

Expand All @@ -55,7 +56,12 @@ Constructs an inference method for the Rt-without-renewal pipeline.
# Examples
"""
function make_inference_method(pipeline::AbstractRtwithoutRenewalPipeline)
return make_inference_method(pipeline; ndraws = pipeline.ndraws,
mcmc_ensemble = pipeline.mcmc_ensemble, nruns_pthf = pipeline.nruns_pthf,
maxiters_pthf = pipeline.maxiters_pthf, nchains = pipeline.nchains)
if pipeline.priorpredictive
return make_inference_method(pipeline, Val(:priorpredictive))
else
return make_inference_method(
pipeline.ndraws, pipeline; mcmc_ensemble = pipeline.mcmc_ensemble,
nruns_pthf = pipeline.nruns_pthf,
maxiters_pthf = pipeline.maxiters_pthf, nchains = pipeline.nchains)
end
end
8 changes: 8 additions & 0 deletions pipeline/src/constructors/selector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,11 @@ Example/test mode is to return a randomly selected item from the list.
function _selector(list, pipeline::EpiAwareExamplePipeline)
return [rand(list)]
end

"""
Internal method for selecting from a list of items based on the pipeline type.
Example/test mode is to return a randomly selected item from the list.
"""
function _selector(list, pipeline::AbstractRtwithoutRenewalPipeline)
return pipeline.testmode ? [rand(list)] : list
end
139 changes: 116 additions & 23 deletions pipeline/src/infer/InferenceConfig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,18 @@ Inference configuration struct for specifying the parameters and models used in
- `epimethod::E`: Inference method.
- `transformation::F`: Transformation function.
- `log_I0_prior::Distribution`: Prior for log initial infections. Default is `Normal(log(100.0), 1e-5)`.
- `lookahead::X`: Number of days to forecast ahead.
- `latent_model_name::String`: Name of the latent model.
- `pipeline`: Pipeline type defining the inference scenario.

# Constructors
- `InferenceConfig(igp, latent_model, observation_model; gi_mean, gi_std, case_data, tspan, epimethod, transformation = exp)`: Constructs an `InferenceConfig` object with the specified parameters.
- `InferenceConfig(inference_config::Dict; case_data, tspan, epimethod)`: Constructs an `InferenceConfig` object from a dictionary of configuration values.

"""
struct InferenceConfig{T, F, IGP, L, O, E, D <: Distribution, X <: Integer}
struct InferenceConfig{
T, F, IGP, L, O, E, D <: Distribution, X <: Integer,
P <: AbstractRtwithoutRenewalPipeline}
gi_mean::T
gi_std::T
igp::IGP
Expand All @@ -32,19 +37,23 @@ struct InferenceConfig{T, F, IGP, L, O, E, D <: Distribution, X <: Integer}
transformation::F
log_I0_prior::D
lookahead::X
latent_model_name::String
pipeline::P

function InferenceConfig(igp, latent_model, observation_model; gi_mean, gi_std,
case_data, truth_I_t, truth_I0, tspan, epimethod,
transformation = exp, log_I0_prior, lookahead)
new{typeof(gi_mean), typeof(transformation),
typeof(igp), typeof(latent_model), typeof(observation_model),
typeof(epimethod), typeof(log_I0_prior), typeof(lookahead)}(
function InferenceConfig(
igp, latent_model, observation_model; gi_mean, gi_std, case_data,
truth_I_t, truth_I0, tspan, epimethod, transformation = oneexpy,
log_I0_prior, lookahead, latent_model_name, pipeline)
new{typeof(gi_mean), typeof(transformation), typeof(igp),
typeof(latent_model), typeof(observation_model), typeof(epimethod),
typeof(log_I0_prior), typeof(lookahead), typeof(pipeline)}(
gi_mean, gi_std, igp, latent_model, observation_model,
case_data, truth_I_t, truth_I0, tspan, epimethod, transformation, log_I0_prior, lookahead)
case_data, truth_I_t, truth_I0, tspan, epimethod,
transformation, log_I0_prior, lookahead, latent_model_name, pipeline)
end

function InferenceConfig(
inference_config::Dict; case_data, truth_I_t, truth_I0, tspan, epimethod)
inference_config::Dict; case_data, truth_I_t, truth_I0, tspan, epimethod, pipeline)
InferenceConfig(
inference_config["igp"],
inference_config["latent_namemodels"].second,
Expand All @@ -57,11 +66,71 @@ struct InferenceConfig{T, F, IGP, L, O, E, D <: Distribution, X <: Integer}
tspan = tspan,
epimethod = epimethod,
log_I0_prior = inference_config["log_I0_prior"],
lookahead = inference_config["lookahead"]
lookahead = inference_config["lookahead"],
latent_model_name = inference_config["latent_namemodels"].first,
pipeline
)
end
end

"""
Internal function that returns a dictionary containing key configuration fields from the given `InferenceConfig` object.
The dictionary includes the following keys:

- `"igp"`: The string representation of the `igp` field.
- `"latent_model"`: The name of the latent model.
- `"gi_mean"`: The mean of the generation interval.
- `"gi_std"`: The standard deviation of the generation interval.
- `"tspan"`: The time span for the inference.
- `"priorpredictive"`: The prior predictive setting.

# Arguments
- `config::InferenceConfig`: An instance of `InferenceConfig` containing the configuration details.

# Returns
- `Dict{String, Any}`: A dictionary with the key configuration fields.
"""
function _saved_config_fields(config::InferenceConfig)
return Dict(
"igp" => string(config.igp),
"latent_model" => config.latent_model_name,
"gi_mean" => config.gi_mean,
"gi_std" => config.gi_std,
"tspan" => string(config.tspan[1]) * "_" * string(config.tspan[2]),
"priorpredictive" => config.pipeline.priorpredictive,
"scenario" => config.pipeline |> typeof |> string
)
end

"""
This function makes inference on the underlying parameters of the model specified
in the `InferenceConfig` object `config`.

# Arguments
- `config::InferenceConfig`: The configuration object containing the case data
to make inference on and model configuration.
- `epiprob::EpiProblem`: The EpiProblem object containing the model to make inference on.

# Returns
- `inference_results`: The results of the simulation or inference.

"""
function create_inference_results(config, epiprob)
#Return the sampled infections and observations
idxs = config.tspan[1]:config.tspan[2]
y_t = ismissing(config.case_data) ? missing :
Vector{Union{Missing, Int64}}(config.case_data[idxs])
inference_results = apply_method(epiprob,
config.epimethod,
(y_t = y_t,)
)
inference_results = apply_method(epiprob,
config.epimethod,
(y_t = y_t,);
)
return inference_results
end

"""
This method makes inference on the underlying parameters of the model specified
in the `InferenceConfig` object `config`.
Expand All @@ -77,22 +146,46 @@ to make inference on and model configuration.
function infer(config::InferenceConfig)
#Define the EpiProblem
epiprob = define_epiprob(config)
idxs = config.tspan[1]:config.tspan[2]

#Define savable configuration
save_config = _saved_config_fields(config)

#Return the sampled infections and observations
y_t = ismissing(config.case_data) ? missing : config.case_data[idxs]
inference_results = apply_method(epiprob,
config.epimethod,
(y_t = y_t,);
)
inference_results = try
create_inference_results(config, epiprob)
catch e
string(e)
end

forecast_results = generate_forecasts(
inference_results.samples, inference_results.data, epiprob, config.lookahead)
if config.pipeline.priorpredictive
if inference_results isa String
return Dict("priorpredictive" => inference_results,
"inference_config" => save_config)
else
fig = prior_predictive_plot(
config, inference_results, epiprob; ps = [0.025, 0.1, 0.25])
figdir = config.pipeline.testmode ? mktempdir() : plotsdir("priorpredictive")
figpath = joinpath(figdir, "priorpred_" * savename(save_config) * ".png")
CairoMakie.save(figpath, fig)
return Dict("priorpredictive" => "Pass", "inference_config" => save_config)
end
else
forecast_results = try
generate_forecasts(
inference_results.samples, inference_results.data, epiprob, config.lookahead)
catch e
string(e)
end

epidata = epiprob.epi_model.data
score_results = summarise_crps(config, inference_results, forecast_results, epidata)
epidata = epiprob.epi_model.data
score_results = try
summarise_crps(config, inference_results, forecast_results, epidata)
catch e
string(e)
end

return Dict("inference_results" => inference_results,
"epiprob" => epiprob, "inference_config" => config,
"forecast_results" => forecast_results, "score_results" => score_results)
return Dict(
"inference_results" => inference_results, "inference_config" => save_config,
"forecast_results" => forecast_results, "score_results" => score_results)
end
end
2 changes: 1 addition & 1 deletion pipeline/src/infer/define_epiprob.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function define_epiprob(config::InferenceConfig)
model_data = EpiData(
gen_distribution = gen_distribution, transformation = config.transformation)
#Build the epidemiological model
epi = config.igp(model_data, config.log_I0_prior)
epi = config.igp(data = model_data, initialisation_prior = config.log_I0_prior)

epi_prob = EpiProblem(epi_model = epi,
latent_model = config.latent_model,
Expand Down
Loading
Loading