Skip to content

Commit

Permalink
modified: estimators.py
Browse files Browse the repository at this point in the history
	modified:   experiments.py
	modified:   reporting.py
  • Loading branch information
chiragnagpal committed Mar 25, 2022
1 parent 55c5761 commit 529c6cc
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 11 deletions.
18 changes: 11 additions & 7 deletions auton_survival/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ def _fit_dcm(features, outcomes, random_seed, **hyperparams):
np.array : A float or list of the times at which to compute the survival probability.
"""
raise NotImplementedError()

from .models.dcm import DeepCoxMixture
from sdcm.dcm_utils import train
from .models.dcm import DeepCoxMixtures

import torch
torch.manual_seed(random_seed)
Expand Down Expand Up @@ -162,6 +162,8 @@ def _predict_dcm(model, features, times):
"""

raise NotImplementedError()

from sdcm.dcm_utils import predict_scores

import torch
Expand Down Expand Up @@ -352,7 +354,9 @@ def _fit_cph(features, outcomes, random_seed, **hyperparams):
data = outcomes.join(features)
penalizer = hyperparams.get('l2', 1e-3)

return CoxPHFitter(penalizer=penalizer).fit(data, duration_col='time', event_col='event')
return CoxPHFitter(penalizer=penalizer).fit(data,
duration_col='time',
event_col='event')

def _fit_rsf(features, outcomes, random_seed, **hyperparams):

Expand Down Expand Up @@ -708,9 +712,8 @@ def predict_risk(self, features, times):

class CounterfactualSurvivalModel:

"""Universal interface to train multiple differenct counterfactual survival models."""

_VALID_MODELS = ['rsf', 'cph', 'dsm']
"""Universal interface to train multiple different counterfactual
survival models."""

def __init__(self, treated_model, control_model):

Expand All @@ -727,4 +730,5 @@ def predict_counterfactuals(self, features, times):
control_outcomes = self.control_model.predict(features, times)
treated_outcomes = self.treated_model.predict(features, times)

return treated_outcomes, control_outcomes
return treated_outcomes, control_outcomes

3 changes: 1 addition & 2 deletions auton_survival/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ def evaluate(self, features, outcomes, metrics=['auc', 'ctd'], horizons=[]):
for metric in metrics:
raise NotImplementedError()


class CounterfactualSurvivalRegressionCV:

r"""Universal interface to train Counterfactual Survival Analysis models in a
Expand Down Expand Up @@ -227,7 +226,7 @@ class CounterfactualSurvivalRegressionCV:
In Machine Learning for Healthcare Conference, pages 674–708. PMLR
"""

def __init__(self, model, cv_folds=5, random_seed=0, hyperparam_grid={}):

self.model = model
Expand Down
2 changes: 0 additions & 2 deletions auton_survival/reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

from sklearn.metrics import roc_curve, auc

import os


def plot_kaplanmeier(outcomes, groups=None, plot_counts=False, **kwargs):

Expand Down

0 comments on commit 529c6cc

Please sign in to comment.