Skip to content

Commit

Permalink
conv_model
Browse files Browse the repository at this point in the history
  • Loading branch information
Chufan Gao committed Dec 11, 2020
1 parent 9972b3b commit cb20d8d
Show file tree
Hide file tree
Showing 16 changed files with 867 additions and 7 deletions.
2 changes: 1 addition & 1 deletion dsm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,4 +181,4 @@
"""

from dsm.dsm_api import DeepSurvivalMachines, DeepRecurrentSurvivalMachines
from dsm.dsm_api import DeepSurvivalMachines, DeepRecurrentSurvivalMachines, DeepConvolutionalSurvivalMachines
33 changes: 33 additions & 0 deletions dsm/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler

import torchvision

def increase_censoring(e, t, p):

uncens = np.where(e == 1)[0]
Expand Down Expand Up @@ -192,6 +194,35 @@ def _load_support_dataset():
remove = ~np.isnan(t)
return x[remove], t[remove], e[remove]

def _load_mnist():
"""Helper function to load and preprocess the MNIST dataset.
The MNIST database of handwritten digits, available from this page, has a
training set of 60,000 examples, and a test set of 10,000 examples.
It is a good database for people who want to try learning techniques and
pattern recognition methods on real-world data while spending minimal
efforts on preprocessing and formatting [1].
Please refer to http://yann.lecun.com/exdb/mnist/.
for the original datasource.
References
----------
[1]: LeCun, Y. (1998). The MNIST database of handwritten digits.
http://yann.lecun.com/exdb/mnist/.
"""


train = torchvision.datasets.MNIST(root='datasets/',
train=True, download=True)
x = train.data.numpy()
x = np.expand_dims(x, 1).astype(float)
t = train.targets.numpy().astype(float) + 1

e, t = increase_censoring(np.ones(t.shape), t, p=.5)

return x, t, e

def load_dataset(dataset='SUPPORT', **kwargs):
"""Helper function to load datasets to test Survival Analysis models.
Expand Down Expand Up @@ -249,5 +280,7 @@ def load_dataset(dataset='SUPPORT', **kwargs):
return _load_pbc_dataset(sequential)
if dataset == 'FRAMINGHAM':
return _load_framingham_dataset(sequential)
if dataset == 'MNIST':
return _load_mnist()
else:
raise NotImplementedError('Dataset '+dataset+' not implemented.')
38 changes: 32 additions & 6 deletions dsm/dsm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"""

from dsm.dsm_torch import DeepSurvivalMachinesTorch
from dsm.dsm_torch import DeepRecurrentSurvivalMachinesTorch
from dsm.dsm_torch import DeepRecurrentSurvivalMachinesTorch, DeepConvolutionalSurvivalMachinesTorch
from dsm.losses import predict_cdf
import dsm.losses as losses
from dsm.utilities import train_dsm, _get_padded_features, _get_padded_targets
Expand Down Expand Up @@ -66,7 +66,8 @@ def _gen_torch_model(self, inputdim, optimizer, risks):

def fit(self, x, t, e, vsize=0.15,
iters=1, learning_rate=1e-3, batch_size=100,
elbo=True, optimizer="Adam", random_state=100):
elbo=True, optimizer="Adam", random_state=100,
cuda=False):

r"""This method is used to train an instance of the DSM model.
Expand Down Expand Up @@ -185,7 +186,7 @@ def predict_risk(self, x, t, risk=1):
"before calling `predict_risk`.")


