Skip to content

Commit

Permalink
modified: dsm/dsm_api.py
Browse files Browse the repository at this point in the history
	modified:   dsm/dsm_torch.py
	modified:   dsm/utilities.py
  • Loading branch information
chiragnagpal committed Jan 30, 2021
1 parent 5252d6f commit 49a9580
Show file tree
Hide file tree
Showing 3 changed files with 375 additions and 143 deletions.
73 changes: 57 additions & 16 deletions dsm/dsm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from dsm.dsm_torch import DeepSurvivalMachinesTorch
from dsm.dsm_torch import DeepRecurrentSurvivalMachinesTorch
from dsm.dsm_torch import DeepConvolutionalSurvivalMachinesTorch
from dsm.dsm_torch import DeepCNNRNNSurvivalMachinesTorch

import dsm.losses as losses

Expand Down Expand Up @@ -109,7 +110,13 @@ def fit(self, x, t, e, vsize=0.15,
random_state)
x_train, t_train, e_train, x_val, t_val, e_val = processed_data

inputdim = x_train.shape[-1]
#Todo: Change this somehow. The base design shouldn't depend on child
if type(self).__name__ in ["DeepConvolutionalSurvivalMachines",
"DeepCNNRNNSurvivalMachines"]:
inputdim = tuple(x_train.shape)[-2:]
else:
inputdim = x_train.shape[-1]

maxrisk = int(np.nanmax(e_train.cpu().numpy()))
model = self._gen_torch_model(inputdim, optimizer, risks=maxrisk)
model, _ = train_dsm(model,
Expand All @@ -122,7 +129,7 @@ def fit(self, x, t, e, vsize=0.15,

self.torch_model = model.eval()
self.fitted = True

return self


Expand All @@ -140,11 +147,11 @@ def compute_nll(self, x, t, e):
e: np.ndarray
A numpy array of the event/censoring indicators, \( \delta \).
\( \delta = r \) means the event r took place.
Returns:
float: Negative log likelihood.
"""
if not(self.fitted):
if not self.fitted:
raise Exception("The model has not been fitted yet. Please fit the " +
"model using the `fit` method on some training data " +
"before calling `_eval_nll`.")
Expand All @@ -155,9 +162,9 @@ def compute_nll(self, x, t, e):
_reshape_tensor_with_nans(e_val)
loss = 0
for r in range(self.torch_model.risks):
loss += float(losses.conditional_loss(self.torch_model,
x_val, t_val, e_val, elbo=False,
risk=str(r+1)).detach().numpy())
loss += float(losses.conditional_loss(self.torch_model,
x_val, t_val, e_val, elbo=False,
risk=str(r+1)).detach().numpy())
return loss

def _prepocess_test_data(self, x):
Expand Down Expand Up @@ -369,6 +376,7 @@ def _prepocess_training_data(self, x, t, e, vsize, random_state):
if vsize is not None:

vsize = int(vsize*x_train.shape[0])

x_val, t_val, e_val = x_train[-vsize:], t_train[-vsize:], e_train[-vsize:]

x_train = x_train[:-vsize]
Expand All @@ -389,7 +397,7 @@ class DeepConvolutionalSurvivalMachines(DSMBase):
"""

def __init__(self, k=3, layers=None, hidden=None,
distribution='Weibull', temp=1000., discount=1.0, typ='ConvNet'):
distribution="Weibull", temp=1000., discount=1.0, typ="ConvNet"):
super(DeepConvolutionalSurvivalMachines, self).__init__(k=k,
distribution=distribution,
temp=temp,
Expand All @@ -399,11 +407,44 @@ def __init__(self, k=3, layers=None, hidden=None,
def _gen_torch_model(self, inputdim, optimizer, risks):
"""Helper function to return a torch model."""
return DeepConvolutionalSurvivalMachinesTorch(inputdim,
k=self.k,
hidden=self.hidden,
dist=self.dist,
temp=self.temp,
discount=self.discount,
optimizer=optimizer,
typ=self.typ,
risks=risks)
k=self.k,
hidden=self.hidden,
dist=self.dist,
temp=self.temp,
discount=self.discount,
optimizer=optimizer,
typ=self.typ,
risks=risks)


class DeepCNNRNNSurvivalMachines(DeepRecurrentSurvivalMachines):

"""The Deep CNN-RNN Survival Machines model to handle data with
moving image streams.
"""

def __init__(self, k=3, layers=None, hidden=None,
distribution="Weibull", temp=1000., discount=1.0, typ="LSTM"):
super(DeepCNNRNNSurvivalMachines, self).__init__(k=k,
layers=layers,
distribution=distribution,
temp=temp,
discount=discount)
self.hidden = hidden
self.typ = typ

def _gen_torch_model(self, inputdim, optimizer, risks):
"""Helper function to return a torch model."""
return DeepCNNRNNSurvivalMachinesTorch(inputdim,
k=self.k,
layers=self.layers,
hidden=self.hidden,
dist=self.dist,
temp=self.temp,
discount=self.discount,
optimizer=optimizer,
typ=self.typ,
risks=risks)


Loading

0 comments on commit 49a9580

Please sign in to comment.