diff --git a/auton_survival/models/cmhe/cmhe_utilities.py b/auton_survival/models/cmhe/cmhe_utilities.py
index aebf5c3..474a292 100644
--- a/auton_survival/models/cmhe/cmhe_utilities.py
+++ b/auton_survival/models/cmhe/cmhe_utilities.py
@@ -358,8 +358,9 @@ def predict_latent_z(model, x):
 def predict_latent_phi(model, x):
 
   model, _ = model
-  gates, _ = model(x)
 
-  phi_gate_probs = torch.exp(gates).sum(axis=1).detach().numpy()
+  x = model.embedding(x)
 
-  return phi_gate_probs
+  p_phi_gate = torch.nn.Softmax(dim=1)(model.phi_gate(x)).detach().numpy()
+
+  return p_phi_gate