Skip to content

Commit

Permalink
modified: estimators.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragnagpal committed Mar 29, 2022
1 parent bb1e83e commit 00d9d34
Showing 1 changed file with 75 additions and 73 deletions.
148 changes: 75 additions & 73 deletions auton_survival/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import numpy as np
import pandas as pd

from auton_survival.models.dcm.dcm_api import DeepCoxMixtures

def _get_valid_idx(n, size, random_seed):

"""Randomly select sample indices to split into train and test data.
Expand Down Expand Up @@ -78,68 +80,73 @@ def _fit_dcm(features, outcomes, random_seed, **hyperparams):
random_seed : int
Controls the reproducibility of fitted estimators.
hyperparams : dict
Optional arguments for the estimator stored in a python dictionary.
Optional kwarg arguments.
Keys correspond to parameter names as strings and items correspond
to parameter values.
Options include:
- 'k' : int, default=3
Size of the underlying Cox mixtures.
- 'layers' : list, default=[100]
A list consisting of the number of neurons in each hidden layer.
- 'bs' : int, default=128
- 'batch_size' : int, default=128
Learning is performed on mini-batches of input data. This parameter
specifies the size of each mini-batch.
- 'lr' : float, default=1e-3
Learning rate for the 'Adam' optimizer.
- 'epochs' : int, default=50
Number of complete passes through the training data.
-'smoothing_factor' : int, default=0
- 'smoothing_factor' : float, default=1e-4
Returns
-----------
Trained instance of the Deep Cox Mixtures model.
np.array : breslow splines used to interpolate baseline survival rates.
np.array : A float or list of the times at which to compute the survival probability.
"""
raise NotImplementedError()

from .models.dcm import DeepCoxMixtures

import torch
torch.manual_seed(random_seed)
np.random.seed(random_seed)

k = hyperparams.get("k", 3)
layers = hyperparams.get("layers", [100])
bs = hyperparams.get("bs", 128)
batch_size = hyperparams.get("batch_size", 128)
lr = hyperparams.get("lr", 1e-3)
epochs = hyperparams.get("epochs", 50)
smoothing_factor = hyperparams.get("smoothing_factor", 0)
smoothing_factor = hyperparams.get("smoothing_factor", 1e-4)
gamma = hyperparams.get("gamma", 10)

model = DeepCoxMixtures(k=k,
layers=layers,
gamma=gamma,
smoothing_factor=smoothing_factor)

model.fit(features.values, outcomes.time.values, outcomes.event.values,
iters=epochs, batch_size=batch_size, lr=lr,
random_seed=random_seed)

return model

if len(layers): model = DeepCoxMixture(k=k, inputdim=features.shape[1], hidden=layers[0])
else: model = CoxMixture(k=k, inputdim=features.shape[1])
# if len(layers): model = DeepCoxMixture(k=k, inputdim=features.shape[1], hidden=layers[0])
# else: model = CoxMixture(k=k, inputdim=features.shape[1])

x = torch.from_numpy(features.values.astype('float32'))
t = torch.from_numpy(outcomes['time'].values.astype('float32'))
e = torch.from_numpy(outcomes['event'].values.astype('float32'))
# x = torch.from_numpy(features.values.astype('float32'))
# t = torch.from_numpy(outcomes['time'].values.astype('float32'))
# e = torch.from_numpy(outcomes['event'].values.astype('float32'))

vidx = _get_valid_idx(x.shape[0], 0.15, random_seed)
# vidx = _get_valid_idx(x.shape[0], 0.15, random_seed)

train_data = (x[~vidx], t[~vidx], e[~vidx])
val_data = (x[vidx], t[vidx], e[vidx])
# train_data = (x[~vidx], t[~vidx], e[~vidx])
# val_data = (x[vidx], t[vidx], e[vidx])

(model, breslow_splines, unique_times) = train(model,
train_data,
val_data,
epochs=epochs,
lr=lr, bs=bs,
use_posteriors=True,
patience=5,
return_losses=False,
smoothing_factor=smoothing_factor)
# (model, breslow_splines, unique_times) = train(model,
# train_data,
# val_data,
# epochs=epochs,
# lr=lr, bs=bs,
# use_posteriors=True,
# patience=5,
# return_losses=False,
# smoothing_factor=smoothing_factor)

