Skip to content

Commit

Permalink
Evaluation updated to competing risks (autonlab#29)
Browse files Browse the repository at this point in the history
* Implemented computation of Negative Log-Likelihood.
  • Loading branch information
Jeanselme authored and texchi2 committed Jan 23, 2021
1 parent 8880249 commit 7f2c21e
Showing 1 changed file with 37 additions and 1 deletion.
38 changes: 37 additions & 1 deletion dsm/dsm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
from dsm.dsm_torch import DeepConvolutionalSurvivalMachinesTorch
from dsm.losses import predict_cdf
import dsm.losses as losses
from dsm.utilities import train_dsm, _get_padded_features, _get_padded_targets
from dsm.utilities import train_dsm
from dsm.utilities import _get_padded_features, _get_padded_targets
from dsm.utilities import _reshape_tensor_with_nans

import torch
import numpy as np
Expand Down Expand Up @@ -120,6 +122,40 @@ def fit(self, x, t, e, vsize=0.15,

return self

def _eval_nll(self, x, t, e):
r"""This function computes the negative log likelihood of the given data.
In case of competing risks, the negative log likelihoods are summed over
the different events' type.
Parameters
----------
x: np.ndarray
A numpy array of the input features, \( x \).
t: np.ndarray
A numpy array of the event/censoring times, \( t \).
e: np.ndarray
A numpy array of the event/censoring indicators, \( \delta \).
\( \delta = r \) means the event r took place.
Returns:
float: Negative log likelihood.
"""
if not(self.fitted):
raise Exception("The model has not been fitted yet. Please fit the " +
"model using the `fit` method on some training data " +
"before calling `_eval_nll`.")
processed_data = self._prepocess_training_data(x, t, e, 0, 0)
_, _, _, x_val, t_val, e_val = processed_data
x_val, t_val, e_val = x_val,\
_reshape_tensor_with_nans(t_val),\
_reshape_tensor_with_nans(e_val)
loss = 0
for r in range(self.torch_model.risks):
loss += float(losses.conditional_loss(self.torch_model,
x_val, t_val, e_val, elbo=False,
risk=str(r+1)).detach().numpy())
return loss

def _prepocess_test_data(self, x):
return torch.from_numpy(x)

Expand Down

0 comments on commit 7f2c21e

Please sign in to comment.