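🚀 Feature
Add Weighted Absolute Percentage Error (WAPE) metric. The description and formula can be found here: https://en.wikipedia.org/wiki/WMAPE
Motivation
WAPE is a common metric used in time series forecasting.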
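For reference, WAPE divides the total absolute error by the total absolute magnitude of the targets: WAPE = Σ|y_t − ŷ_t| / Σ|y_t|. Below is a minimal plain-Python sketch of that formula. The function name `wape` is illustrative only and is not an existing library API.

```python
from typing import Sequence


def wape(preds: Sequence[float], target: Sequence[float]) -> float:
    """Weighted absolute percentage error: sum |y - y_hat| / sum |y|."""
    if len(preds) != len(target):
        raise ValueError("preds and target must have the same length")
    # numerator: total absolute error over all observations
    sum_abs_error = sum(abs(p - t) for p, t in zip(preds, target))
    # denominator: total absolute magnitude of the targets
    sum_scale = sum(abs(t) for t in target)
    return sum_abs_error / sum_scale


# Absolute errors are 10, 0 and 5 against a total target volume of 200.
print(wape([110.0, 50.0, 45.0], [100.0, 50.0, 50.0]))  # 0.075
```

Unlike the draft implementation, this sketch does not guard against all-zero targets and raises `ZeroDivisionError` in that case.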
Alternatives
I don't think this metric can be implemented with arithmetic on existing metrics.
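One property any streaming implementation has to respect is that WAPE is a ratio of two sums, so averaging per-batch WAPE values does not reproduce WAPE over the full data. The plain-Python sketch below compares the two approaches on two hypothetical batches; the helper names are illustrative. The difference is why a metric class needs to keep the numerator and denominator as separate accumulated states and divide only once.

```python
from typing import List, Sequence, Tuple

Batch = Tuple[Sequence[float], Sequence[float]]  # (preds, target)


def batch_sums(preds: Sequence[float], target: Sequence[float]) -> Tuple[float, float]:
    """Return (sum of absolute errors, sum of absolute targets) for one batch."""
    sum_abs_error = sum(abs(p - t) for p, t in zip(preds, target))
    sum_scale = sum(abs(t) for t in target)
    return sum_abs_error, sum_scale


def mean_of_batch_wapes(batches: List[Batch]) -> float:
    """Naive approach: compute WAPE per batch, then average the ratios."""
    ratios = [e / s for e, s in (batch_sums(p, t) for p, t in batches)]
    return sum(ratios) / len(ratios)


def wape_from_accumulated_sums(batches: List[Batch]) -> float:
    """Accumulate numerator and denominator across batches, divide once at the end."""
    total_error, total_scale = 0.0, 0.0
    for p, t in batches:
        e, s = batch_sums(p, t)
        total_error += e
        total_scale += s
    return total_error / total_scale


batches = [
    ([12.0], [10.0]),      # small target: error 2 on scale 10
    ([1010.0], [1000.0]),  # large target: error 10 on scale 1000
]
print(round(mean_of_batch_wapes(batches), 4))         # 0.105
print(round(wape_from_accumulated_sums(batches), 4))  # 0.0119
```

The naive average lets the small-volume batch dominate, whereas WAPE weights each observation by its target magnitude.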
Additional context
A draft implementation is as follows. I'd like someone to take over the rest.
```python
# --- functional ---
from typing import Tuple

import torch
from torch import Tensor
from torchmetrics.utilities.checks import _check_same_shape


def _weighted_absolute_percentage_error_update(
    preds: Tensor,
    target: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Updates and returns variables required to compute Weighted Absolute Percentage Error.

    Checks for same shape of input tensors.

    Args:
        preds: Predicted tensor
        target: Ground truth tensor
    """
    _check_same_shape(preds, target)

    # numerator: total absolute error over all elements
    sum_abs_error = (preds - target).abs().sum()
    # denominator: total absolute magnitude of the targets
    sum_scale = target.abs().sum()

    return sum_abs_error, sum_scale


def _weighted_absolute_percentage_error_compute(
    sum_abs_error: Tensor,
    sum_scale: Tensor,
    epsilon: float = 1.17e-06,
) -> Tensor:
    """Computes Weighted Absolute Percentage Error.

    Args:
        sum_abs_error: Sum of absolute errors over all observations
        sum_scale: Sum of absolute target values over all observations
        epsilon: Lower bound on the denominator; avoids division by zero
    """
    return sum_abs_error / torch.clamp(sum_scale, min=epsilon)


def weighted_absolute_percentage_error(preds: Tensor, target: Tensor) -> Tensor:
    r"""Computes weighted absolute percentage error (`WAPE`_):

    .. math::
        \text{WAPE} = \frac{\sum_{t=1}^n | y_t - \hat{y}_t |}{\sum_{t=1}^n | y_t |}

    Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.

    Args:
        preds: estimated labels
        target: ground truth labels

    Return:
        Tensor with WAPE.
    """
    sum_abs_error, sum_scale = _weighted_absolute_percentage_error_update(preds, target)
    return _weighted_absolute_percentage_error_compute(sum_abs_error, sum_scale)


# --- module ---
# TODO: import the two private helpers above from the functional module
# once the functional and module parts live in separate files.
from typing import Any

from torch import Tensor, tensor
from torchmetrics.metric import Metric


class WeightedAbsolutePercentageError(Metric):
    r"""Computes weighted absolute percentage error (`WAPE`_).

    .. math::
        \text{WAPE} = \frac{\sum_{t=1}^n | y_t - \hat{y}_t |}{\sum_{t=1}^n | y_t |}

    Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.

    Args:
        kwargs: Additional keyword arguments forwarded to :class:`~torchmetrics.Metric`,
            e.g. ``dist_sync_on_step``, ``process_group`` or ``dist_sync_fn``.

    Note:
        WAPE output is a non-negative floating point. Best result is 0.0.
    """

    is_differentiable = True
    higher_is_better = False
    # update() only adds to running sums, so it never needs the global state
    full_state_update = False
    sum_abs_error: Tensor
    sum_scale: Tensor

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # both states are plain sums, so summing them across processes is exact
        self.add_state("sum_abs_error", default=tensor(0.0), dist_reduce_fx="sum")
        self.add_state("sum_scale", default=tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
        """Update state with predictions and targets.

        Args:
            preds: Predictions from model
            target: Ground truth values
        """
        sum_abs_error, sum_scale = _weighted_absolute_percentage_error_update(preds, target)
        # accumulate numerator and denominator separately; dividing only in compute()
        # gives WAPE over all data seen so far, not an average of per-batch ratios
        self.sum_abs_error += sum_abs_error
        self.sum_scale += sum_scale

    def compute(self) -> Tensor:
        """Computes weighted absolute percentage error over state."""
        return _weighted_absolute_percentage_error_compute(self.sum_abs_error, self.sum_scale)
```
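The draft registers both states with `dist_reduce_fx="sum"`. The stdlib-only sketch below illustrates why that is sufficient. If each process keeps its own `sum_abs_error` and `sum_scale`, summing those states across processes and dividing once gives the same value as computing WAPE on all the data in one place. The worker split and helper names are hypothetical, and the code mirrors the draft's state handling rather than calling torchmetrics.

```python
from typing import Dict, List, Sequence

EPSILON = 1.17e-06  # same lower bound on the denominator as the draft's compute step


def local_state(preds: Sequence[float], target: Sequence[float]) -> Dict[str, float]:
    """State one process would hold after calling update() on its shard of the data."""
    return {
        "sum_abs_error": sum(abs(p - t) for p, t in zip(preds, target)),
        "sum_scale": sum(abs(t) for t in target),
    }


def reduce_sum(states: List[Dict[str, float]]) -> Dict[str, float]:
    """Element-wise sum of per-process states, as a "sum" reduction would do."""
    return {key: sum(s[key] for s in states) for key in states[0]}


def compute(state: Dict[str, float]) -> float:
    """Divide once, clamping the denominator like torch.clamp(sum_scale, min=epsilon)."""
    return state["sum_abs_error"] / max(state["sum_scale"], EPSILON)


preds = [110.0, 50.0, 45.0, 12.0]
target = [100.0, 50.0, 50.0, 10.0]

# Two hypothetical workers, each seeing half of the data.
worker_states = [local_state(preds[:2], target[:2]), local_state(preds[2:], target[2:])]
distributed = compute(reduce_sum(worker_states))
single_process = compute(local_state(preds, target))
print(distributed == single_process)  # True
```

With all-zero targets the clamp makes the result `sum_abs_error / epsilon`, a large finite number, instead of raising a division error.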