diff --git a/dsm/dsm_api.py b/dsm/dsm_api.py index 6b890b7..6636b13 100644 --- a/dsm/dsm_api.py +++ b/dsm/dsm_api.py @@ -32,7 +32,9 @@ from dsm.dsm_torch import DeepConvolutionalSurvivalMachinesTorch from dsm.losses import predict_cdf import dsm.losses as losses -from dsm.utilities import train_dsm, _get_padded_features, _get_padded_targets +from dsm.utilities import train_dsm +from dsm.utilities import _get_padded_features, _get_padded_targets +from dsm.utilities import _reshape_tensor_with_nans import torch import numpy as np @@ -120,6 +122,40 @@ def fit(self, x, t, e, vsize=0.15, return self + def _eval_nll(self, x, t, e): + r"""This function computes the negative log likelihood of the given data. + In case of competing risks, the negative log likelihoods are summed over + the different events' type. + + Parameters + ---------- + x: np.ndarray + A numpy array of the input features, \( x \). + t: np.ndarray + A numpy array of the event/censoring times, \( t \). + e: np.ndarray + A numpy array of the event/censoring indicators, \( \delta \). + \( \delta = r \) means the event r took place. + + Returns: + float: Negative log likelihood. + """ + if not(self.fitted): + raise Exception("The model has not been fitted yet. Please fit the " + + "model using the `fit` method on some training data " + + "before calling `_eval_nll`.") + processed_data = self._prepocess_training_data(x, t, e, 0, 0) + _, _, _, x_val, t_val, e_val = processed_data + x_val, t_val, e_val = x_val,\ + _reshape_tensor_with_nans(t_val),\ + _reshape_tensor_with_nans(e_val) + loss = 0 + for r in range(self.torch_model.risks): + loss += float(losses.conditional_loss(self.torch_model, + x_val, t_val, e_val, elbo=False, + risk=str(r+1)).detach().numpy()) + return loss + def _prepocess_test_data(self, x): return torch.from_numpy(x)