Skip to content

Commit

Permalink
style: 💄 pylint
Browse files Browse the repository at this point in the history
  • Loading branch information
zezhishao committed Dec 15, 2023
1 parent 1e09f24 commit 333bc64
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 14 deletions.
4 changes: 2 additions & 2 deletions basicts/losses/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
def l1_loss(prediction: torch.Tensor, target: torch._tensor, size_average: Optional[bool] = None, reduce: Optional[bool] = None, reduction: str = "mean") -> torch.Tensor:
"""unmasked mae."""

return F.l1_loss(prediction, target)
return F.l1_loss(prediction, target, size_average=size_average, reduce=reduce, reduction=reduction)


def l2_loss(prediction: torch.Tensor, target: torch.Tensor, size_average: Optional[bool] = None, reduce: Optional[bool] = None, reduction: str = "mean") -> torch.Tensor:
"""unmasked mse"""

return F.mse_loss(prediction, target)
return F.mse_loss(prediction, target, size_average=size_average, reduce=reduce, reduction=reduction)


def masked_mae(prediction: torch.Tensor, target: torch.Tensor, null_val: float = np.nan) -> torch.Tensor:
Expand Down
25 changes: 13 additions & 12 deletions basicts/utils/m4.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class M4Dataset:
values: np.ndarray

@staticmethod
def load(info_file_path: str = None, data: np.array = None) -> 'M4Dataset':
def load(info_file_path: str = None, data: np.array = None) -> "M4Dataset":
"""
Load cached dataset.
Expand Down Expand Up @@ -165,20 +165,20 @@ def group_count(group_name):
return len(np.where(self.test_set.groups == group_name)[0])

weighted_score = {}
for g in ['Yearly', 'Quarterly', 'Monthly']:
for g in ["Yearly", "Quarterly", "Monthly"]:
weighted_score[g] = scores[g] * group_count(g)
scores_summary[g] = scores[g]

others_score = 0
others_count = 0
for g in ['Weekly', 'Daily', 'Hourly']:
for g in ["Weekly", "Daily", "Hourly"]:
others_score += scores[g] * group_count(g)
others_count += group_count(g)
weighted_score['Others'] = others_score
scores_summary['Others'] = others_score / others_count
weighted_score["Others"] = others_score
scores_summary["Others"] = others_score / others_count

average = np.sum(list(weighted_score.values())) / len(self.test_set.groups)
scores_summary['Average'] = average
scores_summary["Average"] = average

return scores_summary

Expand All @@ -187,7 +187,8 @@ def m4_summary(save_dir, project_dir):
"""Summary evaluation for M4 dataset.
Args:
save_dir (str): Directory where prediction results are saved. All "{save_dir}/M4_{seasonal pattern}.npy" should exist. Seasonal patterns = ["Yearly", "Quarterly", "Monthly", "Weekly", "Daily", "Hourly"]
save_dir (str): Directory where prediction results are saved. All "{save_dir}/M4_{seasonal pattern}.npy" should exist.
Seasonal patterns = ["Yearly", "Quarterly", "Monthly", "Weekly", "Daily", "Hourly"]
project_dir (str): Project directory. The M4 raw data should be in "{project_dir}/datasets/raw_data/M4".
"""
seasonal_patterns = ["Yearly", "Quarterly", "Monthly", "Weekly", "Daily", "Hourly"] # the order cannot be changed
Expand All @@ -205,16 +206,16 @@ def build_cache(files: str) -> None:
values = row.values
timeseries_dict[m4id] = values[~np.isnan(values)]
return np.array(list(timeseries_dict.values()), dtype=object)

print("Building cache for M4 dataset...")
# read prediction and ground truth
prediction = []
for seasonal_pattern in seasonal_patterns:
prediction.extend(np.load(save_dir + "/M4_{0}.npy".format(seasonal_pattern)))
prediction = np.array(prediction, dtype=object)
train_values = build_cache('*-train.csv')
test_values = build_cache('*-test.csv')
print("Summarizing M4 dataset...")
train_values = build_cache("*-train.csv")
test_values = build_cache("*-test.csv")
print("Summarizing M4 dataset...")
summary = M4Summary(info_file_path, train_values, test_values, data_dir + "/submission-Naive2.csv")
results = pd.DataFrame(summary.evaluate(prediction), index=['SMAPE', 'OWA'])
results = pd.DataFrame(summary.evaluate(prediction), index=["SMAPE", "OWA"])
return results

0 comments on commit 333bc64

Please sign in to comment.