Make scale config apply to variables computed by the stat transform #2915

Merged (7 commits) on Jul 24, 2022
Changes from all commits

2 changes: 2 additions & 0 deletions seaborn/_compat.py
@@ -121,5 +121,7 @@ def set_scale_obj(ax, axis, scale):
trans = scale.get_transform()
kws["functions"] = (trans._forward, trans._inverse)
method(scale.name, **kws)
axis_obj = getattr(ax, f"{axis}axis")
scale.set_default_locators_and_formatters(axis_obj)
else:
ax.set(**{f"{axis}scale": scale})
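
For context, the two lines added here make the function-scale fallback (used on older matplotlib) also install the scale's default tick locators and formatters on the target axis. Below is a minimal standalone sketch of the matplotlib calls involved; the figure, axis choice, and transform functions are illustrative, not taken from the patch:

import matplotlib.pyplot as plt
from matplotlib.scale import FuncScale

fig, ax = plt.subplots()

# Register a function scale by name, as the compatibility path does...
forward, inverse = (lambda x: x ** 2), (lambda x: x ** 0.5)
ax.set_yscale("function", functions=(forward, inverse))

# ...then explicitly install that scale's default locators/formatters,
# mirroring the two added lines above.
scale = FuncScale(ax.yaxis, (forward, inverse))
scale.set_default_locators_and_formatters(ax.yaxis)
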
296 changes: 138 additions & 158 deletions seaborn/_core/plot.py
@@ -26,7 +26,7 @@
from seaborn._core.scales import Scale
from seaborn._core.subplots import Subplots
from seaborn._core.groupby import GroupBy
from seaborn._core.properties import PROPERTIES, Property, Coordinate
from seaborn._core.properties import PROPERTIES, Property
from seaborn._core.typing import DataSource, VariableSpec, OrderSpec
from seaborn._core.rules import categorical_order
from seaborn._compat import set_scale_obj
@@ -654,29 +654,37 @@ def save(self, fname, **kwargs) -> Plot:

def plot(self, pyplot=False) -> Plotter:
"""
Compile the plot and return the :class:`Plotter` engine.

Compile the plot spec and return a Plotter object.
"""
# TODO if we have _target object, pyplot should be determined by whether it
# is hooked into the pyplot state machine (how do we check?)

plotter = Plotter(pyplot=pyplot)

# Process the variable assignments and initialize the figure
common, layers = plotter._extract_data(self)
plotter._setup_figure(self, common, layers)
plotter._transform_coords(self, common, layers)

# Process the scale spec for coordinate variables and transform their data
coord_vars = [v for v in self._variables if re.match(r"^x|y", v)]
plotter._setup_scales(self, common, layers, coord_vars)

# Apply statistical transform(s)
plotter._compute_stats(self, layers)
plotter._setup_scales(self, layers)

# Process scale spec for semantic variables and coordinates computed by stat
plotter._setup_scales(self, common, layers)

# TODO Remove these after updating other methods
# ---- Maybe have debug= param that attaches these when True?
plotter._data = common
plotter._layers = layers

# Process the data for each layer and add matplotlib artists
for layer in layers:
plotter._plot_layer(self, layer)

# Add various figure decorations
plotter._make_legend(self)
plotter._finalize_figure(self)
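
For a user-facing view of what this reordering buys: coordinate variables are scaled before the stat runs, and a second _setup_scales pass afterwards covers semantic variables and any coordinates the stat created, so a scale specified for a stat-computed variable (e.g. the count a histogram produces on y) is now honored instead of ignored. A rough usage sketch against the objects interface that later shipped as seaborn.objects in 0.12; it is not code from this patch, and module paths differed in the development version this PR targets:

import seaborn as sns
import seaborn.objects as so

tips = sns.load_dataset("tips")

# y is not present in the input data; it is computed by Hist.
# With this change, the sqrt scale applies to that computed variable.
(
    so.Plot(tips, x="total_bill")
    .add(so.Bars(), so.Hist())
    .scale(y="sqrt")
    .show()
)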

@@ -685,7 +693,6 @@ def plot(self, pyplot=False) -> Plotter:
def show(self, **kwargs) -> None:
"""
Render and display the plot.

"""
# TODO make pyplot configurable at the class level, and when not using,
# import IPython.display and call on self to populate cell output?
@@ -898,112 +905,6 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
title_text = ax.set_title(title)
title_text.set_visible(show_title)

def _transform_coords(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:

for var in p._variables:

# Parse name to identify variable (x, y, xmin, etc.) and axis (x/y)
# TODO should we have xmin0/xmin1 or x0min/x1min?
m = re.match(r"^(?P<prefix>(?P<axis>[x|y])\d*).*", var)

if m is None:
continue

prefix = m["prefix"]
axis = m["axis"]

share_state = self._subplots.subplot_spec[f"share{axis}"]

# Concatenate layers, using only the relevant coordinate and faceting vars,
# This is unnecessarily wasteful, as layer data will often be redundant.
# But figuring out the minimal amount we need is more complicated.
cols = [var, "col", "row"]
# TODO basically copied from _setup_scales, and very clumsy
layer_values = [common.frame.filter(cols)]
for layer in layers:
if layer["data"].frame is None:
for df in layer["data"].frames.values():
layer_values.append(df.filter(cols))
else:
layer_values.append(layer["data"].frame.filter(cols))

if layer_values:
var_df = pd.concat(layer_values, ignore_index=True)
else:
var_df = pd.DataFrame(columns=cols)

prop = Coordinate(axis)
scale = self._get_scale(p, prefix, prop, var_df[var])

# Shared categorical axes are broken on matplotlib<3.4.0.
# https://github.com/matplotlib/matplotlib/pull/18308
# This only affects us when sharing *paired* axes. This is a novel/niche
# behavior, so we will raise rather than hack together a workaround.
if Version(mpl.__version__) < Version("3.4.0"):
from seaborn._core.scales import Nominal
paired_axis = axis in p._pair_spec
cat_scale = isinstance(scale, Nominal)
ok_dim = {"x": "col", "y": "row"}[axis]
shared_axes = share_state not in [False, "none", ok_dim]
if paired_axis and cat_scale and shared_axes:
err = "Sharing paired categorical axes requires matplotlib>=3.4.0"
raise RuntimeError(err)

# Now loop through each subplot, deriving the relevant seed data to setup
# the scale (so that axis units / categories are initialized properly)
# And then scale the data in each layer.
subplots = [view for view in self._subplots if view[axis] == prefix]

# Setup the scale on all of the data and plug it into self._scales
# We do this because by the time we do self._setup_scales, coordinate data
# will have been converted to floats already, so scale inference fails
self._scales[var] = scale._setup(var_df[var], prop)

# Set up an empty series to receive the transformed values.
# We need this to handle piecemeal transforms of categories -> floats.
transformed_data = []
for layer in layers:
index = layer["data"].frame.index
transformed_data.append(pd.Series(dtype=float, index=index, name=var))

for view in subplots:
axis_obj = getattr(view["ax"], f"{axis}axis")

if share_state in [True, "all"]:
# The all-shared case is easiest, every subplot sees all the data
seed_values = var_df[var]
else:
# Otherwise, we need to setup separate scales for different subplots
if share_state in [False, "none"]:
# Fully independent axes are also easy: use each subplot's data
idx = self._get_subplot_index(var_df, view)
elif share_state in var_df:
# Sharing within row/col is more complicated
use_rows = var_df[share_state] == view[share_state]
idx = var_df.index[use_rows]
else:
# This configuration doesn't make much sense, but it's fine
idx = var_df.index

seed_values = var_df.loc[idx, var]

scale = scale._setup(seed_values, prop, axis=axis_obj)

for layer, new_series in zip(layers, transformed_data):
layer_df = layer["data"].frame
if var in layer_df:
idx = self._get_subplot_index(layer_df, view)
new_series.loc[idx] = scale(layer_df.loc[idx, var])

# TODO need decision about whether to do this or modify axis transform
set_scale_obj(view["ax"], axis, scale._matplotlib_scale)

# Now the transformed data series are complete, set update the layer data
for layer, new_series in zip(layers, transformed_data):
layer_df = layer["data"].frame
if var in layer_df:
layer_df[var] = new_series

def _compute_stats(self, spec: Plot, layers: list[Layer]) -> None:

grouping_vars = [v for v in PROPERTIES if v not in "xy"]
@@ -1073,68 +974,147 @@ def _get_scale(

return scale

def _setup_scales(self, p: Plot, layers: list[Layer]) -> None:
def _get_subplot_data(self, df, var, view, share_state):

# Identify all of the variables that will be used at some point in the plot
variables = set()
for layer in layers:
if layer["data"].frame.empty and layer["data"].frames:
for df in layer["data"].frames.values():
variables.update(df.columns)
if share_state in [True, "all"]:
# The all-shared case is easiest, every subplot sees all the data
seed_values = df[var]
else:
# Otherwise, we need to setup separate scales for different subplots
if share_state in [False, "none"]:
# Fully independent axes are also easy: use each subplot's data
idx = self._get_subplot_index(df, view)
elif share_state in df:
# Sharing within row/col is more complicated
use_rows = df[share_state] == view[share_state]
idx = df.index[use_rows]
else:
variables.update(layer["data"].frame.columns)
# This configuration doesn't make much sense, but it's fine
idx = df.index

for var in variables:
seed_values = df.loc[idx, var]

if var in self._scales:
# Scales for coordinate variables added in _transform_coords
continue
return seed_values

# Get the data all the distinct appearances of this variable.
parts = []
def _setup_scales(
self, p: Plot,
common: PlotData,
layers: list[Layer],
variables: list[str] | None = None,
) -> None:

if variables is None:
# Add variables that have data but not a scale, which happens
# because this method can be called multiple times, to handle
# variables added during the Stat transform.
variables = []
for layer in layers:
if layer["data"].frame.empty and layer["data"].frames:
for df in layer["data"].frames.values():
parts.append(df.get(var))
else:
parts.append(layer["data"].frame.get(var))
var_values = pd.concat(
parts, axis=0, join="inner", ignore_index=True
).rename(var)
variables.extend(layer["data"].frame.columns)
for df in layer["data"].frames.values():
variables.extend(v for v in df if v not in variables)
variables = [v for v in variables if v not in self._scales]

# Determine whether this is an coordinate variable
for var in variables:

# Determine whether this is a coordinate variable
# (i.e., x/y, paired x/y, or derivative such as xmax)
m = re.match(r"^(?P<prefix>(?P<axis>x|y)\d*).*", var)
m = re.match(r"^(?P<coord>(?P<axis>x|y)\d*).*", var)
if m is None:
axis = None
coord = axis = None
else:
var = m["prefix"]
coord = m["coord"]
axis = m["axis"]

prop = PROPERTIES.get(var if axis is None else axis, Property())
scale = self._get_scale(p, var, prop, var_values)
# Get keys that handle things like x0, xmax, properly where relevant
prop_key = var if axis is None else axis
scale_key = var if coord is None else coord

if prop_key not in PROPERTIES:
continue

# Concatenate layers, using only the relevant coordinate and faceting vars.
# This is unnecessarily wasteful, as layer data will often be redundant.
# But figuring out the minimal amount we need is more complicated.
cols = [var, "col", "row"]
parts = [common.frame.filter(cols)]
for layer in layers:
parts.append(layer["data"].frame.filter(cols))
for df in layer["data"].frames.values():
parts.append(df.filter(cols))
var_df = pd.concat(parts, ignore_index=True)

prop = PROPERTIES[prop_key]
scale = self._get_scale(p, scale_key, prop, var_df[var])

if scale_key not in p._variables:
# TODO this implies that the variable was added by the stat
# It allows downstream orientation inference to work properly.
# But it feels rather hacky, so ideally revisit.
scale._priority = 0 # type: ignore

if axis is None:
# We could think about having a broader concept of (un)shared properties
# In general, not something you want to do (different scales in facets)
# But could make sense e.g. with paired plots. Build later.
share_state = None
subplots = []
else:
share_state = self._subplots.subplot_spec[f"share{axis}"]
subplots = [view for view in self._subplots if view[axis] == coord]

# Shared categorical axes are broken on matplotlib<3.4.0.
# https://github.com/matplotlib/matplotlib/pull/18308
# This only affects us when sharing *paired* axes. This is a novel/niche
# behavior, so we will raise rather than hack together a workaround.
if axis is not None and Version(mpl.__version__) < Version("3.4.0"):
from seaborn._core.scales import Nominal
paired_axis = axis in p._pair_spec
cat_scale = isinstance(scale, Nominal)
ok_dim = {"x": "col", "y": "row"}[axis]
shared_axes = share_state not in [False, "none", ok_dim]
if paired_axis and cat_scale and shared_axes:
err = "Sharing paired categorical axes requires matplotlib>=3.4.0"
raise RuntimeError(err)

# Initialize the data-dependent parameters of the scale
# Note that this returns a copy and does not mutate the original
# This dictionary is used by the semantic mappings
if scale is None:
# TODO what is the cleanest way to implement identity scale?
# We don't really need a Scale, and Identity() will be
# overloaded anyway (but maybe a general Identity object
# that can be used as Scale/Mark/Stat/Move?)
# Note that this may not be the right spacer to use
# (but that is only relevant for coordinates, where identity scale
# doesn't make sense or is poorly defined, since we don't use pixels.)
self._scales[var] = Scale._identity()
else:
scale = scale._setup(var_values, prop)
if isinstance(prop, Coordinate):
# If we have a coordinate here, we didn't assign a scale for it
# in _transform_coords, which means it was added during compute_stat
# This allows downstream orientation inference to work properly.
# But it feels a little hacky, so perhaps revisit.
scale._priority = 0 # type: ignore
self._scales[var] = scale
self._scales[var] = scale._setup(var_df[var], prop)

# Everything below here applies only to coordinate variables
# We additionally skip it when we're working with a value
# that is derived from a coordinate we've already processed.
# e.g., the Stat consumed y and added ymin/ymax. In that case,
# we've already setup the y scale and ymin/max are in scale space.
if axis is None or (var != coord and coord in p._variables):
continue

# Set up an empty series to receive the transformed values.
# We need this to handle piecemeal transforms of categories -> floats.
transformed_data = []
for layer in layers:
index = layer["data"].frame.index
empty_series = pd.Series(dtype=float, index=index, name=var)
transformed_data.append(empty_series)

for view in subplots:

axis_obj = getattr(view["ax"], f"{axis}axis")
seed_values = self._get_subplot_data(var_df, var, view, share_state)
view_scale = scale._setup(seed_values, prop, axis=axis_obj)
set_scale_obj(view["ax"], axis, view_scale._matplotlib_scale)

for layer, new_series in zip(layers, transformed_data):
layer_df = layer["data"].frame
if var in layer_df:
idx = self._get_subplot_index(layer_df, view)
new_series.loc[idx] = view_scale(layer_df.loc[idx, var])

# Now the transformed data series are complete, update the layer data
for layer, new_series in zip(layers, transformed_data):
layer_df = layer["data"].frame
if var in layer_df:
layer_df[var] = new_series

def _plot_layer(self, p: Plot, layer: Layer) -> None:

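
Aside: the coordinate handling in the rewritten _setup_scales hinges on the name-parsing regex shown above. A standalone sketch of how its groups behave for typical variable names (illustrative only, not part of the diff):

import re

pattern = re.compile(r"^(?P<coord>(?P<axis>x|y)\d*).*")

for var in ["x", "y1", "xmin", "x0max", "color"]:
    m = pattern.match(var)
    if m is None:
        print(f"{var!r}: not a coordinate variable")
    else:
        # coord keeps any pairing suffix (x0, y1, ...); axis is just x or y
        print(f"{var!r}: coord={m['coord']!r}, axis={m['axis']!r}")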