Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Norm move class #2827

Merged
merged 2 commits into from
Jun 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 47 additions & 7 deletions seaborn/_core/moves.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import ClassVar, Callable, Optional, Union

import numpy as np
from pandas import DataFrame

from seaborn._core.groupby import GroupBy

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional
from pandas import DataFrame


@dataclass
class Move:
    """
    Base class for moves; concrete subclasses implement ``__call__``.
    """

    # Class-level switch (not a dataclass field). The plotting code visible
    # alongside this class consults it to decide whether the orientation
    # variable is added to the grouping variables handed to the move.
    group_by_orient: ClassVar[bool] = True

    def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame:
        """Apply the move to ``data``; the base class provides no behavior."""
        raise NotImplementedError()

Expand Down Expand Up @@ -61,10 +60,12 @@ class Dodge(Move):
"""
Displacement and narrowing of overlapping marks along orientation axis.
"""
empty: str = "keep" # keep, drop, fill
empty: str = "keep" # Options: keep, drop, fill
gap: float = 0

# TODO accept just a str here?
# TODO should this always be present?
# TODO should the default be an "all" singleton?
by: Optional[list[str]] = None

def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame:
Expand Down Expand Up @@ -117,7 +118,7 @@ class Stack(Move):
"""
Displacement of overlapping bar or area marks along the value axis.
"""
# TODO center? (or should this be a different move?)
# TODO center? (or should this be a different move, e.g. Stream()?)

def _stack(self, df, orient):

Expand All @@ -140,6 +141,7 @@ def _stack(self, df, orient):
def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame:

# TODO where to ensure that other semantic variables are sorted properly?
# TODO why are we not using the passed in groupby here?
groupers = ["col", "row", orient]
return GroupBy(groupers).apply(data, self._stack, orient)

Expand All @@ -158,3 +160,41 @@ def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame:
data["x"] = data["x"] + self.x
data["y"] = data["y"] + self.y
return data


@dataclass
class Norm(Move):
    """
    Divisive scaling on the value axis after aggregating within groups.

    Parameters
    ----------
    func : str or callable
        Aggregation handed to ``Series.agg`` to compute each group's
        denominator.
    where : str, optional
        ``DataFrame.query`` expression selecting the rows of each group that
        contribute to the denominator. All rows are used when ``None``.
    by : list of str, optional
        Grouping variables. Not read inside this class; the plotting code
        looks the attribute up when it builds the ``groupby`` passed in.
    percent : bool
        If True, multiply the scaled values by 100.
    """

    func: Union[Callable, str] = "max"
    where: Optional[str] = None
    by: Optional[list[str]] = None
    percent: bool = False

    # Opt out of having the orientation variable added to the groupers.
    group_by_orient: ClassVar[bool] = False

    def _norm(self, df, var):
        """
        Return a copy of ``df`` with column ``var`` divided by its aggregate.

        The input frame is left untouched: it may be the caller's own data or
        a group produced by a groupby, and mutating either in place is unsafe.
        """
        if self.where is None:
            denom_data = df[var]
        else:
            denom_data = df.query(self.where)[var]

        scaled = df[var] / denom_data.agg(self.func)
        if self.percent:
            scaled = scaled * 100

        # ``assign`` copies the frame before replacing the column, so the
        # result carries the new values without writing through to ``df``.
        return df.assign(**{var: scaled})

    def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame:
        """Scale the value variable (the axis opposite ``orient``) per group."""
        other = {"x": "y", "y": "x"}[orient]
        return groupby.apply(data, self._norm, other)


# TODO
# @dataclass
# class Ridge(Move):
# ...
11 changes: 6 additions & 5 deletions seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,11 +1117,12 @@ def get_order(var):
if move is not None:
moves = move if isinstance(move, list) else [move]
for move_step in moves:
move_groupers = [
orient,
*(getattr(move_step, "by", None) or grouping_properties),
*default_grouping_vars,
]
move_by = getattr(move_step, "by", None)
if move_by is None:
move_by = grouping_properties
move_groupers = [*move_by, *default_grouping_vars]
if move_step.group_by_orient:
move_groupers.insert(0, orient)
order = {var: get_order(var) for var in move_groupers}
groupby = GroupBy(order)
df = move_step(df, groupby, orient)
Expand Down
2 changes: 1 addition & 1 deletion seaborn/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@
from seaborn._stats.regression import OLSFit, PolyFit # noqa: F401
from seaborn._stats.histograms import Hist # noqa: F401

from seaborn._core.moves import Dodge, Jitter, Shift, Stack # noqa: F401
from seaborn._core.moves import Dodge, Jitter, Norm, Shift, Stack # noqa: F401

from seaborn._core.scales import Nominal, Continuous, Temporal # noqa: F401
40 changes: 39 additions & 1 deletion seaborn/tests/_core/test_moves.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pandas.testing import assert_series_equal
from numpy.testing import assert_array_equal, assert_array_almost_equal

from seaborn._core.moves import Dodge, Jitter, Shift, Stack
from seaborn._core.moves import Dodge, Jitter, Shift, Stack, Norm
from seaborn._core.rules import categorical_order
from seaborn._core.groupby import GroupBy

Expand Down Expand Up @@ -318,3 +318,41 @@ def test_moves(self, toy_df, x, y):
res = Shift(x=x, y=y)(toy_df, gb, "x")
assert_array_equal(res["x"], toy_df["x"] + x)
assert_array_equal(res["y"], toy_df["y"] + y)


class TestNorm(MoveFixtures):
    """Checks for the ``Norm`` move."""

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_default_no_groups(self, df, orient):
        """Default settings: the value variable's maximum becomes 1."""
        value_var = {"x": "y", "y": "x"}[orient]
        grouper = GroupBy(["null"])
        normed = Norm()(df, grouper, orient)
        assert normed[value_var].max() == pytest.approx(1)

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_default_groups(self, df, orient):
        """Grouping by ``grp2``: each group's maximum becomes 1."""
        value_var = {"x": "y", "y": "x"}[orient]
        grouper = GroupBy(["grp2"])
        normed = Norm()(df, grouper, orient)
        for _, group_rows in normed.groupby("grp2"):
            assert group_rows[value_var].max() == pytest.approx(1)

    def test_sum(self, df):
        """With ``"sum"`` as the aggregate, the scaled values total 1."""
        grouper = GroupBy(["null"])
        normed = Norm("sum")(df, grouper, "x")
        assert normed["y"].sum() == pytest.approx(1)

    def test_where(self, df):
        """With ``where``, the selected rows' maximum becomes 1."""
        grouper = GroupBy(["null"])
        normed = Norm(where="x == 2")(df, grouper, "x")
        selected = normed.loc[normed["x"] == 2, "y"]
        assert selected.max() == pytest.approx(1)

    def test_percent(self, df):
        """With ``percent=True``, the maximum becomes 100."""
        grouper = GroupBy(["null"])
        normed = Norm(percent=True)(df, grouper, "x")
        assert normed["y"].max() == pytest.approx(100)