Skip to content

Commit

Permalink
modified: estimators.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragnagpal committed Mar 29, 2022
1 parent a4debae commit 2806a45
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions auton_survival/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import numpy as np
import pandas as pd

from auton_survival.models.dcm.dcm_api import DeepCoxMixtures

def _get_valid_idx(n, size, random_seed):

Expand Down Expand Up @@ -108,7 +107,7 @@ def _fit_dcm(features, outcomes, random_seed, **hyperparams):
k = hyperparams.get("k", 3)
layers = hyperparams.get("layers", [100])
batch_size = hyperparams.get("batch_size", 128)
lr = hyperparams.get("lr", 1e-3)
learning_rate = hyperparams.get("learning_rate", 1e-3)
epochs = hyperparams.get("epochs", 50)
smoothing_factor = hyperparams.get("smoothing_factor", 1e-4)
gamma = hyperparams.get("gamma", 10)
Expand All @@ -119,7 +118,7 @@ def _fit_dcm(features, outcomes, random_seed, **hyperparams):
smoothing_factor=smoothing_factor)

model.fit(features.values, outcomes.time.values, outcomes.event.values,
iters=epochs, batch_size=batch_size, lr=lr,
iters=epochs, batch_size=batch_size, learning_rate=learning_rate,
random_seed=random_seed)

return model
Expand Down

0 comments on commit 2806a45

Please sign in to comment.