diff --git a/tests/integration_tests/pandas_postprocessing_tests.py b/tests/integration_tests/pandas_postprocessing_tests.py deleted file mode 100644 index 50612e1da3055..0000000000000 --- a/tests/integration_tests/pandas_postprocessing_tests.py +++ /dev/null @@ -1,1098 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# isort:skip_file -from datetime import datetime -from importlib.util import find_spec -import math -from typing import Any, List, Optional - -import numpy as np -from pandas import DataFrame, Series, Timestamp, to_datetime -import pytest - -from superset.exceptions import QueryObjectValidationError -from superset.utils import pandas_postprocessing as proc -from superset.utils.core import ( - DTTM_ALIAS, - PostProcessingContributionOrientation, - PostProcessingBoxplotWhiskerType, -) - -from .base_tests import SupersetTestCase -from .fixtures.dataframes import ( - categories_df, - single_metric_df, - multiple_metrics_df, - lonlat_df, - names_df, - timeseries_df, - prophet_df, - timeseries_df2, -) - -AGGREGATES_SINGLE = {"idx_nulls": {"operator": "sum"}} -AGGREGATES_MULTIPLE = { - "idx_nulls": {"operator": "sum"}, - "asc_idx": {"operator": "mean"}, -} - - -def series_to_list(series: Series) -> List[Any]: - """ - Converts a `Series` to a regular list, and replaces non-numeric values to - Nones. 
- - :param series: Series to convert - :return: list without nan or inf - """ - return [ - None - if not isinstance(val, str) and (math.isnan(val) or math.isinf(val)) - else val - for val in series.tolist() - ] - - -def round_floats( - floats: List[Optional[float]], precision: int -) -> List[Optional[float]]: - """ - Round list of floats to certain precision - - :param floats: floats to round - :param precision: intended decimal precision - :return: rounded floats - """ - return [round(val, precision) if val else None for val in floats] - - -class TestPostProcessing(SupersetTestCase): - def test_flatten_column_after_pivot(self): - """ - Test pivot column flattening function - """ - # single aggregate cases - self.assertEqual( - proc._flatten_column_after_pivot( - aggregates=AGGREGATES_SINGLE, column="idx_nulls", - ), - "idx_nulls", - ) - self.assertEqual( - proc._flatten_column_after_pivot( - aggregates=AGGREGATES_SINGLE, column=1234, - ), - "1234", - ) - self.assertEqual( - proc._flatten_column_after_pivot( - aggregates=AGGREGATES_SINGLE, column=Timestamp("2020-09-29T00:00:00"), - ), - "2020-09-29 00:00:00", - ) - self.assertEqual( - proc._flatten_column_after_pivot( - aggregates=AGGREGATES_SINGLE, column="idx_nulls", - ), - "idx_nulls", - ) - self.assertEqual( - proc._flatten_column_after_pivot( - aggregates=AGGREGATES_SINGLE, column=("idx_nulls", "col1"), - ), - "col1", - ) - self.assertEqual( - proc._flatten_column_after_pivot( - aggregates=AGGREGATES_SINGLE, column=("idx_nulls", "col1", 1234), - ), - "col1, 1234", - ) - - # Multiple aggregate cases - self.assertEqual( - proc._flatten_column_after_pivot( - aggregates=AGGREGATES_MULTIPLE, column=("idx_nulls", "asc_idx", "col1"), - ), - "idx_nulls, asc_idx, col1", - ) - self.assertEqual( - proc._flatten_column_after_pivot( - aggregates=AGGREGATES_MULTIPLE, - column=("idx_nulls", "asc_idx", "col1", 1234), - ), - "idx_nulls, asc_idx, col1, 1234", - ) - - def test_pivot_without_columns(self): - """ - Make sure pivot 
without columns returns correct DataFrame - """ - df = proc.pivot(df=categories_df, index=["name"], aggregates=AGGREGATES_SINGLE,) - self.assertListEqual( - df.columns.tolist(), ["name", "idx_nulls"], - ) - self.assertEqual(len(df), 101) - self.assertEqual(df.sum()[1], 1050) - - def test_pivot_with_single_column(self): - """ - Make sure pivot with single column returns correct DataFrame - """ - df = proc.pivot( - df=categories_df, - index=["name"], - columns=["category"], - aggregates=AGGREGATES_SINGLE, - ) - self.assertListEqual( - df.columns.tolist(), ["name", "cat0", "cat1", "cat2"], - ) - self.assertEqual(len(df), 101) - self.assertEqual(df.sum()[1], 315) - - df = proc.pivot( - df=categories_df, - index=["dept"], - columns=["category"], - aggregates=AGGREGATES_SINGLE, - ) - self.assertListEqual( - df.columns.tolist(), ["dept", "cat0", "cat1", "cat2"], - ) - self.assertEqual(len(df), 5) - - def test_pivot_with_multiple_columns(self): - """ - Make sure pivot with multiple columns returns correct DataFrame - """ - df = proc.pivot( - df=categories_df, - index=["name"], - columns=["category", "dept"], - aggregates=AGGREGATES_SINGLE, - ) - self.assertEqual(len(df.columns), 1 + 3 * 5) # index + possible permutations - - def test_pivot_fill_values(self): - """ - Make sure pivot with fill values returns correct DataFrame - """ - df = proc.pivot( - df=categories_df, - index=["name"], - columns=["category"], - metric_fill_value=1, - aggregates={"idx_nulls": {"operator": "sum"}}, - ) - self.assertEqual(df.sum()[1], 382) - - def test_pivot_fill_column_values(self): - """ - Make sure pivot witn null column names returns correct DataFrame - """ - df_copy = categories_df.copy() - df_copy["category"] = None - df = proc.pivot( - df=df_copy, - index=["name"], - columns=["category"], - aggregates={"idx_nulls": {"operator": "sum"}}, - ) - assert len(df) == 101 - assert df.columns.tolist() == ["name", ""] - - def test_pivot_exceptions(self): - """ - Make sure pivot raises correct 
Exceptions - """ - # Missing index - self.assertRaises( - TypeError, - proc.pivot, - df=categories_df, - columns=["dept"], - aggregates=AGGREGATES_SINGLE, - ) - - # invalid index reference - self.assertRaises( - QueryObjectValidationError, - proc.pivot, - df=categories_df, - index=["abc"], - columns=["dept"], - aggregates=AGGREGATES_SINGLE, - ) - - # invalid column reference - self.assertRaises( - QueryObjectValidationError, - proc.pivot, - df=categories_df, - index=["dept"], - columns=["abc"], - aggregates=AGGREGATES_SINGLE, - ) - - # invalid aggregate options - self.assertRaises( - QueryObjectValidationError, - proc.pivot, - df=categories_df, - index=["name"], - columns=["category"], - aggregates={"idx_nulls": {}}, - ) - - def test_pivot_eliminate_cartesian_product_columns(self): - # single metric - mock_df = DataFrame( - { - "dttm": to_datetime(["2019-01-01", "2019-01-01"]), - "a": [0, 1], - "b": [0, 1], - "metric": [9, np.NAN], - } - ) - - df = proc.pivot( - df=mock_df, - index=["dttm"], - columns=["a", "b"], - aggregates={"metric": {"operator": "mean"}}, - drop_missing_columns=False, - ) - self.assertEqual(list(df.columns), ["dttm", "0, 0", "1, 1"]) - self.assertTrue(np.isnan(df["1, 1"][0])) - - # multiple metrics - mock_df = DataFrame( - { - "dttm": to_datetime(["2019-01-01", "2019-01-01"]), - "a": [0, 1], - "b": [0, 1], - "metric": [9, np.NAN], - "metric2": [10, 11], - } - ) - - df = proc.pivot( - df=mock_df, - index=["dttm"], - columns=["a", "b"], - aggregates={ - "metric": {"operator": "mean"}, - "metric2": {"operator": "mean"}, - }, - drop_missing_columns=False, - ) - self.assertEqual( - list(df.columns), - ["dttm", "metric, 0, 0", "metric, 1, 1", "metric2, 0, 0", "metric2, 1, 1"], - ) - self.assertTrue(np.isnan(df["metric, 1, 1"][0])) - - def test_pivot_without_flatten_columns_and_reset_index(self): - df = proc.pivot( - df=single_metric_df, - index=["dttm"], - columns=["country"], - aggregates={"sum_metric": {"operator": "sum"}}, - flatten_columns=False, 
- reset_index=False, - ) - # metric - # country UK US - # dttm - # 2019-01-01 5 6 - # 2019-01-02 7 8 - assert df.columns.to_list() == [("sum_metric", "UK"), ("sum_metric", "US")] - assert df.index.to_list() == to_datetime(["2019-01-01", "2019-01-02"]).to_list() - - def test_aggregate(self): - aggregates = { - "asc sum": {"column": "asc_idx", "operator": "sum"}, - "asc q2": { - "column": "asc_idx", - "operator": "percentile", - "options": {"q": 75}, - }, - "desc q1": { - "column": "desc_idx", - "operator": "percentile", - "options": {"q": 25}, - }, - } - df = proc.aggregate( - df=categories_df, groupby=["constant"], aggregates=aggregates - ) - self.assertListEqual( - df.columns.tolist(), ["constant", "asc sum", "asc q2", "desc q1"] - ) - self.assertEqual(series_to_list(df["asc sum"])[0], 5050) - self.assertEqual(series_to_list(df["asc q2"])[0], 75) - self.assertEqual(series_to_list(df["desc q1"])[0], 25) - - def test_sort(self): - df = proc.sort(df=categories_df, columns={"category": True, "asc_idx": False}) - self.assertEqual(96, series_to_list(df["asc_idx"])[1]) - - self.assertRaises( - QueryObjectValidationError, proc.sort, df=df, columns={"abc": True} - ) - - def test_rolling(self): - # sum rolling type - post_df = proc.rolling( - df=timeseries_df, - columns={"y": "y"}, - rolling_type="sum", - window=2, - min_periods=0, - ) - - self.assertListEqual(post_df.columns.tolist(), ["label", "y"]) - self.assertListEqual(series_to_list(post_df["y"]), [1.0, 3.0, 5.0, 7.0]) - - # mean rolling type with alias - post_df = proc.rolling( - df=timeseries_df, - rolling_type="mean", - columns={"y": "y_mean"}, - window=10, - min_periods=0, - ) - self.assertListEqual(post_df.columns.tolist(), ["label", "y", "y_mean"]) - self.assertListEqual(series_to_list(post_df["y_mean"]), [1.0, 1.5, 2.0, 2.5]) - - # count rolling type - post_df = proc.rolling( - df=timeseries_df, - rolling_type="count", - columns={"y": "y"}, - window=10, - min_periods=0, - ) - 
self.assertListEqual(post_df.columns.tolist(), ["label", "y"]) - self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 3.0, 4.0]) - - # quantile rolling type - post_df = proc.rolling( - df=timeseries_df, - columns={"y": "q1"}, - rolling_type="quantile", - rolling_type_options={"quantile": 0.25}, - window=10, - min_periods=0, - ) - self.assertListEqual(post_df.columns.tolist(), ["label", "y", "q1"]) - self.assertListEqual(series_to_list(post_df["q1"]), [1.0, 1.25, 1.5, 1.75]) - - # incorrect rolling type - self.assertRaises( - QueryObjectValidationError, - proc.rolling, - df=timeseries_df, - columns={"y": "y"}, - rolling_type="abc", - window=2, - ) - - # incorrect rolling type options - self.assertRaises( - QueryObjectValidationError, - proc.rolling, - df=timeseries_df, - columns={"y": "y"}, - rolling_type="quantile", - rolling_type_options={"abc": 123}, - window=2, - ) - - def test_rolling_with_pivot_df_and_single_metric(self): - pivot_df = proc.pivot( - df=single_metric_df, - index=["dttm"], - columns=["country"], - aggregates={"sum_metric": {"operator": "sum"}}, - flatten_columns=False, - reset_index=False, - ) - rolling_df = proc.rolling( - df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True, - ) - # dttm UK US - # 0 2019-01-01 5 6 - # 1 2019-01-02 12 14 - assert rolling_df["UK"].to_list() == [5.0, 12.0] - assert rolling_df["US"].to_list() == [6.0, 14.0] - assert ( - rolling_df["dttm"].to_list() - == to_datetime(["2019-01-01", "2019-01-02",]).to_list() - ) - - rolling_df = proc.rolling( - df=pivot_df, rolling_type="sum", window=2, min_periods=2, is_pivot_df=True, - ) - assert rolling_df.empty is True - - def test_rolling_with_pivot_df_and_multiple_metrics(self): - pivot_df = proc.pivot( - df=multiple_metrics_df, - index=["dttm"], - columns=["country"], - aggregates={ - "sum_metric": {"operator": "sum"}, - "count_metric": {"operator": "sum"}, - }, - flatten_columns=False, - reset_index=False, - ) - rolling_df = proc.rolling( - 
df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True, - ) - # dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US - # 0 2019-01-01 1.0 2.0 5.0 6.0 - # 1 2019-01-02 4.0 6.0 12.0 14.0 - assert rolling_df["count_metric, UK"].to_list() == [1.0, 4.0] - assert rolling_df["count_metric, US"].to_list() == [2.0, 6.0] - assert rolling_df["sum_metric, UK"].to_list() == [5.0, 12.0] - assert rolling_df["sum_metric, US"].to_list() == [6.0, 14.0] - assert ( - rolling_df["dttm"].to_list() - == to_datetime(["2019-01-01", "2019-01-02",]).to_list() - ) - - def test_select(self): - # reorder columns - post_df = proc.select(df=timeseries_df, columns=["y", "label"]) - self.assertListEqual(post_df.columns.tolist(), ["y", "label"]) - - # one column - post_df = proc.select(df=timeseries_df, columns=["label"]) - self.assertListEqual(post_df.columns.tolist(), ["label"]) - - # rename and select one column - post_df = proc.select(df=timeseries_df, columns=["y"], rename={"y": "y1"}) - self.assertListEqual(post_df.columns.tolist(), ["y1"]) - - # rename one and leave one unchanged - post_df = proc.select(df=timeseries_df, rename={"y": "y1"}) - self.assertListEqual(post_df.columns.tolist(), ["label", "y1"]) - - # drop one column - post_df = proc.select(df=timeseries_df, exclude=["label"]) - self.assertListEqual(post_df.columns.tolist(), ["y"]) - - # rename and drop one column - post_df = proc.select(df=timeseries_df, rename={"y": "y1"}, exclude=["label"]) - self.assertListEqual(post_df.columns.tolist(), ["y1"]) - - # invalid columns - self.assertRaises( - QueryObjectValidationError, - proc.select, - df=timeseries_df, - columns=["abc"], - rename={"abc": "qwerty"}, - ) - - # select renamed column by new name - self.assertRaises( - QueryObjectValidationError, - proc.select, - df=timeseries_df, - columns=["label_new"], - rename={"label": "label_new"}, - ) - - def test_diff(self): - # overwrite column - post_df = proc.diff(df=timeseries_df, columns={"y": "y"}) - 
self.assertListEqual(post_df.columns.tolist(), ["label", "y"]) - self.assertListEqual(series_to_list(post_df["y"]), [None, 1.0, 1.0, 1.0]) - - # add column - post_df = proc.diff(df=timeseries_df, columns={"y": "y1"}) - self.assertListEqual(post_df.columns.tolist(), ["label", "y", "y1"]) - self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 3.0, 4.0]) - self.assertListEqual(series_to_list(post_df["y1"]), [None, 1.0, 1.0, 1.0]) - - # look ahead - post_df = proc.diff(df=timeseries_df, columns={"y": "y1"}, periods=-1) - self.assertListEqual(series_to_list(post_df["y1"]), [-1.0, -1.0, -1.0, None]) - - # invalid column reference - self.assertRaises( - QueryObjectValidationError, - proc.diff, - df=timeseries_df, - columns={"abc": "abc"}, - ) - - # diff by columns - post_df = proc.diff(df=timeseries_df2, columns={"y": "y", "z": "z"}, axis=1) - self.assertListEqual(post_df.columns.tolist(), ["label", "y", "z"]) - self.assertListEqual(series_to_list(post_df["z"]), [0.0, 2.0, 8.0, 6.0]) - - def test_compare(self): - # `difference` comparison - post_df = proc.compare( - df=timeseries_df2, - source_columns=["y"], - compare_columns=["z"], - compare_type="difference", - ) - self.assertListEqual( - post_df.columns.tolist(), ["label", "y", "z", "difference__y__z",] - ) - self.assertListEqual( - series_to_list(post_df["difference__y__z"]), [0.0, -2.0, -8.0, -6.0], - ) - - # drop original columns - post_df = proc.compare( - df=timeseries_df2, - source_columns=["y"], - compare_columns=["z"], - compare_type="difference", - drop_original_columns=True, - ) - self.assertListEqual(post_df.columns.tolist(), ["label", "difference__y__z",]) - - # `percentage` comparison - post_df = proc.compare( - df=timeseries_df2, - source_columns=["y"], - compare_columns=["z"], - compare_type="percentage", - ) - self.assertListEqual( - post_df.columns.tolist(), ["label", "y", "z", "percentage__y__z",] - ) - self.assertListEqual( - series_to_list(post_df["percentage__y__z"]), [0.0, -0.5, -0.8, 
-0.75], - ) - - # `ratio` comparison - post_df = proc.compare( - df=timeseries_df2, - source_columns=["y"], - compare_columns=["z"], - compare_type="ratio", - ) - self.assertListEqual( - post_df.columns.tolist(), ["label", "y", "z", "ratio__y__z",] - ) - self.assertListEqual( - series_to_list(post_df["ratio__y__z"]), [1.0, 0.5, 0.2, 0.25], - ) - - def test_cum(self): - # create new column (cumsum) - post_df = proc.cum(df=timeseries_df, columns={"y": "y2"}, operator="sum",) - self.assertListEqual(post_df.columns.tolist(), ["label", "y", "y2"]) - self.assertListEqual(series_to_list(post_df["label"]), ["x", "y", "z", "q"]) - self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 3.0, 4.0]) - self.assertListEqual(series_to_list(post_df["y2"]), [1.0, 3.0, 6.0, 10.0]) - - # overwrite column (cumprod) - post_df = proc.cum(df=timeseries_df, columns={"y": "y"}, operator="prod",) - self.assertListEqual(post_df.columns.tolist(), ["label", "y"]) - self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 6.0, 24.0]) - - # overwrite column (cummin) - post_df = proc.cum(df=timeseries_df, columns={"y": "y"}, operator="min",) - self.assertListEqual(post_df.columns.tolist(), ["label", "y"]) - self.assertListEqual(series_to_list(post_df["y"]), [1.0, 1.0, 1.0, 1.0]) - - # invalid operator - self.assertRaises( - QueryObjectValidationError, - proc.cum, - df=timeseries_df, - columns={"y": "y"}, - operator="abc", - ) - - def test_cum_with_pivot_df_and_single_metric(self): - pivot_df = proc.pivot( - df=single_metric_df, - index=["dttm"], - columns=["country"], - aggregates={"sum_metric": {"operator": "sum"}}, - flatten_columns=False, - reset_index=False, - ) - cum_df = proc.cum(df=pivot_df, operator="sum", is_pivot_df=True,) - # dttm UK US - # 0 2019-01-01 5 6 - # 1 2019-01-02 12 14 - assert cum_df["UK"].to_list() == [5.0, 12.0] - assert cum_df["US"].to_list() == [6.0, 14.0] - assert ( - cum_df["dttm"].to_list() - == to_datetime(["2019-01-01", "2019-01-02",]).to_list() - ) - - 
def test_cum_with_pivot_df_and_multiple_metrics(self): - pivot_df = proc.pivot( - df=multiple_metrics_df, - index=["dttm"], - columns=["country"], - aggregates={ - "sum_metric": {"operator": "sum"}, - "count_metric": {"operator": "sum"}, - }, - flatten_columns=False, - reset_index=False, - ) - cum_df = proc.cum(df=pivot_df, operator="sum", is_pivot_df=True,) - # dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US - # 0 2019-01-01 1 2 5 6 - # 1 2019-01-02 4 6 12 14 - assert cum_df["count_metric, UK"].to_list() == [1.0, 4.0] - assert cum_df["count_metric, US"].to_list() == [2.0, 6.0] - assert cum_df["sum_metric, UK"].to_list() == [5.0, 12.0] - assert cum_df["sum_metric, US"].to_list() == [6.0, 14.0] - assert ( - cum_df["dttm"].to_list() - == to_datetime(["2019-01-01", "2019-01-02",]).to_list() - ) - - def test_geohash_decode(self): - # decode lon/lat from geohash - post_df = proc.geohash_decode( - df=lonlat_df[["city", "geohash"]], - geohash="geohash", - latitude="latitude", - longitude="longitude", - ) - self.assertListEqual( - sorted(post_df.columns.tolist()), - sorted(["city", "geohash", "latitude", "longitude"]), - ) - self.assertListEqual( - round_floats(series_to_list(post_df["longitude"]), 6), - round_floats(series_to_list(lonlat_df["longitude"]), 6), - ) - self.assertListEqual( - round_floats(series_to_list(post_df["latitude"]), 6), - round_floats(series_to_list(lonlat_df["latitude"]), 6), - ) - - def test_geohash_encode(self): - # encode lon/lat into geohash - post_df = proc.geohash_encode( - df=lonlat_df[["city", "latitude", "longitude"]], - latitude="latitude", - longitude="longitude", - geohash="geohash", - ) - self.assertListEqual( - sorted(post_df.columns.tolist()), - sorted(["city", "geohash", "latitude", "longitude"]), - ) - self.assertListEqual( - series_to_list(post_df["geohash"]), series_to_list(lonlat_df["geohash"]), - ) - - def test_geodetic_parse(self): - # parse geodetic string with altitude into lon/lat/altitude - post_df = 
proc.geodetic_parse( - df=lonlat_df[["city", "geodetic"]], - geodetic="geodetic", - latitude="latitude", - longitude="longitude", - altitude="altitude", - ) - self.assertListEqual( - sorted(post_df.columns.tolist()), - sorted(["city", "geodetic", "latitude", "longitude", "altitude"]), - ) - self.assertListEqual( - series_to_list(post_df["longitude"]), - series_to_list(lonlat_df["longitude"]), - ) - self.assertListEqual( - series_to_list(post_df["latitude"]), series_to_list(lonlat_df["latitude"]), - ) - self.assertListEqual( - series_to_list(post_df["altitude"]), series_to_list(lonlat_df["altitude"]), - ) - - # parse geodetic string into lon/lat - post_df = proc.geodetic_parse( - df=lonlat_df[["city", "geodetic"]], - geodetic="geodetic", - latitude="latitude", - longitude="longitude", - ) - self.assertListEqual( - sorted(post_df.columns.tolist()), - sorted(["city", "geodetic", "latitude", "longitude"]), - ) - self.assertListEqual( - series_to_list(post_df["longitude"]), - series_to_list(lonlat_df["longitude"]), - ) - self.assertListEqual( - series_to_list(post_df["latitude"]), series_to_list(lonlat_df["latitude"]), - ) - - def test_contribution(self): - df = DataFrame( - { - DTTM_ALIAS: [ - datetime(2020, 7, 16, 14, 49), - datetime(2020, 7, 16, 14, 50), - ], - "a": [1, 3], - "b": [1, 9], - } - ) - with pytest.raises(QueryObjectValidationError, match="not numeric"): - proc.contribution(df, columns=[DTTM_ALIAS]) - - with pytest.raises(QueryObjectValidationError, match="same length"): - proc.contribution(df, columns=["a"], rename_columns=["aa", "bb"]) - - # cell contribution across row - processed_df = proc.contribution( - df, orientation=PostProcessingContributionOrientation.ROW, - ) - self.assertListEqual(processed_df.columns.tolist(), [DTTM_ALIAS, "a", "b"]) - self.assertListEqual(processed_df["a"].tolist(), [0.5, 0.25]) - self.assertListEqual(processed_df["b"].tolist(), [0.5, 0.75]) - - # cell contribution across column without temporal column - df.pop(DTTM_ALIAS) 
- processed_df = proc.contribution( - df, orientation=PostProcessingContributionOrientation.COLUMN - ) - self.assertListEqual(processed_df.columns.tolist(), ["a", "b"]) - self.assertListEqual(processed_df["a"].tolist(), [0.25, 0.75]) - self.assertListEqual(processed_df["b"].tolist(), [0.1, 0.9]) - - # contribution only on selected columns - processed_df = proc.contribution( - df, - orientation=PostProcessingContributionOrientation.COLUMN, - columns=["a"], - rename_columns=["pct_a"], - ) - self.assertListEqual(processed_df.columns.tolist(), ["a", "b", "pct_a"]) - self.assertListEqual(processed_df["a"].tolist(), [1, 3]) - self.assertListEqual(processed_df["b"].tolist(), [1, 9]) - self.assertListEqual(processed_df["pct_a"].tolist(), [0.25, 0.75]) - - def test_prophet_valid(self): - pytest.importorskip("prophet") - - df = proc.prophet( - df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9 - ) - columns = {column for column in df.columns} - assert columns == { - DTTM_ALIAS, - "a__yhat", - "a__yhat_upper", - "a__yhat_lower", - "a", - "b__yhat", - "b__yhat_upper", - "b__yhat_lower", - "b", - } - assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31) - assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 3, 31) - assert len(df) == 7 - - df = proc.prophet( - df=prophet_df, time_grain="P1M", periods=5, confidence_interval=0.9 - ) - assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31) - assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 5, 31) - assert len(df) == 9 - - def test_prophet_valid_zero_periods(self): - pytest.importorskip("prophet") - - df = proc.prophet( - df=prophet_df, time_grain="P1M", periods=0, confidence_interval=0.9 - ) - columns = {column for column in df.columns} - assert columns == { - DTTM_ALIAS, - "a__yhat", - "a__yhat_upper", - "a__yhat_lower", - "a", - "b__yhat", - "b__yhat_upper", - "b__yhat_lower", - "b", - } - assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == 
datetime(2018, 12, 31) - assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2021, 12, 31) - assert len(df) == 4 - - def test_prophet_import(self): - prophet = find_spec("prophet") - if prophet is None: - with pytest.raises(QueryObjectValidationError): - proc.prophet( - df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9 - ) - - def test_prophet_missing_temporal_column(self): - df = prophet_df.drop(DTTM_ALIAS, axis=1) - - self.assertRaises( - QueryObjectValidationError, - proc.prophet, - df=df, - time_grain="P1M", - periods=3, - confidence_interval=0.9, - ) - - def test_prophet_incorrect_confidence_interval(self): - self.assertRaises( - QueryObjectValidationError, - proc.prophet, - df=prophet_df, - time_grain="P1M", - periods=3, - confidence_interval=0.0, - ) - - self.assertRaises( - QueryObjectValidationError, - proc.prophet, - df=prophet_df, - time_grain="P1M", - periods=3, - confidence_interval=1.0, - ) - - def test_prophet_incorrect_periods(self): - self.assertRaises( - QueryObjectValidationError, - proc.prophet, - df=prophet_df, - time_grain="P1M", - periods=-1, - confidence_interval=0.8, - ) - - def test_prophet_incorrect_time_grain(self): - self.assertRaises( - QueryObjectValidationError, - proc.prophet, - df=prophet_df, - time_grain="yearly", - periods=10, - confidence_interval=0.8, - ) - - def test_boxplot_tukey(self): - df = proc.boxplot( - df=names_df, - groupby=["region"], - whisker_type=PostProcessingBoxplotWhiskerType.TUKEY, - metrics=["cars"], - ) - columns = {column for column in df.columns} - assert columns == { - "cars__mean", - "cars__median", - "cars__q1", - "cars__q3", - "cars__max", - "cars__min", - "cars__count", - "cars__outliers", - "region", - } - assert len(df) == 4 - - def test_boxplot_min_max(self): - df = proc.boxplot( - df=names_df, - groupby=["region"], - whisker_type=PostProcessingBoxplotWhiskerType.MINMAX, - metrics=["cars"], - ) - columns = {column for column in df.columns} - assert columns == { - 
"cars__mean", - "cars__median", - "cars__q1", - "cars__q3", - "cars__max", - "cars__min", - "cars__count", - "cars__outliers", - "region", - } - assert len(df) == 4 - - def test_boxplot_percentile(self): - df = proc.boxplot( - df=names_df, - groupby=["region"], - whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE, - metrics=["cars"], - percentiles=[1, 99], - ) - columns = {column for column in df.columns} - assert columns == { - "cars__mean", - "cars__median", - "cars__q1", - "cars__q3", - "cars__max", - "cars__min", - "cars__count", - "cars__outliers", - "region", - } - assert len(df) == 4 - - def test_boxplot_percentile_incorrect_params(self): - with pytest.raises(QueryObjectValidationError): - proc.boxplot( - df=names_df, - groupby=["region"], - whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE, - metrics=["cars"], - ) - - with pytest.raises(QueryObjectValidationError): - proc.boxplot( - df=names_df, - groupby=["region"], - whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE, - metrics=["cars"], - percentiles=[10], - ) - - with pytest.raises(QueryObjectValidationError): - proc.boxplot( - df=names_df, - groupby=["region"], - whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE, - metrics=["cars"], - percentiles=[90, 10], - ) - - with pytest.raises(QueryObjectValidationError): - proc.boxplot( - df=names_df, - groupby=["region"], - whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE, - metrics=["cars"], - percentiles=[10, 90, 10], - ) - - def test_resample(self): - df = timeseries_df.copy() - df.index.name = "time_column" - df.reset_index(inplace=True) - - post_df = proc.resample( - df=df, rule="1D", method="ffill", time_column="time_column", - ) - self.assertListEqual( - post_df["label"].tolist(), ["x", "y", "y", "y", "z", "z", "q"] - ) - self.assertListEqual(post_df["y"].tolist(), [1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0]) - - post_df = proc.resample( - df=df, rule="1D", method="asfreq", time_column="time_column", fill_value=0, - ) - 
self.assertListEqual(post_df["label"].tolist(), ["x", "y", 0, 0, "z", 0, "q"]) - self.assertListEqual(post_df["y"].tolist(), [1.0, 2.0, 0, 0, 3.0, 0, 4.0]) - - def test_resample_with_groupby(self): - """ -The Dataframe contains a timestamp column, a string column and a numeric column. - __timestamp city val -0 2022-01-13 Chicago 6.0 -1 2022-01-13 LA 5.0 -2 2022-01-13 NY 4.0 -3 2022-01-11 Chicago 3.0 -4 2022-01-11 LA 2.0 -5 2022-01-11 NY 1.0 - """ - df = DataFrame( - { - "__timestamp": to_datetime( - [ - "2022-01-13", - "2022-01-13", - "2022-01-13", - "2022-01-11", - "2022-01-11", - "2022-01-11", - ] - ), - "city": ["Chicago", "LA", "NY", "Chicago", "LA", "NY"], - "val": [6.0, 5.0, 4.0, 3.0, 2.0, 1.0], - } - ) - post_df = proc.resample( - df=df, - rule="1D", - method="asfreq", - fill_value=0, - time_column="__timestamp", - groupby_columns=("city",), - ) - assert list(post_df.columns) == [ - "__timestamp", - "city", - "val", - ] - assert [str(dt.date()) for dt in post_df["__timestamp"]] == ( - ["2022-01-11"] * 3 + ["2022-01-12"] * 3 + ["2022-01-13"] * 3 - ) - assert list(post_df["val"]) == [3.0, 2.0, 1.0, 0, 0, 0, 6.0, 5.0, 4.0] - - # should raise error when get a non-existent column - with pytest.raises(QueryObjectValidationError): - proc.resample( - df=df, - rule="1D", - method="asfreq", - fill_value=0, - time_column="__timestamp", - groupby_columns=("city", "unkonw_column",), - ) - - # should raise error when get a None value in groupby list - with pytest.raises(QueryObjectValidationError): - proc.resample( - df=df, - rule="1D", - method="asfreq", - fill_value=0, - time_column="__timestamp", - groupby_columns=("city", None,), - ) diff --git a/tests/integration_tests/fixtures/dataframes.py b/tests/unit_tests/fixtures/dataframes.py similarity index 100% rename from tests/integration_tests/fixtures/dataframes.py rename to tests/unit_tests/fixtures/dataframes.py diff --git a/tests/unit_tests/pandas_postprocessing/__init__.py 
b/tests/unit_tests/pandas_postprocessing/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/unit_tests/pandas_postprocessing/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/pandas_postprocessing/test_aggregate.py b/tests/unit_tests/pandas_postprocessing/test_aggregate.py new file mode 100644 index 0000000000000..69d42e36f06be --- /dev/null +++ b/tests/unit_tests/pandas_postprocessing/test_aggregate.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +from superset.utils.pandas_postprocessing import aggregate +from tests.unit_tests.fixtures.dataframes import categories_df +from tests.unit_tests.pandas_postprocessing.utils import series_to_list + + +def test_aggregate(): + aggregates = { + "asc sum": {"column": "asc_idx", "operator": "sum"}, + "asc q2": { + "column": "asc_idx", + "operator": "percentile", + "options": {"q": 75}, + }, + "desc q1": { + "column": "desc_idx", + "operator": "percentile", + "options": {"q": 25}, + }, + } + df = aggregate(df=categories_df, groupby=["constant"], aggregates=aggregates) + assert df.columns.tolist() == ["constant", "asc sum", "asc q2", "desc q1"] + assert series_to_list(df["asc sum"])[0] == 5050 + assert series_to_list(df["asc q2"])[0] == 75 + assert series_to_list(df["desc q1"])[0] == 25 diff --git a/tests/unit_tests/pandas_postprocessing/test_boxplot.py b/tests/unit_tests/pandas_postprocessing/test_boxplot.py new file mode 100644 index 0000000000000..247aba0134b55 --- /dev/null +++ b/tests/unit_tests/pandas_postprocessing/test_boxplot.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import pytest + +from superset.exceptions import QueryObjectValidationError +from superset.utils.core import PostProcessingBoxplotWhiskerType +from superset.utils.pandas_postprocessing import boxplot +from tests.unit_tests.fixtures.dataframes import names_df + + +def test_boxplot_tukey(): + df = boxplot( + df=names_df, + groupby=["region"], + whisker_type=PostProcessingBoxplotWhiskerType.TUKEY, + metrics=["cars"], + ) + columns = {column for column in df.columns} + assert columns == { + "cars__mean", + "cars__median", + "cars__q1", + "cars__q3", + "cars__max", + "cars__min", + "cars__count", + "cars__outliers", + "region", + } + assert len(df) == 4 + + +def test_boxplot_min_max(): + df = boxplot( + df=names_df, + groupby=["region"], + whisker_type=PostProcessingBoxplotWhiskerType.MINMAX, + metrics=["cars"], + ) + columns = {column for column in df.columns} + assert columns == { + "cars__mean", + "cars__median", + "cars__q1", + "cars__q3", + "cars__max", + "cars__min", + "cars__count", + "cars__outliers", + "region", + } + assert len(df) == 4 + + +def test_boxplot_percentile(): + df = boxplot( + df=names_df, + groupby=["region"], + whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE, + metrics=["cars"], + percentiles=[1, 99], + ) + columns = {column for column in df.columns} + assert columns == { + "cars__mean", + "cars__median", + "cars__q1", + "cars__q3", + "cars__max", + "cars__min", + "cars__count", + "cars__outliers", + "region", + } + assert len(df) == 4 + + +def test_boxplot_percentile_incorrect_params(): + with pytest.raises(QueryObjectValidationError): + boxplot( + df=names_df, + groupby=["region"], + whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE, + metrics=["cars"], + ) + + with pytest.raises(QueryObjectValidationError): + boxplot( + df=names_df, + groupby=["region"], + whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE, + metrics=["cars"], + percentiles=[10], + ) + + with pytest.raises(QueryObjectValidationError): + boxplot( + 
df=names_df, + groupby=["region"], + whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE, + metrics=["cars"], + percentiles=[90, 10], + ) + + with pytest.raises(QueryObjectValidationError): + boxplot( + df=names_df, + groupby=["region"], + whisker_type=PostProcessingBoxplotWhiskerType.PERCENTILE, + metrics=["cars"], + percentiles=[10, 90, 10], + ) diff --git a/tests/unit_tests/pandas_postprocessing/test_compare.py b/tests/unit_tests/pandas_postprocessing/test_compare.py new file mode 100644 index 0000000000000..d9213ca398f36 --- /dev/null +++ b/tests/unit_tests/pandas_postprocessing/test_compare.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from superset.utils.pandas_postprocessing import compare +from tests.unit_tests.fixtures.dataframes import timeseries_df2 +from tests.unit_tests.pandas_postprocessing.utils import series_to_list + + +def test_compare(): + # `difference` comparison + post_df = compare( + df=timeseries_df2, + source_columns=["y"], + compare_columns=["z"], + compare_type="difference", + ) + assert post_df.columns.tolist() == ["label", "y", "z", "difference__y__z"] + assert series_to_list(post_df["difference__y__z"]) == [0.0, -2.0, -8.0, -6.0] + + # drop original columns + post_df = compare( + df=timeseries_df2, + source_columns=["y"], + compare_columns=["z"], + compare_type="difference", + drop_original_columns=True, + ) + assert post_df.columns.tolist() == ["label", "difference__y__z"] + + # `percentage` comparison + post_df = compare( + df=timeseries_df2, + source_columns=["y"], + compare_columns=["z"], + compare_type="percentage", + ) + assert post_df.columns.tolist() == ["label", "y", "z", "percentage__y__z"] + assert series_to_list(post_df["percentage__y__z"]) == [0.0, -0.5, -0.8, -0.75] + + # `ratio` comparison + post_df = compare( + df=timeseries_df2, + source_columns=["y"], + compare_columns=["z"], + compare_type="ratio", + ) + assert post_df.columns.tolist() == ["label", "y", "z", "ratio__y__z"] + assert series_to_list(post_df["ratio__y__z"]) == [1.0, 0.5, 0.2, 0.25] diff --git a/tests/unit_tests/pandas_postprocessing/test_contribution.py b/tests/unit_tests/pandas_postprocessing/test_contribution.py new file mode 100644 index 0000000000000..78212cbe5b851 --- /dev/null +++ b/tests/unit_tests/pandas_postprocessing/test_contribution.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datetime import datetime + +import pytest +from pandas import DataFrame + +from superset.exceptions import QueryObjectValidationError +from superset.utils.core import DTTM_ALIAS, PostProcessingContributionOrientation +from superset.utils.pandas_postprocessing import contribution + + +def test_contribution(): + df = DataFrame( + { + DTTM_ALIAS: [datetime(2020, 7, 16, 14, 49), datetime(2020, 7, 16, 14, 50),], + "a": [1, 3], + "b": [1, 9], + } + ) + with pytest.raises(QueryObjectValidationError, match="not numeric"): + contribution(df, columns=[DTTM_ALIAS]) + + with pytest.raises(QueryObjectValidationError, match="same length"): + contribution(df, columns=["a"], rename_columns=["aa", "bb"]) + + # cell contribution across row + processed_df = contribution( + df, orientation=PostProcessingContributionOrientation.ROW, + ) + assert processed_df.columns.tolist() == [DTTM_ALIAS, "a", "b"] + assert processed_df["a"].tolist() == [0.5, 0.25] + assert processed_df["b"].tolist() == [0.5, 0.75] + + # cell contribution across column without temporal column + df.pop(DTTM_ALIAS) + processed_df = contribution( + df, orientation=PostProcessingContributionOrientation.COLUMN + ) + assert processed_df.columns.tolist() == ["a", "b"] + assert processed_df["a"].tolist() == [0.25, 0.75] + assert processed_df["b"].tolist() == [0.1, 0.9] + + # contribution only on selected columns + processed_df = contribution( + df, 
+ orientation=PostProcessingContributionOrientation.COLUMN, + columns=["a"], + rename_columns=["pct_a"], + ) + assert processed_df.columns.tolist() == ["a", "b", "pct_a"] + assert processed_df["a"].tolist() == [1, 3] + assert processed_df["b"].tolist() == [1, 9] + assert processed_df["pct_a"].tolist() == [0.25, 0.75] diff --git a/tests/unit_tests/pandas_postprocessing/test_cum.py b/tests/unit_tests/pandas_postprocessing/test_cum.py new file mode 100644 index 0000000000000..b4b8fadb3067a --- /dev/null +++ b/tests/unit_tests/pandas_postprocessing/test_cum.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import pytest +from pandas import to_datetime + +from superset.exceptions import QueryObjectValidationError +from superset.utils.pandas_postprocessing import cum, pivot +from tests.unit_tests.fixtures.dataframes import ( + multiple_metrics_df, + single_metric_df, + timeseries_df, +) +from tests.unit_tests.pandas_postprocessing.utils import series_to_list + + +def test_cum(): + # create new column (cumsum) + post_df = cum(df=timeseries_df, columns={"y": "y2"}, operator="sum",) + assert post_df.columns.tolist() == ["label", "y", "y2"] + assert series_to_list(post_df["label"]) == ["x", "y", "z", "q"] + assert series_to_list(post_df["y"]) == [1.0, 2.0, 3.0, 4.0] + assert series_to_list(post_df["y2"]) == [1.0, 3.0, 6.0, 10.0] + + # overwrite column (cumprod) + post_df = cum(df=timeseries_df, columns={"y": "y"}, operator="prod",) + assert post_df.columns.tolist() == ["label", "y"] + assert series_to_list(post_df["y"]) == [1.0, 2.0, 6.0, 24.0] + + # overwrite column (cummin) + post_df = cum(df=timeseries_df, columns={"y": "y"}, operator="min",) + assert post_df.columns.tolist() == ["label", "y"] + assert series_to_list(post_df["y"]) == [1.0, 1.0, 1.0, 1.0] + + # invalid operator + with pytest.raises(QueryObjectValidationError): + cum( + df=timeseries_df, columns={"y": "y"}, operator="abc", + ) + + +def test_cum_with_pivot_df_and_single_metric(): + pivot_df = pivot( + df=single_metric_df, + index=["dttm"], + columns=["country"], + aggregates={"sum_metric": {"operator": "sum"}}, + flatten_columns=False, + reset_index=False, + ) + cum_df = cum(df=pivot_df, operator="sum", is_pivot_df=True,) + # dttm UK US + # 0 2019-01-01 5 6 + # 1 2019-01-02 12 14 + assert cum_df["UK"].to_list() == [5.0, 12.0] + assert cum_df["US"].to_list() == [6.0, 14.0] + assert ( + cum_df["dttm"].to_list() == to_datetime(["2019-01-01", "2019-01-02"]).to_list() + ) + + +def test_cum_with_pivot_df_and_multiple_metrics(): + pivot_df = pivot( + df=multiple_metrics_df, + index=["dttm"], + 
columns=["country"], + aggregates={ + "sum_metric": {"operator": "sum"}, + "count_metric": {"operator": "sum"}, + }, + flatten_columns=False, + reset_index=False, + ) + cum_df = cum(df=pivot_df, operator="sum", is_pivot_df=True,) + # dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US + # 0 2019-01-01 1 2 5 6 + # 1 2019-01-02 4 6 12 14 + assert cum_df["count_metric, UK"].to_list() == [1.0, 4.0] + assert cum_df["count_metric, US"].to_list() == [2.0, 6.0] + assert cum_df["sum_metric, UK"].to_list() == [5.0, 12.0] + assert cum_df["sum_metric, US"].to_list() == [6.0, 14.0] + assert ( + cum_df["dttm"].to_list() == to_datetime(["2019-01-01", "2019-01-02"]).to_list() + ) diff --git a/tests/unit_tests/pandas_postprocessing/test_diff.py b/tests/unit_tests/pandas_postprocessing/test_diff.py new file mode 100644 index 0000000000000..abade20a9bab8 --- /dev/null +++ b/tests/unit_tests/pandas_postprocessing/test_diff.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import pytest + +from superset.exceptions import QueryObjectValidationError +from superset.utils.pandas_postprocessing import diff +from tests.unit_tests.fixtures.dataframes import timeseries_df, timeseries_df2 +from tests.unit_tests.pandas_postprocessing.utils import series_to_list + + +def test_diff(): + # overwrite column + post_df = diff(df=timeseries_df, columns={"y": "y"}) + assert post_df.columns.tolist() == ["label", "y"] + assert series_to_list(post_df["y"]) == [None, 1.0, 1.0, 1.0] + + # add column + post_df = diff(df=timeseries_df, columns={"y": "y1"}) + assert post_df.columns.tolist() == ["label", "y", "y1"] + assert series_to_list(post_df["y"]) == [1.0, 2.0, 3.0, 4.0] + assert series_to_list(post_df["y1"]) == [None, 1.0, 1.0, 1.0] + + # look ahead + post_df = diff(df=timeseries_df, columns={"y": "y1"}, periods=-1) + assert series_to_list(post_df["y1"]) == [-1.0, -1.0, -1.0, None] + + # invalid column reference + with pytest.raises(QueryObjectValidationError): + diff( + df=timeseries_df, columns={"abc": "abc"}, + ) + + # diff by columns + post_df = diff(df=timeseries_df2, columns={"y": "y", "z": "z"}, axis=1) + assert post_df.columns.tolist() == ["label", "y", "z"] + assert series_to_list(post_df["z"]) == [0.0, 2.0, 8.0, 6.0] diff --git a/tests/unit_tests/pandas_postprocessing/test_geography.py b/tests/unit_tests/pandas_postprocessing/test_geography.py new file mode 100644 index 0000000000000..6162f3c8a0b94 --- /dev/null +++ b/tests/unit_tests/pandas_postprocessing/test_geography.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from superset.utils.pandas_postprocessing import ( + geodetic_parse, + geohash_decode, + geohash_encode, +) +from tests.unit_tests.fixtures.dataframes import lonlat_df +from tests.unit_tests.pandas_postprocessing.utils import round_floats, series_to_list + + +def test_geohash_decode(): + # decode lon/lat from geohash + post_df = geohash_decode( + df=lonlat_df[["city", "geohash"]], + geohash="geohash", + latitude="latitude", + longitude="longitude", + ) + assert sorted(post_df.columns.tolist()) == sorted( + ["city", "geohash", "latitude", "longitude"] + ) + assert round_floats(series_to_list(post_df["longitude"]), 6) == round_floats( + series_to_list(lonlat_df["longitude"]), 6 + ) + assert round_floats(series_to_list(post_df["latitude"]), 6) == round_floats( + series_to_list(lonlat_df["latitude"]), 6 + ) + + +def test_geohash_encode(): + # encode lon/lat into geohash + post_df = geohash_encode( + df=lonlat_df[["city", "latitude", "longitude"]], + latitude="latitude", + longitude="longitude", + geohash="geohash", + ) + assert sorted(post_df.columns.tolist()) == sorted( + ["city", "geohash", "latitude", "longitude"] + ) + assert series_to_list(post_df["geohash"]) == series_to_list(lonlat_df["geohash"]) + + +def test_geodetic_parse(): + # parse geodetic string with altitude into lon/lat/altitude + post_df = geodetic_parse( + df=lonlat_df[["city", "geodetic"]], + geodetic="geodetic", + latitude="latitude", + longitude="longitude", + altitude="altitude", + ) + assert sorted(post_df.columns.tolist()) == sorted( + ["city", "geodetic", "latitude", "longitude", 
"altitude"] + ) + assert series_to_list(post_df["longitude"]) == series_to_list( + lonlat_df["longitude"] + ) + assert series_to_list(post_df["latitude"]) == series_to_list(lonlat_df["latitude"]) + assert series_to_list(post_df["altitude"]) == series_to_list(lonlat_df["altitude"]) + + # parse geodetic string into lon/lat + post_df = geodetic_parse( + df=lonlat_df[["city", "geodetic"]], + geodetic="geodetic", + latitude="latitude", + longitude="longitude", + ) + assert sorted(post_df.columns.tolist()) == sorted( + ["city", "geodetic", "latitude", "longitude"] + ) + assert series_to_list(post_df["longitude"]) == series_to_list( + lonlat_df["longitude"] + ) + assert series_to_list(post_df["latitude"]) == series_to_list(lonlat_df["latitude"]) diff --git a/tests/unit_tests/pandas_postprocessing/test_pivot.py b/tests/unit_tests/pandas_postprocessing/test_pivot.py new file mode 100644 index 0000000000000..55779e39087a3 --- /dev/null +++ b/tests/unit_tests/pandas_postprocessing/test_pivot.py @@ -0,0 +1,266 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import numpy as np +import pytest +from pandas import DataFrame, Timestamp, to_datetime + +from superset.exceptions import QueryObjectValidationError +from superset.utils.pandas_postprocessing import _flatten_column_after_pivot, pivot +from tests.unit_tests.fixtures.dataframes import categories_df, single_metric_df +from tests.unit_tests.pandas_postprocessing.utils import ( + AGGREGATES_MULTIPLE, + AGGREGATES_SINGLE, +) + + +def test_flatten_column_after_pivot(): + """ + Test pivot column flattening function + """ + # single aggregate cases + assert ( + _flatten_column_after_pivot(aggregates=AGGREGATES_SINGLE, column="idx_nulls",) + == "idx_nulls" + ) + + assert ( + _flatten_column_after_pivot(aggregates=AGGREGATES_SINGLE, column=1234,) + == "1234" + ) + + assert ( + _flatten_column_after_pivot( + aggregates=AGGREGATES_SINGLE, column=Timestamp("2020-09-29T00:00:00"), + ) + == "2020-09-29 00:00:00" + ) + + assert ( + _flatten_column_after_pivot(aggregates=AGGREGATES_SINGLE, column="idx_nulls",) + == "idx_nulls" + ) + + assert ( + _flatten_column_after_pivot( + aggregates=AGGREGATES_SINGLE, column=("idx_nulls", "col1"), + ) + == "col1" + ) + + assert ( + _flatten_column_after_pivot( + aggregates=AGGREGATES_SINGLE, column=("idx_nulls", "col1", 1234), + ) + == "col1, 1234" + ) + + # Multiple aggregate cases + assert ( + _flatten_column_after_pivot( + aggregates=AGGREGATES_MULTIPLE, column=("idx_nulls", "asc_idx", "col1"), + ) + == "idx_nulls, asc_idx, col1" + ) + + assert ( + _flatten_column_after_pivot( + aggregates=AGGREGATES_MULTIPLE, + column=("idx_nulls", "asc_idx", "col1", 1234), + ) + == "idx_nulls, asc_idx, col1, 1234" + ) + + +def test_pivot_without_columns(): + """ + Make sure pivot without columns returns correct DataFrame + """ + df = pivot(df=categories_df, index=["name"], aggregates=AGGREGATES_SINGLE,) + assert df.columns.tolist() == ["name", "idx_nulls"] + assert len(df) == 101 + assert df.sum()[1] == 1050 + + +def test_pivot_with_single_column(): + 
""" + Make sure pivot with single column returns correct DataFrame + """ + df = pivot( + df=categories_df, + index=["name"], + columns=["category"], + aggregates=AGGREGATES_SINGLE, + ) + assert df.columns.tolist() == ["name", "cat0", "cat1", "cat2"] + assert len(df) == 101 + assert df.sum()[1] == 315 + + df = pivot( + df=categories_df, + index=["dept"], + columns=["category"], + aggregates=AGGREGATES_SINGLE, + ) + assert df.columns.tolist() == ["dept", "cat0", "cat1", "cat2"] + assert len(df) == 5 + + +def test_pivot_with_multiple_columns(): + """ + Make sure pivot with multiple columns returns correct DataFrame + """ + df = pivot( + df=categories_df, + index=["name"], + columns=["category", "dept"], + aggregates=AGGREGATES_SINGLE, + ) + assert len(df.columns) == 1 + 3 * 5 # index + possible permutations + + +def test_pivot_fill_values(): + """ + Make sure pivot with fill values returns correct DataFrame + """ + df = pivot( + df=categories_df, + index=["name"], + columns=["category"], + metric_fill_value=1, + aggregates={"idx_nulls": {"operator": "sum"}}, + ) + assert df.sum()[1] == 382 + + +def test_pivot_fill_column_values(): + """ + Make sure pivot with null column names returns correct DataFrame + """ + df_copy = categories_df.copy() + df_copy["category"] = None + df = pivot( + df=df_copy, + index=["name"], + columns=["category"], + aggregates={"idx_nulls": {"operator": "sum"}}, + ) + assert len(df) == 101 + assert df.columns.tolist() == ["name", ""] + + +def test_pivot_exceptions(): + """ + Make sure pivot raises correct Exceptions + """ + # Missing index + with pytest.raises(TypeError): + pivot(df=categories_df, columns=["dept"], aggregates=AGGREGATES_SINGLE) + + # invalid index reference + with pytest.raises(QueryObjectValidationError): + pivot( + df=categories_df, + index=["abc"], + columns=["dept"], + aggregates=AGGREGATES_SINGLE, + ) + + # invalid column reference + with pytest.raises(QueryObjectValidationError): + pivot( + df=categories_df, + 
index=["dept"], + columns=["abc"], + aggregates=AGGREGATES_SINGLE, + ) + + # invalid aggregate options + with pytest.raises(QueryObjectValidationError): + pivot( + df=categories_df, + index=["name"], + columns=["category"], + aggregates={"idx_nulls": {}}, + ) + + +def test_pivot_eliminate_cartesian_product_columns(): + # single metric + mock_df = DataFrame( + { + "dttm": to_datetime(["2019-01-01", "2019-01-01"]), + "a": [0, 1], + "b": [0, 1], + "metric": [9, np.nan], + } + ) + + df = pivot( + df=mock_df, + index=["dttm"], + columns=["a", "b"], + aggregates={"metric": {"operator": "mean"}}, + drop_missing_columns=False, + ) + assert list(df.columns) == ["dttm", "0, 0", "1, 1"] + assert np.isnan(df["1, 1"][0]) + + # multiple metrics + mock_df = DataFrame( + { + "dttm": to_datetime(["2019-01-01", "2019-01-01"]), + "a": [0, 1], + "b": [0, 1], + "metric": [9, np.nan], + "metric2": [10, 11], + } + ) + + df = pivot( + df=mock_df, + index=["dttm"], + columns=["a", "b"], + aggregates={"metric": {"operator": "mean"}, "metric2": {"operator": "mean"},}, + drop_missing_columns=False, + ) + assert list(df.columns) == [ + "dttm", + "metric, 0, 0", + "metric, 1, 1", + "metric2, 0, 0", + "metric2, 1, 1", + ] + assert np.isnan(df["metric, 1, 1"][0]) + + +def test_pivot_without_flatten_columns_and_reset_index(): + df = pivot( + df=single_metric_df, + index=["dttm"], + columns=["country"], + aggregates={"sum_metric": {"operator": "sum"}}, + flatten_columns=False, + reset_index=False, + ) + # metric + # country UK US + # dttm + # 2019-01-01 5 6 + # 2019-01-02 7 8 + assert df.columns.to_list() == [("sum_metric", "UK"), ("sum_metric", "US")] + assert df.index.to_list() == to_datetime(["2019-01-01", "2019-01-02"]).to_list() diff --git a/tests/unit_tests/pandas_postprocessing/test_prophet.py b/tests/unit_tests/pandas_postprocessing/test_prophet.py new file mode 100644 index 0000000000000..ce5c45b2c0ab0 --- /dev/null +++ b/tests/unit_tests/pandas_postprocessing/test_prophet.py @@ -0,0 
+1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from datetime import datetime +from importlib.util import find_spec + +import pytest + +from superset.exceptions import QueryObjectValidationError +from superset.utils.core import DTTM_ALIAS +from superset.utils.pandas_postprocessing import prophet +from tests.unit_tests.fixtures.dataframes import prophet_df + + +def test_prophet_valid(): + pytest.importorskip("prophet") + + df = prophet(df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9) + columns = {column for column in df.columns} + assert columns == { + DTTM_ALIAS, + "a__yhat", + "a__yhat_upper", + "a__yhat_lower", + "a", + "b__yhat", + "b__yhat_upper", + "b__yhat_lower", + "b", + } + assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31) + assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 3, 31) + assert len(df) == 7 + + df = prophet(df=prophet_df, time_grain="P1M", periods=5, confidence_interval=0.9) + assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31) + assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 5, 31) + assert len(df) == 9 + + +def test_prophet_valid_zero_periods(): + pytest.importorskip("prophet") + + df = 
prophet(df=prophet_df, time_grain="P1M", periods=0, confidence_interval=0.9) + columns = {column for column in df.columns} + assert columns == { + DTTM_ALIAS, + "a__yhat", + "a__yhat_upper", + "a__yhat_lower", + "a", + "b__yhat", + "b__yhat_upper", + "b__yhat_lower", + "b", + } + assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31) + assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2021, 12, 31) + assert len(df) == 4 + + +def test_prophet_import(): + dynamic_module = find_spec("prophet") + if dynamic_module is None: + with pytest.raises(QueryObjectValidationError): + prophet(df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9) + + +def test_prophet_missing_temporal_column(): + df = prophet_df.drop(DTTM_ALIAS, axis=1) + + with pytest.raises(QueryObjectValidationError): + prophet( + df=df, time_grain="P1M", periods=3, confidence_interval=0.9, + ) + + +def test_prophet_incorrect_confidence_interval(): + with pytest.raises(QueryObjectValidationError): + prophet( + df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.0, + ) + + with pytest.raises(QueryObjectValidationError): + prophet( + df=prophet_df, time_grain="P1M", periods=3, confidence_interval=1.0, + ) + + +def test_prophet_incorrect_periods(): + with pytest.raises(QueryObjectValidationError): + prophet( + df=prophet_df, time_grain="P1M", periods=-1, confidence_interval=0.8, + ) + + +def test_prophet_incorrect_time_grain(): + with pytest.raises(QueryObjectValidationError): + prophet( + df=prophet_df, time_grain="yearly", periods=10, confidence_interval=0.8, + ) diff --git a/tests/unit_tests/pandas_postprocessing/test_resample.py b/tests/unit_tests/pandas_postprocessing/test_resample.py new file mode 100644 index 0000000000000..872f2ed78098e --- /dev/null +++ b/tests/unit_tests/pandas_postprocessing/test_resample.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +from pandas import DataFrame, to_datetime + +from superset.exceptions import QueryObjectValidationError +from superset.utils.pandas_postprocessing import resample +from tests.unit_tests.fixtures.dataframes import timeseries_df + + +def test_resample(): + df = timeseries_df.copy() + df.index.name = "time_column" + df.reset_index(inplace=True) + + post_df = resample(df=df, rule="1D", method="ffill", time_column="time_column",) + assert post_df["label"].tolist() == ["x", "y", "y", "y", "z", "z", "q"] + + assert post_df["y"].tolist() == [1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0] + + post_df = resample( + df=df, rule="1D", method="asfreq", time_column="time_column", fill_value=0, + ) + assert post_df["label"].tolist() == ["x", "y", 0, 0, "z", 0, "q"] + assert post_df["y"].tolist() == [1.0, 2.0, 0, 0, 3.0, 0, 4.0] + + +def test_resample_with_groupby(): + """ +The Dataframe contains a timestamp column, a string column and a numeric column. 
+__timestamp city val +0 2022-01-13 Chicago 6.0 +1 2022-01-13 LA 5.0 +2 2022-01-13 NY 4.0 +3 2022-01-11 Chicago 3.0 +4 2022-01-11 LA 2.0 +5 2022-01-11 NY 1.0 + """ + df = DataFrame( + { + "__timestamp": to_datetime( + [ + "2022-01-13", + "2022-01-13", + "2022-01-13", + "2022-01-11", + "2022-01-11", + "2022-01-11", + ] + ), + "city": ["Chicago", "LA", "NY", "Chicago", "LA", "NY"], + "val": [6.0, 5.0, 4.0, 3.0, 2.0, 1.0], + } + ) + post_df = resample( + df=df, + rule="1D", + method="asfreq", + fill_value=0, + time_column="__timestamp", + groupby_columns=("city",), + ) + assert list(post_df.columns) == [ + "__timestamp", + "city", + "val", + ] + assert [str(dt.date()) for dt in post_df["__timestamp"]] == ( + ["2022-01-11"] * 3 + ["2022-01-12"] * 3 + ["2022-01-13"] * 3 + ) + assert list(post_df["val"]) == [3.0, 2.0, 1.0, 0, 0, 0, 6.0, 5.0, 4.0] + + # should raise error when get a non-existent column + with pytest.raises(QueryObjectValidationError): + resample( + df=df, + rule="1D", + method="asfreq", + fill_value=0, + time_column="__timestamp", + groupby_columns=("city", "unkonw_column",), + ) + + # should raise error when get a None value in groupby list + with pytest.raises(QueryObjectValidationError): + resample( + df=df, + rule="1D", + method="asfreq", + fill_value=0, + time_column="__timestamp", + groupby_columns=("city", None,), + ) diff --git a/tests/unit_tests/pandas_postprocessing/test_rolling.py b/tests/unit_tests/pandas_postprocessing/test_rolling.py new file mode 100644 index 0000000000000..227b03a0224be --- /dev/null +++ b/tests/unit_tests/pandas_postprocessing/test_rolling.py @@ -0,0 +1,147 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
# The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from pandas import to_datetime

from superset.exceptions import QueryObjectValidationError
from superset.utils.pandas_postprocessing import pivot, rolling
from tests.unit_tests.fixtures.dataframes import (
    multiple_metrics_df,
    single_metric_df,
    timeseries_df,
)
from tests.unit_tests.pandas_postprocessing.utils import series_to_list


def test_rolling():
    """Check `rolling` on the plain timeseries fixture for each rolling type."""
    # Keyword arguments shared by the three calls that use a 10-row window.
    wide_window = {"df": timeseries_df, "window": 10, "min_periods": 0}

    # sum rolling type
    summed = rolling(
        df=timeseries_df,
        columns={"y": "y"},
        rolling_type="sum",
        window=2,
        min_periods=0,
    )
    assert summed.columns.tolist() == ["label", "y"]
    assert series_to_list(summed["y"]) == [1.0, 3.0, 5.0, 7.0]

    # mean rolling type with alias
    averaged = rolling(rolling_type="mean", columns={"y": "y_mean"}, **wide_window)
    assert averaged.columns.tolist() == ["label", "y", "y_mean"]
    assert series_to_list(averaged["y_mean"]) == [1.0, 1.5, 2.0, 2.5]

    # count rolling type
    counted = rolling(rolling_type="count", columns={"y": "y"}, **wide_window)
    assert counted.columns.tolist() == ["label", "y"]
    assert series_to_list(counted["y"]) == [1.0, 2.0, 3.0, 4.0]

    # quantile rolling type
    first_quartile = rolling(
        columns={"y": "q1"},
        rolling_type="quantile",
        rolling_type_options={"quantile": 0.25},
        **wide_window,
    )
    assert first_quartile.columns.tolist() == ["label", "y", "q1"]
    assert series_to_list(first_quartile["q1"]) == [1.0, 1.25, 1.5, 1.75]

    # incorrect rolling type
    with pytest.raises(QueryObjectValidationError):
        rolling(df=timeseries_df, columns={"y": "y"}, rolling_type="abc", window=2)

    # incorrect rolling type options
    with pytest.raises(QueryObjectValidationError):
        rolling(
            df=timeseries_df,
            columns={"y": "y"},
            rolling_type="quantile",
            rolling_type_options={"abc": 123},
            window=2,
        )


def test_rolling_with_pivot_df_and_single_metric():
    """Check `rolling` with `is_pivot_df=True` on a single-metric pivot."""
    pivoted = pivot(
        df=single_metric_df,
        index=["dttm"],
        columns=["country"],
        aggregates={"sum_metric": {"operator": "sum"}},
        flatten_columns=False,
        reset_index=False,
    )
    # Both rolling calls below differ only in `min_periods`.
    shared = {
        "df": pivoted,
        "rolling_type": "sum",
        "window": 2,
        "is_pivot_df": True,
    }

    rolled = rolling(min_periods=0, **shared)
    #         dttm  UK  US
    # 0 2019-01-01   5   6
    # 1 2019-01-02  12  14
    assert rolled["UK"].to_list() == [5.0, 12.0]
    assert rolled["US"].to_list() == [6.0, 14.0]
    expected_dates = to_datetime(["2019-01-01", "2019-01-02"]).to_list()
    assert rolled["dttm"].to_list() == expected_dates

    # With min_periods=2 the test expects an empty frame back.
    too_few_periods = rolling(min_periods=2, **shared)
    assert too_few_periods.empty is True


def test_rolling_with_pivot_df_and_multiple_metrics():
    """Check `rolling` with `is_pivot_df=True` on a two-metric pivot."""
    pivoted = pivot(
        df=multiple_metrics_df,
        index=["dttm"],
        columns=["country"],
        aggregates={
            "sum_metric": {"operator": "sum"},
            "count_metric": {"operator": "sum"},
        },
        flatten_columns=False,
        reset_index=False,
    )
    rolled = rolling(
        df=pivoted,
        rolling_type="sum",
        window=2,
        min_periods=0,
        is_pivot_df=True,
    )
    #         dttm  count_metric, UK  count_metric, US  sum_metric, UK  sum_metric, US
    # 0 2019-01-01               1.0               2.0             5.0             6.0
    # 1 2019-01-02               4.0               6.0            12.0            14.0
    expected_by_column = {
        "count_metric, UK": [1.0, 4.0],
        "count_metric, US": [2.0, 6.0],
        "sum_metric, UK": [5.0, 12.0],
        "sum_metric, US": [6.0, 14.0],
    }
    for column_name, expected_values in expected_by_column.items():
        assert rolled[column_name].to_list() == expected_values

    expected_dates = to_datetime(["2019-01-01", "2019-01-02"]).to_list()
    assert rolled["dttm"].to_list() == expected_dates


# diff --git a/tests/unit_tests/pandas_postprocessing/test_select.py b/tests/unit_tests/pandas_postprocessing/test_select.py
# new file mode 100644
# index 0000000000000..aac644d316e6a
# --- /dev/null
# +++ b/tests/unit_tests/pandas_postprocessing/test_select.py
# @@ -0,0 +1,55 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest

from superset.exceptions import QueryObjectValidationError
from superset.utils.pandas_postprocessing.select import select
from tests.unit_tests.fixtures.dataframes import timeseries_df


def test_select():
    """Check column selection, renaming and exclusion done by `select`."""
    # (keyword arguments for `select`, expected resulting column order)
    valid_cases = [
        # reorder columns
        ({"columns": ["y", "label"]}, ["y", "label"]),
        # one column
        ({"columns": ["label"]}, ["label"]),
        # rename and select one column
        ({"columns": ["y"], "rename": {"y": "y1"}}, ["y1"]),
        # rename one and leave one unchanged
        ({"rename": {"y": "y1"}}, ["label", "y1"]),
        # drop one column
        ({"exclude": ["label"]}, ["y"]),
        # rename and drop one column
        ({"rename": {"y": "y1"}, "exclude": ["label"]}, ["y1"]),
    ]
    for select_kwargs, expected_columns in valid_cases:
        selected = select(df=timeseries_df, **select_kwargs)
        assert selected.columns.tolist() == expected_columns

    rejected_cases = [
        # invalid columns
        {"columns": ["abc"], "rename": {"abc": "qwerty"}},
        # select renamed column by new name
        {"columns": ["label_new"], "rename": {"label": "label_new"}},
    ]
    for select_kwargs in rejected_cases:
        with pytest.raises(QueryObjectValidationError):
            select(df=timeseries_df, **select_kwargs)


# diff --git a/tests/unit_tests/pandas_postprocessing/test_sort.py b/tests/unit_tests/pandas_postprocessing/test_sort.py
# new file mode 100644
# index 0000000000000..43daa9ce2b194
# --- /dev/null
# +++ b/tests/unit_tests/pandas_postprocessing/test_sort.py
# @@ -0,0 +1,30 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.
# The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest

from superset.exceptions import QueryObjectValidationError
from superset.utils.pandas_postprocessing import sort
from tests.unit_tests.fixtures.dataframes import categories_df
from tests.unit_tests.pandas_postprocessing.utils import series_to_list


def test_sort():
    """Check a two-column `sort` and its rejection of an unknown column."""
    sort_spec = {"category": True, "asc_idx": False}
    sorted_df = sort(df=categories_df, columns=sort_spec)
    asc_idx_values = series_to_list(sorted_df["asc_idx"])
    assert asc_idx_values[1] == 96

    # A column name that is not in the frame must be rejected.
    with pytest.raises(QueryObjectValidationError):
        sort(df=sorted_df, columns={"abc": True})


# diff --git a/tests/unit_tests/pandas_postprocessing/utils.py b/tests/unit_tests/pandas_postprocessing/utils.py
# new file mode 100644
# index 0000000000000..07366b15774d1
# --- /dev/null
# +++ b/tests/unit_tests/pandas_postprocessing/utils.py
# @@ -0,0 +1,55 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import math
from typing import Any, List, Optional

from pandas import Series

AGGREGATES_SINGLE = {"idx_nulls": {"operator": "sum"}}
AGGREGATES_MULTIPLE = {
    "idx_nulls": {"operator": "sum"},
    "asc_idx": {"operator": "mean"},
}


def series_to_list(series: Series) -> List[Any]:
    """
    Converts a `Series` to a regular list, replacing NaN and infinite values
    with `None`. String values are passed through unchanged.

    :param series: Series to convert
    :return: list without nan or inf
    """
    return [
        None
        if not isinstance(val, str) and (math.isnan(val) or math.isinf(val))
        else val
        for val in series.tolist()
    ]


def round_floats(
    floats: List[Optional[float]], precision: int
) -> List[Optional[float]]:
    """
    Round list of floats to certain precision. `None` entries are kept as
    `None`; every other value, including zero, is rounded.

    :param floats: floats to round
    :param precision: intended decimal precision
    :return: rounded floats
    """
    # Compare against None explicitly: a truthiness check would also turn a
    # legitimate 0.0 into None.
    return [None if val is None else round(val, precision) for val in floats]