Skip to content

Commit

Permalink
Fix backwards compatibility issues on older matplotlibs
Browse files Browse the repository at this point in the history
Squashed commit of the following:

commit d03e6a5d8bf078df32d52035b504d6ff44f0036f
Author: Michael Waskom <[email protected]>
Date:   Sun Aug 22 18:08:36 2021 -0400

    Matplotlib/pandas norm backcompat in test

commit d115f7c1e780bec8b2ddc0a5834f47567bdfaa43
Author: Michael Waskom <[email protected]>
Date:   Sun Aug 22 18:08:25 2021 -0400

    Simpler workaround for matplotlib scale backcompat

commit dab8b3abb5cb9a8390695560cbdfe8fa6dc0a1e3
Author: Michael Waskom <[email protected]>
Date:   Sun Aug 22 16:45:33 2021 -0400

    Add workaround for uncopiable transforms on matplotlib<3.4

commit d25100241c24640d659a5ec6d2be84445240b461
Author: Michael Waskom <[email protected]>
Date:   Sun Aug 22 16:14:48 2021 -0400

    Add backcompat for matplotlib scale factory

commit 6c166644e6cba14b3a0bf696c8bee52b596470d1
Author: Michael Waskom <[email protected]>
Date:   Sun Aug 22 15:49:19 2021 -0400

    Update minimally supported dependency versions

commit 657739b2137fbc4353374db1bb5ac5bc9ee8da40
Author: Michael Waskom <[email protected]>
Date:   Sun Aug 22 15:46:56 2021 -0400

    Add backcompat layer for ax.set_{xy}scale(scale_obj)

commit 8863dd0f2d454a215d17b0daa9e6c287a3a5d792
Author: Michael Waskom <[email protected]>
Date:   Sun Aug 22 15:45:52 2021 -0400

    Add backcompat layer for GridSpec.n{col,row}s
  • Loading branch information
Michael Waskom committed Aug 29, 2021
1 parent 3a10862 commit 1b1929b
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 52 deletions.
10 changes: 5 additions & 5 deletions ci/deps_pinned.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
numpy~=1.16.0
pandas~=0.24.0
matplotlib~=3.0.0
scipy~=1.2.0
statsmodels~=0.9.0
numpy~=1.17.0
pandas~=0.25.0
matplotlib~=3.1.0
scipy~=1.3.0
statsmodels~=0.10.0
35 changes: 29 additions & 6 deletions seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import io
import itertools
from copy import deepcopy
from distutils.version import LooseVersion

