Skip to content

Commit

Permalink
notebook and random seed updates (#64)
Browse files Browse the repository at this point in the history
* notebook and random seed updates

* modified:   estimators.py
	modified:   preprocessing.py
  • Loading branch information
PotosnakW authored Apr 2, 2022
1 parent e347556 commit db17080
Show file tree
Hide file tree
Showing 17 changed files with 2,129 additions and 551 deletions.
273 changes: 146 additions & 127 deletions auton_survival/estimators.py

Large diffs are not rendered by default.

30 changes: 19 additions & 11 deletions auton_survival/experiments.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
from sklearn.utils import shuffle

from auton_survival.estimators import SurvivalModel, CounterfactualSurvivalModel
from auton_survival.metrics import survival_regression_metric
Expand Down Expand Up @@ -63,7 +62,8 @@ def __init__(self, model, cv_folds=5, random_seed=0, hyperparam_grid={}):

def fit(self, features, outcomes, ret_trained_model=True):

r"""Fits the Survival Regression Model to the data in a Cross Validation fashion.
r"""Fits the Survival Regression Model to the data in a Cross
Validation fashion.
Parameters
-----------
Expand All @@ -76,14 +76,13 @@ def fit(self, features, outcomes, ret_trained_model=True):
a column named 'event' that contains the censoring status.
\( \delta_i = 1 \) if the event is observed.
ret_trained_model : bool
If True, the trained model is returned. If False, the fit function returns
self.
If True, the trained model is returned. If False, the fit function
returns self.
Returns
-----------
auton_survival.estimators.SurvivalModel:
The selected survival model based on lowest integrated brier score.
"""

n = len(features)
Expand Down Expand Up @@ -121,16 +120,23 @@ def fit(self, features, outcomes, ret_trained_model=True):
fold_models = {}
for fold in tqdm(range(self.cv_folds)):
# Fit the model
fold_model = SurvivalModel(model=self.model, random_seed=self.random_seed, **hyper_param)
fold_model = SurvivalModel(model=self.model, random_seed=self.random_seed, **hyper_param)
fold_model.fit(features.loc[folds!=fold], outcomes.loc[folds!=fold])
fold_models[fold] = fold_model

# Predict risk scores
predictions[folds==fold] = fold_model.predict_survival(features.loc[folds==fold], times=unique_times)
# Evaluate IBS
predictions[folds==fold] = fold_model.predict_survival(features.loc[folds==fold],
times=unique_times)

score_per_fold = []
for fold in range(self.cv_folds):
score = survival_regression_metric('ibs', predictions, outcomes, unique_times, folds, fold)
outcomes_train = outcomes.loc[folds!=fold]
outcomes_test = outcomes.loc[folds==fold]
predictions_test = predictions[folds==fold]

# Compute IBS
score = survival_regression_metric('ibs', outcomes_train, outcomes_test,
predictions_test, unique_times)
score_per_fold.append(score)

current_score = np.mean(score_per_fold)
Expand All @@ -147,7 +153,8 @@ def fit(self, features, outcomes, ret_trained_model=True):

if ret_trained_model:

model = SurvivalModel(model=self.model, random_seed=self.random_seed, **self.best_hyperparameter)
model = SurvivalModel(model=self.model, random_seed=self.random_seed,
**self.best_hyperparameter)
model.fit(features, outcomes)

return model
Expand All @@ -173,7 +180,8 @@ def evaluate(self, features, outcomes, metrics=['auc', 'ctd'], horizons=[]):
for fold in range(self.cv_folds):

fold_model = self.best_model_per_fold[fold]
fold_predictions = fold_model.predict(features.loc[self.folds==fold], times=horizons)
fold_predictions = fold_model.predict(features.loc[self.folds==fold],
times=horizons)

for i, horizon in enumerate(horizons):
for metric in metrics:
Expand Down
Loading

0 comments on commit db17080

Please sign in to comment.