diff --git a/src/distributions/hierarchical_distribution.jl b/src/distributions/hierarchical_distribution.jl index 951b4422f..0ce661115 100644 --- a/src/distributions/hierarchical_distribution.jl +++ b/src/distributions/hierarchical_distribution.jl @@ -255,10 +255,15 @@ function Distributions.insupport(ud::UnshapedHDist, x::AbstractVector) end +function Statistics.mean(d::HierarchicalDistribution) + varshape(d)(mean(unshaped(d))) +end + function Statistics.mean(ud::UnshapedHDist) mean(nestedview(rand(_bat_determ_rng(), ud, 10^5))) end + function Statistics.cov(ud::UnshapedHDist) cov(nestedview(rand(_bat_determ_rng(), ud, 10^5))) end diff --git a/test/distributions/test_hierarchical_distribution.jl b/test/distributions/test_hierarchical_distribution.jl index d57433cec..74e9c6aa4 100644 --- a/test/distributions/test_hierarchical_distribution.jl +++ b/test/distributions/test_hierarchical_distribution.jl @@ -5,6 +5,7 @@ using Test using Distributions, StatsBase, IntervalSets, ValueShapes, ArraysOfArrays using AutoDiffOperators, ForwardDiff +using InverseFunctions import AdvancedHMC @@ -58,5 +59,6 @@ import AdvancedHMC @test isapprox(cov(unshaped(hd)), cov_expected, rtol = 0.05) @test isapprox(mean(unshaped.(rand(sampler(hd), 10^5))), [2.3, 2.3], rtol = 0.05) + @test isapprox(inverse(varshape(hd))(mean(hd)), mean(unshaped(hd)), rtol = 0.05) end end