Skip to content

Commit

Permalink
modified: __init__.py
Browse files Browse the repository at this point in the history
	modified:   cmhe_utilities.py
  • Loading branch information
chiragnagpal committed Feb 28, 2022
1 parent 4aec0d1 commit c40887d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
5 changes: 3 additions & 2 deletions auton_survival/models/cmhe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _gen_torch_model(self, inputdim, optimizer):

def fit(self, x, t, e, a, vsize=0.15, val_data=None,
iters=1, learning_rate=1e-3, batch_size=100,
optimizer="Adam", random_state=100):
patience=2, optimizer="Adam", random_state=100):

r"""This method is used to train an instance of the DSM model.
Expand Down Expand Up @@ -234,6 +234,7 @@ def fit(self, x, t, e, a, vsize=0.15, val_data=None,
epochs=iters,
lr=learning_rate,
bs=batch_size,
patience=patience,
return_losses=True)

self.torch_model = (model[0].eval(), model[1])
Expand Down Expand Up @@ -272,7 +273,7 @@ def predict_survival(self, x, a, t=None):
"model using the `fit` method on some training data " +
"before calling `predict_survival`.")

x = self._preprocess_test_data(x, a)
x, a = self._preprocess_test_data(x, a)

if t is not None:
if not isinstance(t, list):
Expand Down
3 changes: 2 additions & 1 deletion auton_survival/models/cmhe/cmhe_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def train_step(model, x, t, e, a, breslow_splines, optimizer,
return breslow_splines


def test_step(model, x, t, e, a, breslow_splines, loss='q', typ='soft'):
def test_step(model, x, t, e, a, breslow_splines, loss='q', typ='soft'):

if loss == 'q':
with torch.no_grad():
Expand Down Expand Up @@ -330,6 +330,7 @@ def predict_survival(model, x, a, t):
if isinstance(t, (int, float)): t = [t]

model, breslow_splines = model

gates, lrisks = model(x, a=a)

lrisks = lrisks.detach().numpy()
Expand Down

0 comments on commit c40887d

Please sign in to comment.