From 67c9c3cf2730c74b82424eaa384dfa4abb11e2ea Mon Sep 17 00:00:00 2001
From: Patrick Aschermayr
Date: Fri, 29 Jul 2022 11:03:38 +0100
Subject: [PATCH] Update trace
---
Project.toml | 2 +-
src/sampling/inference.jl | 83 ++++++++++++++++++++++++++++++++++++---
2 files changed, 78 insertions(+), 7 deletions(-)
diff --git a/Project.toml b/Project.toml
index 4fcecdd..07eeec1 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
name = "Baytes"
uuid = "72ddfcfc-6e9d-43df-829b-7aed7c549d4f"
authors = ["Patrick Aschermayr "]
-version = "0.1.9"
+version = "0.1.10"
[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
diff --git a/src/sampling/inference.jl b/src/sampling/inference.jl
index bb75591..9a86875 100644
--- a/src/sampling/inference.jl
+++ b/src/sampling/inference.jl
@@ -1,4 +1,68 @@
################################################################################
+"""
+$(TYPEDEF)
+
+Contains arguments for trace to extract parameter. Used to construct a 'TraceTransform'.
+
+# Fields
+$(TYPEDFIELDS)
+"""
+struct TransformInfo
+ "Chain indices that are used for output diagnostics."
+ chains :: Vector{Int64}
+ "Algorithm indices that are used for output diagnostics."
+ algorithms :: Vector{Int64}
+ "Number of burnin steps before output diagnostics are taken."
+ burnin :: Int64
+ "Number of steps that are set between 2 consecutive samples."
+ thinning :: Int64
+ "Maximum number of iterations to be collected for each chain."
+ maxiterations :: Int64
+ "StepRange for indices of effective samples"
+ effective_iterations :: StepRange{Int64, Int64}
+ function TransformInfo(
+ chains::Vector{Int64},
+ algorithms::Vector{Int64},
+ burnin::Int64,
+ thinning::Int64,
+ maxiterations::Int64
+ ) where {
+ T<:Tagged,P
+ }
+ ArgCheck.@argcheck maxiterations >= burnin >= 0
+ ArgCheck.@argcheck thinning > 0
+ ArgCheck.@argcheck maxiterations > 0
+ #Assign indices for subsetting trace
+ effective_iterations = (burnin+1):thinning:maxiterations
+ return new(chains, algorithms, burnin, thinning, maxiterations, effective_iterations)
+ end
+end
+function TransformInfo(
+ chains::Vector{Int64},
+ algorithms::Vector{Int64},
+ effective_iterations::StepRange{Int64, Int64}
+)
+ burnin = effective_iterations.start-1
+ thinning = effective_iterations.step
+ iterations = effective_iterations.stop
+ return TransformInfo(
+ chains,
+ algorithms,
+ burnin,
+ thinning,
+ iterations
+)
+end
+
+################################################################################
+"""
+$(TYPEDEF)
+
+Contains arguments for trace to extract parameter from a 'Trace'.
+
+# Fields
+$(TYPEDFIELDS)
+"""
struct TraceTransform{T<:Tagged, P}
"Contains parameter where output information is printed."
tagged :: T
@@ -35,22 +99,28 @@ struct TraceTransform{T<:Tagged, P}
return new{T,P}(tagged, paramnames, chains, algorithms, burnin, thinning, maxiterations, effective_iterations)
end
end
+
function TraceTransform(
trace::Trace,
model::ModelWrapper,
- tagged::Tagged = Tagged(model, trace.info.sampling.printedparam.printed)
+ tagged::Tagged = Tagged(model, trace.info.sampling.printedparam.printed),
+ info::TransformInfo = TransformInfo(
+ collect(Base.OneTo(trace.info.sampling.Nchains)),
+ collect(Base.OneTo(trace.info.sampling.Nalgorithms)),
+ trace.info.sampling.burnin,
+ trace.info.sampling.thinning,
+ trace.info.sampling.iterations
+ )
)
+ @unpack chains, algorithms, burnin, thinning, maxiterations = info
paramnames = ModelWrappers.paramnames(
tagged.info.reconstruct.default, tagged.info.constraint, subset(model.val, tagged.parameter)
)
+
return TraceTransform(
tagged,
paramnames,
- collect(Base.OneTo(trace.info.sampling.Nchains)),
- collect(Base.OneTo(trace.info.sampling.Nalgorithms)),
- trace.info.sampling.burnin,
- trace.info.sampling.thinning,
- trace.info.sampling.iterations
+ chains, algorithms, burnin, thinning, maxiterations
)
end
@@ -191,6 +261,7 @@ end
############################################################################################
#export
export
+ TransformInfo,
TraceTransform,
trace_to_3DArray,
trace_to_posteriormean,