return (model, breslow_splines, unique_times)
#return (model, breslow_splines, unique_times)

def _predict_dcm(model, features, times):

Expand All @@ -162,14 +169,9 @@ def _predict_dcm(model, features, times):
"""

raise NotImplementedError()

from sdcm.dcm_utils import predict_scores
#raise NotImplementedError()

import torch
x = torch.from_numpy(features.values.astype('float32'))

survival_predictions = predict_scores(model, x, None, model[-1])
survival_predictions = model.predict_survival(features, times)
if len(times)>1:
survival_predictions = pd.DataFrame(survival_predictions, columns=times).T
return __interpolate_missing_times(survival_predictions, times)
Expand Down Expand Up @@ -221,52 +223,52 @@ def _fit_dcph(features, outcomes, random_seed, **hyperparams):
Trained instance of the Deep Cox Proportional Hazards model.
"""
raise NotImplementedError()
# import torch
# import torchtuples as ttup

import torch
import torchtuples as ttup

from pycox.models import CoxPH
# from pycox.models import CoxPH

torch.manual_seed(random_seed)
np.random.seed(random_seed)
# torch.manual_seed(random_seed)
# np.random.seed(random_seed)

layers = hyperparams.get('layers', [100])
lr = hyperparams.get('lr', 1e-3)
bs = hyperparams.get('bs', 100)
epochs = hyperparams.get('epochs', 50)
activation = hyperparams.get('activation', 'relu')
# layers = hyperparams.get('layers', [100])
# lr = hyperparams.get('lr', 1e-3)
# bs = hyperparams.get('bs', 100)
# epochs = hyperparams.get('epochs', 50)
# activation = hyperparams.get('activation', 'relu')

if activation == 'relu': activation = torch.nn.ReLU
elif activation == 'relu6': activation = torch.nn.ReLU6
elif activation == 'tanh': activation = torch.nn.Tanh
else: raise NotImplementedError("Activation function not implemented")
# if activation == 'relu': activation = torch.nn.ReLU
# elif activation == 'relu6': activation = torch.nn.ReLU6
# elif activation == 'tanh': activation = torch.nn.Tanh
# else: raise NotImplementedError("Activation function not implemented")

x = features.values.astype('float32')
t = outcomes['time'].values.astype('float32')
e = outcomes['event'].values.astype('bool')
# x = features.values.astype('float32')
# t = outcomes['time'].values.astype('float32')
# e = outcomes['event'].values.astype('bool')

in_features = x.shape[1]
out_features = 1
batch_norm = False
dropout = 0.0
# in_features = x.shape[1]
# out_features = 1
# batch_norm = False
# dropout = 0.0

net = ttup.practical.MLPVanilla(in_features, layers,
out_features, batch_norm, dropout,
activation=activation,
output_bias=False)
# net = ttup.practical.MLPVanilla(in_features, layers,
# out_features, batch_norm, dropout,
# activation=activation,
# output_bias=False)

model = CoxPH(net, torch.optim.Adam)
# model = CoxPH(net, torch.optim.Adam)

vidx = _get_valid_idx(x.shape[0], 0.15, random_seed)
# vidx = _get_valid_idx(x.shape[0], 0.15, random_seed)

y_train, y_val = (t[~vidx], e[~vidx]), (t[vidx], e[vidx])
val_data = x[vidx], y_val
# y_train, y_val = (t[~vidx], e[~vidx]), (t[vidx], e[vidx])
# val_data = x[vidx], y_val

callbacks = [ttup.callbacks.EarlyStopping()]
model.fit(x[~vidx], y_train, bs, epochs, callbacks, True,
val_data=val_data,
val_batch_size=bs)
model.compute_baseline_hazards()
# callbacks = [ttup.callbacks.EarlyStopping()]
# model.fit(x[~vidx], y_train, bs, epochs, callbacks, True,
# val_data=val_data,
# val_batch_size=bs)
# model.compute_baseline_hazards()

return model

Expand Down

0 comments on commit 00d9d34

Please sign in to comment.