Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

notebooks and metrics updates #67

Merged
merged 1 commit into from
Apr 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 78 additions & 72 deletions auton_survival/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ def survival_diff_metric(metric, outcomes, treatment_indicator,
treated_outcomes = outcomes[treatment_indicator]
control_outcomes = outcomes[~treatment_indicator]

if metric == 'survival_at': _metric = _survival_at_diff
if metric == 'survival_at':
_metric = _survival_at_diff
elif metric == 'time_to':
_metric = _time_to_diff
elif metric == 'restricted_mean':
Expand Down Expand Up @@ -201,118 +202,123 @@ def survival_regression_metric(metric, outcomes_train, outcomes_test,
else:
raise NotImplementedError()

def phenotype_purity(phenotypes, outcomes,
strategy='instantaneous', folds=None,
fold=None, time=None, bootstrap=None):
def phenotype_purity(phenotypes_train, outcomes_train,
phenotypes_test=None, outcomes_test=None,
strategy='instantaneous', horizon=None,
bootstrap=None):
"""Compute the brier score to assess survival model performance
for phenotypes.

Parameters
-----------
phenotypes: np.array
A numpy array containing a list of strings that define subgroups.
outcomes : pd.DataFrame
phenotypes_train: np.array
A numpy array containing an array of integers that define subgroups
for the train set.
outcomes_train : pd.DataFrame
A pandas dataframe with rows corresponding to individual samples and
columns 'time' and 'event' for the train set.
phenotypes_test: np.array
A numpy array containing an array of integers that define subgroups
for the test set.
outcomes_test : pd.DataFrame
A pandas dataframe with rows corresponding to individual samples and
columns 'time' and 'event'.
strategy: string, default='instantaneous'
columns 'time' and 'event' for the test set.
strategy : string, default='instantaneous'
Options include:
- `instantaneous` : Predict the Kaplan Meier survival estimate at a
certain point in time and compute the brier score.
- `integrated` : Predict the Kaplan Meier survival estimate at all unique
times points and compute the integrated brier score.
folds: pd.DataFrame, default=None
A pandas dataframe of train and test folds.
fold: int, default=None
A specific fold number in the folds input.
time: int, default=None
A certain point in time at which to predict the Kaplan Meier survival
estimate.
bootstrap: integer, default=None
- `instantaneous` : Compute the brier score.
- `integrated` : Compute the integrated brier score.
horizon : float or int or np.array of floats or ints, default=None
Event horizon(s) at which to compute the metric
bootstrap : integer, default=None
The number of bootstrap iterations.

Returns
-----------
float:
The brier score is computed for the 'instantaneous' strategy.
The integreted brier score is computed for the 'integrated' strategy.
list:
Columns are metric values computed for each event horizon.
If bootstrapping, rows are bootstrap results.

"""

# CODE UPDATE: enable phenotype purity to be computed for the test set...
# without specifying folds to determine the train/test sets when folds are inapplicable (no CV)
np.random.seed(0)

if folds is None:
assert fold is None, "Please pass the data folds.."

assert time is not None, "Please pass the time of evaluation!"
if (outcomes_test is None) & (phenotypes_test is not None):
raise Exception("Specify outcomes for test set.")
if (outcomes_test is not None) & (phenotypes_test is None):
raise Exception("Specify phenotypes for test set.")

if folds is not None:
outcomes_train = outcomes.iloc[folds!=fold]
outcomes_test = outcomes.iloc[folds==fold]
phenotypes_train = phenotypes[folds!=fold]
phenotypes_test = phenotypes[folds==fold]
else:
outcomes_train, outcomes_test = outcomes, outcomes
phenotypes_train, phenotypes_test = phenotypes, phenotypes
assert horizon is not None, "Please specify Event Horizon"

assert (time<outcomes_test['time'].max()) and (time>outcomes_test['time'].min())
assert (time<outcomes_train['time'].max()) and (time>outcomes_train['time'].min())
if isinstance(horizon, float) | isinstance(horizon, int):
horizon = [horizon]

for phenotype in set(phenotypes_test):
assert phenotype in phenotypes_train, "Testing on Phenotypes not found \
in the Training set!!"
if outcomes_test is None:
phenotypes_test = phenotypes_train
outcomes_test = outcomes_train
warnings.warn("You are are estimating survival probabilities for \
the same dataset used to estimate the censoring \
distribution.")

survival_curves = {}
for phenotype in set(phenotypes_train):
survival_curves[phenotype] = KaplanMeierFitter().fit(outcomes_train.iloc[phenotypes_train==phenotype]['time'],
outcomes_train.iloc[phenotypes_test==phenotype]['event'])
for phenotype in np.unique(phenotypes_train):
survival_curves[phenotype] = KaplanMeierFitter().fit(outcomes_train.time.iloc[phenotypes_train==phenotype],
outcomes_train.event.iloc[phenotypes_train==phenotype])

survival_train = util.Surv.from_dataframe('event', 'time', outcomes_train)
survival_test = util.Surv.from_dataframe('event', 'time', outcomes_test)

n = len(survival_test)

if strategy == 'instantaneous':

predictions = np.zeros(len(survival_test))
for phenotype in set(phenotypes):
predictions[phenotypes==phenotype] = float(survival_curves[phenotype].predict(times=time,
interpolate=True))
predictions = np.zeros((len(survival_test), len(horizon)))
for phenotype in set(phenotypes_test):
predictions[phenotypes_test==phenotype, :] = survival_curves[phenotype].predict(times=horizon,
interpolate=True)
if bootstrap is None:
return float(metrics.brier_score(survival_train, survival_test,
predictions, time)[1])
return metrics.brier_score(survival_train, survival_test,
predictions, horizon)[1]
else:
scores = []
for i in tqdm(range(bootstrap)):
idx = np.random.choice(n, size=n, replace=True)
score = float(metrics.brier_score(survival_train, survival_test[idx],
predictions[idx], time)[1])
score = metrics.brier_score(survival_train, survival_test[idx],
predictions[idx], horizon)[1]
scores.append(score)
return scores

elif strategy == 'integrated':

times = np.unique(outcomes_test['time'])
times = times[times<time]
predictions = np.zeros((len(survival_test), len(times)))
for phenotype in set(phenotypes):
predictions[phenotypes==phenotype, :] = survival_curves[phenotype].predict(times=times,
interpolate=True).values
horizon_scores = []
for time in horizon:
times = np.unique(outcomes_test['time'])
times = times[times<time]
predictions = np.zeros((len(survival_test), len(times)))
for phenotype in set(phenotypes_test):
predictions[phenotypes_test==phenotype, :] = survival_curves[phenotype].predict(times=times,
interpolate=True).values
if bootstrap is None:
horizon_scores.append(metrics.integrated_brier_score(survival_train,
survival_test,
predictions,
times))

else:
score = []
for i in tqdm(range(bootstrap)):
idx = np.random.choice(n, size=n, replace=True)
score.append(metrics.integrated_brier_score(survival_train,
survival_test[idx],
predictions[idx],
times))
horizon_scores.append(score)

if bootstrap is None:
return metrics.integrated_brier_score(survival_train,
survival_test,
predictions,
times)
return np.array(horizon_scores)
else:
scores = []
for i in tqdm(range(bootstrap)):
idx = np.random.choice(n, size=n, replace=True)
score = metrics.integrated_brier_score(survival_train,
survival_test[idx],
predictions[idx],
times)
scores.append(score)
return scores
# Format scores exactly like "instantaneous" option w/ bootstrapping for consistency
return [np.array([j[i] for j in np.array(horizon_scores)]) for i in range(bootstrap)]

else:
raise NotImplementedError()
Expand Down
Loading