Skip to content

Commit

Permalink
Address more small TODOs and flesh out docs
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed May 23, 2020
1 parent 2aaef33 commit 811261e
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 54 deletions.
107 changes: 63 additions & 44 deletions seaborn/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,51 +26,68 @@


class SemanticMapping:
"""Base class for mapping data values to plot attributes."""

# -- Default attributes that all SemanticMapping subclasses must set

# Whether the mapping is numeric, categorical, or datetime
map_type = None

# Ordered list of unique values in the input data
levels = None

# A mapping from the data values to corresponding plot attributes
lookup_table = None

def __init__(self, plotter):

# TODO Putting this here so we can continue to use a lot of the
# logic that's built into the library, but the idea of this class
# is to move towards semantic mappings that are agnostic about the
# kind of plot they're going to be used to draw, which is going to
# take some rethinking.
# kind of plot they're going to be used to draw.
# Fully achieving that is going to take some thinking.
self.plotter = plotter

def map(cls, plotter, *args, **kwargs):
# This method is assigned the __init__ docstring
method_name = "_{}_map".format(cls.__name__[:-7].lower())
setattr(plotter, method_name, cls(plotter, *args, **kwargs))
return plotter

def _lookup_single(self, key):
"""Apply the mapping to a single data value."""
return self.lookup_table[key]

value = self.lookup_table[key]
return value

def __call__(self, data, *args, **kwargs):

if isinstance(data, (list, np.ndarray, pd.Series)):
# TODO need to debug why data.map(self.lookup_table) doesn't work
return [self._lookup_single(key, *args, **kwargs) for key in data]
def __call__(self, key, *args, **kwargs):
"""Get the attribute value(s) for the data key(s)."""
if isinstance(key, (list, np.ndarray, pd.Series)):
return [self._lookup_single(k, *args, **kwargs) for k in key]
else:
return self._lookup_single(data, *args, **kwargs)
return self._lookup_single(key, *args, **kwargs)


@share_init_params_with_map
class HueMapping(SemanticMapping):
"""Mapping that sets artist colors according to data values."""
# A specification of the colors that should appear in the plot
palette = None

# Default attributes (TODO use data class?)
# TODO Add docs for what these are (maybe on the base class?)
map_type = None
levels = None
# An object that normalizes data values to [0, 1] range for color mapping
norm = None

# A continuous colormap object for interpolating in a numeric context
cmap = None
palette = None
lookup_table = None

def __init__(
self, plotter, palette=None, order=None, norm=None,
):
"""Map the levels of the `hue` variable to distinct colors.
Parameters
----------
# TODO add generic parameters
"""
super().__init__(plotter)