def predict_survival(self, x, t, risk=1):
def predict_survival(self, x, t, risk=1, cuda=False):
r"""Returns the estimated survival probability at time \( t \),
\( \widehat{\mathbb{P}}(T > t|X) \) for some input data \( x \).
Expand Down Expand Up @@ -327,6 +328,31 @@ def _prepocess_training_data(self, x, t, e, vsize, random_state):
x_val, t_val, e_val)


class DeepConvolutionalSurvivalMachines(DeepRecurrentSurvivalMachines):
__doc__ = "..warning:: Not Implemented"
pass
class DeepConvolutionalSurvivalMachines(DSMBase):
"""The Deep Convolutional Survival Machines model to handle data with
image-based covariates.
"""

def __init__(self, k=3, layers=None, hidden=None,
distribution='Weibull', temp=1000., discount=1.0, typ='ConvNet'):
super(DeepConvolutionalSurvivalMachines, self).__init__(k=k,
layers=layers,
distribution=distribution,
temp=temp,
discount=discount)
self.hidden = hidden
self.typ = typ
def _gen_torch_model(self, inputdim, optimizer, risks):
"""Helper function to return a torch model."""
return DeepConvolutionalSurvivalMachinesTorch(inputdim,
k=self.k,
layers=self.layers,
hidden=self.hidden,
dist=self.dist,
temp=self.temp,
discount=self.discount,
optimizer=optimizer,
typ=self.typ,
risks=risks)

129 changes: 129 additions & 0 deletions dsm/dsm_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@

import torch.nn as nn
import torch
import torchvision
import torch.nn.functional as F

__pdoc__ = {}

Expand Down Expand Up @@ -337,3 +339,130 @@ def forward(self, x, risk='1'):
def get_shape_scale(self, risk='1'):
return(self.shape[risk],
self.scale[risk])

class DeepConvolutionalSurvivalMachinesTorch(nn.Module):
"""A Torch implementation of Deep Convolutional Survival Machines model.
This is an implementation of Deep Convolutional Survival Machines model
in torch. It inherits from `DeepSurvivalMachinesTorch` and replaces the
input representation learning MLP with an simple convnet, the parameters of the
underlying distributions and the forward function which is called whenever
data is passed to the module. Each of the parameters are nn.Parameters and
torch automatically keeps track and computes gradients for them.
.. warning::
Not designed to be used directly.
Please use the API inferface `dsm.dsm_api.DeepConvolutionalSurvivalMachines`!!
Parameters
----------
inputdim: int
Dimensionality of the input features.
k: int
The number of underlying parametric distributions.
layers: int
The number of hidden layers in the LSTM or RNN cell.
hidden: int
The number of neurons in each hidden layer.
init: tuple
A tuple for initialization of the parameters for the underlying
distributions. (shape, scale).
dist: str
Choice of the underlying survival distributions.
One of 'Weibull', 'LogNormal'.
Default is 'Weibull'.
temp: float
The logits for the gate are rescaled with this value.
Default is 1000.
discount: float
a float in [0,1] that determines how to discount the tail bias
from the uncensored instances.
Default is 1.
"""

def __init__(self, inputdim, k, typ='ResNet', layers=1,
hidden=None, dist='Weibull',
temp=1000., discount=1.0, optimizer='Adam', risks=1):
super(DeepConvolutionalSurvivalMachinesTorch, self).__init__()

self.k = k
self.dist = dist
self.temp = float(temp)
self.discount = float(discount)
self.optimizer = optimizer
self.hidden = hidden
self.layers = layers
self.typ = typ
self.risks = risks

if self.dist in ['Weibull']:
self.act = nn.SELU()
self.shape = nn.ParameterDict({str(r+1): nn.Parameter(-torch.ones(k))
for r in range(self.risks)})
self.scale = nn.ParameterDict({str(r+1):nn.Parameter(-torch.ones(k))
for r in range(self.risks)})
elif self.dist in ['Normal']:
self.act = nn.Identity()
self.shape = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k))
for r in range(self.risks)})
self.scale = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k))
for r in range(self.risks)})
elif self.dist in ['LogNormal']:
self.act = nn.Tanh()
self.shape = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k))
for r in range(self.risks)})
self.scale = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k))
for r in range(self.risks)})
else:
raise NotImplementedError('Distribution: '+self.dist+' not implemented'+
' yet.')

self.gate = nn.ModuleDict({str(r+1): nn.Sequential(
nn.Linear(hidden, k, bias=False)
) for r in range(self.risks)})

self.scaleg = nn.ModuleDict({str(r+1): nn.Sequential(
nn.Linear(hidden, k, bias=True)
) for r in range(self.risks)})

self.shapeg = nn.ModuleDict({str(r+1): nn.Sequential(
nn.Linear(hidden, k, bias=True)
) for r in range(self.risks)})

if self.typ == 'ConvNet':
# self.cnn = torchvision.models.resnet18(pretrained=True).float()
# self.cnn.conv1 = torch.nn.Conv1d(1, 64, (7, 7), (2, 2), (3, 3), bias=False)
# self.linear = torch.nn.Linear(1000, hidden)
self.conv1 = nn.Conv2d(1, 6, 3)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 3)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, hidden)


def forward(self, x, risk='1'):
"""The forward function that is called when data is passed through DSM.
Args:
x:
a torch.tensor of the input features.
"""
# xrep = self.linear(self.cnn(x))
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
xrep = self.fc3(x)

dim = x.shape[0]
return(self.act(self.shapeg[risk](xrep))+self.shape[risk].expand(dim, -1),
self.act(self.scaleg[risk](xrep))+self.scale[risk].expand(dim, -1),
self.gate[risk](xrep)/self.temp)

def get_shape_scale(self, risk='1'):
return(self.shape[risk],
self.scale[risk])
Loading

0 comments on commit cb20d8d

Please sign in to comment.