Skip to content

Commit

Permalink
Begin refactoring setup pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Waskom committed Sep 26, 2021
1 parent edf2729 commit 42b2481
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 48 deletions.
6 changes: 5 additions & 1 deletion seaborn/_core/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,11 @@ def setup(
levels, mapping = self._setup_categorical(
data, palette, order,
)
return LookupMapping(mapping)
# return LookupMapping(mapping)
# TODO hack to keep things runnable until downstream refactor
mapping = LookupMapping(mapping)
mapping.levels = levels
return mapping

elif map_type == "numeric":

Expand Down
89 changes: 44 additions & 45 deletions seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,9 @@ def __init__(
"dash": DashSemantic(),
}

# TODO is using "unknown" here the best approach?
# Other options would be:
# - None as the value for type
# - some sort of uninitialized singleton for the object,
self._scales = {
"x": ScaleWrapper(mpl.scale.LinearScale("x"), "unknown"),
"y": ScaleWrapper(mpl.scale.LinearScale("y"), "unknown"),
}

self._target = None

self._scales = {}
self._subplotspec = {}
self._facetspec = {}
self._pairspec = {}
Expand Down Expand Up @@ -208,16 +200,10 @@ def pair(
for axis in "xy":
keys = []
for i, col in enumerate(pairspec.get(axis, [])):

key = f"{axis}{i}"
keys.append(key)
pairspec["variables"][key] = col

# TODO how much type inference to do here?
# (i.e., should we force .scale_categorical, etc.?)
# We could also accept a scales keyword? Or document that calling, e.g.
# p.scale_categorical("x4") is the right approach
self._scales[key] = ScaleWrapper(mpl.scale.LinearScale(key), "unknown")
if keys:
pairspec["structure"][axis] = keys

Expand Down Expand Up @@ -430,9 +416,10 @@ def resize(self, val):

def plot(self, pyplot=False) -> Plot:

self._setup_layers()
self._setup_data()
self._setup_scales()
self._setup_mappings()
self._setup_orderings()
self._setup_figure(pyplot)

for layer in self._layers:
Expand Down Expand Up @@ -476,11 +463,9 @@ def save(self) -> Plot: # TODO perhaps this should not return self?
# End of public API
# ================================================================================ #

# TODO order these methods to match the order they get called in
def _setup_data(self):

def _setup_layers(self):

common_data = (
self._data = (
self._data
.concat(
self._facetspec.get("source"),
Expand All @@ -495,18 +480,51 @@ def _setup_layers(self):
# TODO concat with mapping spec

for layer in self._layers:
layer.data = common_data.concat(layer.source, layer.variables)
# TODO FIXME:mutable we need to make this not modify the existing object
# TODO one idea is add() inserts a dict into _layerspec or something
layer.data = self._data.concat(layer.source, layer.variables)

def _setup_scales(self) -> None:

variables = set(self._data.frame)
for layer in self._layers:
variables |= set(layer.data.frame)

for var in variables:
scale = self._scales.get(var, ScaleWrapper(mpl.scale.LinearScale(var)))
var_type = scale.type
if var_type is None:
values = [
self._data.frame.get(var),
# TODO important to check for var in x.variables, not just in x
# Because we only want to concat if a variable was *added* here
*(x.data.frame.get(var) for x in self._layers if var in x.variables)
]
var_type = variable_type(pd.concat(values, ignore_index=True))

# TODO eventually this will be updating a different dictionary
# TODO do this as ScaleWrapper.from_wrapper(scale, var_type)?
self._scales[var] = ScaleWrapper(scale._scale, var_type, scale.norm)

def _setup_mappings(self) -> None:

layers = self._layers
for var, scale in self._scales.items():
if scale.type == "unknown" and any(var in layer for layer in layers):
# TODO this is copied from _setup_mappings ... ripe for abstraction!

# TODO we should setup default mappings here based on whether a mapping
# variable appears in at least one of the layer data but isn't in self._mappings
# Source of what mappings to check can be some dictionary of default mappings?

for var, mapping in self._mappings.items():
if any(var in layer for layer in layers):
all_data = pd.concat(
[layer.data.frame.get(var) for layer in layers]
).reset_index(drop=True)
scale.type = variable_type(all_data)
scale = self._scales.get(var)
self._mappings[var] = mapping.setup(all_data, scale)

def _setup_orderings(self) -> None:

...

def _setup_figure(self, pyplot: bool = False) -> None:

Expand Down Expand Up @@ -614,22 +632,6 @@ def _setup_figure(self, pyplot: bool = False) -> None:
title_text = ax.set_title(title)
title_text.set_visible(show_title)

def _setup_mappings(self) -> None:

layers = self._layers

# TODO we should setup default mappings here based on whether a mapping
# variable appears in at least one of the layer data but isn't in self._mappings
# Source of what mappings to check can be some dictionary of default mappings?

for var, mapping in self._mappings.items():
if any(var in layer for layer in layers):
all_data = pd.concat(
[layer.data.frame.get(var) for layer in layers]
).reset_index(drop=True)
scale = self._scales.get(var)
mapping.setup(all_data, scale)

def _plot_layer(self, layer: Layer, mappings: dict[str, SemanticMapping]) -> None:

default_grouping_vars = ["col", "row", "group"] # TODO where best to define?
Expand Down Expand Up @@ -902,7 +904,7 @@ def _repr_png_(self) -> bytes:

class Layer:

data: PlotData # TODO added externally (bad design?)
data = PlotData

def __init__(
self,
Expand All @@ -918,7 +920,4 @@ def __init__(
self.variables = variables

def __contains__(self, key: str) -> bool:

if self.data is None:
return False
return key in self.data
4 changes: 2 additions & 2 deletions seaborn/_core/scales.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class ScaleWrapper:
def __init__(
self,
scale: ScaleBase,
type: VariableType, # TODO don't use builtin name?
type: VariableType | None = None, # TODO don't use builtin name?
norm: tuple[float | None, float | None] | Normalize | None = None,
):

Expand All @@ -32,7 +32,7 @@ def __init__(
self.reverse = transform.inverted().transform

# TODO can't we get type from the scale object in most cases?
self.type = VarType(type)
self.type = type if type is None else VarType(type)

if norm is None:
norm = norm_from_scale(scale, norm)
Expand Down

0 comments on commit 42b2481

Please sign in to comment.