From 9e95ad1cc25fd48db8967d06aefff75b036a281b Mon Sep 17 00:00:00 2001 From: Chirag Nagpal Date: Tue, 28 Jun 2022 10:51:30 -0400 Subject: [PATCH] modified: estimators.py --- auton_survival/estimators.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/auton_survival/estimators.py b/auton_survival/estimators.py index 01a1ca7..cc65afd 100644 --- a/auton_survival/estimators.py +++ b/auton_survival/estimators.py @@ -228,6 +228,9 @@ def _predict_dcph(model, features, times): if isinstance(times, float) or isinstance(times, int): times = [float(times)] + if isinstance(times, np.ndarray): + times = times.ravel().tolist() + return model.predict_survival(x=features.values, t=times) def _fit_cph(features, outcomes, val_data, random_seed, **hyperparams): @@ -378,7 +381,7 @@ def _fit_dsm(features, outcomes, val_data, random_seed, **hyperparams): k = hyperparams.get("k", 3) layers = hyperparams.get("layers", [100]) - epochs = hyperparams.get("iters", 10) + epochs = hyperparams.get("iters", 50) distribution = hyperparams.get("distribution", "Weibull") temperature = hyperparams.get("temperature", 1.0) lr = hyperparams.get("learning_rate", 1e-3)