Skip to content

Commit

Permalink
modified: dsm_api.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragnagpal committed Feb 7, 2021
1 parent cd648ec commit 95d1057
Showing 1 changed file with 34 additions and 18 deletions.
52 changes: 34 additions & 18 deletions dsm/dsm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _gen_torch_model(self, inputdim, optimizer, risks):
optimizer=optimizer,
risks=risks)

def fit(self, x, t, e, vsize=0.15,
def fit(self, x, t, e, vsize=0.15, val_data=None,
iters=1, learning_rate=1e-3, batch_size=100,
elbo=True, optimizer="Adam", random_state=100):

Expand All @@ -88,6 +88,8 @@ def fit(self, x, t, e, vsize=0.15,
\( \delta = 1 \) means the event took place.
vsize: float
Amount of data to set aside as the validation set.
val_data: tuple
A tuple of the validation dataset. If passed vsize is ignored.
iters: int
The maximum number of training iterations on the training dataset.
learning_rate: float
Expand All @@ -106,7 +108,8 @@ def fit(self, x, t, e, vsize=0.15,
"""

processed_data = self._prepocess_training_data(x, t, e, vsize,
processed_data = self._prepocess_training_data(x, t, e,
vsize, val_data,
random_state)
x_train, t_train, e_train, x_val, t_val, e_val = processed_data

Expand Down Expand Up @@ -155,7 +158,7 @@ def compute_nll(self, x, t, e):
raise Exception("The model has not been fitted yet. Please fit the " +
"model using the `fit` method on some training data " +
"before calling `_eval_nll`.")
processed_data = self._prepocess_training_data(x, t, e, 0, 0)
processed_data = self._prepocess_training_data(x, t, e, 0, None, 0)
_, _, _, x_val, t_val, e_val = processed_data
x_val, t_val, e_val = x_val,\
_reshape_tensor_with_nans(t_val),\
Expand All @@ -170,7 +173,7 @@ def compute_nll(self, x, t, e):
def _prepocess_test_data(self, x):
return torch.from_numpy(x)

def _prepocess_training_data(self, x, t, e, vsize, random_state):
def _prepocess_training_data(self, x, t, e, vsize, val_data, random_state):

idx = list(range(x.shape[0]))
np.random.seed(random_state)
Expand All @@ -181,21 +184,25 @@ def _prepocess_training_data(self, x, t, e, vsize, random_state):
t_train = torch.from_numpy(t_train).double()
e_train = torch.from_numpy(e_train).double()

if vsize is not None:
if val_data is None:

vsize = int(vsize*x_train.shape[0])

x_val, t_val, e_val = x_train[-vsize:], t_train[-vsize:], e_train[-vsize:]

x_train = x_train[:-vsize]
t_train = t_train[:-vsize]
e_train = e_train[:-vsize]

return (x_train, t_train, e_train,
x_val, t_val, e_val)

else:
return (x_train, t_train, e_train,
x_train, t_train, e_train)

x_val, t_val, e_val = val_data

x_val = torch.from_numpy(x_val).double()
t_val = torch.from_numpy(t_val).double()
e_val = torch.from_numpy(e_val).double()

return (x_train, t_train, e_train,
x_val, t_val, e_val)


def predict_mean(self, x, risk=1):
Expand Down Expand Up @@ -266,7 +273,7 @@ def predict_survival(self, x, t, risk=1):
else:
raise Exception("The model has not been fitted yet. Please fit the " +
"model using the `fit` method on some training data " +
"before calling `predict_risk`.")
"before calling `predict_survival`.")


class DeepSurvivalMachines(DSMBase):
Expand Down Expand Up @@ -356,7 +363,7 @@ def _gen_torch_model(self, inputdim, optimizer, risks):
def _prepocess_test_data(self, x):
return torch.from_numpy(_get_padded_features(x))

def _prepocess_training_data(self, x, t, e, vsize, random_state):
def _prepocess_training_data(self, x, t, e, vsize, val_data, random_state):
"""RNNs require different preprocessing for variable length sequences"""

idx = list(range(x.shape[0]))
Expand All @@ -373,7 +380,7 @@ def _prepocess_training_data(self, x, t, e, vsize, random_state):
t_train = torch.from_numpy(t_train).double()
e_train = torch.from_numpy(e_train).double()

if vsize is not None:
if val_data is None:

vsize = int(vsize*x_train.shape[0])

Expand All @@ -383,11 +390,20 @@ def _prepocess_training_data(self, x, t, e, vsize, random_state):
t_train = t_train[:-vsize]
e_train = e_train[:-vsize]

return (x_train, t_train, e_train,
x_val, t_val, e_val)
else:
return (x_train, t_train, e_train,
x_train, t_train, e_train)

x_val, t_val, e_val = val_data

x_val = _get_padded_features(x_val)
t_val = _get_padded_features(t_val)
e_val = _get_padded_features(e_val)

x_val = torch.from_numpy(x_val).double()
t_val = torch.from_numpy(t_val).double()
e_val = torch.from_numpy(e_val).double()

return (x_train, t_train, e_train,
x_val, t_val, e_val)


class DeepConvolutionalSurvivalMachines(DSMBase):
Expand Down

0 comments on commit 95d1057

Please sign in to comment.