Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 578684882
  • Loading branch information
drewbryant authored and colaboratory-team committed Dec 4, 2023
1 parent 97b998f commit 218ef46
Show file tree
Hide file tree
Showing 2 changed files with 344 additions and 71 deletions.
100 changes: 29 additions & 71 deletions google/colab/_quickchart.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Automated chart generation for data frames."""
import itertools

import IPython

Expand Down Expand Up @@ -67,6 +66,7 @@ def determine_charts(df, dataframe_registry, max_chart_instances=None):
# Lazy import to avoid loading matplotlib and transitive deps on kernel init.
from google.colab import _quickchart_dtypes # pylint: disable=g-import-not-at-top
from google.colab import _quickchart_helpers # pylint: disable=g-import-not-at-top
from google.colab import _quickchart_rank # pylint: disable=g-import-not-at-top

dtype_groups = _quickchart_dtypes.classify_dtypes(df)
numeric_cols = dtype_groups['numeric']
Expand All @@ -76,23 +76,30 @@ def determine_charts(df, dataframe_registry, max_chart_instances=None):

if numeric_cols:
section = _quickchart_helpers.histograms_section(
df, numeric_cols[:max_chart_instances], dataframe_registry
df,
_quickchart_rank.rank_histograms(df, numeric_cols)[
:max_chart_instances
],
dataframe_registry,
)
if section.charts:
chart_sections.append(section)

if categorical_cols:
selected_categorical_cols = categorical_cols[:max_chart_instances]
section = _quickchart_helpers.categorical_histograms_section(
df, selected_categorical_cols, dataframe_registry
df,
_quickchart_rank.rank_histograms(df, categorical_cols)[
:max_chart_instances
],
dataframe_registry,
)
if section.charts:
chart_sections.append(section)

if len(numeric_cols) >= 2:
section = _quickchart_helpers.scatter_section(
df,
_select_first_k_pairs(numeric_cols, k=max_chart_instances),
_quickchart_rank.rank_scatter(df, numeric_cols)[:max_chart_instances],
dataframe_registry,
)
if section.charts:
Expand All @@ -101,28 +108,34 @@ def determine_charts(df, dataframe_registry, max_chart_instances=None):
if time_cols:
section = _quickchart_helpers.time_series_line_plots_section(
df,
_select_time_series_cols(
time_cols=time_cols,
numeric_cols=numeric_cols,
categorical_cols=categorical_cols,
k=max_chart_instances,
),
_quickchart_rank.rank_time_series_plots(
df,
time_colnames=time_cols,
numeric_colnames=numeric_cols,
categorical_colnames=categorical_cols,
)[:max_chart_instances],
dataframe_registry,
)
if section.charts:
chart_sections.append(section)

if numeric_cols:
section = _quickchart_helpers.value_plots_section(
df, numeric_cols[:max_chart_instances], dataframe_registry
df,
_quickchart_rank.rank_value_plots(df, numeric_cols)[
:max_chart_instances
],
dataframe_registry,
)
if section.charts:
chart_sections.append(section)

if len(categorical_cols) >= 2:
section = _quickchart_helpers.heatmaps_section(
df,
_select_first_k_pairs(categorical_cols, k=max_chart_instances),
_quickchart_rank.rank_heatmaps(df, categorical_cols)[
:max_chart_instances
],
dataframe_registry,
)
if section.charts:
Expand All @@ -131,67 +144,12 @@ def determine_charts(df, dataframe_registry, max_chart_instances=None):
if categorical_cols and numeric_cols:
section = _quickchart_helpers.faceted_distributions_section(
df,
_select_faceted_numeric_cols(
numeric_cols, categorical_cols, k=max_chart_instances
),
_quickchart_rank.rank_faceted_distributions(
df, value_colnames=numeric_cols, facet_colnames=categorical_cols
)[:max_chart_instances],
dataframe_registry,
)
if section.charts:
chart_sections.append(section)

return chart_sections


def _select_first_k_pairs(colnames, k=None):
"""Selects the first k pairs of column names, sequentially.
e.g., ['a', 'b', 'c'] => [('a', b'), ('b', 'c')] for k=2
Args:
colnames: (iterable<str>) Column names from which to generate pairs.
k: (int) The number of column pairs.
Returns:
(list<(str, str)>) A k-length sequence of column name pairs.
"""
return itertools.islice(itertools.pairwise(colnames), k)


def _select_faceted_numeric_cols(numeric_cols, categorical_cols, k=None):
"""Selects numeric columns and corresponding categorical facets.
Args:
numeric_cols: (iterable<str>) Available numeric columns.
categorical_cols: (iterable<str>) Available categorical columns.
k: (int) The number of column pairs to select.
Returns:
(iter<(str, str)>) Prioritized sequence of (numeric, categorical) column
pairs.
"""
return itertools.islice(itertools.product(numeric_cols, categorical_cols), k)


def _select_time_series_cols(time_cols, numeric_cols, categorical_cols, k=None):
"""Selects combinations of colnames that can be plotted as time series.
Args:
time_cols: (iter<str>) Available time-like columns.
numeric_cols: (iter<str>) Available numeric columns.
categorical_cols: (iter<str>) Available categorical columns.
k: (int) The number of combinations to select.
Returns:
(iter<(str, str, str)>) Prioritized sequence of (time, value, series)
colname combinations.
"""
numeric_cols = [c for c in numeric_cols if c not in time_cols]
numeric_aggregates = ['count()']
if not categorical_cols:
categorical_cols = [None]
return itertools.islice(
itertools.product(
time_cols, numeric_cols + numeric_aggregates, categorical_cols
),
k,
)
Loading

0 comments on commit 218ef46

Please sign in to comment.