From bfba0550c14f454a4959d0a578aeca4a4429b2d1 Mon Sep 17 00:00:00 2001 From: Chirag Nagpal Date: Fri, 25 Feb 2022 14:07:23 -0500 Subject: [PATCH] modified: cmhe_utilities.py --- auton_survival/models/cmhe/cmhe_utilities.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/auton_survival/models/cmhe/cmhe_utilities.py b/auton_survival/models/cmhe/cmhe_utilities.py index aebf5c3..474a292 100644 --- a/auton_survival/models/cmhe/cmhe_utilities.py +++ b/auton_survival/models/cmhe/cmhe_utilities.py @@ -358,8 +358,9 @@ def predict_latent_z(model, x): def predict_latent_phi(model, x): model, _ = model - gates, _ = model(x) - phi_gate_probs = torch.exp(gates).sum(axis=1).detach().numpy() + x = model.embedding(x) - return phi_gate_probs + p_phi_gate = torch.nn.Softmax(dim=1)(model.phi_gate(x)).detach().numpy() + + return p_phi_gate