Skip to content

Commit

Permalink
docs: ✏️ fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
zezhishao committed Dec 11, 2024
1 parent b39b45b commit 75eae85
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion basicts/metrics/corr.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,4 @@ def masked_corr(prediction: torch.Tensor, target: torch.Tensor, null_val: float
loss = loss * mask # Apply the mask to the loss
loss = torch.nan_to_num(loss) # Replace any NaNs in the loss with zero

return torch.mean(loss)
return torch.mean(loss)
4 changes: 2 additions & 2 deletions basicts/metrics/r_square.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def masked_r2(prediction: torch.Tensor, target: torch.Tensor, null_val: float =

ss_res = torch.sum(torch.pow((target - prediction), 2), dim=1) # 残差平方和
ss_tot = torch.sum(torch.pow(target - torch.mean(target, dim=1, keepdim=True), 2), dim=1) # 总平方和

# 计算 R^2
loss = 1 - (ss_res / (ss_tot + eps))

loss = torch.nan_to_num(loss) # Replace any NaNs in the loss with zero
return torch.mean(loss)

0 comments on commit 75eae85

Please sign in to comment.