diff --git a/auton_survival/models/cmhe/cmhe_utilities.py b/auton_survival/models/cmhe/cmhe_utilities.py index aebf5c3..474a292 100644 --- a/auton_survival/models/cmhe/cmhe_utilities.py +++ b/auton_survival/models/cmhe/cmhe_utilities.py @@ -358,8 +358,9 @@ def predict_latent_z(model, x): def predict_latent_phi(model, x): model, _ = model - gates, _ = model(x) - phi_gate_probs = torch.exp(gates).sum(axis=1).detach().numpy() + x = model.embedding(x) - return phi_gate_probs + p_phi_gate = torch.nn.Softmax(dim=1)(model.phi_gate(x)).detach().numpy() + + return p_phi_gate