From 75eae85b7317b6a27de0100aeaeff068f6e1c6a1 Mon Sep 17 00:00:00 2001 From: Zezhi Shao <864453277@qq.com> Date: Wed, 11 Dec 2024 13:07:55 +0800 Subject: [PATCH] =?UTF-8?q?docs:=20=E2=9C=8F=EF=B8=8F=20fix=20lint?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- basicts/metrics/corr.py | 2 +- basicts/metrics/r_square.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/basicts/metrics/corr.py b/basicts/metrics/corr.py index 95cbdcd..4e10ee4 100644 --- a/basicts/metrics/corr.py +++ b/basicts/metrics/corr.py @@ -47,4 +47,4 @@ def masked_corr(prediction: torch.Tensor, target: torch.Tensor, null_val: float loss = loss * mask # Apply the mask to the loss loss = torch.nan_to_num(loss) # Replace any NaNs in the loss with zero - return torch.mean(loss) \ No newline at end of file + return torch.mean(loss) diff --git a/basicts/metrics/r_square.py b/basicts/metrics/r_square.py index 5034440..2d3ce51 100644 --- a/basicts/metrics/r_square.py +++ b/basicts/metrics/r_square.py @@ -36,9 +36,9 @@ def masked_r2(prediction: torch.Tensor, target: torch.Tensor, null_val: float = ss_res = torch.sum(torch.pow((target - prediction), 2), dim=1) # 残差平方和 ss_tot = torch.sum(torch.pow(target - torch.mean(target, dim=1, keepdim=True), 2), dim=1) # 总平方和 - + # 计算 R^2 loss = 1 - (ss_res / (ss_tot + eps)) - + loss = torch.nan_to_num(loss) # Replace any NaNs in the loss with zero return torch.mean(loss)