data = plotter.plot_data["hue"]
Expand All @@ -89,22 +106,19 @@ def __init__(

if map_type == "numeric":

# TODO do we need levels for numeric map?
data = pd.to_numeric(data)
levels, lookup_table, norm, cmap = self.numeric_mapping(
data, palette, norm,
)
# TODO what do we need limits for? Can we get this downstream?
limits = norm.vmin, norm.vmax

# --- Option 2: categorical mapping using seaborn palette

else:

cmap = limits = norm = None
cmap = norm = None
levels, lookup_table = self.categorical_mapping(
# Casting data to list to handle differences in the way
# pandas represents numpy datetime64 data
# pandas and numpy represent datetime64 data
list(data), palette, order,
)

Expand All @@ -116,23 +130,25 @@ def __init__(
self.lookup_table = lookup_table
self.palette = palette
self.levels = levels
self.limits = limits # TODO do we need this
self.norm = norm
self.cmap = cmap

def _lookup_single(self, key):

"""Get the color for a single value, using colormap to interpolate."""
try:
# Use a value that's in the original data vector
value = self.lookup_table[key]
except KeyError:
# Use the colormap to interpolate between existing datapoints
# (e.g. in the context of making a continuous legend)
normed = self.norm(key)
if np.ma.is_masked(normed):
normed = np.nan
value = self.cmap(normed)
return value

def infer_map_type(self, palette, norm, input_format, var_type):

"""Determine how to implement the mapping."""
if palette in QUAL_PALETTES:
map_type = "categorical"
elif norm is not None:
Expand All @@ -148,7 +164,6 @@ def infer_map_type(self, palette, norm, input_format, var_type):

def categorical_mapping(self, data, palette, order):
"""Determine colors when the hue mapping is categorical."""

# -- Identify the order and name of the levels

if order is None:
Expand Down Expand Up @@ -189,7 +204,6 @@ def categorical_mapping(self, data, palette, order):

def numeric_mapping(self, data, palette, norm):
"""Determine colors when the hue variable is quantitative."""

if isinstance(palette, dict):

# The presence of a norm object overrides a dictionary of hues
Expand Down Expand Up @@ -244,17 +258,20 @@ def numeric_mapping(self, data, palette, norm):

@share_init_params_with_map
class SizeMapping(SemanticMapping):

map_type = None
levels = None
"""Mapping that sets artist sizes according to data values."""
# An object that normalizes data values to [0, 1] range
norm = None
sizes = None
lookup_table = None

def __init__(
self, plotter, sizes=None, order=None, norm=None,
):
"""Map the levels of the `size` variable to distinct values.
Parameters
----------
# TODO add generic parameters
"""
super().__init__(plotter)

data = plotter.plot_data["size"]
Expand Down Expand Up @@ -455,17 +472,21 @@ def numeric_mapping(self, data, sizes, norm):

@share_init_params_with_map
class StyleMapping(SemanticMapping):
"""Mapping that sets artist style according to data values."""

# TODO define shared defaults on base class?
# TODO Add docs for what these are (maybe on the base class?)
map_type = None
levels = None
lookup_table = None
# Style mapping is always treated as categorical
map_type = "categorical"

def __init__(
self, plotter, markers=None, dashes=None, order=None,
):
"""Map the levels of the `style` variable to distinct values.
Parameters
----------
# TODO add generic parameters
"""
super().__init__(plotter)

data = plotter.plot_data["style"]
Expand Down Expand Up @@ -514,7 +535,7 @@ def __init__(
self.lookup_table = lookup_table

def _lookup_single(self, key, attr=None):

"""Get attribute(s) for a given data point."""
if attr is None:
value = self.lookup_table[key]
else:
Expand Down Expand Up @@ -566,7 +587,7 @@ class VectorPlotter:

def __init__(self, data=None, variables={}):

plot_data, variables = self.assign_variables(data, variables)
self.assign_variables(data, variables)

for var, cls in self._semantic_mappings.items():
if var in self.semantics:
Expand All @@ -579,12 +600,12 @@ def __init__(self, data=None, variables={}):
getattr(self, f"map_{var}")()

@classmethod
def get_variables(cls, kwargs):
def get_semantics(cls, kwargs):
"""Subset a dictionary of arguments to the known semantic variables."""
return {k: kwargs[k] for k in cls.semantics}

def assign_variables(self, data=None, variables={}):
"""Define plot variables."""
"""Define plot variables, optionally using lookup from `data`."""
x = variables.get("x", None)
y = variables.get("y", None)

Expand All @@ -606,8 +627,7 @@ def assign_variables(self, data=None, variables={}):
for v in variables
}

# TODO maybe don't return, or return self for chaining
return plot_data, variables
return self

def _assign_variables_wideform(self, data=None, **kwargs):
"""Define plot variables given wide-form data.
Expand Down Expand Up @@ -638,7 +658,6 @@ def _assign_variables_wideform(self, data=None, **kwargs):
empty = data is None or not len(data)

# Then, determine if we have "flat" data (a single vector)
# TODO extract this into a separate function?
if isinstance(data, dict):
values = data.values()
else:
Expand Down
19 changes: 10 additions & 9 deletions seaborn/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ def update(var_name, val_name, **kws):
locator = mpl.ticker.LogLocator(numticks=3)
else:
locator = mpl.ticker.MaxNLocator(nbins=3)
limits = min(self._hue_map.levels), max(self._hue_map.levels)
hue_levels, hue_formatted_levels = locator_to_legend_entries(
locator, self._hue_map.limits, self.plot_data["hue"].dtype
locator, limits, self.plot_data["hue"].dtype
)
elif self._hue_map.levels is None:
hue_levels = hue_formatted_levels = []
Expand Down Expand Up @@ -659,7 +660,7 @@ def lineplot(
legend="brief", ax=None, **kwargs
):

variables = _LinePlotter.get_variables(locals())
variables = _LinePlotter.get_semantics(locals())
p = _LinePlotter(
data=data, variables=variables,
estimator=estimator, ci=ci, n_boot=n_boot, seed=seed,
Expand Down Expand Up @@ -936,7 +937,7 @@ def scatterplot(
legend="brief", ax=None, **kwargs
):

variables = _ScatterPlotter.get_variables(locals())
variables = _ScatterPlotter.get_semantics(locals())
p = _ScatterPlotter(
data=data, variables=variables,
x_bins=x_bins, y_bins=y_bins,
Expand Down Expand Up @@ -1215,7 +1216,7 @@ def relplot(
# Use the full dataset to map the semantics
p = plotter(
data=data,
variables=plotter.get_variables(locals()),
variables=plotter.get_semantics(locals()),
legend=legend,
)
p.map_hue(palette=palette, order=hue_order, norm=hue_norm)
Expand Down Expand Up @@ -1268,10 +1269,10 @@ def relplot(
plot_variables = {key: key for key in p.variables}
plot_kws.update(plot_variables)

# Define grid_data with row/col semantics
grid_semantics = "row", "col" # TODO define on FacetGrid?
# Add the grid semantics onto the plotter
grid_semantics = "row", "col"
p.semantics = plot_semantics + grid_semantics
full_data, full_variables = p.assign_variables(
p.assign_variables(
data=data,
variables=dict(
x=x, y=y,
Expand All @@ -1282,8 +1283,8 @@ def relplot(

# Pass the row/col variables to FacetGrid with their original
# names so that the axes titles render correctly
grid_kws = {v: full_variables.get(v, None) for v in grid_semantics}
full_data = full_data.rename(columns=grid_kws)
grid_kws = {v: p.variables.get(v, None) for v in grid_semantics}
full_data = p.plot_data.rename(columns=grid_kws)

# Set up the FacetGrid object
facet_kws = {} if facet_kws is None else facet_kws.copy()
Expand Down
2 changes: 1 addition & 1 deletion seaborn/tests/test_relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,7 +1185,7 @@ def test_legend_data(self, long_df):

ax.clear()
p = _LinePlotter(
variables=dict(x=x, y=y, hue=z),
variables=dict(x=x, y=y, hue=z + 1),
legend="brief"
)
p.map_hue(norm=mpl.colors.LogNorm()),
Expand Down

0 comments on commit 811261e

Please sign in to comment.