Skip to content

Commit

Permalink
nde implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
htjb committed Nov 8, 2023
1 parent 3dac5ad commit 9a7c3e9
Showing 1 changed file with 320 additions and 0 deletions.
320 changes: 320 additions & 0 deletions anesthetic/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
iso_probability_contours,
match_contour_to_contourf, histogram_bin_edges)
from anesthetic.boundary import cut_and_normalise_gaussian
from margarine.maf import MAF
from margarine.clustered import clusterMAF


class AxesSeries(Series):
Expand Down Expand Up @@ -973,6 +975,167 @@ def kde_plot_1d(ax, data, *args, **kwargs):

return ans

def nde_plot_1d(ax, data, *args, **kwargs):
    """Plot a 1d marginalised distribution with a neural density estimator.

    This functions as a wrapper around :meth:`matplotlib.axes.Axes.plot`, with
    a neural density estimation (NDE) computation via ``MAF`` (imported from
    :mod:`margarine.maf`) or ``clusterMAF`` (imported from
    :mod:`margarine.clustered`). All remaining keyword arguments are passed
    onwards to :meth:`matplotlib.axes.Axes.plot`.

    Samples with zero weight are dropped before anything else is computed.
    If the x-axis of ``ax`` is on a log scale, the estimator is trained on
    ``log10(data)`` and the plotting grid is mapped back with ``10**x``.

    As a side effect the y-limits of ``ax`` are set: the bottom is pinned to
    zero if ``density`` is true, otherwise the limits are set to (0, 1.1).

    Parameters
    ----------
    ax : :class:`matplotlib.axes.Axes`
        Axis object to plot on.

    data : np.array
        Samples to train the neural density estimator on.

    weights : np.array, optional
        Sample weights.

    ncompress : int, str, default=False
        Degree of compression.

        * If ``False``: no compression.
        * If ``True``: compresses to the channel capacity, equivalent to
          ``ncompress='entropy'``.
        * If ``int``: desired number of samples after compression.
        * If ``str``: determine number from the Huggins-Roy family of
          effective samples in :func:`anesthetic.utils.neff`
          with ``beta=ncompress``.

    nplot_1d : int, default=100
        Number of plotting points to use.

    levels : list
        Values at which to draw iso-probability lines.
        Default: [0.95, 0.68]

    density : bool, default=False
        If ``False`` the curve is scaled so that its maximum over the
        plotting grid is one. If ``True`` the curve is additionally divided
        by its trapezoidal-rule area over the plotting grid (computed in
        ``log10(x)`` when the x-axis is logarithmic).

    q : int or float or tuple, default=5
        Quantile to determine the data range to be plotted.

        * ``0``: full data range, i.e. ``q=0`` --> quantile range (0, 1)
        * ``int``: q-sigma range, e.g. ``q=1`` --> quantile range (0.16, 0.84)
        * ``float``: percentile, e.g. ``q=0.8`` --> quantile range (0.1, 0.9)
        * ``tuple``: quantile range, e.g. (0.16, 0.84)

    color : color, optional
        Line colour. Defaults to the next colour in the axes' colour cycle,
        or to ``plt.get_cmap(cmap)(0.68)`` if ``cmap`` is given.

    cmap : str or colormap, optional
        Only used to derive the default ``color``.

    edgecolor : color, optional
        Edge colour of the shaded regions. If given and truthy it also
        overrides ``color`` for the line.

    facecolor : bool or string, default=False
        If set to True then the 1d plot will be shaded with the value of the
        ``color`` kwarg. Set to a string such as 'blue', 'k', 'r', 'C1' etc.
        to define the color of the shading directly.

    bw_method : optional
        Accepted and discarded, so that it is not forwarded to matplotlib.
        It has no effect on the neural density estimator.

    nde_epochs : int, default=1000
        Number of epochs to train the NDE for.

    nde_lr : float, default=1e-4
        Learning rate for the NDE.

    nde_hidden_layers : list, default=[50, 50]
        Number of hidden layers for the NDE and number of nodes
        in the layers.

    nde_number_networks : int, default=6
        Number of networks to use in the NDE.

    nde_clustering : bool, default=False
        Whether to use clustering in the NDE (``clusterMAF`` instead of
        ``MAF``).

    Returns
    -------
    lines : list of :class:`matplotlib.lines.Line2D`
        A list of line objects representing the plotted data (same as
        :meth:`matplotlib.axes.Axes.plot` command). If ``facecolor`` requests
        shading, a tuple ``(lines, fill)`` is returned instead, where
        ``fill`` is the list of objects returned by
        :meth:`matplotlib.axes.Axes.fill_between`.

    """
    kwargs = normalize_kwargs(kwargs)
    weights = kwargs.pop('weights', None)
    if weights is not None:
        nonzero = weights != 0
        data = data[nonzero]
        weights = weights[nonzero]
    log_x = ax.get_xaxis().get_scale() == 'log'
    if log_x:
        data = np.log10(data)

    ncompress = kwargs.pop('ncompress', False)
    nplot = kwargs.pop('nplot_1d', 100)
    # Not used by the NDE; popped only so it never reaches ``ax.plot``.
    kwargs.pop('bw_method', None)
    levels = kwargs.pop('levels', [0.95, 0.68])
    density = kwargs.pop('density', False)

    nde_epochs = kwargs.pop('nde_epochs', 1000)
    nde_lr = kwargs.pop('nde_lr', 1e-4)
    nde_hidden_layers = kwargs.pop('nde_hidden_layers', [50, 50])
    nde_number_networks = kwargs.pop('nde_number_networks', 6)
    nde_clustering = kwargs.pop('nde_clustering', False)

    cmap = kwargs.pop('cmap', None)
    # NOTE: the default is evaluated eagerly, so the colour cycle advances
    # even when an explicit ``color`` is supplied.
    color = kwargs.pop('color', (ax._get_lines.get_next_color()
                                 if cmap is None
                                 else plt.get_cmap(cmap)(0.68)))
    facecolor = kwargs.pop('facecolor', False)
    if 'edgecolor' in kwargs:
        edgecolor = kwargs.pop('edgecolor')
        if edgecolor:
            color = edgecolor
    else:
        edgecolor = color

    q = kwargs.pop('q', 5)
    q = quantile_plot_interval(q=q)
    xmin = quantile(data, q[0], weights)
    xmax = quantile(data, q[-1], weights)
    x = np.linspace(xmin, xmax, nplot)

    data_compressed, w = sample_compression_1d(data, weights, ncompress)
    if w is None:
        w = np.ones(data_compressed.shape[0])

    # Both estimators are built with the same arguments; only the class
    # differs. The samples are passed as a single-column 2d array.
    nde_class = clusterMAF if nde_clustering else MAF
    nde = nde_class(np.array([data_compressed]).T, weights=w,
                    number_networks=nde_number_networks,
                    hidden_layers=nde_hidden_layers,
                    lr=nde_lr)
    nde.train(epochs=nde_epochs, early_stop=True)

    pp = nde.log_prob(np.array([x]).T)
    if not nde_clustering:
        # Only the MAF result is converted with ``.numpy()``; presumably
        # clusterMAF already returns an array -- confirm against margarine.
        pp = pp.numpy()
    # Subtracting the maximum before exponentiating makes the peak one.
    pp = np.exp(pp - pp.max())

    if density:
        # ``np.trapz`` is deprecated in NumPy 2 in favour of ``np.trapezoid``.
        trapezoid = np.trapezoid if hasattr(np, 'trapezoid') else np.trapz
        area = trapezoid(x=x, y=pp)
    else:
        area = 1
    curve = pp/area

    if log_x:
        x = 10**x
    ans = ax.plot(x, curve, *args, color=color, **kwargs)

    if facecolor and facecolor not in [None, 'None', 'none']:
        if facecolor is True:
            facecolor = color
        # Thresholds are computed on the peak-normalised ``pp``; the shading
        # itself follows the plotted ``curve`` so that it stays underneath
        # the line when ``density`` rescales it.
        c = iso_probability_contours(pp, contours=levels)
        cmap = basic_cmap(facecolor)
        fill = [ax.fill_between(x, curve, where=pp >= level,
                                color=cmap(level), edgecolor=edgecolor)
                for level in c[:-1]]

        ans = ans, fill

    if density:
        ax.set_ylim(bottom=0)
    else:
        ax.set_ylim(0, 1.1)

    return ans


