From 4cb835eea62fc01e125d0fbd4b806761d1877702 Mon Sep 17 00:00:00 2001 From: W Potosnak Date: Wed, 18 May 2022 21:51:04 -0400 Subject: [PATCH] fin updates --- auton_survival/estimators.py | 41 +++++---- auton_survival/experiments.py | 89 +++---------------- auton_survival/phenotyping.py | 20 +++-- ...rvival Regression on SUPPORT Dataset.ipynb | 35 ++++---- ...vival Regression with Auton-Survival.ipynb | 86 +++++++++--------- 5 files changed, 101 insertions(+), 170 deletions(-) diff --git a/auton_survival/estimators.py b/auton_survival/estimators.py index 03fc54a..01a1ca7 100644 --- a/auton_survival/estimators.py +++ b/auton_survival/estimators.py @@ -595,12 +595,12 @@ def fit(self, features, outcomes, vsize=0.15, val_data=None, assert weights_val is None, "Weights for validation data \ must be None if validation data is not specified." - train_data = data.sample(frac=1-vsize, random_state=self.random_seed) - val_data = data[~data.index.isin(train_data.index)] - val_data = (val_data[features.columns], val_data[outcomes.columns]) + data_train = data.sample(frac=1-vsize, random_state=self.random_seed) + data_val = data[~data.index.isin(data_train.index)] else: - train_data = data + data_train = data + data_val = val_data[0].join(val_data[1]) if weights is not None: assert len(weights) == features.shape[0], "Size of passed weights \ @@ -608,38 +608,37 @@ def fit(self, features, outcomes, vsize=0.15, val_data=None, assert (weights>0.).any(), "All weights must be positive." weights = pd.Series(weights, index=data.index) - val_data = val_data[0].join(val_data[1]) - if weights_val is None: - weights_train = weights[train_data.index] - weights_val = weights[val_data.index] - - else: - assert weights_val is not None, "Validation set weights must be \ -specified." - assert len(weights_val) == val_data[0].shape[0], "Size of passed \ -weights_val must match size of validation data." + if weights_val is not None: + assert len(weights_val) == data_val[features.columns].shape[0], "Size \ +of passed weights_val must match size of validation data." assert (weights_val>0.).any(), "All weights_val must be positive." weights_train = weights - train_data_resampled = train_data.sample(weights = weights_train, + else: + assert val_data is None, "Validation weights must be specified if validation \ +data and training set weights are both specified." + weights_train = weights[data_train.index] + weights_val = weights[data_val.index] + + data_train_resampled = data_train.sample(weights = weights_train, frac = resample_size, replace = True, random_state = self.random_seed) - val_data_resampled = val_data.sample(weights = weights_val, + data_val_resampled = data_val.sample(weights = weights_val, frac = resample_size, replace = True, random_state = self.random_seed) - features = train_data_resampled[features.columns] - outcomes = train_data_resampled[outcomes.columns] + features = data_train_resampled[features.columns] + outcomes = data_train_resampled[outcomes.columns] - val_data = (val_data_resampled[features.columns], - val_data_resampled[outcomes.columns]) + data_val = data_val_resampled - val_data = (val_data[0], val_data[1].time, val_data[1].event) + val_data = (data_val[features.columns], data_val[outcomes.columns].time, + data_val[outcomes.columns].event) if self.model == 'cph': self._model = _fit_cph(features, outcomes, diff --git a/auton_survival/experiments.py b/auton_survival/experiments.py index 4c6236d..91b22cd 100644 --- a/auton_survival/experiments.py +++ b/auton_survival/experiments.py @@ -100,8 +100,7 @@ def __init__(self, model='dcph', folds=None, num_folds=5, self.random_seed = random_seed self.hyperparam_grid = list(ParameterGrid(hyperparam_grid)) - def fit(self, features, outcomes, metric='ibs', horizon=None, - cat_feats=None, num_feats=None, one_hot=False): + def fit(self, features, outcomes, metric='ibs', horizon=None): r"""Fits the survival regression model to the data in a cross- validation or nested cross-validation fashion. @@ -124,12 +123,6 @@ def fit(self, features, outcomes, metric='ibs', horizon=None, horizon : int or float, default=None Event-horizon at which to evaluate model performance. If None, then the maximum permissible event-time from the data is used. - cat_feats: list - List of categorical features. - num_feats: list - List of numerical/continuous features. - one_hot : bool, default=False - Whether to perform One-Hot encoding for categorical features. Returns ----------- @@ -138,9 +131,6 @@ def fit(self, features, outcomes, metric='ibs', horizon=None, """ self.metric = metric - self.cat_feats = cat_feats - self.num_feats = num_feats - self.one_hot = one_hot self.horizon = horizon if (horizon is None) & (metric not in ['auc', 'ctd', 'brs']): @@ -153,16 +143,11 @@ def fit(self, features, outcomes, metric='ibs', horizon=None, self.num_folds, self.random_seed) - if self.num_nested_folds is None: - proc_x_tr = self._process_data(features_tr=features, - cat_feats=self.cat_feats, - num_feats=self.num_feats, - one_hot=self.one_hot) - + if self.num_nested_folds is None: best_params = self._cv_select_parameters(features, outcomes, self.folds) model = SurvivalModel(self.model, self.random_seed, **best_params) - return model.fit(proc_x_tr, outcomes) + return model.fit(features, outcomes) else: return self._train_nested_cv_models(features, outcomes) @@ -190,19 +175,14 @@ def _train_nested_cv_models(self, features, outcomes): for fi, fold in enumerate(set(self.folds)): x_tr = features.copy().loc[self.folds!=fold] y_tr = outcomes.loc[self.folds!=fold] - - proc_x_tr = self._process_data(features_tr=x_tr, - cat_feats=self.cat_feats, - num_feats=self.num_feats, - one_hot=self.one_hot) self.nested_folds = self._get_stratified_folds(y_tr, 'event', self.num_nested_folds, self.random_seed) - # Use unprocessed training set for nested CV. + best_params = self._cv_select_parameters(x_tr, y_tr, self.nested_folds) model = SurvivalModel(self.model, self.random_seed, **best_params) - models[fi] = model.fit(proc_x_tr, y_tr) + models[fi] = model.fit(x_tr, y_tr) return models @@ -248,10 +228,7 @@ def _cv_select_parameters(self, features, outcomes, folds): y_tr = outcomes.loc[folds!=fold] y_val = outcomes.loc[folds==fold] - proc_x_tr, proc_x_val = self._process_data(x_tr, x_val, self.cat_feats, - self.num_feats,self.one_hot) - - param_results = self._fit_evaluate_model(proc_x_tr, y_tr, proc_x_val, y_val) + param_results = self._fit_evaluate_model(x_tr, y_tr, x_val, y_val) # Hyperparameter results as row items and fold results as columns. fold_results = pd.concat([fold_results, pd.DataFrame(param_results)], @@ -302,7 +279,7 @@ def _fit_evaluate_model(self, features_tr, outcomes_tr, if self.metric == 'ibs': times = self.times else: - times = self.times[-1] + times = [self.times[-1]] param_results = [] for hyper_param in tqdm(self.hyperparam_grid): @@ -364,47 +341,6 @@ def _get_stratified_folds(self, dataset, event_label, n_folds, random_seed): return df_folds - def _process_data(self, features_tr, features_te=None, cat_feats=None, - num_feats=None, one_hot=False): - - """Fit preprocessors to training set data and transform training - set and test set data. - - Parameters - ----------- - features_tr : pd.DataFrame - A pandas dataframe with rows corresponding to individual samples - and columns as covariates for the training set data. - features_te : pd.DataFrame - A pandas dataframe with rows corresponding to individual samples - and columns as covariates for the test set data. - cat_feats: list - List of categorical features. - num_feats: list - List of numerical/continuous features. - one_hot : bool, default=False - Whether to perform One-Hot encoding for categorical features. - - Returns - ----------- - Pandas dataframes of preprocessed training and test set data. - - """ - - preprocessor = Preprocessor(cat_feat_strat='replace', - num_feat_strat='median', - scaling_strategy='standard', - one_hot=one_hot) - transformer = preprocessor.fit(features_tr, cat_feats=cat_feats, - num_feats=num_feats, fill_value=-1) - features_tr = transformer.transform(features_tr.copy()) - - if features_te is None: - return features_tr - else: - features_te = transformer.transform(features_te.copy()) - return features_tr, features_te - def _check_times(self, outcomes, times, folds): """Verify times are within an appropriate range for model evaluation. @@ -513,8 +449,7 @@ def __init__(self, model, cv_folds=5, random_seed=0, hyperparam_grid={}): random_seed=random_seed, hyperparam_grid=hyperparam_grid) - def fit(self, features, outcomes, interventions, - metric, cat_feats, num_feats): + def fit(self, features, outcomes, interventions, metric): r"""Fits the Survival Regression Model to the data in a Cross Validation fashion. @@ -542,13 +477,9 @@ def fit(self, features, outcomes, interventions, treated_model = self.treated_experiment.fit(features.loc[interventions==1], outcomes.loc[interventions==1], - metric=metric, - cat_feats=cat_feats, - num_feats=num_feats) + metric=metric) control_model = self.control_experiment.fit(features.loc[interventions!=1], outcomes.loc[interventions!=1], - metric=metric, - cat_feats=cat_feats, - num_feats=num_feats) + metric=metric) return CounterfactualSurvivalModel(treated_model, control_model) diff --git a/auton_survival/phenotyping.py b/auton_survival/phenotyping.py index 1941783..faf8ca5 100644 --- a/auton_survival/phenotyping.py +++ b/auton_survival/phenotyping.py @@ -479,7 +479,7 @@ def __init__(self, self.random_seed = random_seed def fit(self, features, outcomes, interventions, metric, - horizon, cat_feats, num_feats): + horizon): """Fit a counterfactual model and regress the difference of the estimated counterfactual Restricted Mean Survival Time using a Random Forest regressor. @@ -492,16 +492,19 @@ def fit(self, features, outcomes, interventions, metric, outcomes : pd.DataFrame A pandas dataframe with rows corresponding to individual samples and columns 'time' and 'event'. - treatment_indicator : np.array + interventions : np.array Boolean numpy array of treatment indicators. True means individual was assigned a specific treatment. + metric : str, default='ibs' + Metric used to evaluate model performance and tune hyperparameters. + Options include: + - 'auc': Dynamic area under the ROC curve + - 'brs' : Brier Score + - 'ibs' : Integrated Brier Score + - 'ctd' : Concordance Index horizon : np.float The event horizon at which to compute the counterfacutal RMST for - regression. - cat_feats: list - List of categorical features. - num_feats: list - List of numerical/continuous features. + regression. Returns ----------- @@ -512,8 +515,7 @@ def fit(self, features, outcomes, interventions, metric, cf_model = CounterfactualSurvivalRegressionCV(model=self.cf_method, hyperparam_grid=self.cf_hyperparams) - self.cf_model = cf_model.fit(features, outcomes, interventions, - metric, cat_feats, num_feats) + self.cf_model = cf_model.fit(features, outcomes, interventions, metric) times = np.unique(outcomes.time.values) cf_predictions = self.cf_model.predict_counterfactual_survival(features, diff --git a/examples/CV Survival Regression on SUPPORT Dataset.ipynb b/examples/CV Survival Regression on SUPPORT Dataset.ipynb index a0b7c70..e1410c9 100644 --- a/examples/CV Survival Regression on SUPPORT Dataset.ipynb +++ b/examples/CV Survival Regression on SUPPORT Dataset.ipynb @@ -33,7 +33,13 @@ "cat_feats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca']\n", "num_feats = ['age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp', \n", " 'temp', 'pafi', 'alb', 'bili', 'crea', 'sod', 'ph', \n", - " 'glucose', 'bun', 'urine', 'adlp', 'adls']" + " 'glucose', 'bun', 'urine', 'adlp', 'adls']\n", + "\n", + "# Data should be processed in a fold-independent manner when performing cross-validation. \n", + "# For simplicity in this demo, we process the dataset in a non-independent manner.\n", + "preprocessor = Preprocessor(cat_feat_strat='ignore', num_feat_strat= 'mean') \n", + "x = preprocessor.fit_transform(features, cat_feats=cat_feats, num_feats=num_feats,\n", + " one_hot=True, fill_value=-1)" ] }, { @@ -47,13 +53,6 @@ "times = np.quantile(outcomes.time[outcomes.event==1], horizons).tolist()" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": null, @@ -68,7 +67,7 @@ " 'layers' : [[100]]}\n", "\n", "experiment = SurvivalRegressionCV(model='dsm', num_folds=3, hyperparam_grid=param_grid, random_seed=0)\n", - "model = experiment.fit(features, outcomes, cat_feats=cat_feats, num_feats=num_feats, one_hot=True)\n" + "model = experiment.fit(x, outcomes, metric='ctd')" ] }, { @@ -81,21 +80,21 @@ "model" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "from auton_survival.preprocessing import Preprocessor\n", - "\n", - "preprocessor = Preprocessor(cat_feat_strat='replace', num_feat_strat='median',\n", - " scaling_strategy='standard', one_hot=True)\n", - "features_preprocessed = preprocessor.fit_transform(features, cat_feats=cat_feats, \n", - " num_feats=num_feats, fill_value=-1)\n", - "\n", - "out_risk = model.predict_risk(features_preprocessed, times)\n", - "out_survival = model.predict_survival(features_preprocessed, times)" + "out_risk = model.predict_risk(x, times)\n", + "out_survival = model.predict_survival(x, times)" ] }, { diff --git a/examples/Survival Regression with Auton-Survival.ipynb b/examples/Survival Regression with Auton-Survival.ipynb index d28f400..973029f 100644 --- a/examples/Survival Regression with Auton-Survival.ipynb +++ b/examples/Survival Regression with Auton-Survival.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "e4613ff9", + "id": "2963a324", "metadata": {}, "source": [ "# Survival Regression with `estimators.SurvivalModel`\n", @@ -47,7 +47,7 @@ }, { "cell_type": "markdown", - "id": "40c1c241", + "id": "6f3632c5", "metadata": {}, "source": [ "\n", @@ -57,7 +57,7 @@ }, { "cell_type": "markdown", - "id": "f778788d", + "id": "c673f52f", "metadata": {}, "source": [ "The `SurvivalModels` class offers a steamlined approach to train two `auton-survival` models and three baseline survival models for right-censored time-to-event data. The fit method requires the same inputs across all five models, however, model parameter types vary and must be defined and tuned for the specified model.\n", @@ -103,7 +103,7 @@ }, { "cell_type": "markdown", - "id": "313e994b", + "id": "7ddb455c", "metadata": {}, "source": [ "\n", @@ -113,7 +113,7 @@ }, { "cell_type": "markdown", - "id": "9ed40cb0", + "id": "eac37671", "metadata": {}, "source": [ "*For the original datasource, please refer to the following [website](https://biostat.app.vumc.org/wiki/Main/SupportDesc).*\n", @@ -124,7 +124,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cffe5b7b", + "id": "7cbd4534", "metadata": {}, "outputs": [], "source": [ @@ -138,7 +138,7 @@ { "cell_type": "code", "execution_count": null, - "id": "94c63604", + "id": "c4f515b7", "metadata": {}, "outputs": [], "source": [ @@ -160,7 +160,7 @@ }, { "cell_type": "markdown", - "id": "93da9245", + "id": "fb70e37a", "metadata": {}, "source": [ "\n", @@ -170,7 +170,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f8c1c8d1", + "id": "c1186ce4", "metadata": {}, "outputs": [], "source": [ @@ -189,7 +189,7 @@ { "cell_type": "code", "execution_count": null, - "id": "524bb920", + "id": "5deea197", "metadata": {}, "outputs": [], "source": [ @@ -206,7 +206,7 @@ }, { "cell_type": "markdown", - "id": "735839db", + "id": "6d487a62", "metadata": {}, "source": [ "\n", @@ -215,7 +215,7 @@ }, { "cell_type": "markdown", - "id": "778c44c3", + "id": "1736009f", "metadata": {}, "source": [ "CPH [2] model assumes that individuals across the population have constant proportional hazards overtime. In this model, the estimator of the survival function conditional on $X, S(ยท|X) , P(T > t|X)$, is assumed to have constant proportional hazard. Thus, the relative proportional hazard between individuals is constant across time.\n", @@ -227,7 +227,7 @@ }, { "cell_type": "markdown", - "id": "f816d49c", + "id": "99014953", "metadata": {}, "source": [ "\n", @@ -237,7 +237,7 @@ { "cell_type": "code", "execution_count": null, - "id": "921144b0", + "id": "f65f9e13", "metadata": {}, "outputs": [], "source": [ @@ -273,7 +273,7 @@ }, { "cell_type": "markdown", - "id": "ab9d4f80", + "id": "960fa050", "metadata": {}, "source": [ "\n", @@ -283,7 +283,7 @@ { "cell_type": "code", "execution_count": null, - "id": "da091f55", + "id": "592b484c", "metadata": {}, "outputs": [], "source": [ @@ -303,7 +303,7 @@ }, { "cell_type": "markdown", - "id": "284f2cf9", + "id": "fe34852e", "metadata": {}, "source": [ "\n", @@ -312,7 +312,7 @@ }, { "cell_type": "markdown", - "id": "7ed82807", + "id": "e445a853", "metadata": {}, "source": [ "DCPH [2], [3] is an extension to the CPH model. DCPH involves modeling the proportional hazard ratios over the individuals with Deep Neural Networks allowing the ability to learn non linear hazard ratios.\n", @@ -326,7 +326,7 @@ }, { "cell_type": "markdown", - "id": "9e68494c", + "id": "fe1665d1", "metadata": {}, "source": [ "\n", @@ -336,7 +336,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8d05ab45", + "id": "ff10e3e5", "metadata": {}, "outputs": [], "source": [ @@ -376,7 +376,7 @@ }, { "cell_type": "markdown", - "id": "931935ec", + "id": "2f438c28", "metadata": {}, "source": [ "\n", @@ -385,7 +385,7 @@ }, { "cell_type": "markdown", - "id": "372c3602", + "id": "c12714fe", "metadata": {}, "source": [ "Compute the Brier Score and time-dependent concordance index for the test set. See notebook introduction for more details." @@ -394,7 +394,7 @@ { "cell_type": "code", "execution_count": null, - "id": "baf94753", + "id": "b3f92680", "metadata": {}, "outputs": [], "source": [ @@ -419,7 +419,7 @@ } }, "cell_type": "markdown", - "id": "edba5fb0", + "id": "aecd7591", "metadata": {}, "source": [ "\n", @@ -440,7 +440,7 @@ }, { "cell_type": "markdown", - "id": "35d5a8c3", + "id": "2594e7be", "metadata": {}, "source": [ "\n", @@ -450,7 +450,7 @@ { "cell_type": "code", "execution_count": null, - "id": "be2feaa7", + "id": "641c5fda", "metadata": {}, "outputs": [], "source": [ @@ -490,7 +490,7 @@ }, { "cell_type": "markdown", - "id": "73deaff5", + "id": "8dad2aae", "metadata": {}, "source": [ "\n", @@ -499,7 +499,7 @@ }, { "cell_type": "markdown", - "id": "8d327af4", + "id": "d9c09ab3", "metadata": {}, "source": [ "Compute the Brier Score and time-dependent concordance index for the test set. See notebook introduction for more details." @@ -508,7 +508,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c235869e", + "id": "be7efe1a", "metadata": {}, "outputs": [], "source": [ @@ -528,7 +528,7 @@ }, { "cell_type": "markdown", - "id": "185bc3e9", + "id": "43c9a02c", "metadata": {}, "source": [ "\n", @@ -542,7 +542,7 @@ } }, "cell_type": "markdown", - "id": "ec6a2e2f", + "id": "3131f043", "metadata": {}, "source": [ "DCM [2] generalizes the proportional hazards assumption via a mixture model, by assuming that there are latent groups and within each, the proportional hazards assumption holds. DCM allows the hazard ratio in each latent group, as well as the latent group membership, to be flexibly modeled by a deep neural network.\n", @@ -560,7 +560,7 @@ }, { "cell_type": "markdown", - "id": "1f7ed7dc", + "id": "4745228c", "metadata": {}, "source": [ "\n", @@ -570,7 +570,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1f065f49", + "id": "5f458cdb", "metadata": {}, "outputs": [], "source": [ @@ -610,7 +610,7 @@ }, { "cell_type": "markdown", - "id": "eac60a5b", + "id": "12df8c99", "metadata": {}, "source": [ "\n", @@ -619,7 +619,7 @@ }, { "cell_type": "markdown", - "id": "82e10ce9", + "id": "75f8aac3", "metadata": {}, "source": [ "Compute the Brier Score and time-dependent concordance index for the test set. See notebook introduction for more details." @@ -628,7 +628,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d4f4276c", + "id": "91c641cc", "metadata": {}, "outputs": [], "source": [ @@ -648,7 +648,7 @@ }, { "cell_type": "markdown", - "id": "cf23ffd9", + "id": "995015ef", "metadata": {}, "source": [ "\n", @@ -665,7 +665,7 @@ }, { "cell_type": "markdown", - "id": "60c3b394", + "id": "c33d4764", "metadata": {}, "source": [ "\n", @@ -675,7 +675,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d1d09058", + "id": "d5cff6fa", "metadata": {}, "outputs": [], "source": [ @@ -715,7 +715,7 @@ }, { "cell_type": "markdown", - "id": "43638d7d", + "id": "2fda09b5", "metadata": {}, "source": [ "\n", @@ -724,7 +724,7 @@ }, { "cell_type": "markdown", - "id": "52d822d1", + "id": "8645f579", "metadata": {}, "source": [ "Compute the Brier Score and time-dependent concordance index for the test set. See notebook introduction for more details." @@ -733,7 +733,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cdb73742", + "id": "e6b604ab", "metadata": {}, "outputs": [], "source": [ @@ -754,7 +754,7 @@ { "cell_type": "code", "execution_count": null, - "id": "38ce3818", + "id": "e1aa943c", "metadata": {}, "outputs": [], "source": []