From 67c9c3cf2730c74b82424eaa384dfa4abb11e2ea Mon Sep 17 00:00:00 2001 From: Patrick Aschermayr Date: Fri, 29 Jul 2022 11:03:38 +0100 Subject: [PATCH] Update trace --- Project.toml | 2 +- src/sampling/inference.jl | 83 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 78 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 4fcecdd..07eeec1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Baytes" uuid = "72ddfcfc-6e9d-43df-829b-7aed7c549d4f" authors = ["Patrick Aschermayr "] -version = "0.1.9" +version = "0.1.10" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/sampling/inference.jl b/src/sampling/inference.jl index bb75591..9a86875 100644 --- a/src/sampling/inference.jl +++ b/src/sampling/inference.jl @@ -1,4 +1,68 @@ ################################################################################ +""" +$(TYPEDEF) + +Contains arguments for trace to extract parameter. Used to construct a 'TraceTransform'. + +# Fields +$(TYPEDFIELDS) +""" +struct TransformInfo + "Chain indices that are used for output diagnostics." + chains :: Vector{Int64} + "Algorithm indices that are used for output diagnostics." + algorithms :: Vector{Int64} + "Number of burnin steps before output diagnostics are taken." + burnin :: Int64 + "Number of steps that are set between 2 consecutive samples." + thinning :: Int64 + "Maximum number of iterations to be collected for each chain." + maxiterations :: Int64 + "StepRange for indices of effective samples" + effective_iterations :: StepRange{Int64, Int64} + function TransformInfo( + chains::Vector{Int64}, + algorithms::Vector{Int64}, + burnin::Int64, + thinning::Int64, + maxiterations::Int64 + ) where { + T<:Tagged,P + } + ArgCheck.@argcheck maxiterations >= burnin >= 0 + ArgCheck.@argcheck thinning > 0 + ArgCheck.@argcheck maxiterations > 0 + #Assign indices for subsetting trace + effective_iterations = (burnin+1):thinning:maxiterations + return new(chains, algorithms, burnin, thinning, maxiterations, effective_iterations) + end +end +function TransformInfo( + chains::Vector{Int64}, + algorithms::Vector{Int64}, + effective_iterations::StepRange{Int64, Int64} +) + burnin = effective_iterations.start-1 + thinning = effective_iterations.step + iterations = effective_iterations.stop + return TransformInfo( + chains, + algorithms, + burnin, + thinning, + iterations +) +end + +################################################################################ +""" +$(TYPEDEF) + +Contains arguments for trace to extract parameter from a 'Trace'. + +# Fields +$(TYPEDFIELDS) +""" struct TraceTransform{T<:Tagged, P} "Contains parameter where output information is printed." tagged :: T @@ -35,22 +99,28 @@ struct TraceTransform{T<:Tagged, P} return new{T,P}(tagged, paramnames, chains, algorithms, burnin, thinning, maxiterations, effective_iterations) end end + function TraceTransform( trace::Trace, model::ModelWrapper, - tagged::Tagged = Tagged(model, trace.info.sampling.printedparam.printed) + tagged::Tagged = Tagged(model, trace.info.sampling.printedparam.printed), + info::TransformInfo = TransformInfo( + collect(Base.OneTo(trace.info.sampling.Nchains)), + collect(Base.OneTo(trace.info.sampling.Nalgorithms)), + trace.info.sampling.burnin, + trace.info.sampling.thinning, + trace.info.sampling.iterations + ) ) + @unpack chains, algorithms, burnin, thinning, maxiterations = info paramnames = ModelWrappers.paramnames( tagged.info.reconstruct.default, tagged.info.constraint, subset(model.val, tagged.parameter) ) + return TraceTransform( tagged, paramnames, - collect(Base.OneTo(trace.info.sampling.Nchains)), - collect(Base.OneTo(trace.info.sampling.Nalgorithms)), - trace.info.sampling.burnin, - trace.info.sampling.thinning, - trace.info.sampling.iterations + chains, algorithms, burnin, thinning, maxiterations ) end @@ -191,6 +261,7 @@ end ############################################################################################ #export export + TransformInfo, TraceTransform, trace_to_3DArray, trace_to_posteriormean,