Skip to content

Commit

Permalink
Prototypes of several methods; also some name bikeshedding
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed May 26, 2021
1 parent ef3f4d2 commit 239cad9
Showing 1 changed file with 96 additions and 60 deletions.
156 changes: 96 additions & 60 deletions seaborn/_new_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
from collections import UserString
from datetime import datetime
import warnings
import io

import numpy as np
from numpy import ndarray
import pandas as pd
from pandas import DataFrame, Series, Index
from pandas.api.types import is_categorical_dtype, is_numeric_dtype, is_datetime64_dtype
import matplotlib as mpl
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from matplotlib.colors import Colormap, Normalize

from .palettes import (
Expand All @@ -37,7 +40,11 @@ class Plot:

_data: PlotData
_layers: list[Layer]
_mappings: dict[str, SemanticMapping] # TODO keys could be Literal
_mappings: dict[str, SemanticMapping] # TODO keys as Literal, or use TypedDict?

_figure: Figure
_ax: Optional[Axes]
# _facets: Optional[FacetGrid] # TODO would have a circular import?

def __init__(
self,
Expand Down Expand Up @@ -116,6 +123,15 @@ def plot(self) -> Plot:
# Another option would be to raise if layers have different variable
# types (this is basically what ggplot does), but that adds complexity.

# === TODO clean series of setup functions (TODO bikeshed names)
self._setup_figure()

# ===

# TODO we need to be able to show a blank figure
if not self._layers:
return self

all_data = pd.concat([layer.data.frame for layer in self._layers])

# TODO This feels messy here. Should be in a separate method. But called here?
Expand All @@ -132,7 +148,6 @@ def plot(self) -> Plot:
mapping.train(all_data[var]) # TODO return self?
mappings[var] = mapping

# TODO or something like this
for layer in self._layers:

# TODO alt. assign as attribute on Layer?
Expand All @@ -142,6 +157,22 @@ def plot(self) -> Plot:

return self

def _setup_figure(self):
    """Create the figure and axes for this plot and label the x/y axes.

    A new ``Figure`` is stored on ``self._figure`` and the axes returned by
    its ``add_subplot()`` call is stored on ``self._ax``. For each of "x"
    and "y", a non-None entry in ``self._data.names`` is applied as the
    corresponding axis label.
    """
    # TODO add external API for parameterizing figure, etc.
    # TODO add external API for parameterizing FacetGrid if using
    # TODO add external API for passing existing ax (maybe in same method)

    # TODO use context manager with theme that has been set
    figure = Figure()
    self._figure = figure
    axes = figure.add_subplot()
    self._ax = axes

    # TODO good place to do this? (needs to handle FacetGrid)
    for axis in ("x", "y"):
        label = self._data.names.get(axis, None)
        if label is None:
            # No name recorded for this variable; leave the axis unlabeled.
            continue
        axes.set(**{f"{axis}label": label})

def _plot_layer(self, layer, mappings):

df = layer.data.frame
Expand All @@ -158,7 +189,9 @@ def _plot_layer(self, layer, mappings):
# data = self.invert_scale(data)

# Something like this?
layer.mark._plot(df, mappings) # do we pass ax (and/or facets?! here?)
# TODO pass in split generator that will be the source of ax
ax = self._ax
layer.mark._plot(df, mappings, ax)

def show(self) -> Plot:

Expand All @@ -176,13 +209,24 @@ def show(self) -> Plot:

def save(self) -> Plot:  # or to_file or similar to match pandas?
    """Save the plot to a file (placeholder; not implemented yet).

    Raises
    ------
    NotImplementedError
        Always, until saving is implemented.

    """
    # TODO return self once implemented, to keep the method chainable
    # (that is what the ``-> Plot`` annotation anticipates).
    raise NotImplementedError()

def _repr_html_(self) -> str:
def _repr_png_(self) -> bytes:

# TODO better to do this through a Jupyter hook?
# TODO Would like to allow for svg too ... how to configure?
# TODO We want to skip if the plot has otherwise been shown, but tricky...

html = ""
# TODO we need some way of not plotting multiple times
self.plot()

return html
buffer = io.BytesIO()

# TODO use bbox_inches="tight" like the inline backend?
# pro: better results, con: (sometimes) confusing results
self._figure.savefig(buffer, format="png", bbox_inches="tight")
return buffer.getvalue()


# TODO
Expand Down Expand Up @@ -421,15 +465,10 @@ def __init__(self, **kwargs):

self.kwargs = kwargs

def _plot(self, data, mappings):
def _plot(self, data, mappings, ax):

kws = self.kwargs.copy()

# TODO! Get this from plotter or generator ===
import matplotlib.pyplot as plt
ax = plt.gca()
# ===

# TODO since names match, can probably be automated!
if "hue" in data:
c = mappings["hue"](data["hue"])
Expand Down Expand Up @@ -503,7 +542,7 @@ def __init__(

def train( # TODO ggplot name; let's come up with something better
self,
data: Series,
data: Series, # TODO generally rename Series arguments to distinguish from DF?
) -> None:

palette: Optional[PaletteSpec] = self._input_palette
Expand All @@ -512,84 +551,81 @@ def train( # TODO ggplot name; let's come up with something better
cmap: Optional[Colormap] = None

# TODO these are currently extracted from a passed in plotter instance
# can we avoid doing that now that we are deferring the mapping?
# TODO can just remove if we excise wide-data handling from core
input_format: Literal["long", "wide"] = "long"
var_type = variable_type(data)

if data.notna().any():
map_type = self._infer_map_type(data, palette, norm, input_format)

map_type = self.infer_map_type(palette, norm, input_format, var_type)
# Our goal is to end up with a dictionary mapping every unique
# value in `data` to a color. We will also keep track of the
# metadata about this mapping we will need for, e.g., a legend

# Our goal is to end up with a dictionary mapping every unique
# value in `data` to a color. We will also keep track of the
# metadata about this mapping we will need for, e.g., a legend
# --- Option 1: numeric mapping with a matplotlib colormap

# --- Option 1: numeric mapping with a matplotlib colormap
if map_type == "numeric":

if map_type == "numeric":
data = pd.to_numeric(data)
levels, lookup_table, norm, cmap = self._setup_numeric(
data, palette, norm,
)

data = pd.to_numeric(data)
levels, lookup_table, norm, cmap = self.numeric_mapping(
data, palette, norm,
)
# --- Option 2: categorical mapping using seaborn palette

# --- Option 2: categorical mapping using seaborn palette
elif map_type == "categorical":

elif map_type == "categorical":
levels, lookup_table = self._setup_categorical(
data, palette, order,
)

levels, lookup_table = self.categorical_mapping(
data, palette, order,
)
# --- Option 3: datetime mapping

# --- Option 3: datetime mapping
elif map_type == "datetime":
# TODO this needs actual implementation
cmap = norm = None
levels, lookup_table = self._setup_categorical(
# Casting data to list to handle differences in the way
# pandas and numpy represent datetime64 data
list(data), palette, order,
)

elif map_type == "datetime":
# TODO this needs actual implementation
cmap = norm = None
levels, lookup_table = self.categorical_mapping(
# Casting data to list to handle differences in the way
# pandas and numpy represent datetime64 data
list(data), palette, order,
)

else:
raise RuntimeError() # TODO should never get here ...
else:
raise RuntimeError() # TODO should never get here ...

# TODO do we need to return and assign out here or can the
# type-specific methods do the assignment internally
# TODO do we need to return and assign out here or can the
# type-specific methods do the assignment internally

# TODO I don't love how this is kind of a mish-mash of attributes
# Can we be more consistent across SemanticMapping subclasses?
self.map_type = map_type
self.lookup_table = lookup_table
self.palette = palette
self.levels = levels
self.norm = norm
self.cmap = cmap
# TODO I don't love how this is kind of a mish-mash of attributes
# Can we be more consistent across SemanticMapping subclasses?
self.map_type = map_type
self.lookup_table = lookup_table
self.palette = palette
self.levels = levels
self.norm = norm
self.cmap = cmap

def _infer_map_type(
    self,
    data: Series,
    palette: Optional[PaletteSpec],
    norm: Optional[Normalize],
    input_format: Literal["long", "wide"],
) -> Optional[Literal["numeric", "categorical", "datetime"]]:
    """Determine how to implement the mapping.

    The checks run in priority order and the first match wins:

    1. ``palette`` is a member of ``QUAL_PALETTES`` -> "categorical"
    2. ``norm`` is given -> "numeric"
    3. ``palette`` is a dict or list -> "categorical"
    4. ``input_format`` is "wide" -> "categorical"
    5. otherwise, whatever ``variable_type(data)`` reports

    Parameters
    ----------
    data
        Values the mapping will be trained on; only consulted in the
        fallback case.
    palette
        User-specified palette, if any.
    norm
        User-specified normalization, if any.
    input_format
        Whether the source data were in "long" or "wide" form.

    Returns
    -------
    map_type
        Name of the mapping strategy to use.

    """
    map_type: Optional[Literal["numeric", "categorical", "datetime"]]
    if palette in QUAL_PALETTES:
        map_type = "categorical"
    elif norm is not None:
        map_type = "numeric"
    elif isinstance(palette, (dict, list)):  # TODO mapping/sequence?
        map_type = "categorical"
    elif input_format == "wide":
        map_type = "categorical"
    else:
        # NOTE(review): presumably variable_type returns one of the three
        # literal names above -- confirm against its definition.
        map_type = variable_type(data)

    return map_type

def categorical_mapping(
def _setup_categorical(
self,
data: Series,
palette: Optional[PaletteSpec],
Expand Down Expand Up @@ -622,7 +658,7 @@ def categorical_mapping(
elif isinstance(palette, list):
if len(palette) != n_colors:
err = "The palette list has the wrong number of colors."
raise ValueError(err)
raise ValueError(err) # TODO downgrade this to a warning?
colors = palette
else:
colors = color_palette(palette, n_colors)
Expand All @@ -631,7 +667,7 @@ def categorical_mapping(

return levels, lookup_table

def numeric_mapping(
def _setup_numeric(
self,
data: Series,
palette: Optional[PaletteSpec],
Expand Down

0 comments on commit 239cad9

Please sign in to comment.