diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py index 6db90feae17b8..65112ddc20433 100644 --- a/superset/utils/pandas_postprocessing.py +++ b/superset/utils/pandas_postprocessing.py @@ -131,6 +131,9 @@ def _flatten_column_after_pivot( def validate_column_args(*argnames: str) -> Callable[..., Any]: def wrapper(func: Callable[..., Any]) -> Callable[..., Any]: def wrapped(df: DataFrame, **options: Any) -> Any: + if options.get("is_pivot_df"): + # skip validation when pivot Dataframe + return func(df, **options) columns = df.columns.tolist() for name in argnames: if name in options and not all( @@ -223,6 +226,7 @@ def pivot( # pylint: disable=too-many-arguments,too-many-locals marginal_distributions: Optional[bool] = None, marginal_distribution_name: Optional[str] = None, flatten_columns: bool = True, + reset_index: bool = True, ) -> DataFrame: """ Perform a pivot operation on a DataFrame. @@ -243,6 +247,7 @@ def pivot( # pylint: disable=too-many-arguments,too-many-locals :param marginal_distribution_name: Name of row/column with marginal distribution. Default to 'All'. :param flatten_columns: Convert column names to strings + :param reset_index: Convert index to column :return: A pivot table :raises QueryObjectValidationError: If the request in incorrect """ @@ -300,7 +305,8 @@ def pivot( # pylint: disable=too-many-arguments,too-many-locals _flatten_column_after_pivot(col, aggregates) for col in df.columns ] # return index as regular column - df.reset_index(level=0, inplace=True) + if reset_index: + df.reset_index(level=0, inplace=True) return df @@ -343,13 +349,14 @@ def sort(df: DataFrame, columns: Dict[str, bool]) -> DataFrame: @validate_column_args("columns") def rolling( # pylint: disable=too-many-arguments df: DataFrame, - columns: Dict[str, str], rolling_type: str, + columns: Optional[Dict[str, str]] = None, window: Optional[int] = None, rolling_type_options: Optional[Dict[str, Any]] = None, center: bool = False, win_type: Optional[str] = None, min_periods: Optional[int] = None, + is_pivot_df: bool = False, ) -> DataFrame: """ Apply a rolling window on the dataset. See the Pandas docs for further details: @@ -369,11 +376,16 @@ def rolling( # pylint: disable=too-many-arguments :param win_type: Type of window function. :param min_periods: The minimum amount of periods required for a row to be included in the result set. + :param is_pivot_df: Dataframe is pivoted or not :return: DataFrame with the rolling columns :raises QueryObjectValidationError: If the request in incorrect """ rolling_type_options = rolling_type_options or {} - df_rolling = df[columns.keys()] + columns = columns or {} + if is_pivot_df: + df_rolling = df + else: + df_rolling = df[columns.keys()] kwargs: Dict[str, Union[str, int]] = {} if window is None: raise QueryObjectValidationError(_("Undefined window for rolling operation")) @@ -405,10 +417,20 @@ def rolling( # pylint: disable=too-many-arguments options=rolling_type_options, ) ) from ex - df = _append_columns(df, df_rolling, columns) + + if is_pivot_df: + agg_in_pivot_df = df.columns.get_level_values(0).drop_duplicates().to_list() + agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df} + df_rolling.columns = [ + _flatten_column_after_pivot(col, agg) for col in df_rolling.columns + ] + df_rolling.reset_index(level=0, inplace=True) + else: + df_rolling = _append_columns(df, df_rolling, columns) + if min_periods: - df = df[min_periods:] - return df + df_rolling = df_rolling[min_periods:] + return df_rolling @validate_column_args("columns", "drop", "rename") @@ -524,7 +546,12 @@ def compare( # pylint: disable=too-many-arguments @validate_column_args("columns") -def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame: +def cum( + df: DataFrame, + operator: str, + columns: Optional[Dict[str, str]] = None, + is_pivot_df: bool = False, +) -> DataFrame: """ Calculate cumulative sum/product/min/max for select columns. @@ -535,9 +562,14 @@ def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame: `y2` based on cumulative values calculated from `y`, leaving the original column `y` unchanged. :param operator: cumulative operator, e.g. `sum`, `prod`, `min`, `max` + :param is_pivot_df: Dataframe is pivoted or not :return: DataFrame with cumulated columns """ - df_cum = df[columns.keys()] + columns = columns or {} + if is_pivot_df: + df_cum = df + else: + df_cum = df[columns.keys()] operation = "cum" + operator if operation not in ALLOWLIST_CUMULATIVE_FUNCTIONS or not hasattr( df_cum, operation @@ -545,7 +577,17 @@ def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame: raise QueryObjectValidationError( _("Invalid cumulative operator: %(operator)s", operator=operator) ) - return _append_columns(df, getattr(df_cum, operation)(), columns) + if is_pivot_df: + df_cum = getattr(df_cum, operation)() + agg_in_pivot_df = df.columns.get_level_values(0).drop_duplicates().to_list() + agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df} + df_cum.columns = [ + _flatten_column_after_pivot(col, agg) for col in df_cum.columns + ] + df_cum.reset_index(level=0, inplace=True) + else: + df_cum = _append_columns(df, getattr(df_cum, operation)(), columns) + return df_cum def geohash_decode( diff --git a/tests/integration_tests/fixtures/dataframes.py b/tests/integration_tests/fixtures/dataframes.py index 28bc32fade774..2a49bd3f8d951 100644 --- a/tests/integration_tests/fixtures/dataframes.py +++ b/tests/integration_tests/fixtures/dataframes.py @@ -165,3 +165,19 @@ "b": [4, 3, 4.1, 3.95], } ) + +single_metric_df = DataFrame( + { + "dttm": to_datetime(["2019-01-01", "2019-01-01", "2019-01-02", "2019-01-02",]), + "country": ["UK", "US", "UK", "US"], + "sum_metric": [5, 6, 7, 8], + } +) +multiple_metrics_df = DataFrame( + { + "dttm": to_datetime(["2019-01-01", "2019-01-01", "2019-01-02", "2019-01-02",]), + "country": ["UK", "US", "UK", "US"], + "sum_metric": [5, 6, 7, 8], + "count_metric": [1, 2, 3, 4], + } +) diff --git a/tests/integration_tests/pandas_postprocessing_tests.py b/tests/integration_tests/pandas_postprocessing_tests.py index f242e3232291a..7221130be85bf 100644 --- a/tests/integration_tests/pandas_postprocessing_tests.py +++ b/tests/integration_tests/pandas_postprocessing_tests.py @@ -35,6 +35,8 @@ from .base_tests import SupersetTestCase from .fixtures.dataframes import ( categories_df, + single_metric_df, + multiple_metrics_df, lonlat_df, names_df, timeseries_df, @@ -305,6 +307,23 @@ def test_pivot_eliminate_cartesian_product_columns(self): ) self.assertTrue(np.isnan(df["metric, 1, 1"][0])) + def test_pivot_without_flatten_columns_and_reset_index(self): + df = proc.pivot( + df=single_metric_df, + index=["dttm"], + columns=["country"], + aggregates={"sum_metric": {"operator": "sum"}}, + flatten_columns=False, + reset_index=False, + ) + # metric + # country UK US + # dttm + # 2019-01-01 5 6 + # 2019-01-02 7 8 + assert df.columns.to_list() == [("sum_metric", "UK"), ("sum_metric", "US")] + assert df.index.to_list() == to_datetime(["2019-01-01", "2019-01-02"]).to_list() + def test_aggregate(self): aggregates = { "asc sum": {"column": "asc_idx", "operator": "sum"}, @@ -405,6 +424,60 @@ def test_rolling(self): window=2, ) + def test_rolling_with_pivot_df_and_single_metric(self): + pivot_df = proc.pivot( + df=single_metric_df, + index=["dttm"], + columns=["country"], + aggregates={"sum_metric": {"operator": "sum"}}, + flatten_columns=False, + reset_index=False, + ) + rolling_df = proc.rolling( + df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True, + ) + # dttm UK US + # 0 2019-01-01 5 6 + # 1 2019-01-02 12 14 + assert rolling_df["UK"].to_list() == [5.0, 12.0] + assert rolling_df["US"].to_list() == [6.0, 14.0] + assert ( + rolling_df["dttm"].to_list() + == to_datetime(["2019-01-01", "2019-01-02",]).to_list() + ) + + rolling_df = proc.rolling( + df=pivot_df, rolling_type="sum", window=2, min_periods=2, is_pivot_df=True, + ) + assert rolling_df.empty is True + + def test_rolling_with_pivot_df_and_multiple_metrics(self): + pivot_df = proc.pivot( + df=multiple_metrics_df, + index=["dttm"], + columns=["country"], + aggregates={ + "sum_metric": {"operator": "sum"}, + "count_metric": {"operator": "sum"}, + }, + flatten_columns=False, + reset_index=False, + ) + rolling_df = proc.rolling( + df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True, + ) + # dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US + # 0 2019-01-01 1.0 2.0 5.0 6.0 + # 1 2019-01-02 4.0 6.0 12.0 14.0 + assert rolling_df["count_metric, UK"].to_list() == [1.0, 4.0] + assert rolling_df["count_metric, US"].to_list() == [2.0, 6.0] + assert rolling_df["sum_metric, UK"].to_list() == [5.0, 12.0] + assert rolling_df["sum_metric, US"].to_list() == [6.0, 14.0] + assert ( + rolling_df["dttm"].to_list() + == to_datetime(["2019-01-01", "2019-01-02",]).to_list() + ) + def test_select(self): # reorder columns post_df = proc.select(df=timeseries_df, columns=["y", "label"]) @@ -557,6 +630,51 @@ def test_cum(self): operator="abc", ) + def test_cum_with_pivot_df_and_single_metric(self): + pivot_df = proc.pivot( + df=single_metric_df, + index=["dttm"], + columns=["country"], + aggregates={"sum_metric": {"operator": "sum"}}, + flatten_columns=False, + reset_index=False, + ) + cum_df = proc.cum(df=pivot_df, operator="sum", is_pivot_df=True,) + # dttm UK US + # 0 2019-01-01 5 6 + # 1 2019-01-02 12 14 + assert cum_df["UK"].to_list() == [5.0, 12.0] + assert cum_df["US"].to_list() == [6.0, 14.0] + assert ( + cum_df["dttm"].to_list() + == to_datetime(["2019-01-01", "2019-01-02",]).to_list() + ) + + def test_cum_with_pivot_df_and_multiple_metrics(self): + pivot_df = proc.pivot( + df=multiple_metrics_df, + index=["dttm"], + columns=["country"], + aggregates={ + "sum_metric": {"operator": "sum"}, + "count_metric": {"operator": "sum"}, + }, + flatten_columns=False, + reset_index=False, + ) + cum_df = proc.cum(df=pivot_df, operator="sum", is_pivot_df=True,) + # dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US + # 0 2019-01-01 1 2 5 6 + # 1 2019-01-02 4 6 12 14 + assert cum_df["count_metric, UK"].to_list() == [1.0, 4.0] + assert cum_df["count_metric, US"].to_list() == [2.0, 6.0] + assert cum_df["sum_metric, UK"].to_list() == [5.0, 12.0] + assert cum_df["sum_metric, US"].to_list() == [6.0, 14.0] + assert ( + cum_df["dttm"].to_list() + == to_datetime(["2019-01-01", "2019-01-02",]).to_list() + ) + def test_geohash_decode(self): # decode lon/lat from geohash post_df = proc.geohash_decode(