diff --git a/snuba/web/rpc/common/aggregation.py b/snuba/web/rpc/common/aggregation.py index 4045e8880c..96ea4818d0 100644 --- a/snuba/web/rpc/common/aggregation.py +++ b/snuba/web/rpc/common/aggregation.py @@ -21,6 +21,7 @@ from snuba.web.rpc.common.exceptions import BadSnubaRPCRequestException sampling_weight_column = column("sampling_weight") +sign_column = column("sign") # Z value for 95% confidence interval is 1.96 which comes from the normal distribution z score. z_value = 1.96 @@ -211,26 +212,27 @@ def get_extrapolated_function( Function.ValueType, CurriedFunctionCall | FunctionCall ] = { Function.FUNCTION_SUM: f.sum( - f.multiply(field, sampling_weight_column), **alias_dict + f.multiply(field, f.multiply(sign_column, sampling_weight_column)), + **alias_dict, ), Function.FUNCTION_AVERAGE: f.divide( - f.sum(f.multiply(field, sampling_weight_column)), + f.sum(f.multiply(field, f.multiply(sign_column, sampling_weight_column))), f.sumIf( - sampling_weight_column, + f.multiply(sign_column, sampling_weight_column), get_field_existence_expression(aggregation), ), **alias_dict, ), Function.FUNCTION_AVG: f.divide( - f.sum(f.multiply(field, sampling_weight_column)), + f.sum(f.multiply(field, f.multiply(sign_column, sampling_weight_column))), f.sumIf( - sampling_weight_column, + f.multiply(sign_column, sampling_weight_column), get_field_existence_expression(aggregation), ), **alias_dict, ), Function.FUNCTION_COUNT: f.sumIf( - sampling_weight_column, + f.multiply(sign_column, sampling_weight_column), get_field_existence_expression(aggregation), **alias_dict, ), @@ -407,9 +409,17 @@ def aggregation_to_expression(aggregation: AttributeAggregation) -> Expression: alias = aggregation.label if aggregation.label else None alias_dict = {"alias": alias} if alias else {} function_map: dict[Function.ValueType, CurriedFunctionCall | FunctionCall] = { - Function.FUNCTION_SUM: f.sum(field, **alias_dict), - Function.FUNCTION_AVERAGE: f.avg(field, **alias_dict), - Function.FUNCTION_COUNT: f.count(field, **alias_dict), + Function.FUNCTION_SUM: f.sum(f.multiply(field, sign_column), **alias_dict), + Function.FUNCTION_AVERAGE: f.divide( + f.sum(f.multiply(field, sign_column)), + f.sumIf(sign_column, get_field_existence_expression(aggregation)), + **alias_dict, + ), + Function.FUNCTION_COUNT: f.sumIf( + sign_column, + get_field_existence_expression(aggregation), + **alias_dict, + ), Function.FUNCTION_P50: cf.quantile(0.5)(field, **alias_dict), Function.FUNCTION_P75: cf.quantile(0.75)(field, **alias_dict), Function.FUNCTION_P90: cf.quantile(0.9)(field, **alias_dict),