Skip to content

Commit

Permalink
Accept Stat and/or Move(s) as anonymous positional args in Plot.add (#…
Browse files Browse the repository at this point in the history
…2948)

* Accept Stat and/or Move(s) as anonymous positional args in Plot.add

* Hacky fix to Move/Stat typing

* Fix Plot.add type checks

* Remove unnecessary type check

* Update nextgen demo

* Make Move.__call__ signature mirror Stat.__call__

* Fix tests that create mock Moves
  • Loading branch information
mwaskom authored Aug 13, 2022
1 parent daa9924 commit e9ad419
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 83 deletions.
25 changes: 11 additions & 14 deletions doc/nextgen/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@
"id": "ae0e288e-74cf-461c-8e68-786e364032a1",
"metadata": {},
"source": [
"### Data transformation: the Stat\n",
"### Data transformations: the Stat\n",
"\n",
"\n",
"Built-in statistical transformations are one of seaborn's key features. But currently, they are tied up with the different visual representations. E.g., you can aggregate data in `lineplot`, but not in `scatterplot`.\n",
Expand All @@ -273,7 +273,7 @@
"id": "1788d935-5ad5-4262-993f-8d48c66631b9",
"metadata": {},
"source": [
"The `Stat` is computed on subsets of data defined by the semantic mappings:"
"A `Stat` is computed on subsets of data defined by the semantic mappings:"
]
},
{
Expand Down Expand Up @@ -323,7 +323,7 @@
"outputs": [],
"source": [
"class PeakAnnotation(so.Mark):\n",
" def plot(self, split_generator, scales, orient):\n",
" def _plot(self, split_generator, scales, orient):\n",
" for keys, data, ax in split_generator():\n",
" ix = data[\"y\"].idxmax()\n",
" ax.annotate(\n",
Expand Down Expand Up @@ -388,7 +388,7 @@
"source": [
"(\n",
" so.Plot(tips, \"day\", \"total_bill\", color=\"time\")\n",
" .add(so.Bar(), so.Agg(), move=so.Dodge())\n",
" .add(so.Dot(), so.Dodge())\n",
")"
]
},
Expand All @@ -409,7 +409,7 @@
"source": [
"(\n",
" so.Plot(tips, \"day\", \"total_bill\", color=\"time\")\n",
" .add(so.Bar(), so.Agg(), move=so.Dodge(empty=\"fill\", gap=.1))\n",
" .add(so.Bar(), so.Agg(), so.Dodge(empty=\"fill\", gap=.1))\n",
")"
]
},
Expand All @@ -430,7 +430,7 @@
"source": [
"(\n",
" so.Plot(tips, \"day\", \"total_bill\", color=\"time\", alpha=\"sex\")\n",
" .add(so.Bar(), so.Agg(), move=so.Dodge())\n",
" .add(so.Bar(), so.Agg(), so.Dodge())\n",
")"
]
},
Expand All @@ -451,7 +451,7 @@
"source": [
"(\n",
" so.Plot(tips, \"day\", \"total_bill\", color=\"time\", alpha=\"smoker\")\n",
" .add(so.Dot(), move=so.Dodge(by=[\"color\"]))\n",
" .add(so.Dot(), so.Dodge(by=[\"color\"]))\n",
")"
]
},
Expand All @@ -460,7 +460,7 @@
"id": "c001004a-6771-46eb-b231-6accf88fe330",
"metadata": {},
"source": [
"It's also possible to stack multiple moves or kinds of moves by passing a list:"
"It's also possible to stack multiple moves or kinds of moves:"
]
},
{
Expand All @@ -472,10 +472,7 @@
"source": [
"(\n",
" so.Plot(tips, \"day\", \"total_bill\", color=\"time\", alpha=\"smoker\")\n",
" .add(\n",
" so.Dot(),\n",
" move=[so.Dodge(by=[\"color\"]), so.Jitter(.5)]\n",
" )\n",
" .add(so.Dot(), so.Dodge(by=[\"color\"]), so.Jitter(.5))\n",
")"
]
},
Expand Down Expand Up @@ -568,8 +565,8 @@
" so.Plot(planets, x=\"mass\", y=\"distance\", color=\"orbital_period\")\n",
" .scale(\n",
" x=\"log\",\n",
" y=so.Continuous(transform=\"log\").tick(at=[3, 10, 30, 100, 300]),\n",
" color=so.Continuous(\"rocket\", transform=\"log\"),\n",
" y=so.Continuous(trans=\"log\").tick(at=[3, 10, 30, 100, 300]),\n",
" color=so.Continuous(\"rocket\", trans=\"log\"),\n",
" )\n",
" .add(so.Dots())\n",
")"
Expand Down
25 changes: 19 additions & 6 deletions seaborn/_core/moves.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
from pandas import DataFrame

from seaborn._core.groupby import GroupBy
from seaborn._core.scales import Scale


@dataclass
class Move:

group_by_orient: ClassVar[bool] = True

def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame:
def __call__(
self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
) -> DataFrame:
raise NotImplementedError


Expand All @@ -31,7 +34,9 @@ class Jitter(Move):
# TODO what is the best way to have a reasonable default?
# The problem is that "reasonable" seems dependent on the mark

def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame:
def __call__(
self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
) -> DataFrame:

# TODO is it a problem that GroupBy is not used for anything here?
# Should we type it as optional?
Expand Down Expand Up @@ -68,7 +73,9 @@ class Dodge(Move):
# TODO should the default be an "all" singleton?
by: Optional[list[str]] = None

