Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragnagpal committed Apr 2, 2022
2 parents 0ef4487 + db17080 commit eb88f79
Show file tree
Hide file tree
Showing 17 changed files with 1,037 additions and 439 deletions.
273 changes: 146 additions & 127 deletions auton_survival/estimators.py

Large diffs are not rendered by default.

27 changes: 17 additions & 10 deletions auton_survival/experiments.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
from sklearn.utils import shuffle

from auton_survival.estimators import SurvivalModel, CounterfactualSurvivalModel
from auton_survival.metrics import survival_regression_metric
Expand Down Expand Up @@ -63,7 +62,7 @@ def __init__(self, model, cv_folds=5, random_seed=0, hyperparam_grid={}):

def fit(self, features, outcomes, ret_trained_model=True):

r"""Fits the Survival Regression Model to the data in a Cross
r"""Fits the Survival Regression Model to the data in a Cross
Validation fashion.
Parameters
Expand All @@ -77,14 +76,13 @@ def fit(self, features, outcomes, ret_trained_model=True):
a column named 'event' that contains the censoring status.
\( \delta_i = 1 \) if the event is observed.
ret_trained_model : bool
If True, the trained model is returned. If False, the fit function
If True, the trained model is returned. If False, the fit function
returns self.
Returns
-----------
auton_survival.estimators.SurvivalModel:
The selected survival model based on lowest integrated brier score.
"""

n = len(features)
Expand Down Expand Up @@ -122,16 +120,23 @@ def fit(self, features, outcomes, ret_trained_model=True):
fold_models = {}
for fold in tqdm(range(self.cv_folds)):
# Fit the model
fold_model = SurvivalModel(model=self.model, random_seed=self.random_seed, **hyper_param)
fold_model = SurvivalModel(model=self.model, random_seed=self.random_seed, **hyper_param)
fold_model.fit(features.loc[folds!=fold], outcomes.loc[folds!=fold])
fold_models[fold] = fold_model

# Predict risk scores
predictions[folds==fold] = fold_model.predict_survival(features.loc[folds==fold], times=unique_times)
# Evaluate IBS
predictions[folds==fold] = fold_model.predict_survival(features.loc[folds==fold],
times=unique_times)

score_per_fold = []
for fold in range(self.cv_folds):
score = survival_regression_metric('ibs', predictions, outcomes, unique_times, folds, fold)
outcomes_train = outcomes.loc[folds!=fold]
outcomes_test = outcomes.loc[folds==fold]
predictions_test = predictions[folds==fold]

# Compute IBS
score = survival_regression_metric('ibs', outcomes_train, outcomes_test,
predictions_test, unique_times)
score_per_fold.append(score)

current_score = np.mean(score_per_fold)
Expand All @@ -148,7 +153,8 @@ def fit(self, features, outcomes, ret_trained_model=True):

if ret_trained_model:

model = SurvivalModel(model=self.model, random_seed=self.random_seed, **self.best_hyperparameter)
model = SurvivalModel(model=self.model, random_seed=self.random_seed,
**self.best_hyperparameter)
model.fit(features, outcomes)

return model
Expand All @@ -174,7 +180,8 @@ def evaluate(self, features, outcomes, metrics=['auc', 'ctd'], horizons=[]):
for fold in range(self.cv_folds):

fold_model = self.best_model_per_fold[fold]
fold_predictions = fold_model.predict(features.loc[self.folds==fold], times=horizons)
fold_predictions = fold_model.predict(features.loc[self.folds==fold],
times=horizons)

for i, horizon in enumerate(horizons):
for metric in metrics:
Expand Down
Loading

0 comments on commit eb88f79

Please sign in to comment.