Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to inferencedata method #67

Merged
merged 10 commits into from
Dec 9, 2022
Merged

Conversation

sethaxen
Copy link
Contributor

@sethaxen sethaxen commented Dec 8, 2022

This PR proposes some changes to inferencedata1. The main ones are:

  • Instead of _symbol, use the suffix _var for keyword args. This is consistent with the from_cmdstan, from_pystan, and from_cmdstanpy methods in Python ArviZ and would be more familiar for a Python Stan user coming to Julia.
  • The _var keyword args can now take Symbol, or an iterator of Symbols.
  • Add a predictions_var keyword arg
  • Add split_nt and split_nt_all convenience methods and refactor to use them for splitting the samples named tuple into the various group named tuples
  • Support other numbers of warmup or sample draws than 1000.

Here's an example output, using the notebook at https://github.com/StanJulia/Stan.jl/blob/master/Examples_Notebooks/InferenceObjects.jl

julia> idata = StanSample.inferencedata(
           m_schools;
           posterior_predictive_var=:y_hat,
           log_likelihood_var=[:log_lik],
           dims=(; (k => [:school] for k in [:theta, :theta_tilde, :y_hat, :log_lik])...),
       )
InferenceData with groups:
  > posterior
  > posterior_predictive
  > log_likelihood
  > sample_stats
  > warmup_posterior
  > warmup_posterior_predictive
  > warmup_sample_stats
  > warmup_log_likelihood

julia> idata.posterior
Dataset with dimensions: 
  Dim{:school} Sampled{Int64} Base.OneTo(8) ForwardOrdered Regular Points,
  Dim{:draw} Sampled{Int64} 1001:2000 ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
and 4 layers:
  :theta_tilde Float64 dims: Dim{:school}, Dim{:draw}, Dim{:chain} (8×1000×4)
  :mu          Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :tau         Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :theta       Float64 dims: Dim{:school}, Dim{:draw}, Dim{:chain} (8×1000×4)

with metadata Dict{String, Any} with 1 entry:
  "created_at" => "2022-12-08T22:34:23.197"

julia> idata.sample_stats
Dataset with dimensions: 
  Dim{:draw} Sampled{Int64} 1001:2000 ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
and 7 layers:
  :tree_depth      Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :energy          Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :diverging       Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :acceptance_rate Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :n_steps         Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :lp              Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :step_size       Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)

with metadata Dict{String, Any} with 1 entry:
  "created_at" => "2022-12-08T22:34:23.07"

julia> idata.log_likelihood
Dataset with dimensions: 
  Dim{:school} Sampled{Int64} Base.OneTo(8) ForwardOrdered Regular Points,
  Dim{:draw} Sampled{Int64} 1001:2000 ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
and 1 layer:
  :log_lik Float64 dims: Dim{:school}, Dim{:draw}, Dim{:chain} (8×1000×4)

with metadata Dict{String, Any} with 1 entry:
  "created_at" => "2022-12-08T22:34:23.095"

julia> idata.warmup_posterior
Dataset with dimensions: 
  Dim{:school} Sampled{Int64} Base.OneTo(8) ForwardOrdered Regular Points,
  Dim{:draw} Sampled{Int64} 1:1000 ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
and 4 layers:
  :theta_tilde Float64 dims: Dim{:school}, Dim{:draw}, Dim{:chain} (8×1000×4)
  :mu          Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :tau         Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :theta       Float64 dims: Dim{:school}, Dim{:draw}, Dim{:chain} (8×1000×4)

with metadata Dict{String, Any} with 1 entry:
  "created_at" => "2022-12-08T22:34:23.197"

Note that in the next breaking release of InferenceObjects, the dimension orders of arrays will change (arviz-devs/InferenceObjects.jl#40), and the default indices for all dimensions will be the axes of the underlying arrays (arviz-devs/InferenceObjects.jl#39; so after splitting samples from warmup, no reindexing will be needed)

@sethaxen
Copy link
Contributor Author

sethaxen commented Dec 8, 2022

Relates #60

@goedman goedman merged commit 4e49184 into StanJulia:master Dec 9, 2022
@sethaxen sethaxen deleted the idataupdate branch December 9, 2022 14:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants