
Commit

Fix size norm extrapolation (#2580)
* fix size range

* add default value to sizes

* Add test for GH2539

* Restrict usage of default size range attribute

* Work around archaic matplotlib auto-legending of named Series objects

Co-authored-by: Risako <[email protected]>
mwaskom and Risako authored May 10, 2021
1 parent 536fa2d commit 405f666
Showing 2 changed files with 43 additions and 8 deletions.
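For context: the regression this commit fixes (https://github.com/mwaskom/seaborn/issues/2539) appeared when the plotted data covered only part of an explicit size_norm, so point sizes were rescaled against the subset rather than the requested range. The following is a minimal sketch, distilled from the new test added below rather than taken from the commit itself, of the behaviour the fix restores:

import numpy as np
import matplotlib.pyplot as plt
from seaborn import scatterplot

# Full data and a subset of it, plotted with the same explicit size norm.
x = np.arange(0, 20, 2)
f, axs = plt.subplots(1, 2, sharex=True, sharey=True)
kws = dict(sizes=(50, 200), size_norm=(0, x.max()), legend="brief")

scatterplot(x=x, y=x, size=x, ax=axs[0], **kws)                   # full range
scatterplot(x=x[:5], y=x[:5], size=x[:5], ax=axs[1], **kws)       # first five points only

# With the fix, the first five sizes on the left match the sizes on the right;
# before it, the subset was rescaled to its own min/max and the panels disagreed.
print(axs[0].collections[0].get_sizes()[:5])
print(axs[1].collections[0].get_sizes())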
17 changes: 9 additions & 8 deletions seaborn/_core.py
@@ -286,7 +286,7 @@ def __init__(

if map_type == "numeric":

- levels, lookup_table, norm = self.numeric_mapping(
+ levels, lookup_table, norm, size_range = self.numeric_mapping(
data, sizes, norm,
)

@@ -297,6 +297,7 @@ def __init__(
levels, lookup_table = self.categorical_mapping(
data, sizes, order,
)
+ size_range = None

# --- Option 3: datetime mapping

@@ -308,11 +309,13 @@ def __init__(
# pandas and numpy represent datetime64 data
list(data), sizes, order,
)
+ size_range = None

self.map_type = map_type
self.levels = levels
self.norm = norm
self.sizes = sizes
+ self.size_range = size_range
self.lookup_table = lookup_table

def infer_map_type(self, norm, sizes, var_type):
@@ -334,9 +337,7 @@ def _lookup_single(self, key):
normed = self.norm(key)
if np.ma.is_masked(normed):
normed = np.nan
- size_values = self.lookup_table.values()
- size_range = min(size_values), max(size_values)
- value = size_range[0] + normed * np.ptp(size_range)
+ value = self.size_range[0] + normed * np.ptp(self.size_range)
return value

def categorical_mapping(self, data, sizes, order):
@@ -385,15 +386,15 @@ def categorical_mapping(self, data, sizes, order):
# across the visual representation of the data. But at this
# point, we don't know the visual representation. Likely we
# want to change the logic of this Mapping so that it gives
- # points on a nornalized range that then gets unnormalized
+ # points on a normalized range that then gets un-normalized
# when we know what we're drawing. But given the way the
# package works now, this way is cleanest.
sizes = self.plotter._default_size_range

# For categorical sizes, use regularly-spaced linear steps
# between the minimum and maximum sizes. Then reverse the
# ramp so that the largest value is used for the first entry
- # in size_order, etc. This is because "ordered" categoricals
+ # in size_order, etc. This is because "ordered" categories
# are often though to go in decreasing priority.
sizes = np.linspace(*sizes, len(levels))[::-1]
lookup_table = dict(zip(levels, sizes))
@@ -437,7 +438,7 @@ def numeric_mapping(self, data, sizes, norm):

# When not provided, we get the size range from the plotter
# object we are attached to. See the note in the categorical
- # method about how this is suboptimal for future development.:
+ # method about how this is suboptimal for future development.
size_range = self.plotter._default_size_range

# Now that we know the minimum and maximum sizes that will get drawn,
@@ -477,7 +478,7 @@ def numeric_mapping(self, data, sizes, norm):
sizes = lo + sizes_scaled * (hi - lo)
lookup_table = dict(zip(levels, sizes))

- return levels, lookup_table, norm
+ return levels, lookup_table, norm, size_range


@share_init_params_with_map
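In short, the substantive change above is in _lookup_single: sizes for values that are not in the lookup table are now interpolated within the full size range that numeric_mapping computes (and that __init__ now stores as self.size_range), instead of within the min/max of the lookup table's values, which only cover the levels present in the data and therefore shrink the range whenever the data do not span the full norm. Below is a standalone sketch of the difference, using hypothetical values and helper names rather than seaborn's actual classes:

import numpy as np

size_range = (50, 200)                    # what the user asked for via sizes=(50, 200)
lookup_table = {4: 80, 8: 110, 12: 140}   # sizes only for the levels seen in the data
norm = lambda key: key / 20               # size_norm=(0, 20)

def old_lookup(key):
    # Old behaviour: rescale within the lookup table's own min/max,
    # which is narrower than size_range when the data stop short of the norm's max.
    vals = lookup_table.values()
    rng = min(vals), max(vals)
    return rng[0] + norm(key) * np.ptp(rng)

def new_lookup(key):
    # New behaviour: rescale within the size range the user configured.
    return size_range[0] + norm(key) * np.ptp(size_range)

print(old_lookup(16))  # 128.0 -- squeezed into the lookup table's min/max
print(new_lookup(16))  # 170.0 -- respects sizes=(50, 200)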
34 changes: 34 additions & 0 deletions seaborn/tests/test_relational.py
@@ -1626,6 +1626,40 @@ def test_linewidths(self, long_df):
scatterplot(data=long_df, x="x", y="y", linewidth=lw)
assert ax.collections[0].get_linewidths().item() == lw

def test_size_norm_extrapolation(self):

# https://github.com/mwaskom/seaborn/issues/2539
x = np.arange(0, 20, 2)
f, axs = plt.subplots(1, 2, sharex=True, sharey=True)

slc = 5
kws = dict(sizes=(50, 200), size_norm=(0, x.max()), legend="brief")

scatterplot(x=x, y=x, size=x, ax=axs[0], **kws)
scatterplot(x=x[:slc], y=x[:slc], size=x[:slc], ax=axs[1], **kws)

assert np.allclose(
axs[0].collections[0].get_sizes()[:slc],
axs[1].collections[0].get_sizes()
)

legends = [ax.legend_ for ax in axs]
legend_data = [
{
label.get_text(): handle.get_sizes().item()
for label, handle in zip(legend.get_texts(), legend.legendHandles)
} for legend in legends
]

for key in set(legend_data[0]) & set(legend_data[1]):
if key == "y":
# At some point (circa 3.0) matplotlib auto-added pandas series
# with a valid name into the legend, which messes up this test.
# I can't track down when that was added (or removed), so let's
# just anticipate and ignore it here.
continue
assert legend_data[0][key] == legend_data[1][key]

def test_datetime_scale(self, long_df):

ax = scatterplot(data=long_df, x="t", y="y")
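To check the fix locally, the new regression test can be selected on its own with something like pytest seaborn/tests/test_relational.py -k size_norm_extrapolation; the exact invocation depends on how the repository is set up in your environment.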
