diff --git a/README.rst b/README.rst index 6a6ad009..d9b8554e 100644 --- a/README.rst +++ b/README.rst @@ -2,7 +2,7 @@ anesthetic: nested sampling post-processing =========================================== :Authors: Will Handley and Lukas Hergt -:Version: 2.8.1 +:Version: 2.9.0 :Homepage: https://github.com/handley-lab/anesthetic :Documentation: http://anesthetic.readthedocs.io/ diff --git a/anesthetic/_version.py b/anesthetic/_version.py index 80e22f7a..387cfacc 100644 --- a/anesthetic/_version.py +++ b/anesthetic/_version.py @@ -1 +1 @@ -__version__ = '2.8.1' +__version__ = '2.9.0' diff --git a/anesthetic/plot.py b/anesthetic/plot.py index 1b1e065a..2283533e 100644 --- a/anesthetic/plot.py +++ b/anesthetic/plot.py @@ -974,6 +974,176 @@ def kde_plot_1d(ax, data, *args, **kwargs): return ans +def nde_plot_1d(ax, data, *args, **kwargs): + """Plot a 1d marginalised distribution. + + This functions as a wrapper around :meth:`matplotlib.axes.Axes.plot`, with + a normalising flow as a neural density estimation (NDE) provided by + :class:`margarine.maf.MAF` or :class:`margarine.clustered.clusterMAF`. + See also `Bevins et al. (2022) `_. + + All remaining keyword arguments are passed onwards. + + Parameters + ---------- + ax : :class:`matplotlib.axes.Axes` + Axis object to plot on. + + data : np.array + Samples to generate kernel density estimator. + + weights : np.array, optional + Sample weights. + + ncompress : int, str, default=False + Degree of compression. + + * If ``False``: no compression. + * If ``True``: compresses to the channel capacity, equivalent to + ``ncompress='entropy'``. + * If ``int``: desired number of samples after compression. + * If ``str``: determine number from the Huggins-Roy family of + effective samples in :func:`anesthetic.utils.neff` + with ``beta=ncompress``. + + nplot_1d : int, default=100 + Number of plotting points to use. + + levels : list + Values at which to draw iso-probability lines. + Default: [0.95, 0.68] + + q : int or float or tuple, default=5 + Quantile to determine the data range to be plotted. + + * ``0``: full data range, i.e. ``q=0`` --> quantile range (0, 1) + * ``int``: q-sigma range, e.g. ``q=1`` --> quantile range (0.16, 0.84) + * ``float``: percentile, e.g. ``q=0.8`` --> quantile range (0.1, 0.9) + * ``tuple``: quantile range, e.g. (0.16, 0.84) + + facecolor : bool or string, default=False + If set to True then the 1d plot will be shaded with the value of the + ``color`` kwarg. Set to a string such as 'blue', 'k', 'r', 'C1' ect. + to define the color of the shading directly. + + beta : int, float, default = 1 + The value of beta used to calculate the number of effective samples + + nde_epochs : int, default=1000 + Number of epochs to train the NDE for. + + nde_lr : float, default=1e-4 + Learning rate for the NDE. + + nde_hidden_layers : list, default=[50, 50] + Number of hidden layers for the NDE and number of nodes + in the layers. + + nde_number_networks : int, default=6 + Number of networks to use in the NDE. + + nde_clustering : bool, default=False + Whether to use clustering in the NDE. + + Returns + ------- + lines : :class:`matplotlib.lines.Line2D` + A list of line objects representing the plotted data (same as + :meth:`matplotlib.axes.Axes.plot` command). + + """ + kwargs = normalize_kwargs(kwargs) + weights = kwargs.pop('weights', None) + if weights is not None: + data = data[weights != 0] + weights = weights[weights != 0] + if ax.get_xaxis().get_scale() == 'log': + data = np.log10(data) + + ncompress = kwargs.pop('ncompress', False) + nplot = kwargs.pop('nplot_1d', 100) + levels = kwargs.pop('levels', [0.95, 0.68]) + density = kwargs.pop('density', False) + + nde_epochs = kwargs.pop('nde_epochs', 1000) + nde_lr = kwargs.pop('nde_lr', 1e-4) + nde_hidden_layers = kwargs.pop('nde_hidden_layers', [50, 50]) + nde_number_networks = kwargs.pop('nde_number_networks', 6) + nde_clustering = kwargs.pop('nde_clustering', False) + + cmap = kwargs.pop('cmap', None) + color = kwargs.pop('color', (ax._get_lines.get_next_color() + if cmap is None + else plt.get_cmap(cmap)(0.68))) + facecolor = kwargs.pop('facecolor', False) + if 'edgecolor' in kwargs: + edgecolor = kwargs.pop('edgecolor') + if edgecolor: + color = edgecolor + else: + edgecolor = color + + q = kwargs.pop('q', 5) + q = quantile_plot_interval(q=q) + xmin = quantile(data, q[0], weights) + xmax = quantile(data, q[-1], weights) + x = np.linspace(xmin, xmax, nplot) + + data_compressed, w = sample_compression_1d(data, weights, ncompress) + if w is None: + w = np.ones(data_compressed.shape[0]) + + try: + from margarine.maf import MAF + from margarine.clustered import clusterMAF + + except ImportError: + raise ImportError("Please install margarine to use nde_1d") + + if nde_clustering: + nde = clusterMAF(np.array([data_compressed]).T, weights=w, + number_networks=nde_number_networks, + hidden_layers=nde_hidden_layers, + lr=nde_lr) + else: + nde = MAF(np.array([data_compressed]).T, weights=w, + number_networks=nde_number_networks, + hidden_layers=nde_hidden_layers, + lr=nde_lr) + nde.train(epochs=nde_epochs, early_stop=True) + + if nde_clustering: + pp = nde.log_prob(np.array([x]).T) + else: + pp = nde.log_prob(np.array([x]).T).numpy() + pp = np.exp(pp - pp.max()) + + area = np.trapz(x=x, y=pp) if density else 1 + + if ax.get_xaxis().get_scale() == 'log': + x = 10**x + ans = ax.plot(x, pp/area, color=color, *args, **kwargs) + + if facecolor and facecolor not in [None, 'None', 'none']: + if facecolor is True: + facecolor = color + c = iso_probability_contours(pp, contours=levels) + cmap = basic_cmap(facecolor) + fill = [] + for j in range(len(c)-1): + fill.append(ax.fill_between(x, pp, where=pp >= c[j], + color=cmap(c[j]), edgecolor=edgecolor)) + + ans = ans, fill + + if density: + ax.set_ylim(bottom=0) + else: + ax.set_ylim(0, 1.1) + + return ans + + def hist_plot_1d(ax, data, *args, **kwargs): """Plot a 1d histogram. @@ -1312,6 +1482,175 @@ def kde_contour_plot_2d(ax, data_x, data_y, *args, **kwargs): return contf, cont +def nde_contour_plot_2d(ax, data_x, data_y, *args, **kwargs): + """Plot a 2d marginalised distribution as contours. + + This functions as a wrapper around :meth:`matplotlib.axes.Axes.contour` + and :meth:`matplotlib.axes.Axes.contourf` with a normalising flow as a + neural density estimator (NDE) provided by :class:`margarine.maf.MAF` or + :class:`margarine.clustered.clusterMAF`. + See also `Bevins et al. (2022) `_. + + All remaining keyword arguments are passed onwards to both functions. + + Parameters + ---------- + ax : :class:`matplotlib.axes.Axes` + Axis object to plot on. + + data_x, data_y : np.array + The x and y coordinates of uniformly weighted samples to generate + kernel density estimator. + + weights : np.array, optional + Sample weights. + + levels : list, optional + Amount of mass within each iso-probability contour. + Has to be ordered from outermost to innermost contour. + Default: [0.95, 0.68] + + ncompress : int, str, default='equal' + Degree of compression. + + * If ``int``: desired number of samples after compression. + * If ``False``: no compression. + * If ``True``: compresses to the channel capacity, equivalent to + ``ncompress='entropy'``. + * If ``str``: determine number from the Huggins-Roy family of + effective samples in :func:`anesthetic.utils.neff` + with ``beta=ncompress``. + + nplot_2d : int, default=1000 + Number of plotting points to use. + + nde_epochs : int, default=1000 + Number of epochs to train the NDE for. + + nde_lr : float, default=1e-4 + Learning rate for the NDE. + + nde_hidden_layers : list, default=[50, 50] + Number of hidden layers for the NDE and number of nodes + in the layers. + + nde_number_networks : int, default=6 + Number of networks to use in the NDE. + + nde_clustering : bool, default=False + Whether to use clustering in the NDE. + + Returns + ------- + c : :class:`matplotlib.contour.QuadContourSet` + A set of contourlines or filled regions. + + """ + kwargs = normalize_kwargs(kwargs, dict(linewidths=['linewidth', 'lw'], + linestyles=['linestyle', 'ls'], + color=['c'], + facecolor=['fc'], + edgecolor=['ec'])) + + weights = kwargs.pop('weights', None) + if weights is not None: + data_x = data_x[weights != 0] + data_y = data_y[weights != 0] + weights = weights[weights != 0] + if ax.get_xaxis().get_scale() == 'log': + data_x = np.log10(data_x) + if ax.get_yaxis().get_scale() == 'log': + data_y = np.log10(data_y) + + ncompress = kwargs.pop('ncompress', 'equal') + nplot = kwargs.pop('nplot_2d', 1000) + label = kwargs.pop('label', None) + zorder = kwargs.pop('zorder', 1) + levels = kwargs.pop('levels', [0.95, 0.68]) + + color = kwargs.pop('color', ax._get_lines.get_next_color()) + facecolor = kwargs.pop('facecolor', True) + edgecolor = kwargs.pop('edgecolor', None) + cmap = kwargs.pop('cmap', None) + facecolor, edgecolor, cmap = set_colors(c=color, fc=facecolor, + ec=edgecolor, cmap=cmap) + + nde_epochs = kwargs.pop('nde_epochs', 1000) + nde_lr = kwargs.pop('nde_lr', 1e-4) + nde_hidden_layers = kwargs.pop('nde_hidden_layers', [50, 50]) + nde_number_networks = kwargs.pop('nde_number_networks', 6) + nde_clustering = kwargs.pop('nde_clustering', False) + + kwargs.pop('q', None) + + q = kwargs.pop('q', 5) + q = quantile_plot_interval(q=q) + xmin = quantile(data_x, q[0], weights) + xmax = quantile(data_x, q[-1], weights) + ymin = quantile(data_y, q[0], weights) + ymax = quantile(data_y, q[-1], weights) + X, Y = np.mgrid[xmin:xmax:1j*np.sqrt(nplot), ymin:ymax:1j*np.sqrt(nplot)] + + cov = np.cov(data_x, data_y, aweights=weights) + tri, w = triangular_sample_compression_2d(data_x, data_y, cov, + weights, ncompress) + data = np.array([tri.x, tri.y]).T + + try: + from margarine.maf import MAF + from margarine.clustered import clusterMAF + + except ImportError: + raise ImportError("Please install margarine to use nde_2d") + + if nde_clustering: + nde = clusterMAF(data, weights=w, lr=nde_lr, + hidden_layers=nde_hidden_layers, + number_networks=nde_number_networks) + else: + nde = MAF(data, weights=w, lr=nde_lr, + hidden_layers=nde_hidden_layers, + number_networks=nde_number_networks) + nde.train(epochs=nde_epochs, early_stop=True) + + if nde_clustering: + P = nde.log_prob(np.array([X.ravel(), Y.ravel()]).T) + else: + P = nde.log_prob(np.array([X.ravel(), Y.ravel()]).T).numpy() + P = np.exp(P - P.max()).reshape(X.shape) + + levels = iso_probability_contours(P, contours=levels) + if ax.get_xaxis().get_scale() == 'log': + X = 10**X + if ax.get_yaxis().get_scale() == 'log': + Y = 10**Y + + if facecolor not in [None, 'None', 'none']: + linewidths = kwargs.pop('linewidths', 0.5) + contf = ax.contourf(X, Y, P, levels=levels, cmap=cmap, zorder=zorder, + vmin=0, vmax=P.max(), *args, **kwargs) + contf.set_cmap(cmap) + ax.add_patch(plt.Rectangle((0, 0), 0, 0, lw=2, label=label, + fc=cmap(0.999), ec=cmap(0.32))) + cmap = None + else: + linewidths = kwargs.pop('linewidths', + plt.rcParams.get('lines.linewidth')) + contf = None + ax.add_patch( + plt.Rectangle((0, 0), 0, 0, lw=2, label=label, + fc='None' if cmap is None else cmap(0.999), + ec=edgecolor if cmap is None else cmap(0.32)) + ) + + vmin, vmax = match_contour_to_contourf(levels, vmin=0, vmax=P.max()) + cont = ax.contour(X, Y, P, levels=levels, zorder=zorder, + vmin=vmin, vmax=vmax, linewidths=linewidths, + colors=edgecolor, cmap=cmap, *args, **kwargs) + + return contf, cont + + def hist_plot_2d(ax, data_x, data_y, *args, **kwargs): """Plot a 2d marginalised distribution as a histogram. diff --git a/anesthetic/plotting/_core.py b/anesthetic/plotting/_core.py index fb394ced..f2d32fac 100644 --- a/anesthetic/plotting/_core.py +++ b/anesthetic/plotting/_core.py @@ -9,10 +9,14 @@ def _process_docstring(doc): " - 'kde_1d' : 1d Kernel Density Estimation plot\n" " - 'fastkde_1d' : 1d Kernel Density Estimation plot" " with fastkde package\n" + " - 'nde_1d' : 1d Neural Density Estimation plot" + " with margarine package\n" " - 'hist_2d' : 2d histogram (DataFrame only)\n" " - 'kde_2d' : 2d Kernel Density Estimation plot (DataFrame only)\n" " - 'fastkde_2d' : 2d Kernel Density Estimation plot" " with fastkde package (DataFrame only)\n" + " - 'nde_2d' : 2d Neural Density Estimation plot" + " with margarine package\n" " - 'scatter_2d' : 2d scatter plot (DataFrame only)\n" ) return doc[:i] + e + doc[i:] @@ -22,10 +26,10 @@ class PlotAccessor(_PlotAccessor): # noqa: disable=D101 __doc__ = _process_docstring(_PlotAccessor.__doc__) _common_kinds = _PlotAccessor._common_kinds \ - + ("hist_1d", "kde_1d", "fastkde_1d") + + ("hist_1d", "kde_1d", "fastkde_1d", "nde_1d") _series_kinds = _PlotAccessor._series_kinds + () _dataframe_kinds = _PlotAccessor._dataframe_kinds \ - + ("hist_2d", "kde_2d", "fastkde_2d", "scatter_2d") + + ("hist_2d", "kde_2d", "fastkde_2d", "nde_2d", "scatter_2d") _all_kinds = _common_kinds + _series_kinds + _dataframe_kinds def hist_1d(self, **kwargs): @@ -40,6 +44,10 @@ def fastkde_1d(self, **kwargs): """KDE plot: See :func:`anesthetic.plot.fastkde_plot_1d`.""" return self(kind="fastkde_1d", **kwargs) + def nde_1d(self, **kwargs): + """NDE plot: See :func:`anesthetic.plot.nde_plot_1d`.""" + return self(kind="nde_1d", **kwargs) + def kde_2d(self, x, y, **kwargs): """KDE plot: See :func:`anesthetic.plot.kde_contour_plot_2d`.""" return self(kind="kde_2d", x=x, y=y, **kwargs) @@ -48,6 +56,10 @@ def fastkde_2d(self, x, y, **kwargs): """KDE plot: See :func:`anesthetic.plot.fastkde_contour_plot_2d`.""" return self(kind="fastkde_2d", x=x, y=y, **kwargs) + def nde_2d(self, x, y, **kwargs): + """NDE plot: See :func:`anesthetic.plot.nde_contour_plot_2d`.""" + return self(kind="nde_2d", x=x, y=y, **kwargs) + def hist_2d(self, x, y, **kwargs): """Histogram plot: See :func:`anesthetic.plot.hist_plot_2d`.""" return self(kind="hist_2d", x=x, y=y, **kwargs) diff --git a/anesthetic/plotting/_matplotlib/__init__.py b/anesthetic/plotting/_matplotlib/__init__.py index 9925a020..46f268e2 100644 --- a/anesthetic/plotting/_matplotlib/__init__.py +++ b/anesthetic/plotting/_matplotlib/__init__.py @@ -27,8 +27,10 @@ KdePlot, Kde1dPlot, FastKde1dPlot, + Nde1dPlot, Kde2dPlot, FastKde2dPlot, + Nde2dPlot, HistPlot, Hist1dPlot, Hist2dPlot, @@ -62,7 +64,9 @@ PLOT_CLASSES['hist_1d'] = Hist1dPlot PLOT_CLASSES['kde_1d'] = Kde1dPlot PLOT_CLASSES['fastkde_1d'] = FastKde1dPlot +PLOT_CLASSES['nde_1d'] = Nde1dPlot PLOT_CLASSES['hist_2d'] = Hist2dPlot PLOT_CLASSES['kde_2d'] = Kde2dPlot PLOT_CLASSES['fastkde_2d'] = FastKde2dPlot +PLOT_CLASSES['nde_2d'] = Nde2dPlot PLOT_CLASSES['scatter_2d'] = ScatterPlot2d diff --git a/anesthetic/plotting/_matplotlib/hist.py b/anesthetic/plotting/_matplotlib/hist.py index 763591cb..b9cb31c8 100644 --- a/anesthetic/plotting/_matplotlib/hist.py +++ b/anesthetic/plotting/_matplotlib/hist.py @@ -17,8 +17,10 @@ kde_contour_plot_2d, hist_plot_2d, fastkde_contour_plot_2d, + nde_contour_plot_2d, kde_plot_1d, fastkde_plot_1d, + nde_plot_1d, hist_plot_1d, quantile_plot_interval, ) @@ -165,6 +167,29 @@ def _plot( return fastkde_plot_1d(ax, y, *args, **kwds) +class Nde1dPlot(KdePlot): + # noqa: disable=D101 + @property + def _kind(self) -> Literal["nde_1d"]: + return "nde_1d" + + # noqa: disable=D101 + @classmethod + def _plot( + cls, + ax, + y, + style=None, + ind=None, + column_num=None, + stacking_id=None, + **kwds, + ): + args = (style,) if style is not None else tuple() + kwds.pop('bw_method', None) + return nde_plot_1d(ax, y, *args, **kwds) + + class Hist1dPlot(HistPlot): # noqa: disable=D101 @property @@ -244,6 +269,17 @@ def _plot(cls, ax, x, y, **kwds): return fastkde_contour_plot_2d(ax, x, y, **kwds) +class Nde2dPlot(_CompressedMPLPlot, _PlanePlot2d): + # noqa: disable=D101 + @property + def _kind(self) -> Literal["nde_2d"]: + return "nde_2d" + + @classmethod + def _plot(cls, ax, x, y, **kwds): + return nde_contour_plot_2d(ax, x, y, **kwds) + + class Hist2dPlot(_WeightedMPLPlot, _PlanePlot2d): # noqa: disable=D101 @property diff --git a/anesthetic/samples.py b/anesthetic/samples.py index 182a75ff..c1417f4a 100644 --- a/anesthetic/samples.py +++ b/anesthetic/samples.py @@ -115,6 +115,7 @@ def plot_1d(self, axes=None, *args, **kwargs): * 'hist_1d': :func:`anesthetic.plot.hist_plot_1d` * 'kde_1d': :func:`anesthetic.plot.kde_plot_1d` * 'fastkde_1d': :func:`anesthetic.plot.fastkde_plot_1d` + * 'nde_1d': :func:`anesthetic.plot.nde_plot_1d` Warning -- while the other pandas plotting options {'line', 'bar', 'barh', 'area', 'pie'} are also accessible, these @@ -215,6 +216,7 @@ def plot_2d(self, axes=None, *args, **kwargs): - 'kde_1d': :func:`anesthetic.plot.kde_plot_1d` - 'hist_1d': :func:`anesthetic.plot.hist_plot_1d` - 'fastkde_1d': :func:`anesthetic.plot.fastkde_plot_1d` + - 'nde_1d': :func:`anesthetic.plot.nde_plot_1d` - 'kde': :meth:`pandas.Series.plot.kde` - 'hist': :meth:`pandas.Series.plot.hist` - 'box': :meth:`pandas.Series.plot.box` @@ -226,6 +228,7 @@ def plot_2d(self, axes=None, *args, **kwargs): - 'hist_2d': :func:`anesthetic.plot.hist_plot_2d` - 'scatter_2d': :func:`anesthetic.plot.scatter_plot_2d` - 'fastkde_2d': :func:`anesthetic.plot.fastkde_contour_plot_2d` + - 'nde_2d': :func:`anesthetic.plot.nde_contour_plot_2d` - 'kde': :meth:`pandas.DataFrame.plot.kde` - 'scatter': :meth:`pandas.DataFrame.plot.scatter` - 'hexbin': :meth:`pandas.DataFrame.plot.hexbin` @@ -368,6 +371,7 @@ def plot_2d(self, axes=None, *args, **kwargs): 'hist': {'diagonal': 'hist_1d', 'lower': 'hist_2d'}, 'hist_1d': {'diagonal': 'hist_1d'}, 'hist_2d': {'lower': 'hist_2d'}, + 'nde': {'diagonal': 'nde_1d', 'lower': 'nde_2d'}, } def importance_sample(self, logL_new, action='add', inplace=False): diff --git a/docs/source/conf.py b/docs/source/conf.py index 466a88e2..3fe65354 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -241,7 +241,8 @@ def get_version(short=False): 'scipy':('https://docs.scipy.org/doc/scipy', None), 'pandas':('https://pandas.pydata.org/pandas-docs/stable', None), 'matplotlib':('https://matplotlib.org/stable', None), - 'getdist':('https://getdist.readthedocs.io/en/latest/', None) + 'getdist':('https://getdist.readthedocs.io/en/latest/', None), + 'margarine':('https://margarine.readthedocs.io/en/latest/', None) } # -- Options for todo extension ---------------------------------------------- diff --git a/pyproject.toml b/pyproject.toml index caa28267..30d6ce52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,9 +65,10 @@ test = ["pytest", "pytest-cov", "flake8", "pydocstyle", "packaging", "pre-commit ultranest = ["h5py"] astropy = ["astropy"] fastkde = ["fastkde"] +nde = ["margarine"] getdist = ["getdist"] hdf5 = ["tables"] -all = ["h5py", "astropy", "fastkde", "getdist", "tables"] +all = ["h5py", "astropy", "fastkde", "getdist", "tables", "margarine"] [project.scripts] anesthetic = "anesthetic.scripts:gui" diff --git a/tests/test_plot.py b/tests/test_plot.py index 8864d905..5ebe27bf 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -7,8 +7,9 @@ import matplotlib.pyplot as plt import matplotlib.gridspec as gs from anesthetic.plot import (make_1d_axes, make_2d_axes, kde_plot_1d, - fastkde_plot_1d, hist_plot_1d, hist_plot_2d, - fastkde_contour_plot_2d, kde_contour_plot_2d, + fastkde_plot_1d, nde_plot_1d, hist_plot_1d, + hist_plot_2d, fastkde_contour_plot_2d, + nde_contour_plot_2d, kde_contour_plot_2d, scatter_plot_2d, quantile_plot_interval, basic_cmap, AxesSeries, AxesDataFrame) from numpy.testing import assert_array_equal @@ -22,7 +23,8 @@ from scipy.special import erf from scipy.interpolate import interp1d from scipy import stats -from utils import skipif_no_fastkde, astropy_mark_xfail, fastkde_mark_xfail +from utils import (skipif_no_fastkde, skipif_no_margarine, + astropy_mark_xfail, fastkde_mark_xfail) @pytest.fixture(autouse=True) @@ -342,7 +344,8 @@ def test_2d_axes_scatter(axesparams, params, upper): @pytest.mark.parametrize('plot_1d', [kde_plot_1d, - skipif_no_fastkde(fastkde_plot_1d)]) + skipif_no_fastkde(fastkde_plot_1d), + skipif_no_margarine(nde_plot_1d)]) def test_kde_plot_1d(plot_1d): fig, ax = plt.subplots() np.random.seed(0) @@ -555,7 +558,8 @@ def test_hist_plot_2d(): @pytest.mark.parametrize('plot_1d', [kde_plot_1d, - skipif_no_fastkde(fastkde_plot_1d)]) + skipif_no_fastkde(fastkde_plot_1d), + skipif_no_margarine(nde_plot_1d)]) @pytest.mark.parametrize('s', [1, 2]) def test_1d_density_kwarg(plot_1d, s): np.random.seed(0) @@ -590,7 +594,8 @@ def test_1d_density_kwarg(plot_1d, s): @pytest.mark.parametrize('contour_plot_2d', [kde_contour_plot_2d, - skipif_no_fastkde(fastkde_contour_plot_2d)]) + skipif_no_fastkde(fastkde_contour_plot_2d), + skipif_no_margarine(nde_contour_plot_2d)]) def test_contour_plot_2d(contour_plot_2d): fig, ax = plt.subplots() np.random.seed(1) @@ -689,7 +694,8 @@ def test_kde_plot_nplot(): @pytest.mark.parametrize('contour_plot_2d', [kde_contour_plot_2d, - skipif_no_fastkde(fastkde_contour_plot_2d)]) + skipif_no_fastkde(fastkde_contour_plot_2d), + skipif_no_margarine(nde_contour_plot_2d)]) @pytest.mark.parametrize('levels', [[0.9], [0.9, 0.6], [0.9, 0.6, 0.3], @@ -802,6 +808,7 @@ def test_make_axes_logscale(): @pytest.mark.parametrize('plot_1d', [kde_plot_1d, skipif_no_fastkde(fastkde_plot_1d), + skipif_no_margarine(nde_plot_1d), hist_plot_1d]) def test_logscale_1d(plot_1d): np.random.seed(42) @@ -811,7 +818,7 @@ def test_logscale_1d(plot_1d): fig, ax = plt.subplots() ax.set_xscale('log') p = plot_1d(ax, data) - if 'kde' in plot_1d.__name__: + if 'kde' in plot_1d.__name__ or 'nde' in plot_1d.__name__: amax = abs(np.log10(p[0].get_xdata()[np.argmax(p[0].get_ydata())])) else: amax = abs(np.log10(p[1][np.argmax(p[0])])) @@ -847,6 +854,7 @@ def test_logscale_hist_kwargs(b): @pytest.mark.parametrize('plot_2d', [kde_contour_plot_2d, skipif_no_fastkde(fastkde_contour_plot_2d), + skipif_no_margarine(nde_contour_plot_2d), hist_plot_2d, scatter_plot_2d]) def test_logscale_2d(plot_2d): np.random.seed(0) @@ -859,7 +867,7 @@ def test_logscale_2d(plot_2d): fig, ax = plt.subplots() ax.set_xscale('log') p = plot_2d(ax, x, logy) - if 'kde' in plot_2d.__name__: + if 'kde' in plot_2d.__name__ or 'nde' in plot_2d.__name__: if version.parse(matplotlib.__version__) >= version.parse('3.8.0'): xmax, ymax = p[0].get_paths()[1].vertices[0].T else: @@ -883,7 +891,7 @@ def test_logscale_2d(plot_2d): fig, ax = plt.subplots() ax.set_yscale('log') p = plot_2d(ax, logx, y) - if 'kde' in plot_2d.__name__: + if 'kde' in plot_2d.__name__ or 'nde' in plot_2d.__name__: if version.parse(matplotlib.__version__) >= version.parse('3.8.0'): xmax, ymax = p[0].get_paths()[1].vertices[0].T else: @@ -908,7 +916,7 @@ def test_logscale_2d(plot_2d): ax.set_xscale('log') ax.set_yscale('log') p = plot_2d(ax, x, y) - if 'kde' in plot_2d.__name__: + if 'kde' in plot_2d.__name__ or 'nde' in plot_2d.__name__: if version.parse(matplotlib.__version__) >= version.parse('3.8.0'): xmax, ymax = p[0].get_paths()[1].vertices[0].T else: diff --git a/tests/test_samples.py b/tests/test_samples.py index a40b85e5..765bace5 100644 --- a/tests/test_samples.py +++ b/tests/test_samples.py @@ -21,7 +21,8 @@ from pandas.testing import assert_frame_equal from matplotlib.colors import to_hex from scipy.stats import ks_2samp, kstest, norm -from utils import skipif_no_fastkde, astropy_mark_xfail, fastkde_mark_skip +from utils import (skipif_no_fastkde, skipif_no_margarine, + astropy_mark_xfail, fastkde_mark_skip) @pytest.fixture(autouse=True) @@ -293,7 +294,8 @@ def test_plot_2d_legend(): assert labels == ['l1', 'l2'] -@pytest.mark.parametrize('kind', ['kde', 'hist', skipif_no_fastkde('fastkde')]) +@pytest.mark.parametrize('kind', ['kde', 'hist', skipif_no_fastkde('fastkde'), + skipif_no_margarine('nde')]) def test_plot_2d_colours(kind): np.random.seed(3) gd = read_chains("./tests/example_data/gd") @@ -344,7 +346,8 @@ def test_plot_2d_colours(kind): dict(cmap="viridis"), dict(colormap="viridis")]) @pytest.mark.parametrize('kind', ['kde', 'hist', 'default', - skipif_no_fastkde('fastkde')]) + skipif_no_fastkde('fastkde'), + skipif_no_margarine('nde')]) def test_plot_2d_kwargs(kind, kwargs): np.random.seed(42) pc = read_chains("./tests/example_data/pc") @@ -352,7 +355,8 @@ def test_plot_2d_kwargs(kind, kwargs): pc.plot_2d(axes, kind=kind, **kwargs) -@pytest.mark.parametrize('kind', ['kde', 'hist', skipif_no_fastkde('fastkde')]) +@pytest.mark.parametrize('kind', ['kde', 'hist', skipif_no_fastkde('fastkde'), + skipif_no_margarine('nde')]) def test_plot_1d_colours(kind): np.random.seed(3) gd = read_chains("./tests/example_data/gd") @@ -432,7 +436,8 @@ def test_plot_1d_no_axes(): assert axes.iloc[2].get_xlabel() == 'x2' -@pytest.mark.parametrize('kind', ['kde', 'hist', skipif_no_fastkde('fastkde')]) +@pytest.mark.parametrize('kind', ['kde', 'hist', skipif_no_fastkde('fastkde'), + skipif_no_margarine('nde')]) def test_plot_logscale_1d(kind): ns = read_chains('./tests/example_data/pc') params = ['x0', 'x1', 'x2', 'x3', 'x4'] @@ -445,7 +450,7 @@ def test_plot_logscale_1d(kind): else: assert ax.get_xscale() == 'linear' ax = axes.loc['x2'] - if 'kde' in kind: + if 'kde' in kind or 'nde' in kind: p = ax.get_children() arg = np.argmax(p[0].get_ydata()) pmax = np.log10(p[0].get_xdata()[arg]) @@ -457,7 +462,8 @@ def test_plot_logscale_1d(kind): assert pmax == pytest.approx(-1, abs=d) -@pytest.mark.parametrize('kind', ['kde', 'hist', skipif_no_fastkde('fastkde')]) +@pytest.mark.parametrize('kind', ['kde', 'hist', skipif_no_fastkde('fastkde'), + skipif_no_margarine('nde')]) def test_plot_logscale_2d(kind): ns = read_chains('./tests/example_data/pc') params = ['x0', 'x1', 'x2', 'x3', 'x4'] @@ -972,7 +978,8 @@ def test_stats(): @pytest.mark.parametrize('kind', ['kde', 'hist', 'kde_1d', 'hist_1d', - skipif_no_fastkde('fastkde_1d')]) + skipif_no_fastkde('fastkde_1d'), + skipif_no_margarine('nde_1d')]) def test_masking_1d(kind): pc = read_chains("./tests/example_data/pc") mask = pc['x0'].to_numpy() > 0 @@ -982,7 +989,8 @@ def test_masking_1d(kind): @pytest.mark.parametrize('kind', ['kde', 'scatter', 'scatter_2d', 'kde_2d', - 'hist_2d', skipif_no_fastkde('fastkde_2d')]) + 'hist_2d', skipif_no_fastkde('fastkde_2d'), + skipif_no_margarine('nde_2d')]) def test_masking_2d(kind): pc = read_chains("./tests/example_data/pc") mask = pc['x0'].to_numpy() > 0 @@ -1514,12 +1522,29 @@ def test_samples_dot_plot(): axes = samples[['x0', 'x1', 'x2', 'x3', 'x4']].plot.fastkde_1d() assert len(axes.lines) == 5 plt.close("all") + axes = samples.plot.nde_2d('x0', 'x1') + assert axes.get_xlabel() == '$x_0$' + assert axes.get_ylabel() == '$x_1$' + assert len(axes.collections) > 0 + plt.close("all") + axes = samples.drop_labels().plot.nde_2d('x0', 'x1') + assert axes.get_xlabel() == 'x0' + assert axes.get_ylabel() == 'x1' + assert len(axes.collections) > 0 + plt.close("all") + axes = samples.x0.plot.nde_1d() + assert len(axes.lines) == 1 + plt.close("all") + axes = samples[['x0', 'x1', 'x2', 'x3', 'x4']].plot.nde_1d() + assert len(axes.lines) == 5 + plt.close("all") except ImportError: pass @pytest.mark.parametrize('kind', ['kde', 'hist', 'kde_1d', 'hist_1d', - skipif_no_fastkde('fastkde_1d')]) + skipif_no_fastkde('fastkde_1d'), + skipif_no_margarine('nde_1d')]) def test_samples_dot_plot_legend(kind): samples = read_chains('./tests/example_data/pc') fig, ax = plt.subplots() @@ -1572,7 +1597,8 @@ def test_samples_plot_labels(): assert samples.get_label(col) == ax.get_xlabel() -@pytest.mark.parametrize('kind', ['kde', 'hist', skipif_no_fastkde('fastkde')]) +@pytest.mark.parametrize('kind', ['kde', 'hist', skipif_no_fastkde('fastkde'), + skipif_no_margarine('nde')]) def test_samples_empty_1d_ylabels(kind): samples = read_chains('./tests/example_data/pc') columns = ['x0', 'x1', 'x2', 'x3', 'x4'] diff --git a/tests/utils.py b/tests/utils.py index 8c1c72cd..8b42a1ea 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -32,6 +32,22 @@ def skipif_no_fastkde(param): return pytest.param(param, marks=fastkde_mark_skip) +try: + import margarine # noqa: F401 +except ImportError: + pass +reason = "requires margarine package" +condition = 'margarine' not in sys.modules +raises = ImportError +margarine_mark_skip = pytest.mark.skipif(condition, reason=reason) +margarine_mark_xfail = pytest.mark.xfail(condition, raises=raises, + reason=reason) + + +def skipif_no_margarine(param): + return pytest.param(param, marks=margarine_mark_skip) + + try: import getdist # noqa: F401 except ImportError: