Skip to content

Commit

Permalink
Partly functional prototype of Plot.on
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Waskom committed Aug 30, 2021
1 parent 09f2ad6 commit 55a0d34
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 15 deletions.
43 changes: 34 additions & 9 deletions seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure, SubFigure
import matplotlib.pyplot as plt # TODO defer import into Plot.show()

from seaborn._core.rules import categorical_order, variable_type
from seaborn._core.data import PlotData
Expand All @@ -26,8 +28,6 @@
from typing import Literal, Any
from collections.abc import Callable, Generator, Iterable, Hashable
from pandas import DataFrame, Series, Index
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from matplotlib.scale import ScaleBase
from matplotlib.colors import Normalize
from seaborn._core.mappings import SemanticMapping
Expand Down Expand Up @@ -75,11 +75,13 @@ def __init__(
"y": ScaleWrapper(mpl.scale.LinearScale("y"), "unknown"),
}

self._target_obj = None

self._subplotspec = {}
self._facetspec = {}
self._pairspec = {}

def on(self) -> Plot:
def on(self, obj: Axes | SubFigure | Figure) -> Plot:

# TODO Provisional name for a method that accepts an existing Axes object,
# and possibly one that does all of the figure/subplot configuration
Expand All @@ -91,7 +93,15 @@ def on(self) -> Plot:
# larger figure. Not sure what to do about that. I suppose existing figure could
# disabling legend_out.

raise NotImplementedError()
if not isinstance(obj, (Axes, SubFigure, Figure)):
err = (
f"`obj` must be an instance of {Axes}, {SubFigure}, or {Figure}. "
f"Got an object of class {obj.__class__} instead."
)
raise TypeError(err)

self._target_obj = obj

return self

def add(
Expand Down Expand Up @@ -377,7 +387,7 @@ def plot(self, pyplot=False) -> Plot:
self._plot_layer(layer, layer_mappings)

# TODO this should be configurable
self._figure.tight_layout()
# self._figure.tight_layout()

return self

Expand All @@ -391,7 +401,10 @@ def show(self, **kwargs) -> None:

# Keep an eye on whether matplotlib implements "attaching" an existing
# figure to pyplot: https://github.com/matplotlib/matplotlib/pull/14024
self.clone().plot(pyplot=True)
if self._target_obj is None:
self.clone().plot(pyplot=True)
else:
self.plot(pyplot=True)
plt.show(**kwargs)

def save(self) -> Plot: # TODO perhaps this should not return self?
Expand Down Expand Up @@ -459,8 +472,18 @@ def _setup_figure(self, pyplot: bool = False) -> None:
)

# --- Figure initialization
figure_kws = {"figsize": getattr(self, "_figsize", None)} # TODO fix this hack
self._figure = subplots.init_figure(pyplot, figure_kws)
if isinstance(self._target_obj, Axes):
if self._facetspec or self._pairspec:
err = (
"Cannot create multiple subplots after calling `Plot.on` with a "
f"{Axes} object. You may want to provide a {SubFigure} instead."
)
raise RuntimeError(err)
self._figure = subplots.init_from_axes(self._target_obj)

else:
figure_kws = {"figsize": getattr(self, "_figsize", None)} # TODO fix
self._figure = subplots.init_figure(pyplot, self._target_obj, figure_kws)

# --- Assignment of scales
for sub in subplots:
Expand Down Expand Up @@ -809,6 +832,8 @@ def _repr_png_(self) -> bytes:
# But we can still show a Plot where the user has manually invoked .plot()
if hasattr(self, "_figure"):
figure = self._figure
elif self._target_obj is not None:
figure = self.plot()._figure
else:
figure = self.clone().plot()._figure

Expand Down
45 changes: 39 additions & 6 deletions seaborn/_core/subplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

import numpy as np
import matplotlib as mpl
from matplotlib.figure import Figure, SubFigure
import matplotlib.pyplot as plt

from seaborn._core.rules import categorical_order

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Generator
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from seaborn._core.data import PlotData


Expand Down Expand Up @@ -134,19 +135,51 @@ def _determine_axis_sharing(self) -> None:
val = True
self.subplot_spec[key] = val

def init_figure(self, pyplot: bool, figure_kws: dict | None = None) -> Figure:
def init_from_axes(self, axes: Axes) -> Figure:

self._subplot_list = [{
"ax": axes,
"left": True,
"right": True,
"top": True,
"bottom": True,
"col": None,
"row": None,
"x": "x",
"y": "y",
}]
return axes.figure

def init_figure(
self,
pyplot: bool,
figure_obj: Figure | SubFigure = None,
figure_kws: dict | None = None,
) -> Figure:
"""Initialize matplotlib objects and add seaborn-relevant metadata."""
# TODO other methods don't have defaults, maybe don't have one here either
if figure_kws is None:
figure_kws = {}

if pyplot:
figure = plt.figure(**figure_kws)
# TODO we need a clearer distinction between "figure" and "figure_obj"
# (The first is the top of the mpl artist hierarchy we want to track,
# while the second is the object that directly own the subplots.)
if isinstance(figure_obj, Figure):
figure = figure_obj
elif isinstance(figure_obj, SubFigure):
figure = figure_obj.figure
else:
figure = mpl.figure.Figure(**figure_kws)
if pyplot:
figure = plt.figure(**figure_kws)
else:
figure = mpl.figure.Figure(**figure_kws)
figure_obj = figure
self._figure = figure

axs = figure.subplots(**self.subplot_spec, squeeze=False)
# TODO check that figure does not currently have .axes?
assert not figure_obj.axes

axs = figure_obj.subplots(**self.subplot_spec, squeeze=False)

if self.wrap:
# Remove unused Axes and flatten the rest into a (2D) vector
Expand Down

0 comments on commit 55a0d34

Please sign in to comment.