Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NDE implementation #353

Draft
wants to merge 18 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
anesthetic: nested sampling post-processing
===========================================
:Authors: Will Handley and Lukas Hergt
:Version: 2.8.1
:Version: 2.9.0
:Homepage: https://github.com/handley-lab/anesthetic
:Documentation: http://anesthetic.readthedocs.io/

Expand Down
2 changes: 1 addition & 1 deletion anesthetic/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.8.1'
__version__ = '2.9.0'
339 changes: 339 additions & 0 deletions anesthetic/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,176 @@
return ans


def nde_plot_1d(ax, data, *args, **kwargs):
"""Plot a 1d marginalised distribution.

This functions as a wrapper around :meth:`matplotlib.axes.Axes.plot`, with
a normalising flow as a neural density estimation (NDE) provided by
:class:`margarine.maf.MAF` or :class:`margarine.clustered.clusterMAF`.
See also `Bevins et al. (2022) <https://arxiv.org/abs/2205.12841>`_.

All remaining keyword arguments are passed onwards.

Parameters
----------
ax : :class:`matplotlib.axes.Axes`
Axis object to plot on.

data : np.array
Samples to generate kernel density estimator.

weights : np.array, optional
Sample weights.

ncompress : int, str, default=False
Degree of compression.

* If ``False``: no compression.
* If ``True``: compresses to the channel capacity, equivalent to
``ncompress='entropy'``.
* If ``int``: desired number of samples after compression.
* If ``str``: determine number from the Huggins-Roy family of
effective samples in :func:`anesthetic.utils.neff`
with ``beta=ncompress``.

nplot_1d : int, default=100
Number of plotting points to use.

levels : list
Values at which to draw iso-probability lines.
Default: [0.95, 0.68]

q : int or float or tuple, default=5
Quantile to determine the data range to be plotted.

* ``0``: full data range, i.e. ``q=0`` --> quantile range (0, 1)
* ``int``: q-sigma range, e.g. ``q=1`` --> quantile range (0.16, 0.84)
* ``float``: percentile, e.g. ``q=0.8`` --> quantile range (0.1, 0.9)
* ``tuple``: quantile range, e.g. (0.16, 0.84)

facecolor : bool or string, default=False
If set to True then the 1d plot will be shaded with the value of the
``color`` kwarg. Set to a string such as 'blue', 'k', 'r', 'C1' ect.
to define the color of the shading directly.

beta : int, float, default = 1
The value of beta used to calculate the number of effective samples

nde_epochs : int, default=1000
Number of epochs to train the NDE for.

nde_lr : float, default=1e-4
Learning rate for the NDE.

nde_hidden_layers : list, default=[50, 50]
Number of hidden layers for the NDE and number of nodes
in the layers.

nde_number_networks : int, default=6
Number of networks to use in the NDE.

nde_clustering : bool, default=False
Whether to use clustering in the NDE.

Returns
-------
lines : :class:`matplotlib.lines.Line2D`
A list of line objects representing the plotted data (same as
:meth:`matplotlib.axes.Axes.plot` command).

"""
kwargs = normalize_kwargs(kwargs)
weights = kwargs.pop('weights', None)
if weights is not None:
data = data[weights != 0]
weights = weights[weights != 0]
if ax.get_xaxis().get_scale() == 'log':
data = np.log10(data)

Check warning on line 1061 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1055-L1061

Added lines #L1055 - L1061 were not covered by tests

ncompress = kwargs.pop('ncompress', False)
nplot = kwargs.pop('nplot_1d', 100)
levels = kwargs.pop('levels', [0.95, 0.68])
density = kwargs.pop('density', False)

Check warning on line 1066 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1063-L1066

Added lines #L1063 - L1066 were not covered by tests

nde_epochs = kwargs.pop('nde_epochs', 1000)
nde_lr = kwargs.pop('nde_lr', 1e-4)
nde_hidden_layers = kwargs.pop('nde_hidden_layers', [50, 50])
nde_number_networks = kwargs.pop('nde_number_networks', 6)
nde_clustering = kwargs.pop('nde_clustering', False)

Check warning on line 1072 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1068-L1072

Added lines #L1068 - L1072 were not covered by tests