def hist_plot_1d(ax, data, *args, **kwargs):
"""Plot a 1d histogram.
Expand Down Expand Up @@ -1311,6 +1474,163 @@ def kde_contour_plot_2d(ax, data_x, data_y, *args, **kwargs):

return contf, cont

def nde_contour_plot_2d(ax, data_x, data_y, *args, **kwargs):
    """Plot a 2d marginalised distribution as contours.

    This functions as a wrapper around :meth:`matplotlib.axes.Axes.contour`
    and :meth:`matplotlib.axes.Axes.contourf` with a normalising flow as a
    neural density estimator (NDE), using ``MAF`` (imported from
    :mod:`margarine.maf`) or ``clusterMAF`` (imported from
    :mod:`margarine.clustered`). All remaining keyword arguments are passed
    onwards to both functions.

    Samples with zero weight are dropped before anything else is computed.
    If an axis of ``ax`` is on a log scale, the estimator is trained on
    ``log10`` of the corresponding coordinate and the plotting grid is
    mapped back with ``10**``.

    Parameters
    ----------
    ax : :class:`matplotlib.axes.Axes`
        Axis object to plot on.

    data_x, data_y : np.array
        The x and y coordinates of the samples to train the neural density
        estimator on.

    weights : np.array, optional
        Sample weights.

    levels : list, optional
        Amount of mass within each iso-probability contour.
        Has to be ordered from outermost to innermost contour.
        Default: [0.95, 0.68]

    ncompress : int, str, default='equal'
        Degree of compression.

        * If ``int``: desired number of samples after compression.
        * If ``False``: no compression.
        * If ``True``: compresses to the channel capacity, equivalent to
          ``ncompress='entropy'``.
        * If ``str``: determine number from the Huggins-Roy family of
          effective samples in :func:`anesthetic.utils.neff`
          with ``beta=ncompress``.

    nplot_2d : int, default=1000
        Number of plotting points to use. The grid has ``sqrt(nplot_2d)``
        points along each axis.

    q : int or float or tuple, default=5
        Quantile to determine the data range covered by the plotting grid,
        interpreted by ``quantile_plot_interval``. The same interval is
        applied to both axes.

    color, facecolor, edgecolor, cmap : optional
        Colour controls, resolved together by ``set_colors``.
        ``facecolor`` defaults to ``True``; if it resolves to ``None``,
        ``'None'`` or ``'none'`` no filled contours are drawn.

    label : str, optional
        Legend label, attached to a zero-size proxy rectangle.

    zorder : int, default=1
        Passed to both the filled and the line contours.

    bw_method : optional
        Accepted and discarded, so that it is not forwarded to matplotlib.
        It has no effect on the neural density estimator.

    nde_epochs : int, default=1000
        Number of epochs to train the NDE for.

    nde_lr : float, default=1e-4
        Learning rate for the NDE.

    nde_hidden_layers : list, default=[50, 50]
        Number of hidden layers for the NDE and number of nodes
        in the layers.

    nde_number_networks : int, default=6
        Number of networks to use in the NDE.

    nde_clustering : bool, default=False
        Whether to use clustering in the NDE (``clusterMAF`` instead of
        ``MAF``).

    Returns
    -------
    contf : :class:`matplotlib.contour.QuadContourSet` or None
        The filled contours, or ``None`` if no face colour is drawn.
    cont : :class:`matplotlib.contour.QuadContourSet`
        The contour lines.

    """
    kwargs = normalize_kwargs(kwargs, dict(linewidths=['linewidth', 'lw'],
                                           linestyles=['linestyle', 'ls'],
                                           color=['c'],
                                           facecolor=['fc'],
                                           edgecolor=['ec']))

    weights = kwargs.pop('weights', None)
    if weights is not None:
        nonzero = weights != 0
        data_x = data_x[nonzero]
        data_y = data_y[nonzero]
        weights = weights[nonzero]
    log_x = ax.get_xaxis().get_scale() == 'log'
    log_y = ax.get_yaxis().get_scale() == 'log'
    if log_x:
        data_x = np.log10(data_x)
    if log_y:
        data_y = np.log10(data_y)

    ncompress = kwargs.pop('ncompress', 'equal')
    nplot = kwargs.pop('nplot_2d', 1000)
    # Not used by the NDE; popped only so it never reaches matplotlib.
    kwargs.pop('bw_method', None)
    label = kwargs.pop('label', None)
    zorder = kwargs.pop('zorder', 1)
    levels = kwargs.pop('levels', [0.95, 0.68])

    # NOTE: the default is evaluated eagerly, so the colour cycle advances
    # even when an explicit ``color`` is supplied.
    color = kwargs.pop('color', ax._get_lines.get_next_color())
    facecolor = kwargs.pop('facecolor', True)
    edgecolor = kwargs.pop('edgecolor', None)
    cmap = kwargs.pop('cmap', None)
    facecolor, edgecolor, cmap = set_colors(c=color, fc=facecolor,
                                            ec=edgecolor, cmap=cmap)

    nde_epochs = kwargs.pop('nde_epochs', 1000)
    nde_lr = kwargs.pop('nde_lr', 1e-4)
    nde_hidden_layers = kwargs.pop('nde_hidden_layers', [50, 50])
    nde_number_networks = kwargs.pop('nde_number_networks', 6)
    nde_clustering = kwargs.pop('nde_clustering', False)

    # Pop ``q`` exactly once so that a caller-supplied value is honoured
    # rather than being thrown away and replaced by the default.
    q = kwargs.pop('q', 5)
    q = quantile_plot_interval(q=q)
    xmin = quantile(data_x, q[0], weights)
    xmax = quantile(data_x, q[-1], weights)
    ymin = quantile(data_y, q[0], weights)
    ymax = quantile(data_y, q[-1], weights)
    # A complex "step" makes mgrid treat its magnitude as a point count.
    npoints = 1j*np.sqrt(nplot)
    X, Y = np.mgrid[xmin:xmax:npoints, ymin:ymax:npoints]

    cov = np.cov(data_x, data_y, aweights=weights)
    tri, w = triangular_sample_compression_2d(data_x, data_y, cov,
                                              weights, ncompress)
    data = np.array([tri.x, tri.y]).T

    # Both estimators are built with the same arguments; only the class
    # differs.
    nde_class = clusterMAF if nde_clustering else MAF
    nde = nde_class(data, weights=w, lr=nde_lr,
                    hidden_layers=nde_hidden_layers,
                    number_networks=nde_number_networks)
    nde.train(epochs=nde_epochs, early_stop=True)

    P = nde.log_prob(np.array([X.ravel(), Y.ravel()]).T)
    if not nde_clustering:
        # Only the MAF result is converted with ``.numpy()``; presumably
        # clusterMAF already returns an array -- confirm against margarine.
        P = P.numpy()
    # Subtracting the maximum before exponentiating makes the peak one.
    P = np.exp(P - P.max()).reshape(X.shape)

    levels = iso_probability_contours(P, contours=levels)
    if log_x:
        X = 10**X
    if log_y:
        Y = 10**Y

    if facecolor not in [None, 'None', 'none']:
        linewidths = kwargs.pop('linewidths', 0.5)
        contf = ax.contourf(X, Y, P, *args, levels=levels, cmap=cmap,
                            zorder=zorder, vmin=0, vmax=P.max(), **kwargs)
        contf.set_cmap(cmap)
        # Zero-size patch that carries ``label`` for the legend.
        ax.add_patch(plt.Rectangle((0, 0), 0, 0, lw=2, label=label,
                                   fc=cmap(0.999), ec=cmap(0.32)))
        # The line contours below are then coloured by ``edgecolor`` alone.
        cmap = None
    else:
        linewidths = kwargs.pop('linewidths',
                                plt.rcParams.get('lines.linewidth'))
        contf = None
        # Zero-size patch that carries ``label`` for the legend.
        ax.add_patch(
            plt.Rectangle((0, 0), 0, 0, lw=2, label=label,
                          fc='None' if cmap is None else cmap(0.999),
                          ec=edgecolor if cmap is None else cmap(0.32))
        )

    vmin, vmax = match_contour_to_contourf(levels, vmin=0, vmax=P.max())
    cont = ax.contour(X, Y, P, *args, levels=levels, zorder=zorder,
                      vmin=vmin, vmax=vmax, linewidths=linewidths,
                      colors=edgecolor, cmap=cmap, **kwargs)

    return contf, cont


def hist_plot_2d(ax, data_x, data_y, *args, **kwargs):
"""Plot a 2d marginalised distribution as a histogram.
Expand Down

0 comments on commit 9a7c3e9

Please sign in to comment.