Skip to content

Commit

Permalink
feat: use data_types.is_categorical instead of utils.is_categorical
Browse files Browse the repository at this point in the history
Preparatory commit for the removal of `utils.is_categorical`. This slightly
changes the behavior of some plots; e.g., parallel categories is now stricter in
choosing which columns are considered categorical.
  • Loading branch information
mbelak-dtml committed Aug 30, 2023
1 parent 2ac7a3e commit a5b524b
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 23 deletions.
7 changes: 3 additions & 4 deletions edvart/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import pandas as pd
import plotly.graph_objs as go

from edvart import utils
from edvart.data_types import is_numeric
from edvart.data_types import is_categorical, is_numeric

# Multiplier which makes plotly interactive plots (size in pixels) and
# matplotlib plots (size in inches) about the same size
Expand Down Expand Up @@ -101,7 +100,7 @@ def _scatter_plot_2d_noninteractive(
) -> None:
_fig, ax = plt.subplots(figsize=figsize)
if color_col is not None:
is_color_categorical = utils.is_categorical(df[color_col]) or not is_numeric(df[color_col])
is_color_categorical = is_categorical(df[color_col]) or not is_numeric(df[color_col])
if is_color_categorical:
color_categorical = pd.Categorical(df[color_col])
color_codes = color_categorical.codes
Expand Down Expand Up @@ -163,7 +162,7 @@ def _scatter_plot_2d_interactive(
layout.yaxis.scaleratio = 1
fig = go.Figure(layout=layout)
if color_col is not None:
is_color_categorical = utils.is_categorical(df[color_col]) or not is_numeric(df[color_col])
is_color_categorical = is_categorical(df[color_col]) or not is_numeric(df[color_col])
if is_color_categorical:
df = df.copy()
x_name, y_name = "__edvart_scatter_x", "__edvart_scatter_y"
Expand Down
20 changes: 9 additions & 11 deletions edvart/report_sections/multivariate_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
from IPython.display import Markdown, display
from sklearn.preprocessing import StandardScaler

from edvart.data_types import is_numeric
from edvart.data_types import is_boolean, is_categorical, is_numeric
from edvart.plots import scatter_plot_2d
from edvart.report_sections.code_string_formatting import get_code, total_dedent
from edvart.report_sections.section_base import ReportSection, Section, Verbosity
from edvart.utils import discrete_colorscale, is_categorical
from edvart.utils import discrete_colorscale

try:
from edvart.report_sections.umap import UMAP
Expand Down Expand Up @@ -533,11 +533,9 @@ def __init__(
columns = [
col
for col in df.columns
if is_numeric(df[col])
or (
is_categorical(df[col], nunique_max=nunique_max)
and df[col].nunique() <= nunique_max
)
if is_categorical(df[col], unique_value_count_threshold=nunique_max)
or is_boolean(df[col])
or is_numeric(df[col])
]
# If all columns are numeric we don't want to list them all in the generated call
# Setting columns to None will result in the columns argument not being included
Expand Down Expand Up @@ -740,10 +738,10 @@ def __init__(
columns = [
col
for col in df.columns
if (
is_categorical(df[col], nunique_max=nunique_max)
and df[col].nunique() <= nunique_max
)
if is_categorical(
df[col],
unique_value_count_threshold=nunique_max
) or is_boolean(df[col])
]

# If all columns are categorical or boolean we don't want to list them all in the generated call
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import plotly.graph_objects as go
from IPython.display import Markdown, display

from edvart import utils
from edvart.data_types import is_numeric
from edvart.data_types import is_categorical, is_numeric
from edvart.decorators import check_index_time_ascending
from edvart.report_sections.code_string_formatting import get_code, total_dedent
from edvart.report_sections.section_base import Section, Verbosity
Expand Down Expand Up @@ -115,7 +114,7 @@ def _time_series_line_plot_colored(df, columns=None, color_col=None):
)

layout = dict(xaxis_rangeslider_visible=True)
if not utils.is_categorical(df[color_col]):
if not is_categorical(df[color_col]):
raise ValueError(f"Cannot color by non-categorical column `{color_col}`")
if df[color_col].nunique() > 20:
warnings.warn("Coloring by categorical column with many unique values!")
Expand Down
10 changes: 5 additions & 5 deletions tests/test_multivariate_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def test_code_export_verbosity_medium_all_cols_valid():
expected_code = [
"pca_first_vs_second(df=df)",
"pca_explained_variance(df=df)",
"parallel_categories(df=df)",
"parallel_categories(df=df, columns=['col2'])",
]

assert len(exported_code) == len(expected_code)
Expand Down Expand Up @@ -228,14 +228,14 @@ def test_generated_code_verobsity_1():
)"""
),
"parallel_coordinates(df=df)",
"parallel_categories(df=df)",
"parallel_categories(df=df, columns=['B'])",
]
else:
expected_code = [
"pca_first_vs_second(df=df, columns=['A', 'C', 'D'])",
"pca_explained_variance(df=df, columns=['A', 'C', 'D'])",
"parallel_coordinates(df=df)",
"parallel_categories(df=df)",
"parallel_categories(df=df, columns=['B'])",
]

assert len(exported_code) == len(expected_code)
Expand Down Expand Up @@ -275,7 +275,7 @@ def test_generated_code_verobsity_2():
(
get_code(utils.discrete_colorscale),
get_code(multivariate_analysis.ParallelCategories.parallel_categories),
"parallel_categories(df=df)",
"parallel_categories(df=df, columns=['B'])",
)
),
]
Expand Down Expand Up @@ -352,7 +352,7 @@ def test_verbosity_low_different_subsection_verbosities():
expected_subsections_str = ", ".join(expected_subsections)
expected_code = [
"multivariate_analysis(df=df, " f"subsections=[{expected_subsections_str}])",
"parallel_categories(df=df)",
"parallel_categories(df=df, columns=['B'])",
"\n\n".join(
(
get_code(utils.discrete_colorscale),
Expand Down

0 comments on commit a5b524b

Please sign in to comment.