cmap = kwargs.pop('cmap', None)
color = kwargs.pop('color', (ax._get_lines.get_next_color()

Check warning on line 1075 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1074-L1075

Added lines #L1074 - L1075 were not covered by tests
if cmap is None
else plt.get_cmap(cmap)(0.68)))
facecolor = kwargs.pop('facecolor', False)
if 'edgecolor' in kwargs:
edgecolor = kwargs.pop('edgecolor')
if edgecolor:
color = edgecolor

Check warning on line 1082 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1078-L1082

Added lines #L1078 - L1082 were not covered by tests
else:
edgecolor = color

Check warning on line 1084 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1084

Added line #L1084 was not covered by tests

q = kwargs.pop('q', 5)
q = quantile_plot_interval(q=q)
xmin = quantile(data, q[0], weights)
xmax = quantile(data, q[-1], weights)
x = np.linspace(xmin, xmax, nplot)

Check warning on line 1090 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1086-L1090

Added lines #L1086 - L1090 were not covered by tests

data_compressed, w = sample_compression_1d(data, weights, ncompress)
if w is None:
w = np.ones(data_compressed.shape[0])

Check warning on line 1094 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1092-L1094

Added lines #L1092 - L1094 were not covered by tests

try:
from margarine.maf import MAF
from margarine.clustered import clusterMAF

Check warning on line 1098 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1096-L1098

Added lines #L1096 - L1098 were not covered by tests

except ImportError:
raise ImportError("Please install margarine to use nde_1d")

Check warning on line 1101 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1100-L1101

Added lines #L1100 - L1101 were not covered by tests

if nde_clustering:
nde = clusterMAF(np.array([data_compressed]).T, weights=w,

Check warning on line 1104 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1103-L1104

Added lines #L1103 - L1104 were not covered by tests
number_networks=nde_number_networks,
hidden_layers=nde_hidden_layers,
lr=nde_lr)
else:
nde = MAF(np.array([data_compressed]).T, weights=w,

Check warning on line 1109 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1109

Added line #L1109 was not covered by tests
number_networks=nde_number_networks,
hidden_layers=nde_hidden_layers,
lr=nde_lr)
nde.train(epochs=nde_epochs, early_stop=True)

Check warning on line 1113 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1113

Added line #L1113 was not covered by tests

if nde_clustering:
pp = nde.log_prob(np.array([x]).T)

Check warning on line 1116 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1115-L1116

Added lines #L1115 - L1116 were not covered by tests
else:
pp = nde.log_prob(np.array([x]).T).numpy()
pp = np.exp(pp - pp.max())

Check warning on line 1119 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1118-L1119

Added lines #L1118 - L1119 were not covered by tests

area = np.trapz(x=x, y=pp) if density else 1

Check warning on line 1121 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1121

Added line #L1121 was not covered by tests

if ax.get_xaxis().get_scale() == 'log':
x = 10**x
ans = ax.plot(x, pp/area, color=color, *args, **kwargs)

Check warning on line 1125 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1123-L1125

Added lines #L1123 - L1125 were not covered by tests

if facecolor and facecolor not in [None, 'None', 'none']:
if facecolor is True:
facecolor = color
c = iso_probability_contours(pp, contours=levels)
cmap = basic_cmap(facecolor)
fill = []
for j in range(len(c)-1):
fill.append(ax.fill_between(x, pp, where=pp >= c[j],

Check warning on line 1134 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1127-L1134

Added lines #L1127 - L1134 were not covered by tests
color=cmap(c[j]), edgecolor=edgecolor))

ans = ans, fill

Check warning on line 1137 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1137

Added line #L1137 was not covered by tests

if density:
ax.set_ylim(bottom=0)

Check warning on line 1140 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1139-L1140

Added lines #L1139 - L1140 were not covered by tests
else:
ax.set_ylim(0, 1.1)

Check warning on line 1142 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1142

Added line #L1142 was not covered by tests

return ans

Check warning on line 1144 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1144

Added line #L1144 was not covered by tests


def hist_plot_1d(ax, data, *args, **kwargs):
"""Plot a 1d histogram.

Expand Down Expand Up @@ -1312,6 +1482,175 @@
return contf, cont


def nde_contour_plot_2d(ax, data_x, data_y, *args, **kwargs):
"""Plot a 2d marginalised distribution as contours.

