modified: metrics.py
chiragnagpal committed Jul 7, 2022
1 parent d231e6c commit e2c619e
Showing 1 changed file with 61 additions and 15 deletions.

auton_survival/metrics.py: 76 changes (61 additions & 15 deletions)
@@ -166,7 +166,8 @@ def treatment_effect(metric, outcomes, treatment_indicator,
                   random_seed=i) for i in range(n_bootstrap)]

 def survival_regression_metric(metric, outcomes, predictions,
-                               times, outcomes_train=None):
+                               times, outcomes_train=None,
+                               n_bootstrap=None, random_seed=0):
   """Compute metrics to assess survival model performance.
   Parameters
@@ -189,7 +190,14 @@ def survival_regression_metric(metric, outcomes, predictions,
   outcomes_train : pd.DataFrame
     A pandas dataframe with rows corresponding to individual samples and
     columns 'time' and 'event' for training data.
+  n_bootstrap : int, default=None
+    The number of bootstrap samples to use.
+    If None, bootstrapping is not performed.
+  size_bootstrap : float, default=1.0
+    The fraction of the population to sample for each bootstrap sample.
+  random_seed : int, default=0
+    Controls the reproducibility of random sampling for bootstrapping.
   Returns
   -----------
   float: The metric value for the specified metric.
@@ -213,25 +221,63 @@ def survival_regression_metric(metric, outcomes, predictions,
   survival_test = util.Surv.from_dataframe('event', 'time', outcomes)

   if metric == 'brs':
-    return metrics.brier_score(survival_train, survival_test,
-                               predictions, times)[-1]
+    _metric = _brier_score
   elif metric == 'ibs':
-    return metrics.integrated_brier_score(survival_train, survival_test,
-                                          predictions, times)
+    _metric = _integrated_brier_score
   elif metric == 'auc':
-    return metrics.cumulative_dynamic_auc(survival_train, survival_test,
-                                          1-predictions, times)[0]
+    _metric = _cumulative_dynamic_auc
   elif metric == 'ctd':
-    vals = []
-    for i in range(len(times)):
-      vals.append(metrics.concordance_index_ipcw(survival_train, survival_test,
-                                                 1-predictions[:,i],
-                                                 tau=times[i])[0])
-    return vals
+    _metric = _concordance_index_ipcw
   else:
     raise NotImplementedError()

+  if n_bootstrap is None:
+    return _metric(survival_train, survival_test, predictions, times)
+  else:
+    return [_metric(survival_train, survival_test, predictions, times, random_seed=i) for i in range(n_bootstrap)]


+def _brier_score(survival_train, survival_test, predictions, times, random_seed=None):
+
+  idx = np.arange(len(predictions))
+  if random_seed is not None:
+    np.random.seed(random_seed)
+    idx = np.random.choice(idx, len(predictions), replace=True)
+
+  return metrics.brier_score(survival_train, survival_test[idx], predictions[idx], times)[-1]
+
+def _integrated_brier_score(survival_train, survival_test, predictions, times, random_seed=None):
+
+  idx = np.arange(len(predictions))
+  if random_seed is not None:
+    np.random.seed(random_seed)
+    idx = np.random.choice(idx, len(predictions), replace=True)
+
+  return metrics.integrated_brier_score(survival_train, survival_test[idx], predictions[idx], times)
+
+def _cumulative_dynamic_auc(survival_train, survival_test, predictions, times, random_seed=None):
+
+  idx = np.arange(len(predictions))
+  if random_seed is not None:
+    np.random.seed(random_seed)
+    idx = np.random.choice(idx, len(predictions), replace=True)
+
+  return metrics.cumulative_dynamic_auc(survival_train, survival_test[idx], 1-predictions[idx], times)[0]
+
+def _concordance_index_ipcw(survival_train, survival_test, predictions, times, random_seed=None):
+
+  idx = np.arange(len(predictions))
+  if random_seed is not None:
+    np.random.seed(random_seed)
+    idx = np.random.choice(idx, len(predictions), replace=True)
+
+  vals = []
+  for i in range(len(times)):
+    vals.append(metrics.concordance_index_ipcw(survival_train, survival_test[idx],
+                                               1-predictions[idx][:,i],
+                                               tau=times[i])[0])
+  return vals

 def phenotype_purity(phenotypes_train, outcomes_train,
                      phenotypes_test=None, outcomes_test=None,
                      strategy='instantaneous', horizons=None,
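Usage note (not part of the commit): a minimal sketch of calling survival_regression_metric with the new n_bootstrap argument, using synthetic outcomes and placeholder predictions. The module path auton_survival.metrics follows the file shown above; the data, model-free predictions, and reuse of the same cohort as training data are assumptions made only to keep the sketch self-contained.

# Hedged usage sketch: synthetic data, only to illustrate the n_bootstrap argument.
import numpy as np
import pandas as pd
from auton_survival.metrics import survival_regression_metric

rng = np.random.default_rng(0)
n = 200

# A synthetic cohort with 'time' and 'event' columns; for brevity the same
# cohort is reused as training data to estimate the censoring distribution.
outcomes = pd.DataFrame({'time': rng.exponential(10, n),
                         'event': rng.integers(0, 2, n).astype(bool)})

# Placeholder survival probabilities at each horizon, shape (n, len(times));
# in practice these would come from a fitted survival regression model.
times = [2.0, 5.0, 8.0]
predictions = np.sort(rng.uniform(0.5, 1.0, (n, len(times))), axis=1)[:, ::-1]

# Without bootstrapping: one Brier score per evaluation horizon.
brs = survival_regression_metric(metric='brs', outcomes=outcomes,
                                 predictions=predictions, times=times,
                                 outcomes_train=outcomes)

# With bootstrapping: a list with one entry per replicate, each computed on a
# with-replacement resample of the test set (seeded 0, 1, ..., 24).
brs_boot = survival_regression_metric(metric='brs', outcomes=outcomes,
                                      predictions=predictions, times=times,
                                      outcomes_train=outcomes,
                                      n_bootstrap=25)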

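Since each bootstrap replicate returns its own score, the returned list can be reduced to a point estimate with a percentile interval. A minimal sketch, assuming a hypothetical helper that is not part of auton_survival:

# Hypothetical post-processing of bootstrapped metric values (not in the commit).
import numpy as np

def summarize_bootstrap(values, alpha=0.05):
  """Return the mean and a (1 - alpha) percentile interval across replicates."""
  values = np.asarray(values, dtype=float)  # shape (n_bootstrap,) or (n_bootstrap, n_times)
  lower = np.quantile(values, alpha / 2, axis=0)
  upper = np.quantile(values, 1 - alpha / 2, axis=0)
  return values.mean(axis=0), (lower, upper)

# e.g. mean_brs, (lo, hi) = summarize_bootstrap(brs_boot)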