Skip to content

Commit

Permalink
refactor: decouple pandas postprocessing operator (#18710)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaoyongjie authored Feb 14, 2022
1 parent ea12024 commit 8d6aff3
Show file tree
Hide file tree
Showing 16 changed files with 1,364 additions and 1,002 deletions.
1,002 changes: 0 additions & 1,002 deletions superset/utils/pandas_postprocessing.py

This file was deleted.

53 changes: 53 additions & 0 deletions superset/utils/pandas_postprocessing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.utils.pandas_postprocessing.aggregate import aggregate
from superset.utils.pandas_postprocessing.boxplot import boxplot
from superset.utils.pandas_postprocessing.compare import compare
from superset.utils.pandas_postprocessing.contribution import contribution
from superset.utils.pandas_postprocessing.cum import cum
from superset.utils.pandas_postprocessing.diff import diff
from superset.utils.pandas_postprocessing.geography import (
geodetic_parse,
geohash_decode,
geohash_encode,
)
from superset.utils.pandas_postprocessing.pivot import pivot
from superset.utils.pandas_postprocessing.prophet import prophet
from superset.utils.pandas_postprocessing.resample import resample
from superset.utils.pandas_postprocessing.rolling import rolling
from superset.utils.pandas_postprocessing.select import select
from superset.utils.pandas_postprocessing.sort import sort
from superset.utils.pandas_postprocessing.utils import _flatten_column_after_pivot

__all__ = [
"aggregate",
"boxplot",
"compare",
"contribution",
"cum",
"diff",
"geohash_encode",
"geohash_decode",
"geodetic_parse",
"pivot",
"prophet",
"resample",
"rolling",
"select",
"sort",
"_flatten_column_after_pivot",
]
46 changes: 46 additions & 0 deletions superset/utils/pandas_postprocessing/aggregate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, List

from pandas import DataFrame

from superset.utils.pandas_postprocessing.utils import (
_get_aggregate_funcs,
validate_column_args,
)


@validate_column_args("groupby")
def aggregate(
df: DataFrame, groupby: List[str], aggregates: Dict[str, Dict[str, Any]]
) -> DataFrame:
"""
Apply aggregations to a DataFrame.
:param df: Object to aggregate.
:param groupby: columns to aggregate
:param aggregates: A mapping from metric column to the function used to
aggregate values.
:raises QueryObjectValidationError: If the request in incorrect
"""
aggregates = aggregates or {}
aggregate_funcs = _get_aggregate_funcs(df, aggregates)
if groupby:
df_groupby = df.groupby(by=groupby)
else:
df_groupby = df.groupby(lambda _: True)
return df_groupby.agg(**aggregate_funcs).reset_index(drop=not groupby)
125 changes: 125 additions & 0 deletions superset/utils/pandas_postprocessing/boxplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import numpy as np
from flask_babel import gettext as _
from pandas import DataFrame, Series

from superset.exceptions import QueryObjectValidationError
from superset.utils.core import PostProcessingBoxplotWhiskerType
from superset.utils.pandas_postprocessing.aggregate import aggregate


def boxplot(
df: DataFrame,
groupby: List[str],
metrics: List[str],
whisker_type: PostProcessingBoxplotWhiskerType,
percentiles: Optional[
Union[List[Union[int, float]], Tuple[Union[int, float], Union[int, float]]]
] = None,
) -> DataFrame:
"""
Calculate boxplot statistics. For each metric, the operation creates eight
new columns with the column name suffixed with the following values:
- `__mean`: the mean
- `__median`: the median
- `__max`: the maximum value excluding outliers (see whisker type)
- `__min`: the minimum value excluding outliers (see whisker type)
- `__q1`: the median
- `__q1`: the first quartile (25th percentile)
- `__q3`: the third quartile (75th percentile)
- `__count`: count of observations
- `__outliers`: the values that fall outside the minimum/maximum value
(see whisker type)
:param df: DataFrame containing all-numeric data (temporal column ignored)
:param groupby: The categories to group by (x-axis)
:param metrics: The metrics for which to calculate the distribution
:param whisker_type: The confidence level type
:return: DataFrame with boxplot statistics per groupby
"""

def quartile1(series: Series) -> float:
return np.nanpercentile(series, 25, interpolation="midpoint")

def quartile3(series: Series) -> float:
return np.nanpercentile(series, 75, interpolation="midpoint")

if whisker_type == PostProcessingBoxplotWhiskerType.TUKEY:

def whisker_high(series: Series) -> float:
upper_outer_lim = quartile3(series) + 1.5 * (
quartile3(series) - quartile1(series)
)
return series[series <= upper_outer_lim].max()

def whisker_low(series: Series) -> float:
lower_outer_lim = quartile1(series) - 1.5 * (
quartile3(series) - quartile1(series)
)
return series[series >= lower_outer_lim].min()

elif whisker_type == PostProcessingBoxplotWhiskerType.PERCENTILE:
if (
not isinstance(percentiles, (list, tuple))
or len(percentiles) != 2
or not isinstance(percentiles[0], (int, float))
or not isinstance(percentiles[1], (int, float))
or percentiles[0] >= percentiles[1]
):
raise QueryObjectValidationError(
_(
"percentiles must be a list or tuple with two numeric values, "
"of which the first is lower than the second value"
)
)
low, high = percentiles[0], percentiles[1]

def whisker_high(series: Series) -> float:
return np.nanpercentile(series, high)

def whisker_low(series: Series) -> float:
return np.nanpercentile(series, low)

else:
whisker_high = np.max
whisker_low = np.min

def outliers(series: Series) -> Set[float]:
above = series[series > whisker_high(series)]
below = series[series < whisker_low(series)]
return above.tolist() + below.tolist()

operators: Dict[str, Callable[[Any], Any]] = {
"mean": np.mean,
"median": np.median,
"max": whisker_high,
"min": whisker_low,
"q1": quartile1,
"q3": quartile3,
"count": np.ma.count,
"outliers": outliers,
}
aggregates: Dict[str, Dict[str, Union[str, Callable[..., Any]]]] = {
f"{metric}__{operator_name}": {"column": metric, "operator": operator}
for operator_name, operator in operators.items()
for metric in metrics
}
return aggregate(df, groupby=groupby, aggregates=aggregates)
79 changes: 79 additions & 0 deletions superset/utils/pandas_postprocessing/compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import List, Optional

import pandas as pd
from flask_babel import gettext as _
from pandas import DataFrame

from superset.constants import PandasPostprocessingCompare
from superset.exceptions import QueryObjectValidationError
from superset.utils.core import TIME_COMPARISION
from superset.utils.pandas_postprocessing.utils import validate_column_args


@validate_column_args("source_columns", "compare_columns")
def compare( # pylint: disable=too-many-arguments
df: DataFrame,
source_columns: List[str],
compare_columns: List[str],
compare_type: Optional[PandasPostprocessingCompare],
drop_original_columns: Optional[bool] = False,
precision: Optional[int] = 4,
) -> DataFrame:
"""
Calculate column-by-column changing for select columns.
:param df: DataFrame on which the compare will be based.
:param source_columns: Main query columns
:param compare_columns: Columns being compared
:param compare_type: Type of compare. Choice of `absolute`, `percentage` or `ratio`
:param drop_original_columns: Whether to remove the source columns and
compare columns.
:param precision: Round a change rate to a variable number of decimal places.
:return: DataFrame with compared columns.
:raises QueryObjectValidationError: If the request in incorrect.
"""
if len(source_columns) != len(compare_columns):
raise QueryObjectValidationError(
_("`compare_columns` must have the same length as `source_columns`.")
)
if compare_type not in tuple(PandasPostprocessingCompare):
raise QueryObjectValidationError(
_("`compare_type` must be `difference`, `percentage` or `ratio`")
)
if len(source_columns) == 0:
return df

for s_col, c_col in zip(source_columns, compare_columns):
if compare_type == PandasPostprocessingCompare.DIFF:
diff_series = df[s_col] - df[c_col]
elif compare_type == PandasPostprocessingCompare.PCT:
diff_series = (
((df[s_col] - df[c_col]) / df[c_col]).astype(float).round(precision)
)
else:
# compare_type == "ratio"
diff_series = (df[s_col] / df[c_col]).astype(float).round(precision)
diff_df = diff_series.to_frame(
name=TIME_COMPARISION.join([compare_type, s_col, c_col])
)
df = pd.concat([df, diff_df], axis=1)

if drop_original_columns:
df = df.drop(source_columns + compare_columns, axis=1)
return df
75 changes: 75 additions & 0 deletions superset/utils/pandas_postprocessing/contribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from decimal import Decimal
from typing import List, Optional

from flask_babel import gettext as _
from pandas import DataFrame

from superset.exceptions import QueryObjectValidationError
from superset.utils.core import PostProcessingContributionOrientation
from superset.utils.pandas_postprocessing.utils import validate_column_args


@validate_column_args("columns")
def contribution(
df: DataFrame,
orientation: Optional[
PostProcessingContributionOrientation
] = PostProcessingContributionOrientation.COLUMN,
columns: Optional[List[str]] = None,
rename_columns: Optional[List[str]] = None,
) -> DataFrame:
"""
Calculate cell contibution to row/column total for numeric columns.
Non-numeric columns will be kept untouched.
If `columns` are specified, only calculate contributions on selected columns.
:param df: DataFrame containing all-numeric data (temporal column ignored)
:param columns: Columns to calculate values from.
:param rename_columns: The new labels for the calculated contribution columns.
The original columns will not be removed.
:param orientation: calculate by dividing cell with row/column total
:return: DataFrame with contributions.
"""
contribution_df = df.copy()
numeric_df = contribution_df.select_dtypes(include=["number", Decimal])
# verify column selections
if columns:
numeric_columns = numeric_df.columns.tolist()
for col in columns:
if col not in numeric_columns:
raise QueryObjectValidationError(
_(
'Column "%(column)s" is not numeric or does not '
"exists in the query results.",
column=col,
)
)
columns = columns or numeric_df.columns
rename_columns = rename_columns or columns
if len(rename_columns) != len(columns):
raise QueryObjectValidationError(
_("`rename_columns` must have the same length as `columns`.")
)
# limit to selected columns
numeric_df = numeric_df[columns]
axis = 0 if orientation == PostProcessingContributionOrientation.COLUMN else 1
numeric_df = numeric_df / numeric_df.values.sum(axis=axis, keepdims=True)
contribution_df[rename_columns] = numeric_df
return contribution_df
Loading

0 comments on commit 8d6aff3

Please sign in to comment.