Skip to content

Commit

Permalink
Update dcm_api.py with predict_risk method (#85)
Browse files Browse the repository at this point in the history
I added the predict_risk method for the DeepCoxMixtures to resolve bug issue#79
  • Loading branch information
haivanle authored Jul 18, 2022
1 parent 0b26fb6 commit 1a06c17
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion dsm/contrib/dcm/dcm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,15 @@ def fit(self, x, t, e, vsize=0.15, val_data=None,
return self


def predict_risk(self, x, t=None):

if self.fitted:
return 1-self.predict_survival(x, t)
else:
raise Exception("The model has not been fitted yet. Please fit the " +
"model using the `fit` method on some training data " +
"before calling `predict_risk`.")

def predict_survival(self, x, t):
r"""Returns the estimated survival probability at time \( t \),
\( \widehat{\mathbb{P}}(T > t|X) \) for some input data \( x \).
Expand Down Expand Up @@ -195,4 +204,4 @@ def predict_latent_z(self, x):
"model using the `fit` method on some training data " +
"before calling `predict_latent_z`.")



0 comments on commit 1a06c17

Please sign in to comment.