import pandas as pd
import matplotlib as mpl
Expand Down Expand Up @@ -281,7 +282,13 @@ def scale_numeric(
# have it work the way you would expect.

if isinstance(scale, str):
scale = mpl.scale.scale_factory(scale, var, **kwargs)
# Matplotlib scales require an Axis object for backwards compatability,
# but it is not used, aside from extraction of the axis_name in LogScale.
# This can be removed when the minimum matplotlib is raised to 3.4,
# and a simple string (`var`) can be passed.
class Axis:
axis_name = var
scale = mpl.scale.scale_factory(scale, Axis(), **kwargs)

if norm is None:
# TODO what about when we want to infer the scale from the norm?
Expand Down Expand Up @@ -455,17 +462,33 @@ def _setup_figure(self, pyplot: bool = False) -> None:
figure_kws = {"figsize": getattr(self, "_figsize", None)} # TODO
self._figure = subplots.init_figure(pyplot, figure_kws)

# --- Assignment of scales
for sub in subplots:
ax = sub["ax"]
for axis in "xy":
axis_key = sub[axis]
scale = self._scales[axis_key]._scale
if LooseVersion(mpl.__version__) < "3.4":
# The ability to pass a BaseScale instance to Axes.set_{axis}scale
# was added to matplotlib in version 3.4.0:
# https://github.com/matplotlib/matplotlib/pull/19089
# Workaround: use the scale name, which is restrictive only
# if the user wants to define a custom scale.
# Additionally, setting the scale after updating the units breaks
# in some cases on older versions of matplotlib (with older pandas?)
# so only do it if necessary.
axis_obj = getattr(ax, f"{axis}axis")
if axis_obj.get_scale() != scale.name:
ax.set(**{f"{axis}scale": scale.name})
else:
ax.set(**{f"{axis}scale": scale})

# --- Figure annotation
for sub in subplots:
ax = sub["ax"]
for axis in "xy":
axis_key = sub[axis]
ax.set(**{
# Note: this is the only non "annotation" part of this code
# everything else can happen after .plot(), but we need this first
# Should perhaps separate it out to make that more clear
# (or pass scales into Subplots)
f"{axis}scale": self._scales[axis_key]._scale,
# TODO Should we make it possible to use only one x/y label for
# all rows/columns in a faceted plot? Maybe using sub{axis}label,
# although the alignments of the labels from taht method leaves
Expand Down
12 changes: 12 additions & 0 deletions seaborn/_core/scales.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations
from copy import copy
from distutils.version import LooseVersion

import numpy as np
import pandas as pd
import matplotlib as mpl
from matplotlib.scale import LinearScale
from matplotlib.colors import Normalize

Expand Down Expand Up @@ -39,6 +42,15 @@ def __init__(

# TODO add a repr with useful information about what is wrapped and metadata

if LooseVersion(mpl.__version__) < "3.4":
# Until matplotlib 3.4, matplotlib transforms could not be deepcopied.
# Fixing PR: https://github.com/matplotlib/matplotlib/pull/19281
# That means that calling deepcopy() on a Plot object fails when the
# recursion gets down to the `ScaleWrapper` objects.
# As a workaround, stop the recursion at this level with older matplotlibs.
def __deepcopy__(self, memo=None):
return copy(self)

@property
def order(self):
if hasattr(self._scale, "order"):
Expand Down
3 changes: 2 additions & 1 deletion seaborn/tests/_core/test_mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,8 @@ def test_numeric_multi_lookup(self, num_vector, num_norm):

cmap = color_palette("mako", as_cmap=True)
m = HueMapping(palette=cmap).setup(num_vector)
assert_array_equal(m(num_vector), cmap(num_norm(num_vector))[:, :3])
expected_colors = cmap(num_norm(num_vector.to_numpy()))[:, :3]
assert_array_equal(m(num_vector), expected_colors)

def test_bad_palette(self, num_vector):

Expand Down
68 changes: 33 additions & 35 deletions seaborn/tests/_core/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import itertools
import warnings
import imghdr
from distutils.version import LooseVersion

import numpy as np
import pandas as pd
Expand All @@ -19,6 +20,17 @@
assert_vector_equal = functools.partial(assert_series_equal, check_names=False)


def assert_gridspec_shape(ax, nrows=1, ncols=1):

gs = ax.get_gridspec()
if LooseVersion(mpl.__version__) < "3.2":
assert gs._nrows == nrows
assert gs._ncols == ncols
else:
assert gs.nrows == nrows
assert gs.ncols == ncols


class MockStat(Stat):

def __call__(self, data):
Expand Down Expand Up @@ -676,7 +688,7 @@ def check_facet_results_1d(self, p, df, dim, key, order=None):
assert subplot[dim] == level
assert subplot[other_dim] is None
assert subplot["ax"].get_title() == f"{key} = {level}"
assert getattr(subplot["ax"].get_gridspec(), f"n{dim}s") == len(order)
assert_gridspec_shape(subplot["ax"], **{f"n{dim}s": len(order)})

def test_1d_from_init(self, long_df, dim):

Expand Down Expand Up @@ -732,9 +744,9 @@ def check_facet_results_2d(self, p, df, variables, order=None):
assert subplot["axes"].get_title() == (
f"{variables['row']} = {row_level} | {variables['col']} = {col_level}"
)
gridspec = subplot["axes"].get_gridspec()
assert gridspec.nrows == len(levels["row"])
assert gridspec.ncols == len(levels["col"])
assert_gridspec_shape(
subplot["axes"], len(levels["row"]), len(levels["col"])
)

def test_2d_from_init(self, long_df):

Expand Down Expand Up @@ -813,10 +825,8 @@ def test_col_wrapping(self):
wrap = 3
p = Plot().facet(col=cols, wrap=wrap).plot()

gridspec = p._figure.axes[0].get_gridspec()
assert len(p._figure.axes) == 4
assert gridspec.ncols == 3
assert gridspec.nrows == 2
assert_gridspec_shape(p._figure.axes[0], len(cols) // wrap + 1, wrap)

# TODO test axis labels and titles

Expand All @@ -826,10 +836,8 @@ def test_row_wrapping(self):
wrap = 3
p = Plot().facet(row=rows, wrap=wrap).plot()

gridspec = p._figure.axes[0].get_gridspec()
assert_gridspec_shape(p._figure.axes[0], wrap, len(rows) // wrap + 1)
assert len(p._figure.axes) == 4
assert gridspec.ncols == 2
assert gridspec.nrows == 3

# TODO test axis labels and titles

Expand All @@ -845,10 +853,7 @@ def check_pair_grid(self, p, x, y):
ax = subplot["ax"]
assert ax.get_xlabel() == "" if x_j is None else x_j
assert ax.get_ylabel() == "" if y_i is None else y_i

gs = subplot["ax"].get_gridspec()
assert gs.ncols == len(x)
assert gs.nrows == len(y)
assert_gridspec_shape(subplot["ax"], len(y), len(x))

@pytest.mark.parametrize(
"vector_type", [list, np.array, pd.Series, pd.Index]
Expand Down Expand Up @@ -886,8 +891,7 @@ def test_non_cartesian(self, long_df):
ax = subplot["ax"]
assert ax.get_xlabel() == x[i]
assert ax.get_ylabel() == y[i]
assert ax.get_gridspec().nrows == 1
assert ax.get_gridspec().ncols == len(x) == len(y)
assert_gridspec_shape(ax, 1, len(x))

root, *other = p._figure.axes
for axis in "xy":
Expand Down Expand Up @@ -930,10 +934,7 @@ def test_with_facets(self, long_df):
assert ax.get_xlabel() == x
assert ax.get_ylabel() == y_i
assert ax.get_title() == f"{col} = {col_i}"

gs = subplot["ax"].get_gridspec()
assert gs.ncols == len(facet_levels)
assert gs.nrows == len(y)
assert_gridspec_shape(ax, len(y), len(facet_levels))

@pytest.mark.parametrize("variables", [("rows", "y"), ("columns", "x")])
def test_error_on_facet_overlap(self, long_df, variables):
Expand Down Expand Up @@ -1005,42 +1006,39 @@ def test_axis_sharing_with_facets(self, long_df):
def test_x_wrapping(self, long_df):

x_vars = ["f", "x", "y", "z"]
p = Plot(long_df, y="y").pair(x=x_vars, wrap=3).plot()
wrap = 3
p = Plot(long_df, y="y").pair(x=x_vars, wrap=wrap).plot()

gridspec = p._figure.axes[0].get_gridspec()
assert len(p._figure.axes) == 4
assert gridspec.ncols == 3
assert gridspec.nrows == 2
assert_gridspec_shape(p._figure.axes[0], len(x_vars) // wrap + 1, wrap)
assert len(p._figure.axes) == len(x_vars)

# TODO test axis labels and visibility

def test_y_wrapping(self, long_df):

y_vars = ["f", "x", "y", "z"]
p = Plot(long_df, x="x").pair(y=y_vars, wrap=3).plot()
wrap = 3
p = Plot(long_df, x="x").pair(y=y_vars, wrap=wrap).plot()

gridspec = p._figure.axes[0].get_gridspec()
assert len(p._figure.axes) == 4
assert gridspec.nrows == 3
assert gridspec.ncols == 2
assert_gridspec_shape(p._figure.axes[0], wrap, len(y_vars) // wrap + 1)
assert len(p._figure.axes) == len(y_vars)

# TODO test axis labels and visibility

def test_noncartesian_wrapping(self, long_df):

x_vars = ["a", "b", "c", "t"]
y_vars = ["f", "x", "y", "z"]
wrap = 3

p = (
Plot(long_df, x="x")
.pair(x=x_vars, y=y_vars, wrap=3, cartesian=False)
.pair(x=x_vars, y=y_vars, wrap=wrap, cartesian=False)
.plot()
)

gridspec = p._figure.axes[0].get_gridspec()
assert len(p._figure.axes) == 4
assert gridspec.nrows == 2
assert gridspec.ncols == 3
assert_gridspec_shape(p._figure.axes[0], len(x_vars) // wrap + 1, wrap)
assert len(p._figure.axes) == len(x_vars)


class TestLabelVisibility:
Expand Down
10 changes: 5 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@
PYTHON_REQUIRES = ">=3.7"

INSTALL_REQUIRES = [
'numpy>=1.16',
'pandas>=0.24',
'matplotlib>=3.0',
'numpy>=1.17',
'pandas>=0.25',
'matplotlib>=3.1',
]

EXTRAS_REQUIRE = {
'all': [
'scipy>=1.2',
'statsmodels>=0.9',
'scipy>=1.3',
'statsmodels>=0.10',
]
}

Expand Down

0 comments on commit 1b1929b

Please sign in to comment.