Add converter from Turing using both Chains and Model #133

Open — wants to merge 36 commits into base: main

Commits (36)
dd7abf8
Add Turing to extras
sethaxen May 19, 2021
dff0ed1
Add initial implementation of from_turing
sethaxen May 19, 2021
55cbc78
Handle non-array eltype constraints
sethaxen May 19, 2021
d25e592
Apply suggestions from code review
sethaxen May 20, 2021
31f902b
Repair predictive model code
sethaxen May 20, 2021
e1da410
Run formatter
sethaxen May 20, 2021
454c23e
Constrain type of model
sethaxen May 20, 2021
98511f4
Add model name to attributes
sethaxen May 20, 2021
191cfd7
Support specifying groups to not be generated
sethaxen May 20, 2021
2c6cc32
Constrain type of rng
sethaxen May 20, 2021
134a6a1
Add docstring
sethaxen May 20, 2021
9e3edb7
Document from_turing
sethaxen May 20, 2021
3408476
Also generate constant_data
sethaxen May 20, 2021
180e560
Make code more modular
sethaxen May 20, 2021
85ddd69
Add Turing tests
sethaxen May 20, 2021
4f50d7a
Force library to be Turing
sethaxen May 20, 2021
1b58b6b
Overload setattribute! for InferenceData
sethaxen May 20, 2021
12ca874
Add function to add inference library info
sethaxen May 20, 2021
a7bb79f
Globally use library utility
sethaxen May 20, 2021
42c8823
Test library utility for Turing
sethaxen May 20, 2021
842447f
Increment version number
sethaxen May 20, 2021
c9b4562
Repair Turing example
sethaxen May 20, 2021
e0d9ae3
Don't import Turing's exports
sethaxen May 20, 2021
a90df63
Return correct variable name
sethaxen May 20, 2021
ea273ef
Indent wrapped lines
sethaxen May 20, 2021
6e93fa7
Update quickstart.md
sethaxen May 20, 2021
92e8a25
Run formatter
sethaxen May 20, 2021
1581eb0
Deep copy arguments
sethaxen May 20, 2021
b863095
Capture status in string
sethaxen May 20, 2021
13ad03a
Better handle adding library info
sethaxen May 21, 2021
a188e27
Run formatter
sethaxen May 21, 2021
ee474ae
Add attribute and library tests
sethaxen May 21, 2021
5ee652a
Extract observed_data from model
sethaxen May 21, 2021
09c37da
Update example
sethaxen May 21, 2021
a05bfd2
Update quickstart
sethaxen May 21, 2021
bf474a1
Fix test
sethaxen May 22, 2021
9 changes: 7 additions & 2 deletions Project.toml
@@ -1,16 +1,18 @@
name = "ArviZ"
uuid = "131c737c-5715-5e2e-ad31-c244f01c1dc7"
authors = ["Seth Axen <[email protected]>"]
version = "0.5.4"
version = "0.5.5"

[deps]
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
PkgVersion = "eebad327-c553-4316-9ea0-9fa01ccd7688"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

@@ -21,10 +23,12 @@ DataFrames = "0.20, 0.21, 0.22, 1.0"
MCMCChains = "0.3.15, 0.4, 1.0, 2.0, 3.0, 4.0"
MonteCarloMeasurements = "0.6.4, 0.7, 0.8"
NamedTupleTools = "0.11.0, 0.12, 0.13"
PkgVersion = "0.1"
PyCall = "1.91.2"
PyPlot = "2.8.2"
Requires = "0.5.2, 1.0"
StatsBase = "0.32, 0.33"
Turing = "0.15"
julia = "^1"

[extras]
@@ -33,6 +37,7 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[targets]
test = ["CmdStan", "MCMCChains", "MonteCarloMeasurements", "Random", "Test"]
test = ["CmdStan", "MCMCChains", "MonteCarloMeasurements", "Random", "Test", "Turing"]
2 changes: 1 addition & 1 deletion docs/make.jl
@@ -1,5 +1,5 @@
using Documenter, ArviZ
using MCMCChains: MCMCChains # make `from_mcmcchains` available for API docs
using Turing: Turing # make `from_mcmcchains` and `from_turing` available for API docs

makedocs(;
modules=[ArviZ],
1 change: 1 addition & 0 deletions docs/src/api.md
@@ -71,6 +71,7 @@
| [`from_namedtuple`](@ref) | Convert `NamedTuple` data into an `InferenceData`. |
| [`from_dict`](@ref) | Convert `Dict` data into an `InferenceData`. |
| [`from_cmdstan`](@ref) | Convert CmdStan data into an `InferenceData`. |
| [`from_turing`](@ref) | Convert data from Turing into an `InferenceData`. |
| [`from_mcmcchains`](@ref) | Convert `MCMCChains` data into an `InferenceData`. |
| [`concat`](@ref) | Concatenate `InferenceData` objects. |
| [`concat!`](@ref) | Concatenate `InferenceData` objects in-place. |
50 changes: 7 additions & 43 deletions docs/src/quickstart.md
@@ -121,7 +121,7 @@ idata = from_mcmcchains(
turing_chns;
coords=Dict("school" => schools),
dims=Dict("y" => ["school"], "σ" => ["school"], "θ" => ["school"]),
library="Turing",
library=Turing,
)
```

@@ -154,52 +154,16 @@ gcf()

### Additional information in Turing.jl

With a few more steps, we can use Turing to compute additional useful groups to add to the [`InferenceData`](@ref).

To sample from the prior, one simply calls `sample` but with the `Prior` sampler:

```@example turing
prior = sample(param_mod, Prior(), nsamples; progress=false)
```

To draw from the prior and posterior predictive distributions we can instantiate a "predictive model", i.e. a Turing model but with the observations set to `missing`, and then calling `predict` on the predictive model and the previously drawn samples:

```@example turing
# Instantiate the predictive model
param_mod_predict = turing_model(similar(y, Missing), σ)
# and then sample!
prior_predictive = predict(param_mod_predict, prior)
posterior_predictive = predict(param_mod_predict, turing_chns)
```

And to extract the pointwise log-likelihoods, which is useful if you want to compute metrics such as [`loo`](@ref),

```@example turing
loglikelihoods = Turing.pointwise_loglikelihoods(
param_mod, MCMCChains.get_sections(turing_chns, :parameters)
)
```

This can then be included in the [`from_mcmcchains`](@ref) call from above:
We would like to compute additional useful groups to add to the [`InferenceData`](@ref).
ArviZ includes a Turing-specific converter [`from_turing`](@ref) that, given a model, posterior samples, and data, can add the missing groups:

```@example turing
using LinearAlgebra
# Ensure the ordering of the loglikelihoods matches the ordering of `posterior_predictive`
ynames = string.(keys(posterior_predictive))
loglikelihoods_vals = getindex.(Ref(loglikelihoods), ynames)
# Reshape into `(nchains, nsamples, size(y)...)`
loglikelihoods_arr = permutedims(cat(loglikelihoods_vals...; dims=3), (2, 1, 3))

idata = from_mcmcchains(
idata = from_turing(
turing_chns;
posterior_predictive=posterior_predictive,
log_likelihood=Dict("y" => loglikelihoods_arr),
prior=prior,
prior_predictive=prior_predictive,
observed_data=Dict("y" => y),
model=param_mod,
rng=rng,
coords=Dict("school" => schools),
dims=Dict("y" => ["school"], "σ" => ["school"], "θ" => ["school"]),
library="Turing",
)
```

@@ -444,7 +408,7 @@ gcf()

```@example
using Pkg
Pkg.status()
Text(sprint(io -> Pkg.status(; io=io)))
```

```@example
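One quickstart change above swaps `Pkg.status()` for `Text(sprint(io -> Pkg.status(; io=io)))`. `sprint` captures whatever a function writes to an `IO` handle and returns it as a `String`, and wrapping the result in `Text` makes Documenter render that captured output verbatim rather than the call's return value. A minimal, dependency-free illustration of the `sprint` half of the idiom:

```julia
# sprint allocates an in-memory IO buffer, passes it to the given
# function, and returns everything written to it as a String.
captured = sprint(io -> print(io, "captured ", 40 + 2))
# captured == "captured 42"
```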
7 changes: 7 additions & 0 deletions src/ArviZ.jl
@@ -2,10 +2,12 @@ __precompile__()
module ArviZ

using Base: @__doc__
using Random
using Requires
using REPL
using NamedTupleTools
using DataFrames
using PkgVersion: PkgVersion

using PyCall
using Conda
@@ -76,6 +78,7 @@ export InferenceData,
from_dict,
from_cmdstan,
from_mcmcchains,
from_turing,
concat,
concat!

@@ -109,6 +112,10 @@ function __init__()
import .MCMCChains: Chains, sections
include("mcmcchains.jl")
end
@require Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" begin
import .Turing: Turing
include("turing.jl")
end
return nothing
end

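The new `@require` block in `__init__` follows the standard Requires.jl pattern for optional dependencies: the block body runs only after the package with the matching UUID is loaded in the user's session, so ArviZ gains Turing support without taking a hard dependency on it. A commented restatement of the pattern, assuming the usual Requires.jl semantics:

```julia
using Requires

function __init__()
    # The UUID must match the target package's Project.toml entry; the
    # body executes only once `using Turing` succeeds at runtime.
    @require Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" begin
        # `.Turing` refers to the conditionally loaded module.
        import .Turing: Turing
        include("turing.jl")  # glue code defining from_turing
    end
    return nothing
end
```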
26 changes: 26 additions & 0 deletions src/data.jl
@@ -165,3 +165,29 @@ function reorder_groups!(data::InferenceData; group_order=SUPPORTED_GROUPS)
setproperty!(obj, :_groups, string.([sorted_names; other_names]))
return data
end

function setattribute!(data::InferenceData, key, value)
for (_, group) in groups(data)
setattribute!(group, key, value)
end
return data
end

function deleteattribute!(data::InferenceData, key)
for (_, group) in groups(data)
deleteattribute!(group, key)
end
return data
end

_add_library_attributes!(data, ::Nothing) = data
function _add_library_attributes!(data, library)
setattribute!(data, :inference_library, string(library))
if library isa Module
lib_version = string(PkgVersion.Version(library))
setattribute!(data, :inference_library_version, lib_version)
else
deleteattribute!(data, :inference_library_version)
end
return data
end
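`_add_library_attributes!` centralizes the library metadata that several converters previously set by hand. Reading the three methods above: `nothing` is a no-op, a `Module` records both the library name and (via PkgVersion) its version, and any other value — e.g. the string `"Turing"` — records only the name while clearing a possibly stale version attribute. A hypothetical usage sketch, assuming an `idata::InferenceData` is in scope:

```julia
# Calls illustrating the three dispatch branches of the helper:
_add_library_attributes!(idata, nothing)   # no attributes touched
_add_library_attributes!(idata, "Turing")  # sets inference_library only
_add_library_attributes!(idata, Turing)    # Module: also sets inference_library_version
```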
23 changes: 14 additions & 9 deletions src/dataset.jl
@@ -69,7 +69,14 @@ end
attributes(data::Dataset) = getproperty(PyObject(data), :_attrs)

function setattribute!(data::Dataset, key, value)
attrs = merge(attributes(data), Dict(key => value))
attrs = merge(attributes(data), Dict(string(key) => value))
setproperty!(PyObject(data), :_attrs, attrs)
return attrs
end

function deleteattribute!(data::Dataset, key)
attrs = attributes(data)
delete!(attrs, string(key))
setproperty!(PyObject(data), :_attrs, attrs)
return attrs
end
@@ -132,11 +139,10 @@ function convert_to_constant_dataset(
end

default_attrs = base.make_attrs()
if library !== nothing
default_attrs = merge(default_attrs, Dict("inference_library" => string(library)))
end
attrs = merge(default_attrs, attrs)
return Dataset(; data_vars=data, coords=coords, attrs=attrs)
ds = Dataset(; data_vars=data, coords=coords, attrs=attrs)
_add_library_attributes!(ds, library)
return ds
end

@doc doc"""
@@ -164,10 +170,9 @@ ArviZ.dict_to_dataset(Dict("x" => randn(4, 100), "y" => randn(4, 100)))
dict_to_dataset

function dict_to_dataset(data; library=nothing, attrs=Dict(), kwargs...)
if library !== nothing
attrs = merge(attrs, Dict("inference_library" => string(library)))
end
return arviz.dict_to_dataset(data; attrs=attrs, kwargs...)
ds = arviz.dict_to_dataset(data; attrs=attrs, kwargs...)
_add_library_attributes!(ds, library)
return ds
end

@doc doc"""
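The `string(key)` normalization in `setattribute!` and the new `deleteattribute!` ensure that a `Symbol` key and its `String` spelling address the same attribute entry; without it, `setattribute!(ds, :foo, 1)` and `setattribute!(ds, "foo", 2)` would create two distinct entries in the string-keyed attrs dict. A plain-`Dict` sketch of the normalization:

```julia
# Normalizing keys with string(...) lets callers mix Symbols and Strings.
attrs = Dict{String,Any}()
attrs[string(:inference_library)] = "Turing"   # Symbol key, stored as "inference_library"
attrs[string("inference_library")] = "Turing"  # String key hits the same entry
delete!(attrs, string(:inference_library))
# isempty(attrs) == true
```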
27 changes: 9 additions & 18 deletions src/mcmcchains.jl
@@ -124,6 +124,8 @@ Convert data in an `MCMCChains.Chains` format into an [`InferenceData`](@ref).
Any keyword argument below without an explicitly annotated type above is allowed, so long
as it can be passed to [`convert_to_inference_data`](@ref).

For chains data from Turing, see [`from_turing`](@ref) for more options.

# Arguments

- `posterior::Chains`: Draws from the posterior
@@ -170,7 +172,6 @@ function from_mcmcchains(
kwargs...,
)
kwargs = convert(Dict, merge((; dims=Dict()), kwargs))
library = string(library)
rekey_fun = d -> rekey(d, stats_key_map)

# Convert chains to dicts
@@ -201,23 +202,18 @@
group_data = popsubdict!(post_dict, group_data)
end
group_dataset = if group_data isa Chains
convert_to_dataset(group_data; library=library, eltypes=eltypes, kwargs...)
convert_to_dataset(group_data; eltypes=eltypes, kwargs...)
else
convert_to_dataset(group_data; library=library, kwargs...)
convert_to_dataset(group_data; kwargs...)
end
setattribute!(group_dataset, "inference_library", library)
concat!(all_idata, InferenceData(; group => group_dataset))
end

attrs_library = Dict("inference_library" => library)
if posterior === nothing
attrs = attrs_library
else
attrs = merge(attributes_dict(posterior), attrs_library)
end
attrs = posterior === nothing ? Dict() : attributes_dict(posterior)
kwargs = convert(Dict, merge((; attrs=attrs, dims=Dict()), kwargs))
post_idata = _from_dict(post_dict; sample_stats=stats_dict, kwargs...)
concat!(all_idata, post_idata)
_add_library_attributes!(all_idata, library)
return all_idata
end
function from_mcmcchains(
@@ -241,18 +237,13 @@ function from_mcmcchains(
posterior_predictive,
predictions,
log_likelihood;
library=library,
eltypes=eltypes,
kwargs...,
)

if prior !== nothing
pre_prior_idata = convert_to_inference_data(
prior;
posterior_predictive=prior_predictive,
library=library,
eltypes=eltypes,
kwargs...,
prior; posterior_predictive=prior_predictive, eltypes=eltypes, kwargs...
)
prior_idata = rekey(
pre_prior_idata,
@@ -272,10 +263,10 @@ ]
]
group_data === nothing && continue
group_data = convert_to_eltypes(group_data, eltypes)
group_dataset = convert_to_constant_dataset(group_data; library=library, kwargs...)
group_dataset = convert_to_constant_dataset(group_data; kwargs...)
concat!(all_idata, InferenceData(; group => group_dataset))
end

_add_library_attributes!(all_idata, library)
return all_idata
end

14 changes: 3 additions & 11 deletions src/namedtuple.jl
@@ -140,20 +140,14 @@ function from_namedtuple(
isempty(group_data) && continue
end
group_dataset = convert_to_dataset(group_data; kwargs...)
if library !== nothing
setattribute!(group_dataset, "inference_library", string(library))
end
concat!(all_idata, InferenceData(; group => group_dataset))
end

(post_dict === nothing || isempty(post_dict)) && return all_idata

group_dataset = convert_to_dataset(post_dict; kwargs...)
if library !== nothing
setattribute!(group_dataset, "inference_library", string(library))
end
concat!(all_idata, InferenceData(; posterior=group_dataset))

_add_library_attributes!(all_idata, library)
return all_idata
end
function from_namedtuple(
@@ -177,7 +171,6 @@
sample_stats,
predictions,
log_likelihood;
library=library,
kwargs...,
)

@@ -186,7 +179,6 @@
prior;
posterior_predictive=prior_predictive,
sample_stats=sample_stats_prior,
library=library,
kwargs...,
)
prior_idata = rekey(
@@ -207,10 +199,10 @@
]
group_data === nothing && continue
group_dict = convert(Dict, group_data)
group_dataset = convert_to_constant_dataset(group_dict; library=library, kwargs...)
group_dataset = convert_to_constant_dataset(group_dict; kwargs...)
concat!(all_idata, InferenceData(; group => group_dataset))
end

_add_library_attributes!(all_idata, library)
return all_idata
end
function from_namedtuple(data::AbstractVector{<:NamedTuple}; kwargs...)