Add WeightedAbsolutePercentageError #928

Closed
Guan-t7 opened this issue Apr 4, 2022 · 1 comment · Fixed by #948
Labels: enhancement (New feature or request), New metric

Guan-t7 (Contributor) commented Apr 4, 2022

🚀 Feature

Add a Weighted Absolute Percentage Error (WAPE) metric. The description and formula can be found here: https://en.wikipedia.org/wiki/WMAPE
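
For quick reference, the definition from that page (and the quantity the draft below computes) is:

$$
\mathrm{WAPE} = \frac{\sum_{t=1}^{n} \lvert y_t - \hat{y}_t \rvert}{\sum_{t=1}^{n} \lvert y_t \rvert}
$$

where $y_t$ are the target values and $\hat{y}_t$ the predictions.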

Motivation

WAPE is a common metric used in time series forecasting.

Alternatives

I don't think this metric can be implemented as an arithmetic combination of existing metrics: WAPE is a ratio of two sums, so the numerator and denominator must be accumulated separately across batches rather than averaging per-batch values.

Additional context

A draft implementation is as follows. I'd like someone to take over the rest.

'''functional'''
from typing import Tuple

import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape


def _weighted_absolute_percentage_error_update(
    preds: Tensor,
    target: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Updates and returns variables required to compute Weighted Absolute Percentage Error.

    Checks for same shape of input tensors.

    Args:
        preds: Predicted tensor
        target: Ground truth tensor
    """

    _check_same_shape(preds, target)

    sum_abs_error = (preds - target).abs().sum()
    sum_scale = target.abs().sum()

    return sum_abs_error, sum_scale


def _weighted_absolute_percentage_error_compute(
    sum_abs_error: Tensor,
    sum_scale: Tensor,
    epsilon: float = 1.17e-06,
) -> Tensor:
    """Computes Weighted Absolute Percentage Error.

    Args:
        sum_abs_error: Sum of absolute errors ``|preds - target|``
        sum_scale: Sum of absolute target values ``|target|``
        epsilon: Small value to clamp the denominator with, avoiding division by zero
    """

    return sum_abs_error / torch.clamp(sum_scale, min=epsilon)


def weighted_absolute_percentage_error(preds: Tensor, target: Tensor) -> Tensor:
    r"""
    Computes weighted absolute percentage error (`WAPE`_):

    .. math:: \text{WAPE} = \frac{\sum_{t=1}^n | y_t - \hat{y}_t |}{\sum_{t=1}^n | y_t |}

    Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.

    Args:
        preds: estimated labels
        target: ground truth labels

    Return:
        Tensor with WAPE.
    """
    sum_abs_error, sum_scale = _weighted_absolute_percentage_error_update(
        preds,
        target,
    )
    weighted_ape = _weighted_absolute_percentage_error_compute(
        sum_abs_error,
        sum_scale,
    )

    return weighted_ape
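
A quick sanity check of the functional interface (hypothetical input values, assuming the snippet above has been run):

'''example'''
import torch

preds = torch.tensor([1.1, 2.2, 2.9])
target = torch.tensor([1.0, 2.0, 3.0])

# sum of absolute errors = 0.1 + 0.2 + 0.1 = 0.4; sum of |target| = 6.0
print(weighted_absolute_percentage_error(preds, target))  # tensor(0.0667), i.e. 0.4 / 6.0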


'''module'''
from typing import Any, Callable, Optional

from torch import Tensor, tensor

# assuming the functional pieces above are placed under torchmetrics.functional;
# the exact module path below is a placeholder
from torchmetrics.functional.regression.wape import (
    _weighted_absolute_percentage_error_compute,
    _weighted_absolute_percentage_error_update,
)
from torchmetrics.metric import Metric


class WeightedAbsolutePercentageError(Metric):
    r"""
    Computes weighted absolute percentage error (`WAPE`_).

    .. math:: \text{WAPE} = \frac{\sum_{t=1}^n | y_t - \hat{y}_t |}{\sum_{t=1}^n | y_t |}

    Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.

    Args:
        compute_on_step:
            Forward only calls ``update()`` and return None if this is set to False.
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()`` before returning the value at the step.
        process_group:
            Specify the process group on which synchronization is called.

    Note:
        WAPE output is a non-negative floating point number. The best result is 0.0.
    """
    is_differentiable = True
    higher_is_better = False
    sum_abs_error: Tensor
    sum_scale: Tensor

    def __init__(
        self,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Optional[Callable] = None,
    ) -> None:
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )

        self.add_state("sum_abs_error", default=tensor(0.0), dist_reduce_fx="sum")
        self.add_state("sum_scale", default=tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
        """Update state with predictions and targets.

        Args:
            preds: Predictions from model
            target: Ground truth values
        """
        sum_abs_error, sum_scale = _weighted_absolute_percentage_error_update(preds, target)

        self.sum_abs_error += sum_abs_error
        self.sum_scale += sum_scale

    def compute(self) -> Tensor:
        """Computes weighted absolute percentage error over state."""
        return _weighted_absolute_percentage_error_compute(self.sum_abs_error, self.sum_scale)
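
A minimal sketch of batched usage (hypothetical values), showing that accumulating the two sums across ``update()`` calls matches a single pass over all the data:

'''example'''
import torch

metric = WeightedAbsolutePercentageError()

# feed the data in two batches; the two running sums make this
# equivalent to computing WAPE over the concatenated tensors
metric.update(torch.tensor([1.1, 2.2]), torch.tensor([1.0, 2.0]))
metric.update(torch.tensor([2.9]), torch.tensor([3.0]))

print(metric.compute())  # tensor(0.0667) == (0.1 + 0.2 + 0.1) / (1.0 + 2.0 + 3.0)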

Guan-t7 added the enhancement (New feature or request) label on Apr 4, 2022
github-actions bot commented Apr 4, 2022

Hi! Thanks for your contribution, great first issue!

SkafteNicki mentioned this issue on Apr 11, 2022
SkafteNicki added this to the v0.8 milestone on Apr 11, 2022
Borda added this to the v0.8 milestone on May 5, 2022