From 3f2047334376c369e7baa19119543442c461f281 Mon Sep 17 00:00:00 2001 From: Cornelius-G Date: Thu, 4 Jun 2020 15:38:03 +0200 Subject: [PATCH 1/3] add MarginalDist --- src/plotting/BATHistogram.jl | 96 +++++----- src/plotting/BATHistogram_utils.jl | 20 +- src/plotting/Marginalization.jl | 106 +++++++++++ src/plotting/plotting.jl | 5 +- src/plotting/recipes_Marginalization_1D.jl | 131 +++++++++++++ src/plotting/recipes_Marginalization_2D.jl | 208 +++++++++++++++++++++ src/plotting/recipes_prior.jl | 12 +- src/plotting/recipes_samples_1D.jl | 15 +- src/plotting/recipes_samples_2D.jl | 10 +- src/plotting/split_histograms.jl | 44 ++--- test_marginal.jl | 110 +++++++++++ 11 files changed, 661 insertions(+), 96 deletions(-) create mode 100644 src/plotting/Marginalization.jl create mode 100644 src/plotting/recipes_Marginalization_1D.jl create mode 100644 src/plotting/recipes_Marginalization_2D.jl create mode 100644 test_marginal.jl diff --git a/src/plotting/BATHistogram.jl b/src/plotting/BATHistogram.jl index e65ac4bdc..98a16ae0e 100644 --- a/src/plotting/BATHistogram.jl +++ b/src/plotting/BATHistogram.jl @@ -5,30 +5,30 @@ mutable struct BATHistogram end - -# construct 1D BATHistogram from sample vector -function BATHistogram( - maybe_shaped_samples::DensitySampleVector, - key::Union{Integer, Symbol}; - nbins = 200, - closed::Symbol = :left, - filter::Bool = false -) - samples = BAT.unshaped.(maybe_shaped_samples) - - if filter - samples = BAT.drop_low_weight_samples(samples) - end - - idx = asindex(maybe_shaped_samples, key) - - hist = fit(Histogram, - flatview(samples.v)[idx, :], - FrequencyWeights(samples.weight), - nbins = nbins, closed = closed) - - return BATHistogram(hist) -end +# +# # construct 1D BATHistogram from sample vector +# function BATHistogram( +# maybe_shaped_samples::DensitySampleVector, +# key::Union{Integer, Symbol}; +# nbins = 200, +# closed::Symbol = :left, +# filter::Bool = false +# ) +# samples = BAT.unshaped.(maybe_shaped_samples) +# +# if filter +# samples = BAT.drop_low_weight_samples(samples) +# end +# +# idx = asindex(maybe_shaped_samples, key) +# +# hist = fit(Histogram, +# flatview(samples.v)[idx, :], +# FrequencyWeights(samples.weight), +# nbins = nbins, closed = closed) +# +# return BATHistogram(hist) +# end @@ -50,30 +50,30 @@ end # construct 2D BATHistogram from sample vector -function BATHistogram( - maybe_shaped_samples::DensitySampleVector, - params::Union{NTuple{2, Symbol}, NTuple{2, Integer}}; - nbins = 200, - closed::Symbol = :left, - filter::Bool = false -) - samples = unshaped.(maybe_shaped_samples) - - if filter - samples = BAT.drop_low_weight_samples(samples) - end - - i = asindex(maybe_shaped_samples, params[1]) - j = asindex(maybe_shaped_samples, params[2]) - - hist = fit(Histogram, - (flatview(samples.v)[i, :], - flatview(samples.v)[j, :]), - FrequencyWeights(samples.weight), - nbins = nbins, closed = closed) - - return BATHistogram(hist) -end +# function BATHistogram( +# maybe_shaped_samples::DensitySampleVector, +# params::Union{NTuple{2, Symbol}, NTuple{2, Integer}}; +# nbins = 200, +# closed::Symbol = :left, +# filter::Bool = false +# ) +# samples = unshaped.(maybe_shaped_samples) +# +# if filter +# samples = BAT.drop_low_weight_samples(samples) +# end +# +# i = asindex(maybe_shaped_samples, params[1]) +# j = asindex(maybe_shaped_samples, params[2]) +# +# hist = fit(Histogram, +# (flatview(samples.v)[i, :], +# flatview(samples.v)[j, :]), +# FrequencyWeights(samples.weight), +# nbins = nbins, closed = closed) +# +# return BATHistogram(hist) +# end diff --git a/src/plotting/BATHistogram_utils.jl b/src/plotting/BATHistogram_utils.jl index 80b6718bf..577e391cf 100644 --- a/src/plotting/BATHistogram_utils.jl +++ b/src/plotting/BATHistogram_utils.jl @@ -8,13 +8,14 @@ Find the modes of a BATHistogram. Returns a vector of the bin-centers of the bin(s) with the heighest weight. """ -function find_localmodes(bathist::BATHistogram) - dims = ndims(bathist.h.weights) +function find_localmodes(marg::MarginalDist) + hist = marg.dist.h + dims = ndims(hist.weights) - max = maximum(bathist.h.weights) - maxima_idx = findall(x->x==max, bathist.h.weights) + max = maximum(hist.weights) + maxima_idx = findall(x->x==max, hist.weights) - bin_centers = get_bin_centers(bathist) + bin_centers = get_bin_centers(marg) return [[bin_centers[d][maxima_idx[i][d]] for d in 1:dims] for i in 1:length(maxima_idx) ] end @@ -27,9 +28,10 @@ end Returns a vector of the bin-centers. """ -function get_bin_centers(bathist::BATHistogram) - edges = bathist.h.edges - dims = ndims(bathist.h.weights) +function get_bin_centers(marg::MarginalDist) + hist = marg.dist.h + edges = hist.edges + dims = ndims(hist.weights) centers = [[edges[d][i]+0.5*(edges[d][i+1]-edges[d][i]) for i in 1:length(edges[d])-1] for d in 1:dims] @@ -63,7 +65,7 @@ function islower(weights, idx) end function isupper(weights, idx) - if idx==length(weights) && weights[idx-1]>0 + if idx==length(weights) && weights[idx-1]>0 return true elseif weights[idx]==0 && weights[idx-1]>0 return true diff --git a/src/plotting/Marginalization.jl b/src/plotting/Marginalization.jl new file mode 100644 index 000000000..27a07c536 --- /dev/null +++ b/src/plotting/Marginalization.jl @@ -0,0 +1,106 @@ +struct MarginalDist{N,D<:Distribution,VS<:AbstractValueShape} + dims::NTuple{N,Int} + dist::D + origvalshape::VS +end + + +#TODO: does not work for unshaped samples +function bat_marginalize( + maybe_shaped_samples::DensitySampleVector, + key::Union{Integer, Symbol, Expr}; + nbins = 200, + closed::Symbol = :left, + filter::Bool = false, + normalize = true +) + samples = BAT.unshaped.(maybe_shaped_samples) + + if filter + samples = BAT.drop_low_weight_samples(samples) + end + + idx = asindex(maybe_shaped_samples, key) + + hist = fit(Histogram, + flatview(samples.v)[idx, :], + FrequencyWeights(samples.weight), + nbins = nbins, closed = closed) + + normalize ? hist = StatsBase.normalize(hist) : nothing + + uvbd = EmpiricalDistributions.UvBinnedDist(hist) + + return MarginalDist((idx,), uvbd, varshape(maybe_shaped_samples)) +end + + +function bat_marginalize( + maybe_shaped_samples::DensitySampleVector, + key::Union{NTuple{2,Integer}, NTuple{2,Union{Symbol, Expr}}}; + nbins = 200, + closed::Symbol = :left, + filter::Bool = false, + normalize = true +) + samples = unshaped.(maybe_shaped_samples) + + if filter + samples = BAT.drop_low_weight_samples(samples) + end + + i = asindex(maybe_shaped_samples, key[1]) + j = asindex(maybe_shaped_samples, key[2]) + + hist = fit(Histogram, + (flatview(samples.v)[i, :], + flatview(samples.v)[j, :]), + FrequencyWeights(samples.weight), + nbins = nbins, closed = closed) + + normalize ? hist = StatsBase.normalize(hist) : nothing + + mvbd = EmpiricalDistributions.MvBinnedDist(hist) + + return MarginalDist((i,j), mvbd, varshape(maybe_shaped_samples)) +end + + +#for prior +function bat_marginalize( + prior::NamedTupleDist, + key::Union{Integer, Symbol}; + nbins = 200, + closed::Symbol = :left, + nsamples::Integer = 10^6, + normalize = true +) + idx = asindex(prior, key) + r = rand(prior, nsamples) + + hist = fit(Histogram, r[idx, :], nbins = nbins, closed = closed) + normalize ? hist = StatsBase.normalize(hist) : nothing + + uvbd = EmpiricalDistributions.UvBinnedDist(hist) + + return MarginalDist((idx,), uvbd, varshape(prior)) +end + + +function BATHistogram( + prior::NamedTupleDist, + params::Union{NTuple{2, Symbol}, NTuple{2, Integer}}; + nbins = 200, + closed::Symbol = :left, + nsamples::Integer = 10^6 +) + i = asindex(prior, params[1]) + j = asindex(prior, params[2]) + + r = rand(prior, nsamples) + hist = fit(Histogram, (r[i, :], r[j, :]), nbins = nbins, closed = closed) + + mvbd = EmpiricalDistributions.MvBinnedDist(hist) + + return MarginalDist((i,j), mvbd, varshape(prior)) +end diff --git a/src/plotting/plotting.jl b/src/plotting/plotting.jl index e389b2998..83ca96422 100644 --- a/src/plotting/plotting.jl +++ b/src/plotting/plotting.jl @@ -5,9 +5,10 @@ const standard_colors = [:chartreuse2, :yellow, :red] include("BATHistogram.jl") +include("Marginalization.jl") +include("recipes_Marginalization_1D.jl") +include("recipes_Marginalization_2D.jl") include("recipes_stats.jl") -include("recipes_BATHistogram_1D.jl") -include("recipes_BATHistogram_2D.jl") include("recipes_samples_overview.jl") include("recipes_prior_overview.jl") include("split_histograms.jl") diff --git a/src/plotting/recipes_Marginalization_1D.jl b/src/plotting/recipes_Marginalization_1D.jl new file mode 100644 index 000000000..c27544835 --- /dev/null +++ b/src/plotting/recipes_Marginalization_1D.jl @@ -0,0 +1,131 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). + +# TODO: add plot without Int for overview? + +function plothistogram(h::StatsBase.Histogram, swap::Bool) + if swap + return h.weights, h.edges[1][1:end-1] + else + return h.edges[1][1:end-1], h.weights + end +end + + +@recipe function f( + marg::MarginalDist, + idx::Integer; + intervals = standard_confidence_vals, + normalize = true, + colors = standard_colors, + interval_labels = [] +) + hist = marg.dist.h + normalize ? hist=StatsBase.normalize(hist) : nothing + + orientation = get(plotattributes, :orientation, :vertical) + (orientation != :vertical) ? swap = true : swap = false + plotattributes[:orientation] = :vertical # without: auto-scaling of axes not correct + + seriestype = get(plotattributes, :seriestype, :stephist) + + xlabel = get(plotattributes, :xguide, "x$(idx)") + ylabel = get(plotattributes, :yguide, "p(x$(idx))") + + if swap + xguide := ylabel + yguide := xlabel + else + xguide := xlabel + yguide := ylabel + end + + # step histogram + if seriestype == :stephist || seriestype == :steppost + @series begin + seriestype := :steppost + label --> "" + linecolor --> :dodgerblue + plothistogram(hist, swap) + end + + # filled histogram + elseif seriestype == :histogram + @series begin + seriestype := :steppost + label --> "" + fillrange --> 0 + fillcolor --> :dodgerblue + linewidth --> 0 + plothistogram(hist, swap) + end + + + # smallest intervals aka highest density region (HDR) + elseif seriestype == :smallest_intervals || seriestype == :HDR + hists, realintervals = get_smallest_intervals(hist, intervals) + colors = colors[sortperm(intervals, rev=true)] + + # colored histogram for each interval + for i in 1:length(realintervals) + @series begin + seriestype := :steppost + fillcolor --> colors[i] + linewidth --> 0 + fillrange --> 0 + + if length(interval_labels) > 0 + label := interval_labels[i] + else + label := "smallest $(@sprintf("%.2f", realintervals[i]*100))% interval(s)" + end + plothistogram(hists[i], swap) + end + end + + # black contour line for total histogram + @series begin + seriestype := :steppost + linecolor --> :black + linewidth --> 0.7 + label --> "" + plothistogram(hist, swap) + end + + + # central intervals + elseif seriestype == :central_intervals + hists, realintervals = split_central(hist, intervals) + colors = colors[sortperm(intervals, rev=true)] + + # colored histogram for each interval + for i in 1:length(realintervals) + @series begin + seriestype := :steppost + fillcolor --> colors[i] + linewidth --> 0 + fillrange --> 0 + + if length(interval_labels) > 0 + label := interval_labels[i] + else + label := "central $(@sprintf("%.2f", realintervals[i]*100))% interval(s)" + end + + plothistogram(hists[i], swap) + end + end + + # black contour line for total histogram + @series begin + seriestype := :steppost + linecolor --> :black + linewidth --> 0.7 + label --> "" + plothistogram(hist, swap) + end + + else + error("seriestype $seriestype not supported") + end + +end diff --git a/src/plotting/recipes_Marginalization_2D.jl b/src/plotting/recipes_Marginalization_2D.jl new file mode 100644 index 000000000..6df43d599 --- /dev/null +++ b/src/plotting/recipes_Marginalization_2D.jl @@ -0,0 +1,208 @@ +# This file is a part of BAT.jl, licensed under the MIT License (MIT). +@recipe function f( + marg::MarginalDist, + parsel::NTuple{2,Integer}; + intervals = standard_confidence_vals, + colors = standard_colors, + diagonal = Dict(), + upper = Dict(), + right = Dict(), + interval_labels = [], + normalize = true +) + _plots_module() != nothing || throw(ErrorException("Package Plots not available, but required for this operation")) + println("hi") + hist = marg.dist.h + + seriestype = get(plotattributes, :seriestype, :histogram2d) + + xlabel = get(plotattributes, :xguide, "x$(parsel[1])") + ylabel = get(plotattributes, :yguide, "x$(parsel[2])") + + + # histogram / heatmap + if seriestype == :histogram2d || seriestype == :histogram || seriestype == :hist + @series begin + seriestype := :bins2d + xguide --> xlabel + yguide --> ylabel + colorbar --> true + + hist.edges[1], hist.edges[2], _plots_module().Surface(hist.weights) + end + + + # smallest interval contours + elseif seriestype == :smallest_intervals_contour || seriestype == :smallest_intervals_contourf + + colors = colors[sortperm(intervals, rev=true)] + + if seriestype == :smallest_intervals_contour + plotstyle = :contour + else + plotstyle = :contourf + end + + lev = calculate_levels(hist, intervals) + x, y = get_bin_centers(hist) + m = hist.weights + + # quick fix: needed when plotting contour on top of histogram + # otherwise scaling of histogram colorbar would change scaling + lev = lev/10000 + m = m/10000 + + colorbar --> false + xguide --> xlabel + yguide --> ylabel + + if _plots_module().backend() == _plots_module().PyPlotBackend() + @series begin + seriestype := plotstyle + levels --> lev + linewidth --> 2 + seriescolor --> colors # currently only works with pyplot + (x, y, m') + end + else + @series begin + seriestype := plotstyle + levels --> lev + linewidth --> 2 + (x, y, m') + end + end + + + # smallest intervals heatmap + elseif seriestype == :smallest_intervals + colors = colors[sortperm(intervals, rev=true)] + + hists, realintervals = get_smallest_intervals(hist, intervals) + + for (i, int) in enumerate(realintervals) + @series begin + seriestype := :bins2d + seriescolor --> _plots_module().cgrad([colors[i], colors[i]]) + xguide --> xlabel + yguide --> ylabel + + hists[i].edges[1], hists[i].edges[2], _plots_module().Surface(hists[i].weights) + end + + # fake a legend + interval_label = isempty(interval_labels) ? "smallest $(@sprintf("%.2f", realintervals[i]*100))% interval(s)" : interval_labels[i] + + @series begin + seriestype := :shape + fillcolor --> colors[i] + linewidth --> 0 + label --> interval_label + colorbar --> false + [hists[i].edges[1][1], hists[i].edges[1][1]], [hists[i].edges[2][1], hists[i].edges[2][1]] + end + end + + + # marginal histograms + elseif seriestype == :marginal + layout --> _plots_module().grid(2,2, widths=(0.8, 0.2), heights=(0.2, 0.8)) + link --> :both + + if get(diagonal, "seriestype", :histogram) != :histogram + colorbar --> false + end + + @series begin + subplot := 1 + xguide := xlabel + yguide := "p("*xlabel*")" + seriestype := get(upper, "seriestype", :histogram) + bins --> get(upper, "nbins", 200) + normalize --> get(upper, "normalize", true) + colors --> get(upper, "colors", standard_colors) + intervals --> get(upper, "intervals", standard_confidence_vals) + legend --> get(upper, "legend", true) + + hist, 1 + end + + # empty plot (needed since @layout macro not available) + @series begin + seriestype := :scatter + subplot := 2 + grid := false + xaxis := false + yaxis := false + markersize := 0.001 + markerstrokewidth := 0 + markeralpha := 1 + markerstrokealpha := 1 + legend := false + label := "" + xguide := "" + yguide := "" + [(0,0)] + end + + @series begin + subplot := 3 + seriestype := get(diagonal, "seriestype", :histogram) + xguide --> xlabel + yguide --> ylabel + normalize --> get(diagonal, "normalize", true) + bins --> get(diagonal, "nbins", 200) + colors --> get(diagonal, "colors", standard_colors) + intervals --> get(diagonal, "intervals", standard_confidence_vals) + legend --> get(diagonal, "legend", false) + + hist, (1, 2) + end + + @series begin + subplot := 4 + seriestype := get(right, "seriestype", :histogram) + orientation := :horizontal + xguide := ylabel + yguide := "p("*ylabel*")" + normalize --> get(right, "normalize", true) + bins --> get(right, "nbins", 200) + colors --> get(right, "colors", standard_colors) + intervals --> get(right, "intervals", standard_confidence_vals) + legend --> get(right, "legend", true) + + hist, 2 + end + + else + error("seriestype $seriestype not supported") + end + +end + + + +# rectangle bounds +@recipe function f(bounds::HyperRectBounds, parsel::NTuple{2,Integer}) + pi_x, pi_y = parsel + + vol = spatialvolume(bounds) + vhi = vol.hi[[pi_x, pi_y]]; vlo = vol.lo[[pi_x, pi_y]] + rect_xy = rectangle_path(vlo, vhi) + bext = 0.1 * (vhi - vlo) + xlims = (vlo[1] - bext[1], vhi[1] + bext[1]) + ylims = (vlo[2] - bext[2], vhi[2] + bext[2]) + + @series begin + seriestype := :path + label --> "bounds" + seriescolor --> :darkred + seriesalpha --> 0.75 + linewidth --> 2 + xlims --> xlims + ylims --> ylims + (rect_xy[:,1], rect_xy[:,2]) + end + + nothing +end diff --git a/src/plotting/recipes_prior.jl b/src/plotting/recipes_prior.jl index 1fdac6bdc..827a3152e 100644 --- a/src/plotting/recipes_prior.jl +++ b/src/plotting/recipes_prior.jl @@ -21,8 +21,14 @@ throw(ArgumentError("Symbol :$parsel refers to a multivariate parameter. Use :($parsel[i]) instead.")) end - bathist = BATHistogram(prior, idx, nbins = bins, closed = closed) - normalize ? bathist.h = StatsBase.normalize(bathist.h) : nothing + marg = bat_marginalize( + prior, + idx, + nbins = bins, + nsamples = nsamples, + closed = closed, + normalize = normalize + ) xlabel = if isa(parsel, Symbol) || isa(parsel, Expr) "$parsel" @@ -50,7 +56,7 @@ colors --> colors interval_labels --> interval_labels - bathist, 1 + marg, 1 end end diff --git a/src/plotting/recipes_samples_1D.jl b/src/plotting/recipes_samples_1D.jl index 4601d4dfb..04bf4f95c 100644 --- a/src/plotting/recipes_samples_1D.jl +++ b/src/plotting/recipes_samples_1D.jl @@ -21,16 +21,15 @@ throw(ArgumentError("Symbol :$parsel refers to a multivariate parameter. Use :($parsel[i]) instead.")) end - bathist = BATHistogram( + marg = bat_marginalize( maybe_shaped_samples, - idx, + parsel, nbins = bins, closed = closed, - filter = filter + filter = filter, + normalize = normalize ) - normalize ? bathist.h = StatsBase.normalize(bathist.h) : nothing - orientation = get(plotattributes, :orientation, :vertical) (orientation != :vertical) ? swap=true : swap = false @@ -61,13 +60,13 @@ colors --> colors interval_labels --> interval_labels - bathist, 1 + marg, 1 end #------ stats ---------------------------- stats = MCMCBasicStats(maybe_shaped_samples) - line_height = maximum(bathist.h.weights)*1.03 + line_height = maximum(marg.dist.h.weights)*1.03 mean_options = convert_to_options(mean) globalmode_options = convert_to_options(globalmode) @@ -123,7 +122,7 @@ # local mode(s) if localmode_options != () - localmode_values = find_localmodes(bathist) + localmode_values = find_localmodes(marg) for (i, l) in enumerate(localmode_values) @series begin diff --git a/src/plotting/recipes_samples_2D.jl b/src/plotting/recipes_samples_2D.jl index fafcdd144..754af5f72 100644 --- a/src/plotting/recipes_samples_2D.jl +++ b/src/plotting/recipes_samples_2D.jl @@ -44,14 +44,16 @@ xguide := get(plotattributes, :xguide, xlabel) yguide := get(plotattributes, :yguide, ylabel) - hist = BATHistogram( + marg = bat_marginalize( samples, (xindx, yindx), nbins = bins, closed = closed, - filter=filter + filter = filter ) + println(typeof(marg)) + if seriestype == :scatter base_markersize = get(plotattributes, :markersize, 1.5) @@ -93,7 +95,7 @@ upper --> upper right --> right - hist, (1, 2) + marg, (1, 2) end end @@ -163,7 +165,7 @@ if localmode_options != () - localmode_values = find_localmodes(hist) + localmode_values = find_localmodes(marg) for (i, l) in enumerate(localmode_values) @series begin seriestype := :scatter diff --git a/src/plotting/split_histograms.jl b/src/plotting/split_histograms.jl index 3b40b6d2f..77ec8eda6 100644 --- a/src/plotting/split_histograms.jl +++ b/src/plotting/split_histograms.jl @@ -2,14 +2,14 @@ # for 1d and 2d histogramsm function get_smallest_intervals( - histogram::BATHistogram, + histogram::StatsBase.Histogram, intervals::Array{Float64, 1} ) intervals = sort(intervals) - bathist = deepcopy(histogram) - dims = size(bathist.h.weights) - weights = vec(bathist.h.weights) + hist = deepcopy(histogram) + dims = size(hist.weights) + weights = vec(hist.weights) totalweight = sum(weights) rel_weights = weights/totalweight @@ -28,14 +28,14 @@ function get_smallest_intervals( end end - hists = Array{BATHistogram}(undef, length(intervals)) + hists = Array{StatsBase.Histogram}(undef, length(intervals)) for i in 1:length(intervals) - hists[i] = deepcopy(bathist) - hists[i].h.weights = reshape(hists_weights[i], dims) + hists[i] = deepcopy(hist) + hists[i].weights = reshape(hists_weights[i], dims) end - realintervals = get_probability_content(bathist, hists) + realintervals = get_probability_content(hist, hists) return reverse(hists), reverse(realintervals) end @@ -44,20 +44,20 @@ end # for 1d histograms function split_central( - histogram::BATHistogram, + histogram::StatsBase.Histogram, intervals::Array{Float64, 1} ) intervals = sort(intervals) intervals = (1 .-intervals)/2 - bathist = deepcopy(histogram) - hists = Array{BATHistogram}(undef, length(intervals)) + hist = deepcopy(histogram) + hists = Array{StatsBase.Histogram}(undef, length(intervals)) for i in 1:length(intervals) - hists[i] = deepcopy(bathist) + hists[i] = deepcopy(hist) end - weights = vec(bathist.h.weights) + weights = vec(hist.h.weights) totalweight = sum(weights) rel_weights = weights/totalweight @@ -68,7 +68,7 @@ function split_central( for l in 1:length(weights) if sum_left + rel_weights[l] < intv sum_left = sum_left + rel_weights[l] - hists[i].h.weights[l] = 0 + hists[i].weights[l] = 0 else break end @@ -77,14 +77,14 @@ function split_central( for r in length(weights):-1:1 if sum_right + rel_weights[r] < intv sum_right = sum_right + rel_weights[r] - hists[i].h.weights[r] = 0 + hists[i].weights[r] = 0 else break end end end - realintervals = get_probability_content(bathist, hists) + realintervals = get_probability_content(hist, hists) return reverse(hists), reverse(realintervals) end @@ -93,23 +93,23 @@ end # calculate probability percentage enclosed inside the intervals of hists function get_probability_content( - hist::BATHistogram, - hists::Array{BATHistogram, 1} + hist::StatsBase.Histogram, + hists::Array{StatsBase.Histogram, 1} ) - totalweight = sum(hist.h.weights) - return [sum(hists[i].h.weights)/totalweight for i in 1:length(hists)] + totalweight = sum(hist.weights) + return [sum(hists[i].weights)/totalweight for i in 1:length(hists)] end function calculate_levels( - bathist::BATHistogram, + hist::StatsBase.Histogram, intervals::Array{<:Real, 1} ) intervals = sort(intervals) levels = Vector{Real}(undef, length(intervals)+1) - weights = sort(vec(bathist.h.weights), rev=true) + weights = sort(vec(hist.weights), rev=true) weight_ids = sortperm(weights, rev=true); sum_of_weights = sum(weights) diff --git a/test_marginal.jl b/test_marginal.jl new file mode 100644 index 000000000..713f56202 --- /dev/null +++ b/test_marginal.jl @@ -0,0 +1,110 @@ +# # BAT.jl plotting tutorial + +using BAT +using Distributions +using IntervalSets + + +likelihood = params -> begin + + r1 = logpdf.( + MixtureModel(Normal[ + Normal(-10.0, 1.2), + Normal(0.0, 1.8), + Normal(10.0, 2.5)], [0.1, 0.3, 0.6]), params.a) + + r2 = logpdf.( + MixtureModel(Normal[ + Normal(-5.0, 2.2), + Normal(5.0, 1.5)], [0.3, 0.7]), params.b[1]) + + r3 = logpdf.(Normal(2.0, 1.5), params.b[2]) + + return LogDVal(r1+r2+r3) +end + +prior = BAT.NamedTupleDist( + a = Normal(-3, 4.5), + b = [-20.0..20.0, -10..10] +) + +posterior = PosteriorDensity(likelihood, prior); + +samples, chains = bat_sample(posterior, (10^5, 4), MetropolisHastings()); + +unshaped_samples = BAT.unshaped.(samples) + +BAT.bat_marginalize(samples, (:(b[1]),:(b[2]))) + + +BAT.bat_marginalize(unshaped_samples, (1, 2)) + + + + + +# ## Set up plotting +# Set up plotting using the [Plots.jl](https://github.com/JuliaPlots/Plots.jl) package: +using Plots + + +plot(samples, :a) #default seriestype = :smallest_intervals (alias :HDR) + +plot(prior, :a) +#or: plot(prior, 1) + +# ## Knowledge update plot +# The knowledge update after performing the sampling can be visualized by plotting the prior and the samples of the psterior together in one plot using `plot!()`: +plot(samples, :(b[1])) +plot!(prior, :(b[1])) +0.7^15*100 + + +lost = 0 +win = 0 +x=0 +for i in 1:20 + while(win<=lost) + global x +=0.5 + res = 18*x + + win = res-lost + end + lost = lost + 6*x + println(x) +end + +function f(x, lost) + j=0 + for i in 1:30 + x += 0.5 + res = 18*x + + #println("lost before: $lost") + win = res + plus = win-6*x-lost + + if(plus > 0) + j = j+1 + println("\nj: $j") + println("x: $x") + println("6*x: $(6*x)") + println("win: $win") + println("plus: $plus") + println("total: $total") + total = total + 6*x + lost = lost + 6*x + # println("i: $i") + #println("lost after: $lost\n") + end + + + end +end + +f(0, 0) + +p = 1-12/37 +p^15*100 + +0.25^10*100 From 1bd4f35d1ccba30d1b7211c5913b953ff9caf9c1 Mon Sep 17 00:00:00 2001 From: Cornelius-G Date: Mon, 8 Jun 2020 18:10:22 +0200 Subject: [PATCH 2/3] introduce MarginalDist, remove BATHistogram --- examples/dev-internal/plotting_examples.jl | 7 +- src/plotting/BATHistogram.jl | 95 -------- src/plotting/BATHistogram_utils.jl | 94 -------- .../{Marginalization.jl => MarginalDist.jl} | 58 +++-- ...it_histograms.jl => MarginalDist_utils.jl} | 84 ++++++- src/plotting/plotting.jl | 10 +- src/plotting/recipes_BATHistogram_1D.jl | 131 ----------- src/plotting/recipes_BATHistogram_2D.jl | 209 ------------------ ...ation_1D.jl => recipes_MarginalDist_1D.jl} | 11 +- ...ation_2D.jl => recipes_MarginalDist_2D.jl} | 14 +- src/plotting/recipes_prior.jl | 15 +- src/plotting/recipes_samples_1D.jl | 2 +- src/plotting/recipes_samples_2D.jl | 7 +- src/plotting/valueshapes_utils.jl | 24 +- test_marginal.jl | 110 --------- 15 files changed, 179 insertions(+), 692 deletions(-) delete mode 100644 src/plotting/BATHistogram.jl delete mode 100644 src/plotting/BATHistogram_utils.jl rename src/plotting/{Marginalization.jl => MarginalDist.jl} (58%) rename src/plotting/{split_histograms.jl => MarginalDist_utils.jl} (60%) delete mode 100644 src/plotting/recipes_BATHistogram_1D.jl delete mode 100644 src/plotting/recipes_BATHistogram_2D.jl rename src/plotting/{recipes_Marginalization_1D.jl => recipes_MarginalDist_1D.jl} (93%) rename src/plotting/{recipes_Marginalization_2D.jl => recipes_MarginalDist_2D.jl} (97%) delete mode 100644 test_marginal.jl diff --git a/examples/dev-internal/plotting_examples.jl b/examples/dev-internal/plotting_examples.jl index 126c6491f..5a5ece5f4 100644 --- a/examples/dev-internal/plotting_examples.jl +++ b/examples/dev-internal/plotting_examples.jl @@ -5,11 +5,6 @@ using Distributions using IntervalSets # ## Generate samples to be plotted -struct MultiModalModel<:AbstractDensity - r::Vector{Float64} - sigma::Vector{Float64} -end - likelihood = params -> begin @@ -123,7 +118,7 @@ plot(samples, :a, localmode=false, # ### Default 2D plot of samples: pyplot() -plot(samples, (:a,:(b[1])), mean=true, std=true) #default seriestype = :smallest_intervals (alias :HDR) +plot(samples, (:a, :(b[1])), mean=true, std=true) #default seriestype = :smallest_intervals (alias :HDR) # The default seriestype for plotting samples is a 3-color heatmap showing the smallest intervals (highest density regions) containing 68.3%, 95.5% and 99.7% of the posterior probability. By default, the local mode # of the histogram is indicated by a black square. diff --git a/src/plotting/BATHistogram.jl b/src/plotting/BATHistogram.jl deleted file mode 100644 index 98a16ae0e..000000000 --- a/src/plotting/BATHistogram.jl +++ /dev/null @@ -1,95 +0,0 @@ -export BATHistogram - -mutable struct BATHistogram - h::StatsBase.Histogram -end - - -# -# # construct 1D BATHistogram from sample vector -# function BATHistogram( -# maybe_shaped_samples::DensitySampleVector, -# key::Union{Integer, Symbol}; -# nbins = 200, -# closed::Symbol = :left, -# filter::Bool = false -# ) -# samples = BAT.unshaped.(maybe_shaped_samples) -# -# if filter -# samples = BAT.drop_low_weight_samples(samples) -# end -# -# idx = asindex(maybe_shaped_samples, key) -# -# hist = fit(Histogram, -# flatview(samples.v)[idx, :], -# FrequencyWeights(samples.weight), -# nbins = nbins, closed = closed) -# -# return BATHistogram(hist) -# end - - - -# construct 1D BATHistogram from prior -function BATHistogram( - prior::NamedTupleDist, - key::Union{Integer, Symbol}; - nbins = 200, - closed::Symbol = :left, - nsamples::Integer = 10^6 -) - idx = asindex(prior, key) - r = rand(prior, nsamples) - hist = fit(Histogram, r[idx, :], nbins = nbins, closed = closed) - - return BATHistogram(hist) -end - - - -# construct 2D BATHistogram from sample vector -# function BATHistogram( -# maybe_shaped_samples::DensitySampleVector, -# params::Union{NTuple{2, Symbol}, NTuple{2, Integer}}; -# nbins = 200, -# closed::Symbol = :left, -# filter::Bool = false -# ) -# samples = unshaped.(maybe_shaped_samples) -# -# if filter -# samples = BAT.drop_low_weight_samples(samples) -# end -# -# i = asindex(maybe_shaped_samples, params[1]) -# j = asindex(maybe_shaped_samples, params[2]) -# -# hist = fit(Histogram, -# (flatview(samples.v)[i, :], -# flatview(samples.v)[j, :]), -# FrequencyWeights(samples.weight), -# nbins = nbins, closed = closed) -# -# return BATHistogram(hist) -# end - - - -# # construct 2D BATHistogram from prior -function BATHistogram( - prior::NamedTupleDist, - params::Union{NTuple{2, Symbol}, NTuple{2, Integer}}; - nbins = 200, - closed::Symbol = :left, - nsamples::Integer = 10^6 -) - i = asindex(prior, params[1]) - j = asindex(prior, params[2]) - - r = rand(prior, nsamples) - hist = fit(Histogram, (r[i, :], r[j, :]), nbins = nbins, closed = closed) - - return BATHistogram(hist) -end diff --git a/src/plotting/BATHistogram_utils.jl b/src/plotting/BATHistogram_utils.jl deleted file mode 100644 index 577e391cf..000000000 --- a/src/plotting/BATHistogram_utils.jl +++ /dev/null @@ -1,94 +0,0 @@ -# This file is a part of BAT.jl, licensed under the MIT License (MIT). - -""" - find_localmodes(bathist::BATHistogram) - -*BAT-internal, not part of stable public API.* - -Find the modes of a BATHistogram. -Returns a vector of the bin-centers of the bin(s) with the heighest weight. -""" -function find_localmodes(marg::MarginalDist) - hist = marg.dist.h - dims = ndims(hist.weights) - - max = maximum(hist.weights) - maxima_idx = findall(x->x==max, hist.weights) - - bin_centers = get_bin_centers(marg) - - return [[bin_centers[d][maxima_idx[i][d]] for d in 1:dims] for i in 1:length(maxima_idx) ] -end - - -""" - get_bin_centers(bathist::BATHistogram) - -*BAT-internal, not part of stable public API.* - -Returns a vector of the bin-centers. -""" -function get_bin_centers(marg::MarginalDist) - hist = marg.dist.h - edges = hist.edges - dims = ndims(hist.weights) - - centers = [[edges[d][i]+0.5*(edges[d][i+1]-edges[d][i]) for i in 1:length(edges[d])-1] for d in 1:dims] - - return centers -end - - -# create a BATHistogram containing some dimensions of higher-dimensional BATHistogram -function subhistogram( - bathist::BATHistogram, - params::Array{<:Integer,1} -) - dims = collect(1:ndims(bathist.h.weights)) - weights = sum(bathist.h.weights, dims=setdiff(dims, params)) - weights = dropdims(weights, dims=Tuple(setdiff(dims, params))) - - edges = Tuple([bathist.h.edges[p] for p in params]) - hist = StatsBase.Histogram(edges, weights, bathist.h.closed) - - return BATHistogram(hist) -end - -function islower(weights, idx) - if idx==1 && weights[idx]>0 - return true - elseif weights[idx]>0 && weights[idx-1]==0 && idx < length(weights) - return true - else - return false - end -end - -function isupper(weights, idx) - if idx==length(weights) && weights[idx-1]>0 - return true - elseif weights[idx]==0 && weights[idx-1]>0 - return true - else - return false - end -end - - -# return the lower and upper edges for clusters in which the bincontent is non-zero for all dimensions of a BATHistogram -# clusters that are seperated <= atol are combined -function get_interval_edges(bathist::BATHistogram; atol::Real = 0) - weights = bathist.h.weights - len = length(weights) - - lower = [bathist.h.edges[1][i] for i in 1:len if islower(weights, i)] - upper = [bathist.h.edges[1][i] for i in 2:len if isupper(weights, i)] - - if atol != 0 - idxs = [i for i in 1:length(upper)-1 if lower[i+1]-upper[i] <= atol] - deleteat!(upper, idxs) - deleteat!(lower, idxs.+1) - end - - return lower, upper -end diff --git a/src/plotting/Marginalization.jl b/src/plotting/MarginalDist.jl similarity index 58% rename from src/plotting/Marginalization.jl rename to src/plotting/MarginalDist.jl index 27a07c536..319c28372 100644 --- a/src/plotting/Marginalization.jl +++ b/src/plotting/MarginalDist.jl @@ -5,7 +5,6 @@ struct MarginalDist{N,D<:Distribution,VS<:AbstractValueShape} end -#TODO: does not work for unshaped samples function bat_marginalize( maybe_shaped_samples::DensitySampleVector, key::Union{Integer, Symbol, Expr}; @@ -37,7 +36,7 @@ end function bat_marginalize( maybe_shaped_samples::DensitySampleVector, - key::Union{NTuple{2,Integer}, NTuple{2,Union{Symbol, Expr}}}; + key::Union{NTuple{n,Integer}, NTuple{n,Union{Symbol, Expr}}} where n; nbins = 200, closed::Symbol = :left, filter::Bool = false, @@ -49,12 +48,11 @@ function bat_marginalize( samples = BAT.drop_low_weight_samples(samples) end - i = asindex(maybe_shaped_samples, key[1]) - j = asindex(maybe_shaped_samples, key[2]) + idxs = asindex.(Ref(maybe_shaped_samples), key) + s = Tuple(BAT.flatview(samples.v)[i, :] for i in idxs) hist = fit(Histogram, - (flatview(samples.v)[i, :], - flatview(samples.v)[j, :]), + s, FrequencyWeights(samples.weight), nbins = nbins, closed = closed) @@ -62,7 +60,7 @@ function bat_marginalize( mvbd = EmpiricalDistributions.MvBinnedDist(hist) - return MarginalDist((i,j), mvbd, varshape(maybe_shaped_samples)) + return MarginalDist(idxs, mvbd, varshape(maybe_shaped_samples)) end @@ -87,20 +85,52 @@ function bat_marginalize( end -function BATHistogram( +function bat_marginalize( prior::NamedTupleDist, - params::Union{NTuple{2, Symbol}, NTuple{2, Integer}}; + key::Union{NTuple{2, Symbol}, NTuple{2, Integer}}; nbins = 200, closed::Symbol = :left, - nsamples::Integer = 10^6 + nsamples::Integer = 10^6, + normalize=true ) - i = asindex(prior, params[1]) - j = asindex(prior, params[2]) + idxs = asindex.(Ref(prior), key) r = rand(prior, nsamples) - hist = fit(Histogram, (r[i, :], r[j, :]), nbins = nbins, closed = closed) + s = Tuple(r[i, :] for i in idxs) + + hist = fit(Histogram, s, nbins = nbins, closed = closed) + + normalize ? hist = StatsBase.normalize(hist) : nothing mvbd = EmpiricalDistributions.MvBinnedDist(hist) - return MarginalDist((i,j), mvbd, varshape(prior)) + return MarginalDist(idxs, mvbd, varshape(prior)) +end + + + +function bat_marginalize( + original::MarginalDist, + parsel::NTuple{n, Int} where n; + normalize=true +) + original_hist = original.dist.h + dims = collect(1:ndims(original_hist.weights)) + parsel = Tuple(findfirst(x-> x == p, original.dims) for p in parsel) + + weights = sum(original_hist.weights, dims=setdiff(dims, parsel)) + weights = dropdims(weights, dims=Tuple(setdiff(dims, parsel))) + + edges = Tuple([original_hist.edges[p] for p in parsel]) + hist = StatsBase.Histogram(edges, weights, original_hist.closed) + + normalize ? hist = StatsBase.normalize(hist) : nothing + + bd = if length(parsel) == 1 + EmpiricalDistributions.UvBinnedDist(hist) + else + EmpiricalDistributions.MvBinnedDist(hist) + end + + return MarginalDist(parsel, bd, original.origvalshape) end diff --git a/src/plotting/split_histograms.jl b/src/plotting/MarginalDist_utils.jl similarity index 60% rename from src/plotting/split_histograms.jl rename to src/plotting/MarginalDist_utils.jl index 77ec8eda6..d7d6055a5 100644 --- a/src/plotting/split_histograms.jl +++ b/src/plotting/MarginalDist_utils.jl @@ -1,5 +1,87 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). +""" + find_localmodes(marg::MarginalDist) + +*BAT-internal, not part of stable public API.* + +Find the modes of a MarginalDist. +Returns a vector of the bin-centers of the bin(s) with the heighest weight. +""" +function find_localmodes(marg::MarginalDist) + hist = marg.dist.h + dims = ndims(hist.weights) + + max = maximum(hist.weights) + maxima_idx = findall(x->x==max, hist.weights) + + bin_centers = get_bin_centers(marg) + + return [[bin_centers[d][maxima_idx[i][d]] for d in 1:dims] for i in 1:length(maxima_idx) ] +end + + +""" + get_bin_centers(marg::MarginalDist) + +*BAT-internal, not part of stable public API.* + +Returns a vector of the bin-centers. +""" +function get_bin_centers(marg::MarginalDist) + hist = marg.dist.h + edges = hist.edges + dims = ndims(hist.weights) + + centers = [[edges[d][i]+0.5*(edges[d][i+1]-edges[d][i]) for i in 1:length(edges[d])-1] for d in 1:dims] + + return centers +end + + + +function islower(weights, idx) + if idx==1 && weights[idx]>0 + return true + elseif weights[idx]>0 && weights[idx-1]==0 && idx < length(weights) + return true + else + return false + end +end + +function isupper(weights, idx) + if idx==length(weights) && weights[idx-1]>0 + return true + elseif weights[idx]==0 && weights[idx-1]>0 + return true + else + return false + end +end + + +# return the lower and upper edges for clusters in which the bincontent is non-zero for all dimensions of a StatsBase.Histogram +# clusters that are seperated <= atol are combined +function get_interval_edges(h::StatsBase.Histogram; atol::Real = 0) + weights = h.weights + len = length(weights) + + lower = [h.edges[1][i] for i in 1:len if islower(weights, i)] + upper = [h.edges[1][i] for i in 2:len if isupper(weights, i)] + + if atol != 0 + idxs = [i for i in 1:length(upper)-1 if lower[i+1]-upper[i] <= atol] + deleteat!(upper, idxs) + deleteat!(lower, idxs.+1) + end + + return lower, upper +end + + +# This file is a part of BAT.jl, licensed under the MIT License (MIT). + # for 1d and 2d histogramsm function get_smallest_intervals( histogram::StatsBase.Histogram, @@ -57,7 +139,7 @@ function split_central( hists[i] = deepcopy(hist) end - weights = vec(hist.h.weights) + weights = vec(hist.weights) totalweight = sum(weights) rel_weights = weights/totalweight diff --git a/src/plotting/plotting.jl b/src/plotting/plotting.jl index 83ca96422..6be11ebd9 100644 --- a/src/plotting/plotting.jl +++ b/src/plotting/plotting.jl @@ -4,15 +4,13 @@ const standard_confidence_vals = [0.683, 0.955, 0.997] const standard_colors = [:chartreuse2, :yellow, :red] -include("BATHistogram.jl") -include("Marginalization.jl") -include("recipes_Marginalization_1D.jl") -include("recipes_Marginalization_2D.jl") +include("MarginalDist.jl") +include("recipes_MarginalDist_1D.jl") +include("recipes_MarginalDist_2D.jl") +include("MarginalDist_utils.jl") include("recipes_stats.jl") include("recipes_samples_overview.jl") include("recipes_prior_overview.jl") -include("split_histograms.jl") -include("BATHistogram_utils.jl") include("recipes_ahmi.jl") include("recipes_prior.jl") include("recipes_samples_1D.jl") diff --git a/src/plotting/recipes_BATHistogram_1D.jl b/src/plotting/recipes_BATHistogram_1D.jl deleted file mode 100644 index 0651d2a16..000000000 --- a/src/plotting/recipes_BATHistogram_1D.jl +++ /dev/null @@ -1,131 +0,0 @@ -# This file is a part of BAT.jl, licensed under the MIT License (MIT). - -# TODO: add plot without Int for overview? - -function plothistogram(bathist::BATHistogram, swap::Bool) - if swap - return bathist.h.weights, bathist.h.edges[1][1:end-1] - else - return bathist.h.edges[1][1:end-1], bathist.h.weights - end -end - - -@recipe function f( - bathist::BATHistogram, - idx::Integer; - intervals = standard_confidence_vals, - normalize = true, - colors = standard_colors, - interval_labels = [] -) - hist = subhistogram(bathist, [idx]) - normalize ? hist.h=StatsBase.normalize(hist.h) : nothing - - orientation = get(plotattributes, :orientation, :vertical) - (orientation != :vertical) ? swap = true : swap = false - plotattributes[:orientation] = :vertical # without: auto-scaling of axes not correct - - seriestype = get(plotattributes, :seriestype, :stephist) - - xlabel = get(plotattributes, :xguide, "x$(idx)") - ylabel = get(plotattributes, :yguide, "p(x$(idx))") - - if swap - xguide := ylabel - yguide := xlabel - else - xguide := xlabel - yguide := ylabel - end - - # step histogram - if seriestype == :stephist || seriestype == :steppost - @series begin - seriestype := :steppost - label --> "" - linecolor --> :dodgerblue - plothistogram(hist, swap) - end - - # filled histogram - elseif seriestype == :histogram - @series begin - seriestype := :steppost - label --> "" - fillrange --> 0 - fillcolor --> :dodgerblue - linewidth --> 0 - plothistogram(hist, swap) - end - - - # smallest intervals aka highest density region (HDR) - elseif seriestype == :smallest_intervals || seriestype == :HDR - hists, realintervals = get_smallest_intervals(hist, intervals) - colors = colors[sortperm(intervals, rev=true)] - - # colored histogram for each interval - for i in 1:length(realintervals) - @series begin - seriestype := :steppost - fillcolor --> colors[i] - linewidth --> 0 - fillrange --> 0 - - if length(interval_labels) > 0 - label := interval_labels[i] - else - label := "smallest $(@sprintf("%.2f", realintervals[i]*100))% interval(s)" - end - plothistogram(hists[i], swap) - end - end - - # black contour line for total histogram - @series begin - seriestype := :steppost - linecolor --> :black - linewidth --> 0.7 - label --> "" - plothistogram(hist, swap) - end - - - # central intervals - elseif seriestype == :central_intervals - hists, realintervals = split_central(hist, intervals) - colors = colors[sortperm(intervals, rev=true)] - - # colored histogram for each interval - for i in 1:length(realintervals) - @series begin - seriestype := :steppost - fillcolor --> colors[i] - linewidth --> 0 - fillrange --> 0 - - if length(interval_labels) > 0 - label := interval_labels[i] - else - label := "central $(@sprintf("%.2f", realintervals[i]*100))% interval(s)" - end - - plothistogram(hists[i], swap) - end - end - - # black contour line for total histogram - @series begin - seriestype := :steppost - linecolor --> :black - linewidth --> 0.7 - label --> "" - plothistogram(hist, swap) - end - - else - error("seriestype $seriestype not supported") - end - -end diff --git a/src/plotting/recipes_BATHistogram_2D.jl b/src/plotting/recipes_BATHistogram_2D.jl deleted file mode 100644 index 16a810227..000000000 --- a/src/plotting/recipes_BATHistogram_2D.jl +++ /dev/null @@ -1,209 +0,0 @@ -# This file is a part of BAT.jl, licensed under the MIT License (MIT). -@recipe function f( - bathist::BATHistogram, - parsel::NTuple{2,Integer}; - intervals = standard_confidence_vals, - colors = standard_colors, - diagonal = Dict(), - upper = Dict(), - right = Dict(), - interval_labels = [], - normalize = true -) - _plots_module() != nothing || throw(ErrorException("Package Plots not available, but required for this operation")) - - hist = subhistogram(bathist, collect(parsel)) - normalize ? hist.h=StatsBase.normalize(hist.h) : nothing - - seriestype = get(plotattributes, :seriestype, :histogram2d) - - xlabel = get(plotattributes, :xguide, "x$(parsel[1])") - ylabel = get(plotattributes, :yguide, "x$(parsel[2])") - - - # histogram / heatmap - if seriestype == :histogram2d || seriestype == :histogram || seriestype == :hist - @series begin - seriestype := :bins2d - xguide --> xlabel - yguide --> ylabel - colorbar --> true - - hist.h.edges[1], hist.h.edges[2], _plots_module().Surface(hist.h.weights) - end - - - # smallest interval contours - elseif seriestype == :smallest_intervals_contour || seriestype == :smallest_intervals_contourf - - colors = colors[sortperm(intervals, rev=true)] - - if seriestype == :smallest_intervals_contour - plotstyle = :contour - else - plotstyle = :contourf - end - - lev = calculate_levels(hist, intervals) - x, y = get_bin_centers(hist) - m = hist.h.weights - - # quick fix: needed when plotting contour on top of histogram - # otherwise scaling of histogram colorbar would change scaling - lev = lev/10000 - m = m/10000 - - colorbar --> false - xguide --> xlabel - yguide --> ylabel - - if _plots_module().backend() == _plots_module().PyPlotBackend() - @series begin - seriestype := plotstyle - levels --> lev - linewidth --> 2 - seriescolor --> colors # currently only works with pyplot - (x, y, m') - end - else - @series begin - seriestype := plotstyle - levels --> lev - linewidth --> 2 - (x, y, m') - end - end - - - # smallest intervals heatmap - elseif seriestype == :smallest_intervals - colors = colors[sortperm(intervals, rev=true)] - - hists, realintervals = get_smallest_intervals(hist, intervals) - - for (i, int) in enumerate(realintervals) - @series begin - seriestype := :bins2d - seriescolor --> _plots_module().cgrad([colors[i], colors[i]]) - xguide --> xlabel - yguide --> ylabel - - hists[i].h.edges[1], hists[i].h.edges[2], _plots_module().Surface(hists[i].h.weights) - end - - # fake a legend - interval_label = isempty(interval_labels) ? "smallest $(@sprintf("%.2f", realintervals[i]*100))% interval(s)" : interval_labels[i] - - @series begin - seriestype := :shape - fillcolor --> colors[i] - linewidth --> 0 - label --> interval_label - colorbar --> false - [hists[i].h.edges[1][1], hists[i].h.edges[1][1]], [hists[i].h.edges[2][1], hists[i].h.edges[2][1]] - end - end - - - # marginal histograms - elseif seriestype == :marginal - layout --> _plots_module().grid(2,2, widths=(0.8, 0.2), heights=(0.2, 0.8)) - link --> :both - - if get(diagonal, "seriestype", :histogram) != :histogram - colorbar --> false - end - - @series begin - subplot := 1 - xguide := xlabel - yguide := "p("*xlabel*")" - seriestype := get(upper, "seriestype", :histogram) - bins --> get(upper, "nbins", 200) - normalize --> get(upper, "normalize", true) - colors --> get(upper, "colors", standard_colors) - intervals --> get(upper, "intervals", standard_confidence_vals) - legend --> get(upper, "legend", true) - - hist, 1 - end - - # empty plot (needed since @layout macro not available) - @series begin - seriestype := :scatter - subplot := 2 - grid := false - xaxis := false - yaxis := false - markersize := 0.001 - markerstrokewidth := 0 - markeralpha := 1 - markerstrokealpha := 1 - legend := false - label := "" - xguide := "" - yguide := "" - [(0,0)] - end - - @series begin - subplot := 3 - seriestype := get(diagonal, "seriestype", :histogram) - xguide --> xlabel - yguide --> ylabel - normalize --> get(diagonal, "normalize", true) - bins --> get(diagonal, "nbins", 200) - colors --> get(diagonal, "colors", standard_colors) - intervals --> get(diagonal, "intervals", standard_confidence_vals) - legend --> get(diagonal, "legend", false) - - hist, (1, 2) - end - - @series begin - subplot := 4 - seriestype := get(right, "seriestype", :histogram) - orientation := :horizontal - xguide := ylabel - yguide := "p("*ylabel*")" - normalize --> get(right, "normalize", true) - bins --> get(right, "nbins", 200) - colors --> get(right, "colors", standard_colors) - intervals --> get(right, "intervals", standard_confidence_vals) - legend --> get(right, "legend", true) - - hist, 2 - end - - else - error("seriestype $seriestype not supported") - end - -end - - - -# rectangle bounds -@recipe function f(bounds::HyperRectBounds, parsel::NTuple{2,Integer}) - pi_x, pi_y = parsel - - vol = spatialvolume(bounds) - vhi = vol.hi[[pi_x, pi_y]]; vlo = vol.lo[[pi_x, pi_y]] - rect_xy = rectangle_path(vlo, vhi) - bext = 0.1 * (vhi - vlo) - xlims = (vlo[1] - bext[1], vhi[1] + bext[1]) - ylims = (vlo[2] - bext[2], vhi[2] + bext[2]) - - @series begin - seriestype := :path - label --> "bounds" - seriescolor --> :darkred - seriesalpha --> 0.75 - linewidth --> 2 - xlims --> xlims - ylims --> ylims - (rect_xy[:,1], rect_xy[:,2]) - end - - nothing -end diff --git a/src/plotting/recipes_Marginalization_1D.jl b/src/plotting/recipes_MarginalDist_1D.jl similarity index 93% rename from src/plotting/recipes_Marginalization_1D.jl rename to src/plotting/recipes_MarginalDist_1D.jl index c27544835..f07c56b7b 100644 --- a/src/plotting/recipes_Marginalization_1D.jl +++ b/src/plotting/recipes_MarginalDist_1D.jl @@ -12,13 +12,16 @@ end @recipe function f( - marg::MarginalDist, - idx::Integer; + origmarg::MarginalDist, + parsel::Union{Integer, Symbol, Expr}; intervals = standard_confidence_vals, normalize = true, colors = standard_colors, interval_labels = [] ) + indx = asindex(origmarg, parsel) + + marg = bat_marginalize(origmarg, (indx, )) hist = marg.dist.h normalize ? hist=StatsBase.normalize(hist) : nothing @@ -28,8 +31,8 @@ end seriestype = get(plotattributes, :seriestype, :stephist) - xlabel = get(plotattributes, :xguide, "x$(idx)") - ylabel = get(plotattributes, :yguide, "p(x$(idx))") + xlabel = get(plotattributes, :xguide, "x$(indx)") + ylabel = get(plotattributes, :yguide, "p(x$(indx))") if swap xguide := ylabel diff --git a/src/plotting/recipes_Marginalization_2D.jl b/src/plotting/recipes_MarginalDist_2D.jl similarity index 97% rename from src/plotting/recipes_Marginalization_2D.jl rename to src/plotting/recipes_MarginalDist_2D.jl index 6df43d599..5b5dbe266 100644 --- a/src/plotting/recipes_Marginalization_2D.jl +++ b/src/plotting/recipes_MarginalDist_2D.jl @@ -1,7 +1,7 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). @recipe function f( marg::MarginalDist, - parsel::NTuple{2,Integer}; + parsel::NTuple{2,Union{Symbol, Expr, Integer}}; intervals = standard_confidence_vals, colors = standard_colors, diagonal = Dict(), @@ -11,15 +11,12 @@ normalize = true ) _plots_module() != nothing || throw(ErrorException("Package Plots not available, but required for this operation")) - println("hi") hist = marg.dist.h - seriestype = get(plotattributes, :seriestype, :histogram2d) xlabel = get(plotattributes, :xguide, "x$(parsel[1])") ylabel = get(plotattributes, :yguide, "x$(parsel[2])") - # histogram / heatmap if seriestype == :histogram2d || seriestype == :histogram || seriestype == :hist @series begin @@ -44,7 +41,7 @@ end lev = calculate_levels(hist, intervals) - x, y = get_bin_centers(hist) + x, y = get_bin_centers(marg) m = hist.weights # quick fix: needed when plotting contour on top of histogram @@ -106,6 +103,7 @@ # marginal histograms elseif seriestype == :marginal + layout --> _plots_module().grid(2,2, widths=(0.8, 0.2), heights=(0.2, 0.8)) link --> :both @@ -124,7 +122,7 @@ intervals --> get(upper, "intervals", standard_confidence_vals) legend --> get(upper, "legend", true) - hist, 1 + marg, parsel[1] end # empty plot (needed since @layout macro not available) @@ -156,7 +154,7 @@ intervals --> get(diagonal, "intervals", standard_confidence_vals) legend --> get(diagonal, "legend", false) - hist, (1, 2) + marg, (parsel[1], parsel[2]) end @series begin @@ -171,7 +169,7 @@ intervals --> get(right, "intervals", standard_confidence_vals) legend --> get(right, "legend", true) - hist, 2 + marg, parsel[2] end else diff --git a/src/plotting/recipes_prior.jl b/src/plotting/recipes_prior.jl index 827a3152e..5001a0391 100644 --- a/src/plotting/recipes_prior.jl +++ b/src/plotting/recipes_prior.jl @@ -56,7 +56,7 @@ colors --> colors interval_labels --> interval_labels - marg, 1 + marg, idx end end @@ -66,7 +66,7 @@ end # 2D plots @recipe function f( prior::NamedTupleDist, - parsel::Union{NTuple{2,Integer}, NTuple{2,Union{Symbol, Expr}}}; + parsel::Union{NTuple{2,Integer}, NTuple{2,Union{Symbol, Expr, Integer}}}; nsamples=10^6, intervals = standard_confidence_vals, bins = 200, @@ -89,12 +89,13 @@ end throw(ArgumentError("Symbol :$(parsel[2]) refers to a multivariate parameter. Use :($(parsel[2])[i]) instead.")) end - bathist = BATHistogram( + + marg = bat_marginalize( prior, (xidx, yidx), - nbins=bins, - closed=closed, - nsamples=nsamples + nbins = bins, + closed = closed, + normalize = normalize ) @@ -116,6 +117,6 @@ end upper --> upper right --> right - bathist, (1, 2) + marg, (xidx, yidx) end end diff --git a/src/plotting/recipes_samples_1D.jl b/src/plotting/recipes_samples_1D.jl index 04bf4f95c..1841eb9ea 100644 --- a/src/plotting/recipes_samples_1D.jl +++ b/src/plotting/recipes_samples_1D.jl @@ -60,7 +60,7 @@ colors --> colors interval_labels --> interval_labels - marg, 1 + marg, idx end #------ stats ---------------------------- diff --git a/src/plotting/recipes_samples_2D.jl b/src/plotting/recipes_samples_2D.jl index 754af5f72..02cd860a9 100644 --- a/src/plotting/recipes_samples_2D.jl +++ b/src/plotting/recipes_samples_2D.jl @@ -1,7 +1,7 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). @recipe function f( maybe_shaped_samples::DensitySampleVector, - parsel::Union{NTuple{2,Integer}, NTuple{2,Union{Symbol, Expr}}}; + parsel::NTuple{2,Union{Symbol, Expr, Integer}}; intervals = standard_confidence_vals, interval_labels = [], colors = standard_colors, @@ -52,9 +52,6 @@ filter = filter ) - println(typeof(marg)) - - if seriestype == :scatter base_markersize = get(plotattributes, :markersize, 1.5) @@ -95,7 +92,7 @@ upper --> upper right --> right - marg, (1, 2) + marg, (xindx, yindx) end end diff --git a/src/plotting/valueshapes_utils.jl b/src/plotting/valueshapes_utils.jl index f0c40239e..e142d9fd3 100644 --- a/src/plotting/valueshapes_utils.jl +++ b/src/plotting/valueshapes_utils.jl @@ -41,12 +41,23 @@ function asindex(ntd::NamedTupleDist, name::Union{Expr, Symbol}) end function asindex( - x::Union{DensitySampleVector, NamedTupleDist}, + x::Union{DensitySampleVector, NamedTupleDist, MarginalDist}, key::Integer ) return key end +#for MarginalDist +function asindex(marg::MarginalDist, name::Union{Expr, Symbol}) + idx = asindex(marg.origvalshape, name) + if idx in marg.dims + return idx + else + throw(ArgumentError("Key :$name not in MarginalDist")) + end +end + + # Return the name corresponding to the index as Symbol (for univariate) # or Expr for (multivariate) distributions function getname(vs::NamedTupleShape, idx::Integer) @@ -77,6 +88,16 @@ function getstring(samples::BAT.DensitySampleVector, idx::Integer) end end +function getstring(marg::MarginalDist, idx::Integer) + println(marg.dims) + if idx in marg.dims + vs = marg.origvalshape + names = allnames(vs) + return names[idx] + else + throw(ArgumentError("Index $idx not in MarginalDist")) + end +end # Return array of strings with the names of all indices. # For a multivariate distribution, names for each dimension are created by appending "[i]" to the name. @@ -87,6 +108,7 @@ function allnames(vs::NamedTupleShape) return reduce(vcat, names) end + # Return array of strings with the names of all indices. # For a multivariate distribution, the name is repeated for all its dimensions. function repeatednames(vs::NamedTupleShape) diff --git a/test_marginal.jl b/test_marginal.jl deleted file mode 100644 index 713f56202..000000000 --- a/test_marginal.jl +++ /dev/null @@ -1,110 +0,0 @@ -# # BAT.jl plotting tutorial - -using BAT -using Distributions -using IntervalSets - - -likelihood = params -> begin - - r1 = logpdf.( - MixtureModel(Normal[ - Normal(-10.0, 1.2), - Normal(0.0, 1.8), - Normal(10.0, 2.5)], [0.1, 0.3, 0.6]), params.a) - - r2 = logpdf.( - MixtureModel(Normal[ - Normal(-5.0, 2.2), - Normal(5.0, 1.5)], [0.3, 0.7]), params.b[1]) - - r3 = logpdf.(Normal(2.0, 1.5), params.b[2]) - - return LogDVal(r1+r2+r3) -end - -prior = BAT.NamedTupleDist( - a = Normal(-3, 4.5), - b = [-20.0..20.0, -10..10] -) - -posterior = PosteriorDensity(likelihood, prior); - -samples, chains = bat_sample(posterior, (10^5, 4), MetropolisHastings()); - -unshaped_samples = BAT.unshaped.(samples) - -BAT.bat_marginalize(samples, (:(b[1]),:(b[2]))) - - -BAT.bat_marginalize(unshaped_samples, (1, 2)) - - - - - -# ## Set up plotting -# Set up plotting using the [Plots.jl](https://github.com/JuliaPlots/Plots.jl) package: -using Plots - - -plot(samples, :a) #default seriestype = :smallest_intervals (alias :HDR) - -plot(prior, :a) -#or: plot(prior, 1) - -# ## Knowledge update plot -# The knowledge update after performing the sampling can be visualized by plotting the prior and the samples of the psterior together in one plot using `plot!()`: -plot(samples, :(b[1])) -plot!(prior, :(b[1])) -0.7^15*100 - - -lost = 0 -win = 0 -x=0 -for i in 1:20 - while(win<=lost) - global x +=0.5 - res = 18*x - - win = res-lost - end - lost = lost + 6*x - println(x) -end - -function f(x, lost) - j=0 - for i in 1:30 - x += 0.5 - res = 18*x - - #println("lost before: $lost") - win = res - plus = win-6*x-lost - - if(plus > 0) - j = j+1 - println("\nj: $j") - println("x: $x") - println("6*x: $(6*x)") - println("win: $win") - println("plus: $plus") - println("total: $total") - total = total + 6*x - lost = lost + 6*x - # println("i: $i") - #println("lost after: $lost\n") - end - - - end -end - -f(0, 0) - -p = 1-12/37 -p^15*100 - -0.25^10*100 From 4cde2e0d00bb2d43a48f6b2f9607904a49fe622d Mon Sep 17 00:00:00 2001 From: Cornelius-G Date: Tue, 9 Jun 2020 14:48:44 +0200 Subject: [PATCH 3/3] make bat_marginalize return namedtuple --- src/plotting/MarginalDist.jl | 16 +++++++++++----- src/plotting/recipes_MarginalDist_1D.jl | 2 +- src/plotting/recipes_prior.jl | 4 ++-- src/plotting/recipes_samples_1D.jl | 2 +- src/plotting/recipes_samples_2D.jl | 2 +- 5 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/plotting/MarginalDist.jl b/src/plotting/MarginalDist.jl index 319c28372..973f189e9 100644 --- a/src/plotting/MarginalDist.jl +++ b/src/plotting/MarginalDist.jl @@ -29,8 +29,9 @@ function bat_marginalize( normalize ? hist = StatsBase.normalize(hist) : nothing uvbd = EmpiricalDistributions.UvBinnedDist(hist) + marg = MarginalDist((idx,), uvbd, varshape(maybe_shaped_samples)) - return MarginalDist((idx,), uvbd, varshape(maybe_shaped_samples)) + return (result = marg, ) end @@ -59,8 +60,9 @@ function bat_marginalize( normalize ? hist = StatsBase.normalize(hist) : nothing mvbd = EmpiricalDistributions.MvBinnedDist(hist) + marg = MarginalDist(idxs, mvbd, varshape(maybe_shaped_samples)) - return MarginalDist(idxs, mvbd, varshape(maybe_shaped_samples)) + return (result = marg, ) end @@ -80,8 +82,9 @@ function bat_marginalize( normalize ? hist = StatsBase.normalize(hist) : nothing uvbd = EmpiricalDistributions.UvBinnedDist(hist) + marg = MarginalDist((idx,), uvbd, varshape(prior)) - return MarginalDist((idx,), uvbd, varshape(prior)) + return (result = marg, ) end @@ -103,8 +106,9 @@ function bat_marginalize( normalize ? hist = StatsBase.normalize(hist) : nothing mvbd = EmpiricalDistributions.MvBinnedDist(hist) + marg = MarginalDist(idxs, mvbd, varshape(prior)) - return MarginalDist(idxs, mvbd, varshape(prior)) + return (result = marg, ) end @@ -132,5 +136,7 @@ function bat_marginalize( EmpiricalDistributions.MvBinnedDist(hist) end - return MarginalDist(parsel, bd, original.origvalshape) + marg = MarginalDist(parsel, bd, original.origvalshape) + + return (result = marg, ) end diff --git a/src/plotting/recipes_MarginalDist_1D.jl b/src/plotting/recipes_MarginalDist_1D.jl index f07c56b7b..a8afc419d 100644 --- a/src/plotting/recipes_MarginalDist_1D.jl +++ b/src/plotting/recipes_MarginalDist_1D.jl @@ -21,7 +21,7 @@ end ) indx = asindex(origmarg, parsel) - marg = bat_marginalize(origmarg, (indx, )) + marg = bat_marginalize(origmarg, (indx, )).result hist = marg.dist.h normalize ? hist=StatsBase.normalize(hist) : nothing diff --git a/src/plotting/recipes_prior.jl b/src/plotting/recipes_prior.jl index 5001a0391..8823e81e6 100644 --- a/src/plotting/recipes_prior.jl +++ b/src/plotting/recipes_prior.jl @@ -28,7 +28,7 @@ nsamples = nsamples, closed = closed, normalize = normalize - ) + ).result xlabel = if isa(parsel, Symbol) || isa(parsel, Expr) "$parsel" @@ -96,7 +96,7 @@ end nbins = bins, closed = closed, normalize = normalize - ) + ).result xlabel, ylabel = if isa(parsel, Symbol) || isa(parsel, Expr) diff --git a/src/plotting/recipes_samples_1D.jl b/src/plotting/recipes_samples_1D.jl index 1841eb9ea..d907dff6e 100644 --- a/src/plotting/recipes_samples_1D.jl +++ b/src/plotting/recipes_samples_1D.jl @@ -28,7 +28,7 @@ closed = closed, filter = filter, normalize = normalize - ) + ).result orientation = get(plotattributes, :orientation, :vertical) (orientation != :vertical) ? swap=true : swap = false diff --git a/src/plotting/recipes_samples_2D.jl b/src/plotting/recipes_samples_2D.jl index 02cd860a9..599fe3678 100644 --- a/src/plotting/recipes_samples_2D.jl +++ b/src/plotting/recipes_samples_2D.jl @@ -50,7 +50,7 @@ nbins = bins, closed = closed, filter = filter - ) + ).result if seriestype == :scatter base_markersize = get(plotattributes, :markersize, 1.5)