Skip to content

Commit

Permalink
fin updates
Browse files Browse the repository at this point in the history
  • Loading branch information
PotosnakW committed May 19, 2022
1 parent 9513068 commit 4cb835e
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 170 deletions.
41 changes: 20 additions & 21 deletions auton_survival/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,51 +595,50 @@ def fit(self, features, outcomes, vsize=0.15, val_data=None,
assert weights_val is None, "Weights for validation data \
must be None if validation data is not specified."

train_data = data.sample(frac=1-vsize, random_state=self.random_seed)
val_data = data[~data.index.isin(train_data.index)]
val_data = (val_data[features.columns], val_data[outcomes.columns])
data_train = data.sample(frac=1-vsize, random_state=self.random_seed)
data_val = data[~data.index.isin(data_train.index)]

else:
train_data = data
data_train = data
data_val = val_data[0].join(val_data[1])

if weights is not None:
assert len(weights) == features.shape[0], "Size of passed weights \
must match size of training data."
assert (weights>0.).any(), "All weights must be positive."

weights = pd.Series(weights, index=data.index)
val_data = val_data[0].join(val_data[1])

if weights_val is None:
weights_train = weights[train_data.index]
weights_val = weights[val_data.index]

else:
assert weights_val is not None, "Validation set weights must be \
specified."
assert len(weights_val) == val_data[0].shape[0], "Size of passed \
weights_val must match size of validation data."
if weights_val is not None:
assert len(weights_val) == data_val[features.columns].shape[0], "Size \
of passed weights_val must match size of validation data."
assert (weights_val>0.).any(), "All weights_val must be positive."

weights_train = weights

train_data_resampled = train_data.sample(weights = weights_train,
else:
assert val_data is None, "Validation weights must be specified if validation \
data and training set weights are both specified."
weights_train = weights[data_train.index]
weights_val = weights[data_val.index]

data_train_resampled = data_train.sample(weights = weights_train,
frac = resample_size,
replace = True,
random_state = self.random_seed)

val_data_resampled = val_data.sample(weights = weights_val,
data_val_resampled = data_val.sample(weights = weights_val,
frac = resample_size,
replace = True,
random_state = self.random_seed)

features = train_data_resampled[features.columns]
outcomes = train_data_resampled[outcomes.columns]
features = data_train_resampled[features.columns]
outcomes = data_train_resampled[outcomes.columns]

val_data = (val_data_resampled[features.columns],
val_data_resampled[outcomes.columns])
data_val = data_val_resampled

val_data = (val_data[0], val_data[1].time, val_data[1].event)
val_data = (data_val[features.columns], data_val[outcomes.columns].time,
data_val[outcomes.columns].event)

if self.model == 'cph':
self._model = _fit_cph(features, outcomes,
Expand Down
89 changes: 10 additions & 79 deletions auton_survival/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ def __init__(self, model='dcph', folds=None, num_folds=5,
self.random_seed = random_seed
self.hyperparam_grid = list(ParameterGrid(hyperparam_grid))

def fit(self, features, outcomes, metric='ibs', horizon=None,
cat_feats=None, num_feats=None, one_hot=False):
def fit(self, features, outcomes, metric='ibs', horizon=None):

r"""Fits the survival regression model to the data in a cross-
validation or nested cross-validation fashion.
Expand All @@ -124,12 +123,6 @@ def fit(self, features, outcomes, metric='ibs', horizon=None,
horizon : int or float, default=None
Event-horizon at which to evaluate model performance.
If None, then the maximum permissible event-time from the data is used.
cat_feats: list
List of categorical features.
num_feats: list
List of numerical/continuous features.
one_hot : bool, default=False
Whether to perform One-Hot encoding for categorical features.
Returns
-----------
Expand All @@ -138,9 +131,6 @@ def fit(self, features, outcomes, metric='ibs', horizon=None,
"""

self.metric = metric
self.cat_feats = cat_feats
self.num_feats = num_feats
self.one_hot = one_hot
self.horizon = horizon

if (horizon is None) & (metric not in ['auc', 'ctd', 'brs']):
Expand All @@ -153,16 +143,11 @@ def fit(self, features, outcomes, metric='ibs', horizon=None,
self.num_folds,
self.random_seed)

if self.num_nested_folds is None:
proc_x_tr = self._process_data(features_tr=features,
cat_feats=self.cat_feats,
num_feats=self.num_feats,
one_hot=self.one_hot)

if self.num_nested_folds is None:
best_params = self._cv_select_parameters(features, outcomes, self.folds)
model = SurvivalModel(self.model, self.random_seed, **best_params)

return model.fit(proc_x_tr, outcomes)
return model.fit(features, outcomes)

else:
return self._train_nested_cv_models(features, outcomes)
Expand Down Expand Up @@ -190,19 +175,14 @@ def _train_nested_cv_models(self, features, outcomes):
for fi, fold in enumerate(set(self.folds)):
x_tr = features.copy().loc[self.folds!=fold]
y_tr = outcomes.loc[self.folds!=fold]

proc_x_tr = self._process_data(features_tr=x_tr,
cat_feats=self.cat_feats,
num_feats=self.num_feats,
one_hot=self.one_hot)

self.nested_folds = self._get_stratified_folds(y_tr, 'event',
self.num_nested_folds,
self.random_seed)
# Use unprocessed training set for nested CV.

best_params = self._cv_select_parameters(x_tr, y_tr, self.nested_folds)
model = SurvivalModel(self.model, self.random_seed, **best_params)
models[fi] = model.fit(proc_x_tr, y_tr)
models[fi] = model.fit(x_tr, y_tr)

return models

Expand Down Expand Up @@ -248,10 +228,7 @@ def _cv_select_parameters(self, features, outcomes, folds):
y_tr = outcomes.loc[folds!=fold]
y_val = outcomes.loc[folds==fold]

proc_x_tr, proc_x_val = self._process_data(x_tr, x_val, self.cat_feats,
self.num_feats,self.one_hot)

param_results = self._fit_evaluate_model(proc_x_tr, y_tr, proc_x_val, y_val)
param_results = self._fit_evaluate_model(x_tr, y_tr, x_val, y_val)

# Hyperparameter results as row items and fold results as columns.
fold_results = pd.concat([fold_results, pd.DataFrame(param_results)],
Expand Down Expand Up @@ -302,7 +279,7 @@ def _fit_evaluate_model(self, features_tr, outcomes_tr,
if self.metric == 'ibs':
times = self.times
else:
times = self.times[-1]
times = [self.times[-1]]

param_results = []
for hyper_param in tqdm(self.hyperparam_grid):
Expand Down Expand Up @@ -364,47 +341,6 @@ def _get_stratified_folds(self, dataset, event_label, n_folds, random_seed):

return df_folds

def _process_data(self, features_tr, features_te=None, cat_feats=None,
num_feats=None, one_hot=False):

"""Fit preprocessors to training set data and transform training
set and test set data.
Parameters
-----------
features_tr : pd.DataFrame
A pandas dataframe with rows corresponding to individual samples
and columns as covariates for the training set data.
features_te : pd.DataFrame
A pandas dataframe with rows corresponding to individual samples
and columns as covariates for the test set data.
cat_feats: list
List of categorical features.
num_feats: list
List of numerical/continuous features.
one_hot : bool, default=False
Whether to perform One-Hot encoding for categorical features.
Returns
-----------
Pandas dataframes of preprocessed training and test set data.
"""

preprocessor = Preprocessor(cat_feat_strat='replace',
num_feat_strat='median',
scaling_strategy='standard',
one_hot=one_hot)
transformer = preprocessor.fit(features_tr, cat_feats=cat_feats,
num_feats=num_feats, fill_value=-1)
features_tr = transformer.transform(features_tr.copy())

if features_te is None:
return features_tr
else:
features_te = transformer.transform(features_te.copy())
return features_tr, features_te

def _check_times(self, outcomes, times, folds):

"""Verify times are within an appropriate range for model evaluation.
Expand Down Expand Up @@ -513,8 +449,7 @@ def __init__(self, model, cv_folds=5, random_seed=0, hyperparam_grid={}):
random_seed=random_seed,
hyperparam_grid=hyperparam_grid)

def fit(self, features, outcomes, interventions,
metric, cat_feats, num_feats):
def fit(self, features, outcomes, interventions, metric):

r"""Fits the Survival Regression Model to the data in a Cross
Validation fashion.
Expand Down Expand Up @@ -542,13 +477,9 @@ def fit(self, features, outcomes, interventions,

treated_model = self.treated_experiment.fit(features.loc[interventions==1],
outcomes.loc[interventions==1],
metric=metric,
cat_feats=cat_feats,
num_feats=num_feats)
metric=metric)
control_model = self.control_experiment.fit(features.loc[interventions!=1],
outcomes.loc[interventions!=1],
metric=metric,
cat_feats=cat_feats,
num_feats=num_feats)
metric=metric)

return CounterfactualSurvivalModel(treated_model, control_model)
20 changes: 11 additions & 9 deletions auton_survival/phenotyping.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ def __init__(self,
self.random_seed = random_seed

def fit(self, features, outcomes, interventions, metric,
horizon, cat_feats, num_feats):
horizon):

"""Fit a counterfactual model and regress the difference of the estimated
counterfactual Restricted Mean Survival Time using a Random Forest regressor.
Expand All @@ -492,16 +492,19 @@ def fit(self, features, outcomes, interventions, metric,
outcomes : pd.DataFrame
A pandas dataframe with rows corresponding to individual samples
and columns 'time' and 'event'.
treatment_indicator : np.array
interventions : np.array
Boolean numpy array of treatment indicators. True means individual
was assigned a specific treatment.
metric : str, default='ibs'
Metric used to evaluate model performance and tune hyperparameters.
Options include:
- 'auc': Dynamic area under the ROC curve
- 'brs' : Brier Score
- 'ibs' : Integrated Brier Score
- 'ctd' : Concordance Index
horizon : np.float
The event horizon at which to compute the counterfactual RMST for
regression.
cat_feats: list
List of categorical features.
num_feats: list
List of numerical/continuous features.
regression.
Returns
-----------
Expand All @@ -512,8 +515,7 @@ def fit(self, features, outcomes, interventions, metric,
cf_model = CounterfactualSurvivalRegressionCV(model=self.cf_method,
hyperparam_grid=self.cf_hyperparams)

self.cf_model = cf_model.fit(features, outcomes, interventions,
metric, cat_feats, num_feats)
self.cf_model = cf_model.fit(features, outcomes, interventions, metric)

times = np.unique(outcomes.time.values)
cf_predictions = self.cf_model.predict_counterfactual_survival(features,
Expand Down
35 changes: 17 additions & 18 deletions examples/CV Survival Regression on SUPPORT Dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@
"cat_feats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca']\n",
"num_feats = ['age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp', \n",
" 'temp', 'pafi', 'alb', 'bili', 'crea', 'sod', 'ph', \n",
" 'glucose', 'bun', 'urine', 'adlp', 'adls']"
" 'glucose', 'bun', 'urine', 'adlp', 'adls']\n",
"\n",
"# Data should be processed in a fold-independent manner when performing cross-validation. \n",
"# For simplicity in this demo, we process the dataset in a non-independent manner.\n",
"preprocessor = Preprocessor(cat_feat_strat='ignore', num_feat_strat= 'mean') \n",
"x = preprocessor.fit_transform(features, cat_feats=cat_feats, num_feats=num_feats,\n",
" one_hot=True, fill_value=-1)"
]
},
{
Expand All @@ -47,13 +53,6 @@
"times = np.quantile(outcomes.time[outcomes.event==1], horizons).tolist()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -68,7 +67,7 @@
" 'layers' : [[100]]}\n",
"\n",
"experiment = SurvivalRegressionCV(model='dsm', num_folds=3, hyperparam_grid=param_grid, random_seed=0)\n",
"model = experiment.fit(features, outcomes, cat_feats=cat_feats, num_feats=num_feats, one_hot=True)\n"
"model = experiment.fit(x, outcomes, metric='ctd')"
]
},
{
Expand All @@ -81,21 +80,21 @@
"model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from auton_survival.preprocessing import Preprocessor\n",
"\n",
"preprocessor = Preprocessor(cat_feat_strat='replace', num_feat_strat='median',\n",
" scaling_strategy='standard', one_hot=True)\n",
"features_preprocessed = preprocessor.fit_transform(features, cat_feats=cat_feats, \n",
" num_feats=num_feats, fill_value=-1)\n",
"\n",
"out_risk = model.predict_risk(features_preprocessed, times)\n",
"out_survival = model.predict_survival(features_preprocessed, times)"
"out_risk = model.predict_risk(x, times)\n",
"out_survival = model.predict_survival(x, times)"
]
},
{
Expand Down
Loading

0 comments on commit 4cb835e

Please sign in to comment.