Skip to content

Commit

Permalink
Merge pull request #521 from AIStream-Peelout/gpu_class_bug
Browse files Browse the repository at this point in the history
DSANet + GPU bug
  • Loading branch information
isaacmg authored Apr 13, 2022
2 parents 5fa534e + 24c43af commit cd605b0
Show file tree
Hide file tree
Showing 16 changed files with 450 additions and 14 deletions.
1 change: 1 addition & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ jobs:
coverage run flood_forecast/trainer.py -p tests/custom_encode.json
coverage run flood_forecast/trainer.py -p tests/multi_decoder_test.json
coverage run flood_forecast/trainer.py -p tests/test_dual.json
coverage run flood_forecast/trainer.py -p tests/dsanet.json
- store_test_results:
path: test-results

Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ Using the library
7. [Transformer XL](https://arxiv.org/abs/1901.02860): Porting Transformer XL for time series.
8. [Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting](https://arxiv.org/abs/2012.07436) (Informer)
9. [DeepAR](https://arxiv.org/abs/1704.04110)
10. [DSANet](https://www.semanticscholar.org/paper/DSANet%3A-Dual-Self-Attention-Network-for-Time-Series-Huang-Wang/6645a09c742760144e4ba0a6f6652e429b1bf107)

**Forthcoming Models**

Expand Down
4 changes: 3 additions & 1 deletion flood_forecast/model_dict_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from flood_forecast.transformer_xl.transformer_bottleneck import DecoderTransformer
from flood_forecast.custom.dilate_loss import DilateLoss
from flood_forecast.meta_models.basic_ae import AE
from flood_forecast.transformer_xl.dsanet import DSANet

"""
Utility dictionaries to map a string to a class
Expand All @@ -32,7 +33,8 @@
"DARNN": DARNN,
"DecoderTransformer": DecoderTransformer,
"BasicAE": AE,
"Informer": Informer
"Informer": Informer,
"DSANet": DSANet
}

pytorch_criterion_dict = {
Expand Down
2 changes: 1 addition & 1 deletion flood_forecast/preprocessing/process_usgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def make_usgs_data(start_date: datetime, end_date: datetime, site_number: str) -> pd.DataFrame:
""" """
""""""
base_url = "https://nwis.waterdata.usgs.gov/usa/nwis/uv/?cb_00060=on&cb_00065&format=rdb&"
full_url = base_url + "site_no=" + site_number + "&period=&begin_date=" + \
start_date.strftime("%Y-%m-%d") + "&end_date=" + end_date.strftime("%Y-%m-%d")
Expand Down
4 changes: 2 additions & 2 deletions flood_forecast/pytorch_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def compute_validation(validation_loader: DataLoader,
wandb.log({"roc_" + str(epoch): wandb.plot.roc_curve(fin, mod_output1, classes_to_plot=None, labels=None,
title="roc_" + str(epoch))})
wandb.log({"pr": wandb.plot.pr_curve(fin, mod_output1)})
wandb.log({"conf_": wandb.plot.confusion_matrix(probs=mod_output1.detach().numpy(), y_true=fin.detach().numpy(),
class_names=None)})
wandb.log({"conf_": wandb.plot.confusion_matrix(probs=mod_output1.detach().cpu().numpy(),
y_true=fin.detach().cpu().numpy(), class_names=None)})
model.train()
return list(scaled_crit.values())[0]
2 changes: 1 addition & 1 deletion flood_forecast/time_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def save_model(self, final_path: str, epoch: int) -> None:
print("Wandb stupid error")
print(e.__traceback__)

def __re_add_params__(self, start_end_params, dataset_params, data_path):
def __re_add_params__(self, start_end_params: Dict, dataset_params, data_path):
"""
Function to re-add the params to the model
"""
Expand Down
Loading

0 comments on commit cd605b0

Please sign in to comment.