Skip to content

Commit

Permalink
modified: cmhe_utilities.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragnagpal committed Feb 25, 2022
1 parent ae399e3 commit bfba055
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions auton_survival/models/cmhe/cmhe_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,9 @@ def predict_latent_z(model, x):
def predict_latent_phi(model, x):

model, _ = model
gates, _ = model(x)

phi_gate_probs = torch.exp(gates).sum(axis=1).detach().numpy()
x = model.embedding(x)

return phi_gate_probs
p_phi_gate = torch.nn.Softmax(dim=1)(model.phi_gate(x)).detach().numpy()

return p_phi_gate

0 comments on commit bfba055

Please sign in to comment.