Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

notebook and random seed updates #64

Merged
merged 2 commits into from
Apr 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
273 changes: 146 additions & 127 deletions auton_survival/estimators.py

Large diffs are not rendered by default.

30 changes: 19 additions & 11 deletions auton_survival/experiments.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
from sklearn.utils import shuffle

from auton_survival.estimators import SurvivalModel, CounterfactualSurvivalModel
from auton_survival.metrics import survival_regression_metric
Expand Down Expand Up @@ -63,7 +62,8 @@ def __init__(self, model, cv_folds=5, random_seed=0, hyperparam_grid={}):

def fit(self, features, outcomes, ret_trained_model=True):

r"""Fits the Survival Regression Model to the data in a Cross Validation fashion.
r"""Fits the Survival Regression Model to the data in a Cross
Validation fashion.

Parameters
-----------
Expand All @@ -76,14 +76,13 @@ def fit(self, features, outcomes, ret_trained_model=True):
a column named 'event' that contains the censoring status.
\( \delta_i = 1 \) if the event is observed.
ret_trained_model : bool
If True, the trained model is returned. If False, the fit function returns
self.
If True, the trained model is returned. If False, the fit function
returns self.

Returns
-----------
auton_survival.estimators.SurvivalModel:
The selected survival model based on lowest integrated brier score.

"""

n = len(features)
Expand Down Expand Up @@ -121,16 +120,23 @@ def fit(self, features, outcomes, ret_trained_model=True):
fold_models = {}
for fold in tqdm(range(self.cv_folds)):
# Fit the model
fold_model = SurvivalModel(model=self.model, random_seed=self.random_seed, **hyper_param)
fold_model = SurvivalModel(model=self.model, random_seed=self.random_seed, **hyper_param)
fold_model.fit(features.loc[folds!=fold], outcomes.loc[folds!=fold])
fold_models[fold] = fold_model

# Predict risk scores
predictions[folds==fold] = fold_model.predict_survival(features.loc[folds==fold], times=unique_times)
# Evaluate IBS
predictions[folds==fold] = fold_model.predict_survival(features.loc[folds==fold],
times=unique_times)

score_per_fold = []
for fold in range(self.cv_folds):
score = survival_regression_metric('ibs', predictions, outcomes, unique_times, folds, fold)
outcomes_train = outcomes.loc[folds!=fold]
outcomes_test = outcomes.loc[folds==fold]
predictions_test = predictions[folds==fold]

# Compute IBS
score = survival_regression_metric('ibs', outcomes_train, outcomes_test,
predictions_test, unique_times)
score_per_fold.append(score)

current_score = np.mean(score_per_fold)
Expand All @@ -147,7 +153,8 @@ def fit(self, features, outcomes, ret_trained_model=True):

if ret_trained_model:

model = SurvivalModel(model=self.model, random_seed=self.random_seed, **self.best_hyperparameter)
model = SurvivalModel(model=self.model, random_seed=self.random_seed,
**self.best_hyperparameter)
model.fit(features, outcomes)

return model
Expand All @@ -173,7 +180,8 @@ def evaluate(self, features, outcomes, metrics=['auc', 'ctd'], horizons=[]):
for fold in range(self.cv_folds):

fold_model = self.best_model_per_fold[fold]
fold_predictions = fold_model.predict(features.loc[self.folds==fold], times=horizons)
fold_predictions = fold_model.predict(features.loc[self.folds==fold],
times=horizons)

for i, horizon in enumerate(horizons):
for metric in metrics:
Expand Down
Loading