Commit 4ff3e69: updates v2

PotosnakW committed Apr 20, 2022
1 parent ff77c7f commit 4ff3e69
Showing 10 changed files with 486 additions and 663 deletions.
135 changes: 91 additions & 44 deletions auton_survival/estimators.py
@@ -55,7 +55,7 @@ def _get_valid_idx(n, size, random_seed):

return vidx

def _fit_dcm(features, outcomes, random_seed, **hyperparams):
def _fit_dcm(features, outcomes, vsize, val_data, random_seed, **hyperparams):

r"""Fit the Deep Cox Mixtures (DCM) [1] model to a given dataset.
@@ -113,12 +113,13 @@ def _fit_dcm(features, outcomes, random_seed, **hyperparams):
gamma=gamma,
smoothing_factor=smoothing_factor,
random_seed=random_seed)
model.fit(features.values, outcomes.time.values, outcomes.event.values,
iters=epochs, batch_size=batch_size, learning_rate=learning_rate)
model.fit(x=features, t=outcomes.time, e=outcomes.event, vsize=vsize,
val_data=val_data, iters=epochs, batch_size=batch_size,
learning_rate=learning_rate)

return model
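
As an illustration (editor's sketch, not part of the commit): a minimal example of how this helper might now be called with toy pandas data. The frames below are purely illustrative; any extra keyword arguments (e.g. 'k', 'layers', 'batch_size') would be forwarded as hyperparameters.

import numpy as np
import pandas as pd

# Toy right-censored data: 100 samples, 5 covariates (illustrative only).
rng = np.random.default_rng(0)
features = pd.DataFrame(rng.normal(size=(100, 5)),
                        columns=[f"x{i}" for i in range(5)])
outcomes = pd.DataFrame({"time": rng.exponential(10, size=100),
                         "event": rng.integers(0, 2, size=100)})

# New calling convention: hold out 15% of the training data for validation.
dcm_model = _fit_dcm(features, outcomes, vsize=0.15, val_data=None,
                     random_seed=0)
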

def _fit_dcph(features, outcomes, random_seed, **hyperparams):
def _fit_dcph(features, outcomes, vsize, val_data, random_seed, **hyperparams):

"""Fit a Deep Cox Proportional Hazards Model/Farragi Simon Network [1,2]
model to a given dataset.
@@ -166,9 +167,9 @@ def _fit_dcph(features, outcomes, random_seed, **hyperparams):

model = DeepCoxPH(layers=layers, random_seed=random_seed)

model.fit(features.values, outcomes.time.values, outcomes.event.values,
iters=epochs, learning_rate=learning_rate, batch_size=bs,
optimizer="Adam")
model.fit(x=features, t=outcomes.time, e=outcomes.event, vsize=vsize,
val_data=val_data, iters=epochs, batch_size=bs,
learning_rate=learning_rate)

return model
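
For reference (editor's sketch, not part of the commit): the equivalent direct call on DeepCoxPH with pandas inputs, reusing the toy frames from the sketch above. The keyword names mirror the fit call in this hunk, and the import path is assumed from the file layout shown later in this commit; layer sizes and training settings are illustrative.

from auton_survival.models.cph import DeepCoxPH

dcph_model = DeepCoxPH(layers=[64, 64], random_seed=0)
dcph_model.fit(x=features, t=outcomes.time, e=outcomes.event,
               vsize=0.15, val_data=None,
               iters=50, batch_size=128, learning_rate=1e-3)
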

@@ -311,7 +312,7 @@ def _fit_rsf(features, outcomes, random_seed, **hyperparams):
return rsf


def _fit_dsm(features, outcomes, random_seed, **hyperparams):
def _fit_dsm(features, outcomes, vsize, val_data, random_seed, **hyperparams):

"""Fit the Deep Survival Machines (DSM) [1] model to a given dataset.
@@ -339,13 +340,18 @@ def _fit_dsm(features, outcomes, random_seed, **hyperparams):
Options include:
- 'layers' : list
A list of integers describing the dimensionality of each hidden layer.
- 'iters' : int, default=10
The maximum number of training iterations on the training dataset.
- 'distribution' : str, default='Weibull'
Choice of the underlying survival distributions.
Options include: 'Weibull' and 'LogNormal'.
- 'temperature' : float, default=1.0
The value with which to rescale the logits for the gate.
- `batch_size` : int, default=100
Learning is performed on mini-batches of input data. This parameter
specifies the size of each mini-batch.
- `lr` : float, default=1e-3
Learning rate for the 'Adam' optimizer.
- `epochs` : int, default=10
Number of complete passes through the training data
(read from the 'iters' key of the hyperparameter dictionary).
Returns
-----------
@@ -357,19 +363,19 @@ def _fit_dsm(features, outcomes, random_seed, **hyperparams):

k = hyperparams.get("k", 3)
layers = hyperparams.get("layers", [100])
iters = hyperparams.get("iters", 10)
epochs = hyperparams.get("iters", 10)
distribution = hyperparams.get("distribution", "Weibull")
temperature = hyperparams.get("temperature", 1.0)
lr = hyperparams.get("lr", 1e-3)
bs = hyperparams.get("batch_size", 100)

model = DeepSurvivalMachines(k=k, layers=layers,
distribution=distribution,
temp=temperature,
random_seed=random_seed)

model.fit(features.values,
outcomes['time'].values,
outcomes['event'].values,
iters=iters)
model.fit(x=features, t=outcomes.time, e=outcomes.event, vsize=vsize,
val_data=val_data, iters=epochs, learning_rate=lr, batch_size=bs)

return model
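
As an illustration (editor's sketch, not part of the commit): the extended helper with the hyperparameter keys documented above, again on the toy frames from the first sketch. Note that the number of passes is read from the 'iters' key.

dsm_model = _fit_dsm(features, outcomes, vsize=0.15, val_data=None,
                     random_seed=0,
                     k=3, layers=[100, 100],
                     distribution="Weibull", temperature=1.0,
                     batch_size=128, lr=1e-3, iters=20)
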

@@ -532,66 +538,107 @@ def __init__(self, model, random_seed=0, **hyperparams):
self.random_seed = random_seed
self.fitted = False

def fit(self, features, outcomes,
weights=None, resample_size=1.0):
def fit(self, features, outcomes, vsize=None, val_data=None,
weights_train=None, weights_val=None, resample_size=1.0):

"""This method is used to train an instance of the survival model.
Parameters
-----------
features: pd.DataFrame
features : pd.DataFrame
a pandas dataframe with rows corresponding to individual samples and
columns as covariates.
outcomes : pd.DataFrame
a pandas dataframe with columns 'time' and 'event'.
weights: list or np.array
vsize : float
Fraction of the training data to set aside as the validation set.
Not applicable to 'rsf' and 'cph' models.
val_data : tuple
A tuple of the validation dataset.
If passed, 'vsize' is ignored.
Not applicable to 'rsf' and 'cph' models.
weights_train : list or np.array
a list or numpy array of importance weights for each sample.
resample_size: float
weights_val : list or np.array
a list or numpy array of importance weights for each validation set sample.
Ignored if val_data is None.
resample_size : float
a float between 0 and 1 that controls the size of the resampled dataset.
weights_clip: float
a float that controls the minimum and maximum importance weight.
(To reduce estimator variance.)
Returns
--------
self
Trained instance of a survival model.
"""

if weights is not None:
assert len(weights) == features.shape[0], "Size of passed weights \
must match size of training data."
assert (weights>0.).any(), "All weights must be positive."
# assert ((weights>0.0)&(weights<=1.0)).all(), "Weights must be in the range (0,1]."
# weights[weights>(1-weights_clip)] = 1-weights_clip
# weights[weights<(weights_clip)] = weights_clip

if (self.model=='cph') | (self.model=='rsf'):
if (vsize is not None) | (val_data is not None):
raise Exception("'vsize' and 'val_data' should be None for 'cph' and 'rsf' models.")

if weights_train is not None:
assert len(weights_train) == features.shape[0], "Size of passed weights \
must match size of training data."
assert (weights_train>0.).all(), "All weights must be positive."
assert (vsize is not None) | (val_data is not None), "'vsize' or 'val_data' must \
be specified if weights are used."

data = features.join(outcomes)
weights = pd.Series(weights_train, index=data.index)
data_resampled = data.sample(weights = weights,
frac = resample_size,
replace = True,
random_state = self.random_seed)
features = data_resampled[features.columns]
outcomes = data_resampled[outcomes.columns]


if val_data is not None:
data_train = data
data_val = val_data
else:
data_train = data.sample(frac=1-vsize, random_state=self.random_seed)
data_val = data[~data.index.isin(data_train.index)]
weights_train = weights[data_train.index]
weights_val = weights[data_val.index]

data_train_resampled = data_train.sample(weights = weights_train,
frac = resample_size,
replace = True,
random_state = self.random_seed)

data_val_resampled = data_val.sample(weights = weights_val,
frac = resample_size,
replace = True,
random_state = self.random_seed)

features = data_train_resampled[features.columns]
outcomes = data_train_resampled[outcomes.columns]

val_data = (data_val_resampled[features.columns],
data_val_resampled[outcomes.columns])
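
For context (editor's sketch, not part of the commit): the importance-weighted resampling above relies on pandas' built-in sampler. A tiny standalone illustration with toy weights; rows with larger weights are drawn more often when sampling with replacement.

import pandas as pd

df = pd.DataFrame({"x": [0, 1, 2, 3]})
w = pd.Series([0.1, 0.1, 0.1, 10.0], index=df.index)

# frac=1.0 keeps the resampled set the same size as the original;
# replace=True lets heavily weighted rows appear multiple times.
print(df.sample(weights=w, frac=1.0, replace=True, random_state=0))
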

if self.model == 'cph':
self._model = _fit_cph(features, outcomes, self.random_seed,
self._model = _fit_cph(features, outcomes,
self.random_seed,
**self.hyperparams)
elif self.model == 'rsf':
self._model = _fit_rsf(features, outcomes, self.random_seed,
self._model = _fit_rsf(features, outcomes,
self.random_seed,
**self.hyperparams)
elif self.model == 'dsm':
self._model = _fit_dsm(features, outcomes, self.random_seed,
self._model = _fit_dsm(features, outcomes,
vsize, val_data,
self.random_seed,
**self.hyperparams)
elif self.model == 'dcph':
self._model = _fit_dcph(features, outcomes, self.random_seed,
self._model = _fit_dcph(features, outcomes,
vsize, val_data,
self.random_seed,
**self.hyperparams)
elif self.model == 'dcm':
self._model = _fit_dcm(features, outcomes, self.random_seed,
self._model = _fit_dcm(features, outcomes,
vsize, val_data,
self.random_seed,
**self.hyperparams)
else : raise NotImplementedError()

else:
raise NotImplementedError()

self.fitted = True
return self
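
Putting the pieces together (editor's sketch, not part of the commit): an end-to-end example of the new arguments at the estimator level, assuming the class this fit method belongs to is the module's SurvivalModel wrapper (the class name is not visible in this hunk) and reusing the toy frames from the first sketch.

from auton_survival.estimators import SurvivalModel  # assumed class name

# Deep models ('dsm', 'dcph', 'dcm') may now take a validation split...
dsm = SurvivalModel('dsm', random_seed=0, layers=[100], iters=20)
dsm.fit(features, outcomes, vsize=0.15)

# ...while 'cph' and 'rsf' must be fit without one, per the new check above.
cph = SurvivalModel('cph', random_seed=0)
cph.fit(features, outcomes)
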

10 changes: 7 additions & 3 deletions auton_survival/metrics.py
@@ -574,9 +574,13 @@ def func(x, y, x1, y1, x2, y2):

tar_diff = []
for risk in risks:
treated_tar = interp_x(treated_risk, treated_horizons, risk)
control_tar = interp_x(control_risk, control_horizons, risk)
tar_diff.append(treated_tar - control_tar)
if risk == 1:
tar_diff.append((treated_horizons[treated_risk==1] -
control_horizons[control_risk==1])[0])
else:
treated_tar = interp_x(treated_risk, treated_horizons, risk)
control_tar = interp_x(control_risk, control_horizons, risk)
tar_diff.append(treated_tar - control_tar)

return np.array(tar_diff)
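
For context (editor's sketch, not part of the commit): the new branch avoids interpolating at the boundary. When the requested risk level is exactly 1, the horizon is read off directly from the point where the risk curve reaches 1 rather than interpolated between neighbours. A self-contained numpy illustration of that logic on toy curves, with np.interp standing in for the module's interp_x helper.

import numpy as np

treated_risk     = np.array([0.2, 0.6, 1.0])
treated_horizons = np.array([10., 20., 30.])
control_risk     = np.array([0.3, 0.7, 1.0])
control_horizons = np.array([8., 15., 25.])

tar_diff = []
for risk in [0.5, 1.0]:
    if risk == 1:
        # Direct lookup: first horizon at which the risk curve hits 1.
        tar_diff.append((treated_horizons[treated_risk == 1] -
                         control_horizons[control_risk == 1])[0])
    else:
        # Stand-in for interp_x: invert each risk curve by interpolation.
        treated_tar = np.interp(risk, treated_risk, treated_horizons)
        control_tar = np.interp(risk, control_risk, control_horizons)
        tar_diff.append(treated_tar - control_tar)

print(np.array(tar_diff))  # [6. 5.] for these toy curves
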

16 changes: 15 additions & 1 deletion auton_survival/models/cmhe/__init__.py
@@ -76,6 +76,7 @@
"""

import numpy as np
import pandas as pd
import torch

from .cmhe_torch import DeepCMHETorch
@@ -139,14 +140,27 @@ def __call__(self):
print("Hidden Layers:", self.layers)

def _preprocess_test_data(self, x, a=None):
if isinstance(x, pd.DataFrame):
x = x.values
if a is not None:
if isinstance(a, (pd.Series, pd.DataFrame)):
a = a.values
return torch.from_numpy(x).float(), torch.from_numpy(a).float()
else:
return torch.from_numpy(x).float()

def _preprocess_training_data(self, x, t, e, a, vsize, val_data,
random_seed):


if isinstance(x, pd.DataFrame):
x = x.values
if isinstance(t, (pd.Series, pd.DataFrame)):
t = t.values
if isinstance(e, (pd.Series, pd.DataFrame)):
e = e.values
if isinstance(a, (pd.Series, pd.DataFrame)):
a = a.values

idx = list(range(x.shape[0]))

np.random.seed(random_seed)
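
For context (editor's sketch, not part of the commit): the same DataFrame/Series-to-numpy guard is repeated across the model front-ends in this commit. A condensed version of the shared pattern as a standalone helper (the name _as_numpy is hypothetical, not part of the package), which is functionally what each _preprocess_* method now does before building torch tensors.

import numpy as np
import pandas as pd
import torch

def _as_numpy(arr):
    # Hypothetical helper: accept a DataFrame/Series or ndarray, return an ndarray.
    if isinstance(arr, (pd.Series, pd.DataFrame)):
        return arr.values
    return arr

x = pd.DataFrame(np.random.randn(8, 3), columns=["x0", "x1", "x2"])
a = pd.Series(np.random.randint(0, 2, size=8))  # toy treatment indicator

x_tensor = torch.from_numpy(_as_numpy(x)).float()
a_tensor = torch.from_numpy(_as_numpy(a)).float()
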
21 changes: 20 additions & 1 deletion auton_survival/models/cph/__init__.py
@@ -25,6 +25,7 @@

import torch
import numpy as np
import pandas as pd

from .dcph_torch import DeepCoxPHTorch, DeepRecurrentCoxPHTorch
from .dcph_utilities import train_dcph, predict_survival
@@ -84,10 +85,19 @@ def __call__(self):
print("Hidden Layers:", self.layers)

def _preprocess_test_data(self, x):
if isinstance(x, pd.DataFrame):
x = x.values
return torch.from_numpy(x).float()

def _preprocess_training_data(self, x, t, e, vsize, val_data, random_seed):


if isinstance(x, pd.DataFrame):
x = x.values
if isinstance(t, (pd.Series, pd.DataFrame)):
t = t.values
if isinstance(e, (pd.Series, pd.DataFrame)):
e = e.values

idx = list(range(x.shape[0]))

np.random.seed(random_seed)
@@ -276,11 +286,20 @@ def _gen_torch_model(self, inputdim, optimizer):
optimizer=optimizer, typ=self.typ)

def _preprocess_test_data(self, x):
if isinstance(x, pd.DataFrame):
x = x.values
return torch.from_numpy(_get_padded_features(x)).float()

def _preprocess_training_data(self, x, t, e, vsize, val_data, random_seed):
"""RNNs require different preprocessing for variable length sequences"""

if isinstance(x, pd.DataFrame):
x = x.values
if isinstance(t, (pd.Series, pd.DataFrame)):
t = t.values
if isinstance(e, (pd.Series, pd.DataFrame)):
e = e.values

idx = list(range(x.shape[0]))
np.random.seed(random_seed)
np.random.shuffle(idx)
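
For context (editor's sketch, not part of the commit): the recurrent variant additionally pads variable-length covariate sequences via _get_padded_features before converting them. Its exact behaviour is not shown in this hunk, so the code below is a hypothetical stand-in that pads a list of per-subject sequences with NaNs to a common length, the usual convention for this kind of preprocessing.

import numpy as np

def pad_sequences(seqs, pad_value=np.nan):
    # Hypothetical stand-in for _get_padded_features: pad a list of
    # (length_i, n_features) arrays to (n_subjects, max_len, n_features).
    max_len = max(s.shape[0] for s in seqs)
    n_feat = seqs[0].shape[1]
    out = np.full((len(seqs), max_len, n_feat), pad_value)
    for i, s in enumerate(seqs):
        out[i, :s.shape[0], :] = s
    return out

seqs = [np.random.randn(3, 2), np.random.randn(5, 2)]  # two subjects, 2 covariates
padded = pad_sequences(seqs)
print(padded.shape)  # (2, 5, 2)
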
12 changes: 11 additions & 1 deletion auton_survival/models/dcm/__init__.py
@@ -52,6 +52,7 @@

import torch
import numpy as np
import pandas as pd

from .dcm_torch import DeepCoxMixturesTorch
from .dcm_utilities import train_dcm, predict_survival, predict_latent_z
@@ -112,10 +113,19 @@ def __call__(self):
print("Hidden Layers:", self.layers)

def _preprocess_test_data(self, x):
if isinstance(x, pd.DataFrame):
x = x.values
return torch.from_numpy(x).float()

def _preprocess_training_data(self, x, t, e, vsize, val_data, random_seed):


if isinstance(x, pd.DataFrame):
x = x.values
if isinstance(t, (pd.Series, pd.DataFrame)):
t = t.values
if isinstance(e, (pd.Series, pd.DataFrame)):
e = e.values

idx = list(range(x.shape[0]))
np.random.seed(random_seed)
np.random.shuffle(idx)
