diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index b74bb2aa24..8842378c49 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -149,10 +149,12 @@ class Plot: _limits: dict[str, tuple[Any, Any]] _labels: dict[str, str | Callable[[str], str] | None] - _subplot_spec: dict[str, Any] # TODO values type _facet_spec: FacetSpec _pair_spec: PairSpec + _figure_spec: dict[str, Any] + _subplot_spec: dict[str, Any] + def __init__( self, *args: DataSource | VariableSpec, @@ -175,10 +177,12 @@ def __init__( self._limits = {} self._labels = {} - self._subplot_spec = {} self._facet_spec = {} self._pair_spec = {} + self._subplot_spec = {} + self._figure_spec = {} + self._target = None def _resolve_positionals( @@ -242,10 +246,12 @@ def _clone(self) -> Plot: new._labels.update(self._labels) new._limits.update(self._limits) - new._subplot_spec.update(self._subplot_spec) new._facet_spec.update(self._facet_spec) new._pair_spec.update(self._pair_spec) + new._figure_spec.update(self._figure_spec) + new._subplot_spec.update(self._subplot_spec) + new._target = self._target return new @@ -612,8 +618,7 @@ def configure( new = self._clone() - # TODO this is a hack; make a proper figure spec object - new._figsize = figsize # type: ignore + new._figure_spec["figsize"] = figsize if sharex is not None: new._subplot_spec["sharex"] = sharex @@ -825,9 +830,8 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None: self._subplots = subplots = Subplots(subplot_spec, facet_spec, pair_spec) # --- Figure initialization - figure_kws = {"figsize": getattr(p, "_figsize", None)} # TODO fix self._figure = subplots.init_figure( - pair_spec, self.pyplot, figure_kws, p._target, + pair_spec, self.pyplot, p._figure_spec, p._target, ) # --- Figure annotation diff --git a/tests/_core/test_plot.py b/tests/_core/test_plot.py index 093e7e2447..af8da1c443 100644 --- a/tests/_core/test_plot.py +++ b/tests/_core/test_plot.py @@ -1122,6 +1122,12 @@ def test_2d_with_order(self, long_df, reorder): p = Plot(long_df).facet(**variables, order=order) self.check_facet_results_2d(p, long_df, variables, order) + def test_figsize(self): + + figsize = (4, 2) + p = Plot().configure(figsize=figsize).plot() + assert tuple(p._figure.get_size_inches()) == figsize + def test_axis_sharing(self, long_df): variables = {"row": "a", "col": "c"}