From 59a0f95d8ac3639b649cd6de09eb2dc037bf4afa Mon Sep 17 00:00:00 2001 From: Zezhi Shao <864453277@qq.com> Date: Wed, 11 Dec 2024 13:13:29 +0800 Subject: [PATCH] =?UTF-8?q?docs:=20=E2=9C=8F=EF=B8=8F=20fix=20isort?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- basicts/metrics/__init__.py | 6 +++--- basicts/metrics/smape.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/basicts/metrics/__init__.py b/basicts/metrics/__init__.py index 4d88ba8..66e465b 100644 --- a/basicts/metrics/__init__.py +++ b/basicts/metrics/__init__.py @@ -1,11 +1,11 @@ +from .corr import masked_corr from .mae import masked_mae from .mape import masked_mape from .mse import masked_mse +from .r_square import masked_r2 from .rmse import masked_rmse -from .wape import masked_wape from .smape import masked_smape -from .r_square import masked_r2 -from .corr import masked_corr +from .wape import masked_wape ALL_METRICS = { 'MAE': masked_mae, diff --git a/basicts/metrics/smape.py b/basicts/metrics/smape.py index 48e4166..ebc2308 100644 --- a/basicts/metrics/smape.py +++ b/basicts/metrics/smape.py @@ -1,5 +1,6 @@ -import torch import numpy as np +import torch + def masked_smape(prediction: torch.Tensor, target: torch.Tensor, null_val: float = np.nan) -> torch.Tensor: """