Skip to content

Commit

Permalink
Add mean for HierarchicalDistribution
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz committed Nov 3, 2023
1 parent 4de7e7f commit a06bdf7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/distributions/hierarchical_distribution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,10 +255,15 @@ function Distributions.insupport(ud::UnshapedHDist, x::AbstractVector)
end


function Statistics.mean(d::HierarchicalDistribution)
varshape(d)(mean(unshaped(d)))
end

function Statistics.mean(ud::UnshapedHDist)
mean(nestedview(rand(_bat_determ_rng(), ud, 10^5)))
end


function Statistics.cov(ud::UnshapedHDist)
cov(nestedview(rand(_bat_determ_rng(), ud, 10^5)))
end
2 changes: 2 additions & 0 deletions test/distributions/test_hierarchical_distribution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Test

using Distributions, StatsBase, IntervalSets, ValueShapes, ArraysOfArrays
using AutoDiffOperators, ForwardDiff
using InverseFunctions

import AdvancedHMC

Expand Down Expand Up @@ -58,5 +59,6 @@ import AdvancedHMC

@test isapprox(cov(unshaped(hd)), cov_expected, rtol = 0.05)
@test isapprox(mean(unshaped.(rand(sampler(hd), 10^5))), [2.3, 2.3], rtol = 0.05)
@test isapprox(inverse(varshape(hd))(mean(hd)), mean(unshaped(hd)), rtol = 0.05)
end
end

0 comments on commit a06bdf7

Please sign in to comment.