Skip to content

Commit

Permalink
modified: estimators.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragnagpal committed Mar 29, 2022
1 parent 00d9d34 commit 9102022
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion auton_survival/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,8 @@ def __init__(self, model, random_seed=0, **hyperparams):
self.random_seed = random_seed
self.fitted = False

def fit(self, features, outcomes):
def fit(self, features, outcomes,
weights=None, resample_size=1.0, weights_clip=1e-2):

"""This method is used to train an instance of the survival model.
Expand All @@ -660,6 +661,22 @@ def fit(self, features, outcomes):
"""

if weights is not None:
assert len(weights) == features.shape[0], "Size of passed weights must match size of training data."
assert ((weights>0.0)&(weights<=1.0)).all(), "Weights must be in the range (0,1]."

weights[weights>(1-weights_clip)] = 1-weights_clip
weights[weights<(weights_clip)] = weights_clip

data = features.join(outcomes)
data_resampled = data.sample(weights = weights,
frac = resample_size,
replace = True,
random_state = self.random_seed)
features = data_resampled[features.columns]
outcomes = data_resampled[outcomes.columns]


if self.model == 'cph': self._model = _fit_cph(features, outcomes, self.random_seed, **self.hyperparams)
elif self.model == 'rsf': self._model = _fit_rsf(features, outcomes, self.random_seed, **self.hyperparams)
elif self.model == 'dsm': self._model = _fit_dsm(features, outcomes, self.random_seed, **self.hyperparams)
Expand Down

0 comments on commit 9102022

Please sign in to comment.