Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No public description #4188

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 29 additions & 71 deletions google/colab/_quickchart.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Automated chart generation for data frames."""
import itertools

import IPython

Expand Down Expand Up @@ -67,6 +66,7 @@ def determine_charts(df, dataframe_registry, max_chart_instances=None):
# Lazy import to avoid loading matplotlib and transitive deps on kernel init.
from google.colab import _quickchart_dtypes # pylint: disable=g-import-not-at-top
from google.colab import _quickchart_helpers # pylint: disable=g-import-not-at-top
from google.colab import _quickchart_rank # pylint: disable=g-import-not-at-top

dtype_groups = _quickchart_dtypes.classify_dtypes(df)
numeric_cols = dtype_groups['numeric']
Expand All @@ -76,23 +76,30 @@ def determine_charts(df, dataframe_registry, max_chart_instances=None):

if numeric_cols:
section = _quickchart_helpers.histograms_section(
df, numeric_cols[:max_chart_instances], dataframe_registry
df,
_quickchart_rank.rank_histograms(df, numeric_cols)[
:max_chart_instances
],
dataframe_registry,
)
if section.charts:
chart_sections.append(section)

if categorical_cols:
selected_categorical_cols = categorical_cols[:max_chart_instances]
section = _quickchart_helpers.categorical_histograms_section(
df, selected_categorical_cols, dataframe_registry
df,
_quickchart_rank.rank_histograms(df, categorical_cols)[
:max_chart_instances
],
dataframe_registry,
)
if section.charts:
chart_sections.append(section)

if len(numeric_cols) >= 2:
section = _quickchart_helpers.scatter_section(
df,
_select_first_k_pairs(numeric_cols, k=max_chart_instances),
_quickchart_rank.rank_scatter(df, numeric_cols)[:max_chart_instances],
dataframe_registry,
)
if section.charts:
Expand All @@ -101,28 +108,34 @@ def determine_charts(df, dataframe_registry, max_chart_instances=None):
if time_cols:
section = _quickchart_helpers.time_series_line_plots_section(
df,
_select_time_series_cols(
time_cols=time_cols,
numeric_cols=numeric_cols,
categorical_cols=categorical_cols,
k=max_chart_instances,
),
_quickchart_rank.rank_time_series_plots(
df,
time_colnames=time_cols,
numeric_colnames=numeric_cols,
categorical_colnames=categorical_cols,
)[:max_chart_instances],
dataframe_registry,
)
if section.charts:
chart_sections.append(section)

if numeric_cols:
section = _quickchart_helpers.value_plots_section(
df, numeric_cols[:max_chart_instances], dataframe_registry
df,
_quickchart_rank.rank_value_plots(df, numeric_cols)[
:max_chart_instances
],
dataframe_registry,
)
if section.charts:
chart_sections.append(section)

if len(categorical_cols) >= 2:
section = _quickchart_helpers.heatmaps_section(
df,
_select_first_k_pairs(categorical_cols, k=max_chart_instances),
_quickchart_rank.rank_heatmaps(df, categorical_cols)[
:max_chart_instances
],
dataframe_registry,
)
if section.charts:
Expand All @@ -131,67 +144,12 @@ def determine_charts(df, dataframe_registry, max_chart_instances=None):
if categorical_cols and numeric_cols:
section = _quickchart_helpers.faceted_distributions_section(
df,
_select_faceted_numeric_cols(
numeric_cols, categorical_cols, k=max_chart_instances
),
_quickchart_rank.rank_faceted_distributions(
df, value_colnames=numeric_cols, facet_colnames=categorical_cols
)[:max_chart_instances],
dataframe_registry,
)
if section.charts:
chart_sections.append(section)

return chart_sections


def _select_first_k_pairs(colnames, k=None):
"""Selects the first k pairs of column names, sequentially.

e.g., ['a', 'b', 'c'] => [('a', b'), ('b', 'c')] for k=2

Args:
colnames: (iterable<str>) Column names from which to generate pairs.
k: (int) The number of column pairs.

Returns:
(list<(str, str)>) A k-length sequence of column name pairs.
"""
return itertools.islice(itertools.pairwise(colnames), k)


def _select_faceted_numeric_cols(numeric_cols, categorical_cols, k=None):
"""Selects numeric columns and corresponding categorical facets.

Args:
numeric_cols: (iterable<str>) Available numeric columns.
categorical_cols: (iterable<str>) Available categorical columns.
k: (int) The number of column pairs to select.

Returns:
(iter<(str, str)>) Prioritized sequence of (numeric, categorical) column
pairs.
"""
return itertools.islice(itertools.product(numeric_cols, categorical_cols), k)


def _select_time_series_cols(time_cols, numeric_cols, categorical_cols, k=None):
"""Selects combinations of colnames that can be plotted as time series.

Args:
time_cols: (iter<str>) Available time-like columns.
numeric_cols: (iter<str>) Available numeric columns.
categorical_cols: (iter<str>) Available categorical columns.
k: (int) The number of combinations to select.

Returns:
(iter<(str, str, str)>) Prioritized sequence of (time, value, series)
colname combinations.
"""
numeric_cols = [c for c in numeric_cols if c not in time_cols]
numeric_aggregates = ['count()']
if not categorical_cols:
categorical_cols = [None]
return itertools.islice(
itertools.product(
time_cols, numeric_cols + numeric_aggregates, categorical_cols
),
k,
)
Loading
Loading