diff --git a/auton_survival/estimators.py b/auton_survival/estimators.py index 47c8625..b7442a7 100644 --- a/auton_survival/estimators.py +++ b/auton_survival/estimators.py @@ -102,9 +102,9 @@ def _fit_dcm(features, outcomes, random_seed, **hyperparams): np.array : A float or list of the times at which to compute the survival probability. """ + raise NotImplementedError() - from .models.dcm import DeepCoxMixture - from sdcm.dcm_utils import train + from .models.dcm import DeepCoxMixtures import torch torch.manual_seed(random_seed) @@ -162,6 +162,8 @@ def _predict_dcm(model, features, times): """ + raise NotImplementedError() + from sdcm.dcm_utils import predict_scores import torch @@ -352,7 +354,9 @@ def _fit_cph(features, outcomes, random_seed, **hyperparams): data = outcomes.join(features) penalizer = hyperparams.get('l2', 1e-3) - return CoxPHFitter(penalizer=penalizer).fit(data, duration_col='time', event_col='event') + return CoxPHFitter(penalizer=penalizer).fit(data, + duration_col='time', + event_col='event') def _fit_rsf(features, outcomes, random_seed, **hyperparams): @@ -708,9 +712,8 @@ def predict_risk(self, features, times): class CounterfactualSurvivalModel: - """Universal interface to train multiple differenct counterfactual survival models.""" - - _VALID_MODELS = ['rsf', 'cph', 'dsm'] + """Universal interface to train multiple different counterfactual + survival models.""" def __init__(self, treated_model, control_model): @@ -727,4 +730,5 @@ def predict_counterfactuals(self, features, times): control_outcomes = self.control_model.predict(features, times) treated_outcomes = self.treated_model.predict(features, times) - return treated_outcomes, control_outcomes \ No newline at end of file + return treated_outcomes, control_outcomes + \ No newline at end of file diff --git a/auton_survival/experiments.py b/auton_survival/experiments.py index 97e4157..687fa55 100644 --- a/auton_survival/experiments.py +++ b/auton_survival/experiments.py @@ -179,7 +179,6 @@ def evaluate(self, features, outcomes, metrics=['auc', 'ctd'], horizons=[]): for metric in metrics: raise NotImplementedError() - class CounterfactualSurvivalRegressionCV: r"""Universal interface to train Counterfactual Survival Analysis models in a @@ -227,7 +226,7 @@ class CounterfactualSurvivalRegressionCV: In Machine Learning for Healthcare Conference, pages 674–708. PMLR """ - + def __init__(self, model, cv_folds=5, random_seed=0, hyperparam_grid={}): self.model = model diff --git a/auton_survival/reporting.py b/auton_survival/reporting.py index 7772fb7..c75d932 100644 --- a/auton_survival/reporting.py +++ b/auton_survival/reporting.py @@ -14,8 +14,6 @@ from sklearn.metrics import roc_curve, auc -import os - def plot_kaplanmeier(outcomes, groups=None, plot_counts=False, **kwargs):