Skip to content

Commit

Permalink
removed_trailing_whitespace
Browse files Browse the repository at this point in the history
  • Loading branch information
Chufan Gao committed Dec 29, 2020
1 parent 427a6d3 commit 92de282
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
12 changes: 6 additions & 6 deletions dsm/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,24 +197,24 @@ def _load_support_dataset():
def _load_mnist():
"""Helper function to load and preprocess the MNIST dataset.
The MNIST database of handwritten digits, available from this page, has a
training set of 60,000 examples, and a test set of 10,000 examples.
It is a good database for people who want to try learning techniques and
pattern recognition methods on real-world data while spending minimal
The MNIST database of handwritten digits, available from this page, has a
training set of 60,000 examples, and a test set of 10,000 examples.
It is a good database for people who want to try learning techniques and
pattern recognition methods on real-world data while spending minimal
efforts on preprocessing and formatting [1].
Please refer to http://yann.lecun.com/exdb/mnist/.
for the original datasource.
References
----------
[1]: LeCun, Y. (1998). The MNIST database of handwritten digits.
[1]: LeCun, Y. (1998). The MNIST database of handwritten digits.
http://yann.lecun.com/exdb/mnist/.
"""


train = torchvision.datasets.MNIST(root='datasets/',
train = torchvision.datasets.MNIST(root='datasets/',
train=True, download=True)
x = train.data.numpy()
x = np.expand_dims(x, 1).astype(float)
Expand Down
9 changes: 4 additions & 5 deletions dsm/dsm_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,11 @@

import torch.nn as nn
import torch
import torchvision
import numpy as np

__pdoc__ = {}

for clsn in ['DeepSurvivalMachinesTorch',
for clsn in ['DeepSurvivalMachinesTorch',
'DeepRecurrentSurvivalMachinesTorch']:
for membr in ['training', 'dump_patches']:

Expand Down Expand Up @@ -368,7 +367,7 @@ def create_conv_representation(inputdim, hidden, typ='ConvNet'):

if typ == 'ConvNet':
inputdim = np.squeeze(inputdim)
linear_dim = ((((inputdim-2) // 2) - 2) // 2) ** 2
linear_dim = ((((inputdim-2) // 2) - 2) // 2) ** 2
linear_dim *= 16
embedding = nn.Sequential(
nn.Conv2d(1, 6, 3),
Expand Down Expand Up @@ -474,8 +473,8 @@ def __init__(self, inputdim, k, typ='ConvNet',
nn.Linear(hidden, k, bias=True)
) for r in range(self.risks)})

self.embedding = create_conv_representation(inputdim=inputdim,
hidden=hidden,
self.embedding = create_conv_representation(inputdim=inputdim,
hidden=hidden,
typ='ConvNet')

def forward(self, x, risk='1'):
Expand Down

0 comments on commit 92de282

Please sign in to comment.