From 9ea6c2c15417ca5084ba4fa1eabbe27e159c7571 Mon Sep 17 00:00:00 2001
From: Chufan Gao
Date: Mon, 28 Dec 2020 14:28:11 -0500
Subject: [PATCH] linting

---
 dsm/dsm_torch.py | 4 +++-
 dsm/utilities.py | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/dsm/dsm_torch.py b/dsm/dsm_torch.py
index ab19077..9d1a684 100644
--- a/dsm/dsm_torch.py
+++ b/dsm/dsm_torch.py
@@ -473,7 +473,9 @@ def __init__(self, inputdim, k, typ='ConvNet',
         nn.Linear(hidden, k, bias=True)
         ) for r in range(self.risks)})
 
-    self.embedding = create_conv_representation(inputdim=inputdim, hidden=hidden, typ='ConvNet')
+    self.embedding = create_conv_representation(inputdim=inputdim,
+                                                 hidden=hidden,
+                                                 typ='ConvNet')
 
   def forward(self, x, risk='1'):
     """The forward function that is called when data is passed through DSM.
diff --git a/dsm/utilities.py b/dsm/utilities.py
index a2f61be..49b647e 100644
--- a/dsm/utilities.py
+++ b/dsm/utilities.py
@@ -165,7 +165,7 @@ def train_dsm(model,
                                  _reshape_tensor_with_nans(eb),
                                  elbo=elbo,
                                  risk=str(r+1))
-      print ("Train Loss:", float(loss), "batch num:", j, "/", nbatches, "n_iter:", i, "/", n_iter)
+      #print ("Train Loss:", float(loss))
       loss.backward()
       optimizer.step()
     valid_loss = 0