This functions as a wrapper around :meth:`matplotlib.axes.Axes.contour`
and :meth:`matplotlib.axes.Axes.contourf` with a normalising flow as a
neural density estimator (NDE) provided by :class:`margarine.maf.MAF` or
:class:`margarine.clustered.clusterMAF`.
See also `Bevins et al. (2022) <https://arxiv.org/abs/2205.12841>`_.

All remaining keyword arguments are passed onwards to both functions.

Parameters
----------
ax : :class:`matplotlib.axes.Axes`
Axis object to plot on.

data_x, data_y : np.array
The x and y coordinates of uniformly weighted samples to generate
kernel density estimator.

weights : np.array, optional
Sample weights.

levels : list, optional
Amount of mass within each iso-probability contour.
Has to be ordered from outermost to innermost contour.
Default: [0.95, 0.68]

ncompress : int, str, default='equal'
Degree of compression.

* If ``int``: desired number of samples after compression.
* If ``False``: no compression.
* If ``True``: compresses to the channel capacity, equivalent to
``ncompress='entropy'``.
* If ``str``: determine number from the Huggins-Roy family of
effective samples in :func:`anesthetic.utils.neff`
with ``beta=ncompress``.

nplot_2d : int, default=1000
Number of plotting points to use.

nde_epochs : int, default=1000
Number of epochs to train the NDE for.

nde_lr : float, default=1e-4
Learning rate for the NDE.

nde_hidden_layers : list, default=[50, 50]
Number of hidden layers for the NDE and number of nodes
in the layers.

nde_number_networks : int, default=6
Number of networks to use in the NDE.

nde_clustering : bool, default=False
Whether to use clustering in the NDE.

Returns
-------
c : :class:`matplotlib.contour.QuadContourSet`
A set of contourlines or filled regions.

