Skip to content

Commit

Permalink
modified: estimators.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragnagpal committed Mar 29, 2022
1 parent 45c7f38 commit 338b155
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions auton_survival/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,11 @@ def _fit_dcm(features, outcomes, random_seed, **hyperparams):
model = DeepCoxMixtures(k=k,
layers=layers,
gamma=gamma,
smoothing_factor=smoothing_factor)
smoothing_factor=smoothing_factor,
random_seed=random_seed)

model.fit(features.values, outcomes.time.values, outcomes.event.values,
iters=epochs, batch_size=batch_size, learning_rate=learning_rate,
random_seed=random_seed)
iters=epochs, batch_size=batch_size, learning_rate=learning_rate)

return model

Expand Down Expand Up @@ -641,7 +641,7 @@ def __init__(self, model, random_seed=0, **hyperparams):
self.fitted = False

def fit(self, features, outcomes,
weights=None, resample_size=1.0, weights_clip=1e-2):
weights=None, resample_size=1.0):

"""This method is used to train an instance of the survival model.
Expand Down Expand Up @@ -671,8 +671,8 @@ def fit(self, features, outcomes,
assert len(weights) == features.shape[0], "Size of passed weights must match size of training data."
assert ((weights>0.0)&(weights<=1.0)).all(), "Weights must be in the range (0,1]."

weights[weights>(1-weights_clip)] = 1-weights_clip
weights[weights<(weights_clip)] = weights_clip
# weights[weights>(1-weights_clip)] = 1-weights_clip
# weights[weights<(weights_clip)] = weights_clip

data = features.join(outcomes)
data_resampled = data.sample(weights = weights,
Expand Down

0 comments on commit 338b155

Please sign in to comment.