def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame:
def __call__(
self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
) -> DataFrame:

grouping_vars = [v for v in groupby.order if v in data]
groups = groupby.agg(data, {"width": "max"})
Expand Down Expand Up @@ -138,7 +145,9 @@ def _stack(self, df, orient):

return df

def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame:
def __call__(
self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
) -> DataFrame:

# TODO where to ensure that other semantic variables are sorted properly?
# TODO why are we not using the passed in groupby here?
Expand All @@ -154,7 +163,9 @@ class Shift(Move):
x: float = 0
y: float = 0

def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame:
def __call__(
self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
) -> DataFrame:

data = data.copy(deep=False)
data["x"] = data["x"] + self.x
Expand Down Expand Up @@ -188,7 +199,9 @@ def _norm(self, df, var):

return df

def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame:
def __call__(
self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
) -> DataFrame:

other = {"x": "y", "y": "x"}[orient]
return groupby.apply(data, self._norm, other)
Expand Down
63 changes: 38 additions & 25 deletions seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from contextlib import contextmanager
from collections import abc
from collections.abc import Callable, Generator, Hashable
from typing import Any, cast
from typing import Any, List, Optional, cast

from cycler import cycler
import pandas as pd
Expand Down Expand Up @@ -338,16 +338,14 @@ def on(self, target: Axes | SubFigure | Figure) -> Plot:
def add(
self,
mark: Mark,
stat: Stat | None = None,
move: Move | list[Move] | None = None,
*,
*transforms: Stat | Mark,
orient: str | None = None,
legend: bool = True,
data: DataSource = None,
**variables: VariableSpec,
) -> Plot:
"""
Define a layer of the visualization.
Define a layer of the visualization in terms of mark and data transform(s).
This is the main method for specifying how the data should be visualized.
It can be called multiple times with different arguments to define
Expand All @@ -357,48 +355,63 @@ def add(
----------
mark : :class:`seaborn.objects.Mark`
The visual representation of the data to use in this layer.
stat : :class:`seaborn.objects.Stat`
A transformation applied to the data before plotting.
move : :class:`seaborn.objects.Move`
Additional transformation(s) to handle over-plotting.
legend : bool
Option to suppress the mark/mappings for this layer from the legend.
transforms : :class:`seaborn.objects.Stat` or :class:`seaborn.objects.Move`
Objects representing transforms to be applied before plotting the data.
Current, at most one :class:`seaborn.objects.Stat` can be used, and it
must be passed first. This constraint will be relaxed in the future.
orient : "x", "y", "v", or "h"
The orientation of the mark, which affects how the stat is computed.
Typically corresponds to the axis that defines groups for aggregation.
The "v" (vertical) and "h" (horizontal) options are synonyms for "x" / "y",
but may be more intuitive with some marks. When not provided, an
orientation will be inferred from characteristics of the data and scales.
legend : bool
Option to suppress the mark/mappings for this layer from the legend.
data : DataFrame or dict
Data source to override the global source provided in the constructor.
variables : data vectors or identifiers
Additional layer-specific variables, including variables that will be
passed directly to the stat without scaling.
passed directly to the transforms without scaling.
"""
if not isinstance(mark, Mark):
msg = f"mark must be a Mark instance, not {type(mark)!r}."
raise TypeError(msg)

if stat is not None and not isinstance(stat, Stat):
msg = f"stat must be a Stat instance, not {type(stat)!r}."
# TODO This API for transforms was a late decision, and previously Plot.add
# accepted 0 or 1 Stat instances and 0, 1, or a list of Move instances.
# It will take some work to refactor the internals so that Stat and Move are
# treated identically, and until then well need to "unpack" the transforms
# here and enforce limitations on the order / types.

stat: Optional[Stat]
move: Optional[List[Move]]
error = False
if not transforms:
stat, move = None, None
elif isinstance(transforms[0], Stat):
stat = transforms[0]
move = [m for m in transforms[1:] if isinstance(m, Move)]
error = len(move) != len(transforms) - 1
else:
stat = None
move = [m for m in transforms if isinstance(m, Move)]
error = len(move) != len(transforms)

if error:
msg = " ".join([
"Transforms must have at most one Stat type (in the first position),",
"and all others must be a Move type. Given transform type(s):",
", ".join(str(type(t).__name__) for t in transforms) + "."
])
raise TypeError(msg)

# TODO decide how to allow Mark to have default Stat/Move
# if stat is None and hasattr(mark, "default_stat"):
# stat = mark.default_stat()

# TODO it doesn't work to supply scalars to variables, but that would be nice

# TODO accept arbitrary variables defined by the stat (/move?) here
# (but not in the Plot constructor)
# Should stat variables ever go in the constructor, or just in the add call?

new = self._clone()
new._layers.append({
"mark": mark,
"stat": stat,
"move": move,
# TODO it doesn't work to supply scalars to variables, but it should
"vars": variables,
"source": data,
"legend": legend,
Expand Down Expand Up @@ -1232,7 +1245,7 @@ def get_order(var):
move_groupers.insert(0, orient)
order = {var: get_order(var) for var in move_groupers}
groupby = GroupBy(order)
df = move_step(df, groupby, orient)
df = move_step(df, groupby, orient, scales)

df = self._unscale_coords(subplots, df, orient)

Expand Down
Loading

0 comments on commit e9ad419

Please sign in to comment.