"""
kwargs = normalize_kwargs(kwargs, dict(linewidths=['linewidth', 'lw'],

Check warning on line 1549 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1549

Added line #L1549 was not covered by tests
linestyles=['linestyle', 'ls'],
color=['c'],
facecolor=['fc'],
edgecolor=['ec']))

weights = kwargs.pop('weights', None)
if weights is not None:
data_x = data_x[weights != 0]
data_y = data_y[weights != 0]
weights = weights[weights != 0]
if ax.get_xaxis().get_scale() == 'log':
data_x = np.log10(data_x)
if ax.get_yaxis().get_scale() == 'log':
data_y = np.log10(data_y)

Check warning on line 1563 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1555-L1563

Added lines #L1555 - L1563 were not covered by tests

ncompress = kwargs.pop('ncompress', 'equal')
nplot = kwargs.pop('nplot_2d', 1000)
label = kwargs.pop('label', None)
zorder = kwargs.pop('zorder', 1)
levels = kwargs.pop('levels', [0.95, 0.68])

Check warning on line 1569 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1565-L1569

Added lines #L1565 - L1569 were not covered by tests

color = kwargs.pop('color', ax._get_lines.get_next_color())
facecolor = kwargs.pop('facecolor', True)
edgecolor = kwargs.pop('edgecolor', None)
cmap = kwargs.pop('cmap', None)
facecolor, edgecolor, cmap = set_colors(c=color, fc=facecolor,

Check warning on line 1575 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1571-L1575

Added lines #L1571 - L1575 were not covered by tests
ec=edgecolor, cmap=cmap)

nde_epochs = kwargs.pop('nde_epochs', 1000)
nde_lr = kwargs.pop('nde_lr', 1e-4)
nde_hidden_layers = kwargs.pop('nde_hidden_layers', [50, 50])
nde_number_networks = kwargs.pop('nde_number_networks', 6)
nde_clustering = kwargs.pop('nde_clustering', False)

Check warning on line 1582 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1578-L1582

Added lines #L1578 - L1582 were not covered by tests

kwargs.pop('q', None)

Check warning on line 1584 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1584

Added line #L1584 was not covered by tests

q = kwargs.pop('q', 5)
q = quantile_plot_interval(q=q)
xmin = quantile(data_x, q[0], weights)
xmax = quantile(data_x, q[-1], weights)
ymin = quantile(data_y, q[0], weights)
ymax = quantile(data_y, q[-1], weights)
X, Y = np.mgrid[xmin:xmax:1j*np.sqrt(nplot), ymin:ymax:1j*np.sqrt(nplot)]

Check warning on line 1592 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1586-L1592

Added lines #L1586 - L1592 were not covered by tests

cov = np.cov(data_x, data_y, aweights=weights)
tri, w = triangular_sample_compression_2d(data_x, data_y, cov,

Check warning on line 1595 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1594-L1595

Added lines #L1594 - L1595 were not covered by tests
weights, ncompress)
data = np.array([tri.x, tri.y]).T

Check warning on line 1597 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1597

Added line #L1597 was not covered by tests

try:
from margarine.maf import MAF
from margarine.clustered import clusterMAF

Check warning on line 1601 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1599-L1601

Added lines #L1599 - L1601 were not covered by tests

except ImportError:
raise ImportError("Please install margarine to use nde_2d")

Check warning on line 1604 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1603-L1604

Added lines #L1603 - L1604 were not covered by tests

if nde_clustering:
nde = clusterMAF(data, weights=w, lr=nde_lr,

Check warning on line 1607 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1606-L1607

Added lines #L1606 - L1607 were not covered by tests
hidden_layers=nde_hidden_layers,
number_networks=nde_number_networks)
else:
nde = MAF(data, weights=w, lr=nde_lr,

Check warning on line 1611 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1611

Added line #L1611 was not covered by tests
hidden_layers=nde_hidden_layers,
number_networks=nde_number_networks)
nde.train(epochs=nde_epochs, early_stop=True)

Check warning on line 1614 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1614

Added line #L1614 was not covered by tests

if nde_clustering:
P = nde.log_prob(np.array([X.ravel(), Y.ravel()]).T)

Check warning on line 1617 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1616-L1617

Added lines #L1616 - L1617 were not covered by tests
else:
P = nde.log_prob(np.array([X.ravel(), Y.ravel()]).T).numpy()
P = np.exp(P - P.max()).reshape(X.shape)

Check warning on line 1620 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1619-L1620

Added lines #L1619 - L1620 were not covered by tests

levels = iso_probability_contours(P, contours=levels)
if ax.get_xaxis().get_scale() == 'log':
X = 10**X
if ax.get_yaxis().get_scale() == 'log':
Y = 10**Y

Check warning on line 1626 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1622-L1626

Added lines #L1622 - L1626 were not covered by tests

if facecolor not in [None, 'None', 'none']:
linewidths = kwargs.pop('linewidths', 0.5)
contf = ax.contourf(X, Y, P, levels=levels, cmap=cmap, zorder=zorder,

Check warning on line 1630 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1628-L1630

Added lines #L1628 - L1630 were not covered by tests
vmin=0, vmax=P.max(), *args, **kwargs)
contf.set_cmap(cmap)
ax.add_patch(plt.Rectangle((0, 0), 0, 0, lw=2, label=label,

Check warning on line 1633 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1632-L1633

Added lines #L1632 - L1633 were not covered by tests
fc=cmap(0.999), ec=cmap(0.32)))
cmap = None

Check warning on line 1635 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1635

Added line #L1635 was not covered by tests
else:
linewidths = kwargs.pop('linewidths',

Check warning on line 1637 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1637

Added line #L1637 was not covered by tests
plt.rcParams.get('lines.linewidth'))
contf = None
ax.add_patch(

Check warning on line 1640 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1639-L1640

Added lines #L1639 - L1640 were not covered by tests
plt.Rectangle((0, 0), 0, 0, lw=2, label=label,
fc='None' if cmap is None else cmap(0.999),
ec=edgecolor if cmap is None else cmap(0.32))
)

vmin, vmax = match_contour_to_contourf(levels, vmin=0, vmax=P.max())
cont = ax.contour(X, Y, P, levels=levels, zorder=zorder,

Check warning on line 1647 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1646-L1647

Added lines #L1646 - L1647 were not covered by tests
vmin=vmin, vmax=vmax, linewidths=linewidths,
colors=edgecolor, cmap=cmap, *args, **kwargs)

return contf, cont

Check warning on line 1651 in anesthetic/plot.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/plot.py#L1651

Added line #L1651 was not covered by tests


def hist_plot_2d(ax, data_x, data_y, *args, **kwargs):
"""Plot a 2d marginalised distribution as a histogram.

Expand Down
Loading
Loading