Skip to content

Commit

Permalink
feat: improve default discrete colorscales in multivariate analysis (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mbelak-dtml authored Dec 21, 2023
1 parent 386d868 commit 8cf365b
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 20 deletions.
32 changes: 24 additions & 8 deletions edvart/report_sections/multivariate_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
from edvart.plots import scatter_plot_2d
from edvart.report_sections.code_string_formatting import code_dedent, get_code
from edvart.report_sections.section_base import ReportSection, Section, Verbosity
from edvart.utils import discrete_colorscale, select_numeric_columns
from edvart.utils import (
get_default_discrete_colorscale,
hsl_wheel_colorscale,
make_discrete_colorscale,
select_numeric_columns,
)

try:
from edvart.report_sections.umap import UMAP # pylint: disable=cyclic-import
Expand Down Expand Up @@ -529,7 +534,9 @@ def required_imports(self) -> List[str]:
if self.verbosity <= Verbosity.MEDIUM:
return ["from edvart.report_sections.multivariate_analysis import parallel_coordinates"]
return [
"from edvart.utils import discrete_colorscale",
"from edvart.utils import ("
" get_default_discrete_colorscale, make_discrete_colorscale, hsl_wheel_colorscale"
")",
"from typing import Iterable",
"import plotly",
"import plotly.graph_objects as go",
Expand Down Expand Up @@ -560,7 +567,11 @@ def add_cells(self, cells: List[Dict[str, Any]], df: pd.DataFrame) -> None:
code = default_call
else:
code = (
get_code(discrete_colorscale)
get_code(hsl_wheel_colorscale)
+ "\n\n"
+ get_code(make_discrete_colorscale)
+ "\n\n"
+ get_code(get_default_discrete_colorscale)
+ "\n\n"
+ get_code(parallel_coordinates)
+ "\n\n"
Expand Down Expand Up @@ -628,7 +639,7 @@ def parallel_coordinates(

if is_categorical_color:
categories = df[color_col].unique()
colorscale = list(discrete_colorscale(len(categories)))
colorscale = get_default_discrete_colorscale(n_colors=len(categories))
# encode categories into numbers
color_series = pd.Series(pd.Categorical(df[color_col]).codes)
else:
Expand Down Expand Up @@ -660,7 +671,6 @@ def parallel_coordinates(
)
else:
line = None

# Add numeric columns to dimensions
dimensions = [{"label": col_name, "values": df[col_name]} for col_name in numeric_columns]
# Add categorical columns to dimensions
Expand Down Expand Up @@ -721,7 +731,9 @@ def required_imports(self) -> List[str]:
if self.verbosity <= Verbosity.MEDIUM:
return ["from edvart.report_sections.multivariate_analysis import parallel_categories"]
return [
"from edvart.utils import discrete_colorscale",
"from edvart.utils import ("
" get_default_discrete_colorscale, make_discrete_colorscale, hsl_wheel_colorscale"
")",
"import plotly.graph_objects as go",
]

Expand Down Expand Up @@ -749,7 +761,11 @@ def add_cells(self, cells: List[Dict[str, Any]], df: pd.DataFrame) -> None:
code = default_call
else:
code = (
get_code(discrete_colorscale)
get_code(hsl_wheel_colorscale)
+ "\n\n"
+ get_code(make_discrete_colorscale)
+ "\n\n"
+ get_code(get_default_discrete_colorscale)
+ "\n\n"
+ get_code(parallel_categories)
+ "\n\n"
Expand Down Expand Up @@ -810,7 +826,7 @@ def parallel_categories(
)
if categorical_color:
categories = df[color_col].unique()
colorscale = list(discrete_colorscale(len(categories)))
colorscale = get_default_discrete_colorscale(n_colors=len(categories))
# encode categories into numbers
color_series = pd.Series(pd.Categorical(df[color_col]).codes)
else:
Expand Down
74 changes: 66 additions & 8 deletions edvart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Dict, Iterable, Iterator, List, Literal, Optional, Tuple, Union

import pandas as pd
import plotly
import statsmodels.api as sm
from scipy import stats

Expand Down Expand Up @@ -126,23 +127,80 @@ def reindex_to_period(
return df


def discrete_colorscale(n, saturation=0.5, lightness=0.5) -> Iterable[Tuple[float, str]]:
def hsl_wheel_colorscale(n: int, saturation=0.5, lightness=0.5) -> Iterable[str]:
    """
    Lazily produce n colors whose hues are evenly distributed over the full HSL wheel.

    Every color shares the same saturation and lightness; only the hue varies, starting
    at 0 degrees and advancing by 360 / n degrees per color.

    Parameters
    ----------
    n : int
        Number of colors to generate.
    saturation : float
        Saturation as a fraction; it is multiplied by 100 to form the percentage.
    lightness : float
        Lightness as a fraction; it is multiplied by 100 to form the percentage.

    Returns
    -------
    Iterable[str]
        An iterable of n plotly-compatible HSL strings.
    """
    for step in range(n):
        hue_degrees = (step / n) * 360
        saturation_pct = saturation * 100
        lightness_pct = lightness * 100
        yield f"hsl({hue_degrees:.2f}, {saturation_pct:.2f}%, {lightness_pct:.2f}%)"


def make_discrete_colorscale(colorscale: List[str], n_colors: int) -> Iterable[Tuple[float, str]]:
    """
    Generate a colorscale of n discrete colors for use in `plotly.graph_objects`.

    Note that when using `plotly.express`, the parameter `color_discrete_sequence`
    can be used instead.

    Parameters
    ----------
    colorscale : List[str]
        A list of plotly-compatible colors. When it holds fewer than `n_colors` colors,
        the colors are reused cyclically.
    n_colors : int
        Number of colors in the generated colorscale.

    Returns
    -------
    Iterable[Tuple[float, str]]
        An iterable of 2n tuples, where each tuple contains a value between 0 and 1
        (the values are equally spaced in the interval and each interior value appears twice),
        and one of the colors from the `colorscale`.

    Raises
    ------
    ValueError
        If `colorscale` is empty while `n_colors` is positive. Because this function is
        a generator, the error is raised when iteration starts, not when it is called.

    Examples
    --------
    >>> list(make_discrete_colorscale(["red", "green", "blue"], 4))
    ... # doctest: +NORMALIZE_WHITESPACE
    [(0.0, 'red'), (0.25, 'red'),
     (0.25, 'green'), (0.5, 'green'),
     (0.5, 'blue'), (0.75, 'blue'),
     (0.75, 'red'), (1.0, 'red')]
    """
    n_available = len(colorscale)
    # Fail with a clear message instead of the ZeroDivisionError the modulo below would raise.
    if n_colors > 0 and n_available == 0:
        raise ValueError("`colorscale` must contain at least one color.")
    for i in range(n_colors):
        # Wrap around so that short palettes are repeated rather than exhausted.
        color = colorscale[i % n_available]
        # Emitting the same color at both ends of the interval gives a flat (discrete) band.
        yield (i / n_colors, color)
        yield ((i + 1) / n_colors, color)


def get_default_discrete_colorscale(n_colors: int) -> List[Tuple[float, str]]:
    """
    Build the default Plotly-compatible discrete colorscale for a given number of colors.

    The base palette depends on how many colors are requested: the qualitative `Plotly`
    palette for up to 10 colors, the qualitative `Light24` palette for up to 24 colors,
    and colors generated by `hsl_wheel_colorscale` beyond that. The chosen palette is then
    expanded by `make_discrete_colorscale`.

    Parameters
    ----------
    n_colors : int
        Number of colors.

    Returns
    -------
    list[tuple[float, str]]
        A list of 2n tuples, where each tuple contains a value between 0 and 1 and a
        plotly-compatible color string.
    """
    if n_colors > 24:
        # No built-in qualitative palette is used past this size; generate hues instead.
        palette = list(hsl_wheel_colorscale(n_colors))
    elif n_colors > 10:
        palette = plotly.colors.qualitative.Light24
    else:
        palette = plotly.colors.qualitative.Plotly
    return list(make_discrete_colorscale(palette, n_colors))


def select_numeric_columns(df: pd.DataFrame, columns: Optional[List[str]]) -> List[str]:
Expand Down
18 changes: 14 additions & 4 deletions tests/test_multivariate_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
MultivariateAnalysisSubsection,
)
from edvart.report_sections.section_base import Verbosity
from edvart.utils import select_numeric_columns
from edvart.utils import (
get_default_discrete_colorscale,
make_discrete_colorscale,
select_numeric_columns,
)

from .execution_utils import check_section_executes
from .pyarrow_utils import pyarrow_parameterize
Expand Down Expand Up @@ -286,14 +290,18 @@ def test_generated_code_verbosity_2(pyarrow_dtypes: bool):
),
"\n\n".join(
(
get_code(utils.discrete_colorscale),
get_code(utils.hsl_wheel_colorscale),
get_code(utils.make_discrete_colorscale),
get_code(utils.get_default_discrete_colorscale),
get_code(multivariate_analysis.parallel_coordinates),
"parallel_coordinates(df=df)",
)
),
"\n\n".join(
(
get_code(utils.discrete_colorscale),
get_code(utils.hsl_wheel_colorscale),
get_code(utils.make_discrete_colorscale),
get_code(utils.get_default_discrete_colorscale),
get_code(multivariate_analysis.parallel_categories),
"parallel_categories(df=df)",
)
Expand Down Expand Up @@ -384,7 +392,9 @@ def test_verbosity_low_different_subsection_verbosities(pyarrow_dtypes: bool):
"parallel_categories(df=df)",
"\n\n".join(
(
get_code(utils.discrete_colorscale),
get_code(utils.hsl_wheel_colorscale),
get_code(utils.make_discrete_colorscale),
get_code(utils.get_default_discrete_colorscale),
get_code(multivariate_analysis.parallel_coordinates),
"parallel_coordinates(df=df)",
)
Expand Down

0 comments on commit 8cf365b

Please sign in to comment.