diff --git a/auton_survival/models/cmhe/__init__.py b/auton_survival/models/cmhe/__init__.py index 9df9a00..267a508 100644 --- a/auton_survival/models/cmhe/__init__.py +++ b/auton_survival/models/cmhe/__init__.py @@ -182,7 +182,7 @@ def _gen_torch_model(self, inputdim, optimizer): def fit(self, x, t, e, a, vsize=0.15, val_data=None, iters=1, learning_rate=1e-3, batch_size=100, - optimizer="Adam", random_state=100): + patience=2, optimizer="Adam", random_state=100): r"""This method is used to train an instance of the DSM model. @@ -234,6 +234,7 @@ def fit(self, x, t, e, a, vsize=0.15, val_data=None, epochs=iters, lr=learning_rate, bs=batch_size, + patience=patience, return_losses=True) self.torch_model = (model[0].eval(), model[1]) @@ -272,7 +273,7 @@ def predict_survival(self, x, a, t=None): "model using the `fit` method on some training data " + "before calling `predict_survival`.") - x = self._preprocess_test_data(x, a) + x, a = self._preprocess_test_data(x, a) if t is not None: if not isinstance(t, list): diff --git a/auton_survival/models/cmhe/cmhe_utilities.py b/auton_survival/models/cmhe/cmhe_utilities.py index 474a292..069e1da 100644 --- a/auton_survival/models/cmhe/cmhe_utilities.py +++ b/auton_survival/models/cmhe/cmhe_utilities.py @@ -259,7 +259,7 @@ def train_step(model, x, t, e, a, breslow_splines, optimizer, return breslow_splines -def test_step(model, x, t, e, a, breslow_splines, loss='q', typ='soft'): +def test_step(model, x, t, e, a, breslow_splines, loss='q', typ='soft'): if loss == 'q': with torch.no_grad(): @@ -330,6 +330,7 @@ def predict_survival(model, x, a, t): if isinstance(t, (int, float)): t = [t] model, breslow_splines = model + gates, lrisks = model(x, a=a) lrisks = lrisks.detach().numpy()