histplot with normalization for each x-bin across hue #2656

Closed
zerothi opened this issue Sep 6, 2021 · 8 comments

@zerothi

zerothi commented Sep 6, 2021

Consider data points that describe:

  • different animal species
  • each species lives in several countries
  • each species has a population in each country

Currently, histplot treats the bins (animal species) as one collection of data and draws each hue level (country) as a separate histogram.
Since the species are independent of each other, one would like to do something like this:

sns.histplot(population, x='specie', hue='country', stat='probability')

However, this does not decouple the species from each other. It would be nice if one could normalize each species individually across the countries, i.e.:

sns.histplot(population, x='specie', hue='country', stat='probability', bin_norm=True, common_norm=True)

which normalizes across hues within each bin. Every bin would then have a total height of 1, and the hue segments within it would show how that species' population is distributed across countries, clearly indicating which country each species tends to live in.

I agree this isn't a regular histogram, but perhaps it would fit in seaborn anyway? :)
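
As a reading aid, the requested normalization can be written out as a formula. The symbols are assumptions introduced here, not notation from the thread: $w_{b,h}$ is the total (weighted) count in bin $b$ for hue level $h$, and $p_{b,h}$ is the height drawn for that segment.

```latex
p_{b,h} = \frac{w_{b,h}}{\sum_{h'} w_{b,h'}},
\qquad \text{so that} \qquad
\sum_{h} p_{b,h} = 1 \quad \text{for every bin } b \text{ with } \sum_{h'} w_{b,h'} > 0 .
```

By contrast, a common normalization over the whole dataset would divide by $\sum_{b'} \sum_{h'} w_{b',h'}$, so the bins would keep their relative heights.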

@zerothi
Author

zerothi commented Sep 6, 2021

I implemented this locally with something like this:

    bin_norm = plot_kws.pop("bin_norm", common_norm and sum_weight > 0.)

    if bin_norm:
        # Normalize each bin to 1 given the different
        # hues
        df = pd.DataFrame(histograms).sum(axis=1)
        for histogram in histograms.values():
            histogram /= df

just after the

    for sub_vars, sub_data in self.iter_data("hue", from_comp_data=True):
        ....

loop.
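
The snippet above can be tried outside seaborn. The following is a minimal standalone sketch, assuming `histograms` is a dict that maps each hue level to a pandas Series of bin heights on a shared index. The helper name `normalize_bins` and the zero-guard for bins that are empty in every hue are additions for illustration, not part of the original patch.

```python
import numpy as np
import pandas as pd


def normalize_bins(histograms):
    """Rescale per-hue histograms so each bin sums to 1 across hue levels."""
    # One column per hue level, one row per bin
    table = pd.DataFrame(histograms)
    bin_total = table.sum(axis=1)
    # A bin that is empty for every hue would give 0/0; treat it as 0 instead
    safe_total = bin_total.replace(0, np.nan)
    return {
        key: (hist / safe_total).fillna(0.0)
        for key, hist in histograms.items()
    }


# Illustrative data (hypothetical): the "cat" bin is empty for both hue levels
bins = pd.Index(["mouse", "dog", "cat"], name="edges")
histograms = {
    "GB": pd.Series([10.0, 0.0, 0.0], index=bins),
    "USA": pd.Series([10.0, 20.0, 0.0], index=bins),
}
normalized = normalize_bins(histograms)
print(pd.DataFrame(normalized))
```

Returning a new dict instead of dividing in place keeps the input histograms unchanged, which makes the before/after comparison easier.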

@mwaskom
Owner

mwaskom commented Sep 6, 2021

Hi, I'm not sure I follow this. Could you make an example with an actual dataset? Is this distinct from your other issue?

@mwaskom mwaskom added this to the v0.12.0 milestone Sep 6, 2021
@mwaskom mwaskom added wishlist and removed bug labels Sep 6, 2021
@mwaskom mwaskom removed this from the v0.12.0 milestone Sep 6, 2021
@zerothi
Author

zerothi commented Sep 8, 2021

Ok, I have added a test example here.
Note that the code is very long because it also has to fix the problem encountered in #2655, in addition to adding the new keyword bin_norm. I don't have a particularly good name for bin_norm, but perhaps norm=str would be a way to differentiate the different norms?

Code
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
def plot_univariate_histogram_0112(
        self,
        multiple,
        element,
        fill,
        common_norm,
        common_bins,
        shrink,
        kde,
        kde_kws,
        color,
        legend,
        line_kws,
        estimate_kws,
        **plot_kws,
    ):
    from numbers import Number
    from functools import partial
    import math
    import warnings

    import numpy as np
    import pandas as pd
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    import matplotlib.transforms as tx
    from matplotlib.colors import to_rgba
    from matplotlib.collections import LineCollection
    from scipy import stats

    from seaborn._core import (
        VectorPlotter,
    )
    from seaborn._statistics import (
        KDE,
        Histogram,
        ECDF,
    )
    from seaborn.axisgrid import (
        FacetGrid,
        _facet_docs,
    )
    from seaborn.utils import (
        remove_na,
        _kde_support,
        _normalize_kwargs,
        _check_argument,
        _assign_default_kwargs,
    )
    from seaborn.palettes import color_palette
    from seaborn.external import husl
    from seaborn._decorators import _deprecate_positional_args
    from seaborn._docstrings import (
        DocstringComponents,
        _core_docs,
    )

    # -- Default keyword dicts
    kde_kws = {} if kde_kws is None else kde_kws.copy()
    line_kws = {} if line_kws is None else line_kws.copy()
    estimate_kws = {} if estimate_kws is None else estimate_kws.copy()

    # --  Input checking
    _check_argument("multiple", ["layer", "stack", "fill", "dodge"], multiple)
    _check_argument("element", ["bars", "step", "poly"], element)

    if estimate_kws["discrete"] and element != "bars":
        raise ValueError("`element` must be 'bars' when `discrete` is True")

    auto_bins_with_weights = (
        "weights" in self.variables
        and estimate_kws["bins"] == "auto"
        and estimate_kws["binwidth"] is None
        and not estimate_kws["discrete"]
    )
    if auto_bins_with_weights:
        msg = (
            "`bins` cannot be 'auto' when using weights. "
            "Setting `bins=10`, but you will likely want to adjust."
        )
        warnings.warn(msg, UserWarning)
        estimate_kws["bins"] = 10

    # Simplify downstream code if we are not normalizing
    if estimate_kws["stat"] == "count":
        common_norm = False

    # Now initialize the Histogram estimator
    estimator = Histogram(**estimate_kws)
    histograms = {}

    # Do pre-compute housekeeping related to multiple groups
    # TODO best way to account for facet/semantic?
    if set(self.variables) - {"x", "y"}:

        all_data = self.comp_data.dropna()

        if common_bins:
            all_observations = all_data[self.data_variable]
            estimator.define_bin_params(
                all_observations,
                weights=all_data.get("weights", None),
            )

    else:
        common_norm = False

    # Estimate the smoothed kernel densities, for use later
    if kde:
        # TODO alternatively, clip at min/max bins?
        kde_kws.setdefault("cut", 0)
        kde_kws["cumulative"] = estimate_kws["cumulative"]
        log_scale = self._log_scaled(self.data_variable)
        densities = self._compute_univariate_density(
            self.data_variable,
            common_norm,
            common_bins,
            kde_kws,
            log_scale,
            warn_singular=False,
        )

    sum_weight = 0.
    if common_norm:
        for sub_vars, sub_data in self.iter_data("hue", from_comp_data=True):
            if "weights" in self.variables:
                sum_weight += sub_data["weights"].sum()

    # First pass through the data to compute the histograms
    for sub_vars, sub_data in self.iter_data("hue", from_comp_data=True):

        # Prepare the relevant data
        key = tuple(sub_vars.items())
        sub_data = sub_data.dropna()
        observations = sub_data[self.data_variable]

        if "weights" in self.variables:
            weights = sub_data["weights"]
        else:
            weights = None

        # Do the histogram computation
        heights, edges = estimator(observations, weights=weights)

        # Rescale the smoothed curve to match the histogram
        if kde and key in densities:
            density = densities[key]
            if estimator.cumulative:
                hist_norm = heights.max()
            else:
                hist_norm = (heights * np.diff(edges)).sum()
            densities[key] *= hist_norm

        # Convert edges back to original units for plotting
        if self._log_scaled(self.data_variable):
            edges = np.power(10, edges)

        # Pack the histogram data and metadata together
        orig_widths = np.diff(edges)
        widths = shrink * orig_widths
        edges = edges[:-1] + (1 - shrink) / 2 * orig_widths
        index = pd.MultiIndex.from_arrays([
            pd.Index(edges, name="edges"),
            pd.Index(widths, name="widths"),
        ])
        hist = pd.Series(heights, index=index, name="heights")

        # Apply scaling to normalize across groups
        if common_norm and weights is None:
            hist *= len(sub_data) / len(all_data)
        elif common_norm:
            hist *= weights.sum() / sum_weight

        # Store the finalized histogram data for future plotting
        histograms[key] = hist

    bin_norm = plot_kws.pop("bin_norm", False)
    if common_norm and bin_norm:
        # Normalize each bin to 1 given the different hues
        column_sum = pd.DataFrame(histograms).sum(axis=1)
        for histogram in histograms.values():
            histogram /= column_sum

    # Modify the histogram and density data to resolve multiple groups
    histograms, baselines = self._resolve_multiple(histograms, multiple)
    if kde:
        densities, _ = self._resolve_multiple(
            densities, None if multiple == "dodge" else multiple
        )

    # Set autoscaling-related meta
    sticky_stat = (0, 1) if multiple == "fill" else (0, np.inf)
    if multiple == "fill":
        # Filled plots should not have any margins
        bin_vals = histograms.index.to_frame()
        edges = bin_vals["edges"]
        widths = bin_vals["widths"]
        sticky_data = (
            edges.min(),
            edges.max() + widths.loc[edges.idxmax()]
        )
    else:
        sticky_data = []

    # --- Handle default visual attributes

    # Note: default linewidth is determined after plotting

    # Default color without a hue semantic should follow the color cycle
    # Note, this is fairly complicated and awkward, I'd like a better way
    # TODO and now with the ax business, this is just super annoying FIX!!
    if "hue" not in self.variables:
        if self.ax is None:
            default_color = "C0" if color is None else color
        else:
            if fill:
                if self.var_types[self.data_variable] == "datetime":
                    # Avoid drawing empty fill_between on date axis
                    # https://github.com/matplotlib/matplotlib/issues/17586
                    scout = None
                    default_color = plot_kws.pop("facecolor", color)
                    if default_color is None:
                        default_color = "C0"
                else:
                    artist = mpl.patches.Rectangle
                    plot_kws = _normalize_kwargs(plot_kws, artist)
                    scout = self.ax.fill_between([], [], color=color, **plot_kws)
                    default_color = tuple(scout.get_facecolor().squeeze())
            else:
                artist = mpl.lines.Line2D
                plot_kws = _normalize_kwargs(plot_kws, artist)
                scout, = self.ax.plot([], [], color=color, **plot_kws)
                default_color = scout.get_color()
            if scout is not None:
                scout.remove()

    # Default alpha should depend on other parameters
    if fill:
        # Note: will need to account for other grouping semantics if added
        if "hue" in self.variables and multiple == "layer":
            default_alpha = .5 if element == "bars" else .25
        elif kde:
            default_alpha = .5
        else:
            default_alpha = .75
    else:
        default_alpha = 1
    alpha = plot_kws.pop("alpha", default_alpha)  # TODO make parameter?

    hist_artists = []

    # Go back through the dataset and draw the plots
    for sub_vars, _ in self.iter_data("hue", reverse=True):

        key = tuple(sub_vars.items())
        hist = histograms[key].rename("heights").reset_index()
        bottom = np.asarray(baselines[key])

        ax = self._get_axes(sub_vars)

        # Define the matplotlib attributes that depend on semantic mapping
        if "hue" in self.variables:
            color = self._hue_map(sub_vars["hue"])
        else:
            color = default_color

        artist_kws = self._artist_kws(
            plot_kws, fill, element, multiple, color, alpha
        )

        if element == "bars":

            # Use matplotlib bar plotting

            plot_func = ax.bar if self.data_variable == "x" else ax.barh
            artists = plot_func(
                hist["edges"],
                hist["heights"] - bottom,
                hist["widths"],
                bottom,
                align="edge",
                **artist_kws,
            )
            for bar in artists:
                if self.data_variable == "x":
                    bar.sticky_edges.x[:] = sticky_data
                    bar.sticky_edges.y[:] = sticky_stat
                else:
                    bar.sticky_edges.x[:] = sticky_stat
                    bar.sticky_edges.y[:] = sticky_data

            hist_artists.extend(artists)

        else:

            # Use either fill_between or plot to draw hull of histogram
            if element == "step":

                final = hist.iloc[-1]
                x = np.append(hist["edges"], final["edges"] + final["widths"])
                y = np.append(hist["heights"], final["heights"])
                b = np.append(bottom, bottom[-1])

                if self.data_variable == "x":
                    step = "post"
                    drawstyle = "steps-post"
                else:
                    step = "post"  # fillbetweenx handles mapping internally
                    drawstyle = "steps-pre"

            elif element == "poly":

                x = hist["edges"] + hist["widths"] / 2
                y = hist["heights"]
                b = bottom

                step = None
                drawstyle = None

            if self.data_variable == "x":
                if fill:
                    artist = ax.fill_between(x, b, y, step=step, **artist_kws)
                else:
                    artist, = ax.plot(x, y, drawstyle=drawstyle, **artist_kws)
                artist.sticky_edges.x[:] = sticky_data
                artist.sticky_edges.y[:] = sticky_stat
            else:
                if fill:
                    artist = ax.fill_betweenx(x, b, y, step=step, **artist_kws)
                else:
                    artist, = ax.plot(y, x, drawstyle=drawstyle, **artist_kws)
                artist.sticky_edges.x[:] = sticky_stat
                artist.sticky_edges.y[:] = sticky_data

            hist_artists.append(artist)

        if kde:

            # Add in the density curves

            try:
                density = densities[key]
            except KeyError:
                continue
            support = density.index

            if "x" in self.variables:
                line_args = support, density
                sticky_x, sticky_y = None, (0, np.inf)
            else:
                line_args = density, support
                sticky_x, sticky_y = (0, np.inf), None

            line_kws["color"] = to_rgba(color, 1)
            line, = ax.plot(
                *line_args, **line_kws,
            )

            if sticky_x is not None:
                line.sticky_edges.x[:] = sticky_x
            if sticky_y is not None:
                line.sticky_edges.y[:] = sticky_y

    if element == "bars" and "linewidth" not in plot_kws:

        # Now we handle linewidth, which depends on the scaling of the plot

        # We will base everything on the minimum bin width
        hist_metadata = pd.concat([
            # Use .items for generality over dict or df
            h.index.to_frame() for _, h in histograms.items()
        ]).reset_index(drop=True)
        thin_bar_idx = hist_metadata["widths"].idxmin()
        binwidth = hist_metadata.loc[thin_bar_idx, "widths"]
        left_edge = hist_metadata.loc[thin_bar_idx, "edges"]

        # Set initial value
        default_linewidth = math.inf

        # Loop through subsets based only on facet variables
        for sub_vars, _ in self.iter_data():

            ax = self._get_axes(sub_vars)

            # Needed in some cases to get valid transforms.
            # Innocuous in other cases?
            ax.autoscale_view()

            # Convert binwidth from data coordinates to pixels
            pts_x, pts_y = 72 / ax.figure.dpi * abs(
                ax.transData.transform([left_edge + binwidth] * 2)
                - ax.transData.transform([left_edge] * 2)
            )
            if self.data_variable == "x":
                binwidth_points = pts_x
            else:
                binwidth_points = pts_y

            # The relative size of the lines depends on the appearance
            # This is a provisional value and may need more tweaking
            default_linewidth = min(.1 * binwidth_points, default_linewidth)

        # Set the attributes
        for bar in hist_artists:

            # Don't let the lines get too thick
            max_linewidth = bar.get_linewidth()
            if not fill:
                max_linewidth *= 1.5

            linewidth = min(default_linewidth, max_linewidth)

            # If not filling, don't let lines disappear
            if not fill:
                min_linewidth = .5
                linewidth = max(linewidth, min_linewidth)

            bar.set_linewidth(linewidth)

    # --- Finalize the plot ----

    # Axis labels
    ax = self.ax if self.ax is not None else self.facets.axes.flat[0]
    default_x = default_y = ""
    if self.data_variable == "x":
        default_y = estimator.stat.capitalize()
    if self.data_variable == "y":
        default_x = estimator.stat.capitalize()
    self._add_axis_labels(ax, default_x, default_y)

    # Legend for semantic variables
    if "hue" in self.variables and legend:

        if fill or element == "bars":
            artist = partial(mpl.patches.Patch)
        else:
            artist = partial(mpl.lines.Line2D, [], [])

        ax_obj = self.ax if self.ax is not None else self.facets
        self._add_legend(
            ax_obj, artist, fill, element, multiple, alpha, plot_kws, {},
        )

# Create data
df = pd.DataFrame({
    'animal': ['mouse', 'mouse', 'mouse', 'dog', 'dog', 'kangaroo'],
    'population': [10,  10, 15, 20, 10, 40],
    'country': ['GB', 'USA', 'AU', 'USA', 'AU', 'AU']
})

from seaborn.distributions import _DistributionPlotter
_DistributionPlotter.plot_univariate_histogram = plot_univariate_histogram_0112

plt.figure()
sns.histplot(df, x='animal', weights='population', hue='country', stat='probability', multiple='stack', common_norm=True, bin_norm=False)

plt.figure()
sns.histplot(df, x='animal', weights='population', hue='country', stat='probability', multiple='stack', common_norm=True, bin_norm=True)
which yields these two images:

Regular histogram plotting (with #2655 fix)
fig1

However, I want to know the probability that a given animal (e.g. a dog) is found in a specific country, so each bin should have a constant height of 1. The same data will then look like this:
fig2
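
The numbers behind the two figures can be checked with plain pandas. This is a sketch of the arithmetic only, using the sample data from the comment above; the names `totals`, `common`, and `per_bin` are introduced here for illustration and it makes no claim about seaborn's internals.

```python
import pandas as pd

# Sample data from the thread
df = pd.DataFrame({
    "animal": ["mouse", "mouse", "mouse", "dog", "dog", "kangaroo"],
    "population": [10, 10, 15, 20, 10, 40],
    "country": ["GB", "USA", "AU", "USA", "AU", "AU"],
})

# Total population per (animal, country); missing combinations become 0
totals = df.pivot_table(
    index="animal", columns="country", values="population",
    aggfunc="sum", fill_value=0,
)

# Common normalization: every cell is a share of the grand total
common = totals / totals.to_numpy().sum()

# Per-bin normalization: every row (animal) sums to 1 across countries
per_bin = totals.div(totals.sum(axis=1), axis=0)

print(per_bin.round(3))
```

For the dog bin, for example, the USA share is 20 / (20 + 10) and the AU share is 10 / (20 + 10).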

@mwaskom
Owner

mwaskom commented Sep 8, 2021

Is your second plot there distinct from multiple="fill"?

@mwaskom
Owner

mwaskom commented Sep 8, 2021

Maybe you could make a diff between your code and the existing code? Then it would be easier to understand what you've changed...

@zerothi
Author

zerothi commented Sep 8, 2021

Maybe you could make a diff between your code and the existing code? Then it would be easier to understand what you've changed...

Absolutely true ;)
Here it is:

diff --git a/seaborn/distributions.py b/seaborn/distributions.py
index 5f63289..8329807 100644
--- a/seaborn/distributions.py
+++ b/seaborn/distributions.py
@@ -424,6 +424,12 @@ class _DistributionPlotter(VectorPlotter):
                 warn_singular=False,
             )
 
+        sum_weight = 0.
+        if common_norm:
+            for sub_vars, sub_data in self.iter_data("hue", from_comp_data=True):
+                if "weights" in self.variables:
+                    sum_weight += sub_data["weights"].sum()
+
         # First pass through the data to compute the histograms
         for sub_vars, sub_data in self.iter_data("hue", from_comp_data=True):
 
@@ -464,12 +470,21 @@ class _DistributionPlotter(VectorPlotter):
             hist = pd.Series(heights, index=index, name="heights")
 
             # Apply scaling to normalize across groups
-            if common_norm:
+            if common_norm and weights is None:
                 hist *= len(sub_data) / len(all_data)
+            elif common_norm:
+                hist *= weights.sum() / sum_weight
 
             # Store the finalized histogram data for future plotting
             histograms[key] = hist
 
+        bin_norm = plot_kws.pop("bin_norm", False)
+        if common_norm and bin_norm:
+            # Normalize each bin to 1 given the different hues
+            column_sum = pd.DataFrame(histograms).sum(axis=1)
+            for histogram in histograms.values():
+                histogram /= column_sum
+
         # Modify the histogram and density data to resolve multiple groups
         histograms, baselines = self._resolve_multiple(histograms, multiple)
         if kde:
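
To see why the diff replaces the row-count ratio with a weight ratio, the two scale factors can be compared on the sample data from this thread. This is a sketch with pandas only; `row_share` and `weight_share` are names introduced here and correspond to `len(sub_data) / len(all_data)` and `weights.sum() / sum_weight` in the diff.

```python
import pandas as pd

# Sample data from the thread
df = pd.DataFrame({
    "animal": ["mouse", "mouse", "mouse", "dog", "dog", "kangaroo"],
    "population": [10, 10, 15, 20, 10, 40],
    "country": ["GB", "USA", "AU", "USA", "AU", "AU"],
})

# Unpatched scale factor per hue level: share of rows
row_share = df.groupby("country").size() / len(df)

# Patched scale factor per hue level: share of total weight
weight_share = (
    df.groupby("country")["population"].sum() / df["population"].sum()
)

print(pd.DataFrame({"row_share": row_share, "weight_share": weight_share}))
```

AU holds 3 of the 6 rows but 65 of the 105 population units, so the two factors disagree as soon as the weights are not uniform.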

Is your second plot there distinct from multiple="fill"?

Hmm.. bummer. Sorry for bringing this up. No, it is exactly the same. I didn't realise this was what it did... :(

However, does it work with weights?
Thanks!!!

@zerothi zerothi closed this as completed Sep 8, 2021
@zerothi zerothi reopened this Sep 8, 2021
@mwaskom
Owner

mwaskom commented Sep 8, 2021

I think it accounts for weights properly with stat="count":

sns.histplot(
    x=["a", "a", "b", "b"],
    hue=["x", "y", "x", "y"],
    weights=[8, 2, .4, .6],
    multiple="fill",
)

image

The "fill" normalization is applied after the stat normalization, so it will inherit whatever issues that has, although I find it hard to say what that would be showing even if it worked properly...
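
The effect of that ordering can be illustrated numerically with the data from the example above. This is a sketch of the arithmetic under one assumption, namely that the stat step normalizes each hue level on its own; it does not reproduce seaborn's actual code path, and `fill_rows` is a helper introduced here.

```python
import pandas as pd

data = pd.DataFrame({
    "x": ["a", "a", "b", "b"],
    "hue": ["x", "y", "x", "y"],
    "weights": [8, 2, 0.4, 0.6],
})

# Weighted counts per (bin, hue level)
counts = data.pivot_table(
    index="x", columns="hue", values="weights", aggfunc="sum", fill_value=0,
)


def fill_rows(table):
    """Rescale so each bin (row) sums to 1 across hue levels."""
    return table.div(table.sum(axis=1), axis=0)


# "fill" applied directly to the weighted counts
fill_of_counts = fill_rows(counts)

# "fill" applied after each hue level was normalized separately (assumption)
per_hue_probability = counts / counts.sum(axis=0)
fill_of_probability = fill_rows(per_hue_probability)

print(fill_of_counts)
print(fill_of_probability.round(3))
```

In bin "a", hue "x" takes 8 / (8 + 2) of the bar in the first table. In the second table its share is smaller, because each hue level was rescaled by a different factor before the fill step.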

@zerothi
Author

zerothi commented Sep 9, 2021

Ok. So the bug in #2655 still applies and I'll consider this a duplicate of that. :)

@zerothi zerothi closed this as completed Sep 9, 2021