diff --git a/auton_survival/models/cmhe/cmhe_api.py b/auton_survival/models/cmhe/cmhe_api.py
new file mode 100644
index 0000000..8aa49c3
--- /dev/null
+++ b/auton_survival/models/cmhe/cmhe_api.py
@@ -0,0 +1,184 @@
+from .cmhe_torch import DeepCMHETorch
+from .cmhe_utilities import train_cmhe, predict_scores
+
+import numpy as np
+import torch
+
+class CoxMixturesHeterogenousEffects:
+  """A Cox Mixtures with Heterogenous Effects model.
+
+  This is the main interface to the Cox Mixtures with Heterogenous
+  Effects model. A model is instantiated with an appropriate set of
+  hyperparameters and fit on numpy arrays consisting of the features,
+  event/censoring times, the event/censoring indicators and the
+  treatment assignments.
+
+  For full details on Cox Mixtures with Heterogenous Effects, refer
+  to the paper [1].
+
+  References
+  ----------
+  [1] Nagpal, Goswami, Dufendach and Dubrawski. Counterfactual
+  Phenotyping with Censored Time-to-Events (2022).
+
+  Parameters
+  ----------
+  k: int
+      The number of underlying Cox distributions.
+  layers: list
+      A list of integers consisting of the number of neurons in each
+      hidden layer.
+
+  Example
+  -------
+  >>> from auton_survival.models.cmhe.cmhe_api import CoxMixturesHeterogenousEffects
+  >>> model = CoxMixturesHeterogenousEffects(k=3)
+  >>> model.fit(x, t, e, a)
+
+  """
+
+  def __init__(self, k=3, layers=None):
+
+    self.k = k
+    self.layers = layers
+    self.fitted = False
+
+  def __call__(self):
+    if self.fitted:
+      print("A fitted instance of the CMHE model")
+    else:
+      print("An unfitted instance of the CMHE model")
+
+    print("Number of underlying Cox distributions (k):", self.k)
+    print("Hidden Layers:", self.layers)
+
+  def _preprocess_test_data(self, x, a):
+    return torch.from_numpy(x).float(), torch.from_numpy(a).float()
+
+  def _preprocess_training_data(self, x, t, e, a, vsize, val_data,
+                                random_state):
+
+    # Shuffle the data with a fixed seed so the validation split
+    # is reproducible for a given `random_state`.
+    idx = list(range(x.shape[0]))
+
+    np.random.seed(random_state)
+    np.random.shuffle(idx)
+
+    x_train, t_train, e_train, a_train = x[idx], t[idx], e[idx], a[idx]
+
+    x_train = torch.from_numpy(x_train).float()
+    t_train = torch.from_numpy(t_train).float()
+    e_train = torch.from_numpy(e_train).float()
+    a_train = torch.from_numpy(a_train).float()
+
+    if val_data is None:
+
+      # Hold out the last `vsize` fraction of the shuffled data.
+      vsize = int(vsize*x_train.shape[0])
+      x_val, t_val, e_val, a_val = (x_train[-vsize:], t_train[-vsize:],
+                                    e_train[-vsize:], a_train[-vsize:])
+
+      x_train = x_train[:-vsize]
+      t_train = t_train[:-vsize]
+      e_train = e_train[:-vsize]
+      a_train = a_train[:-vsize]
+
+    else:
+
+      x_val, t_val, e_val, a_val = val_data
+
+      x_val = torch.from_numpy(x_val).float()
+      t_val = torch.from_numpy(t_val).float()
+      e_val = torch.from_numpy(e_val).float()
+      a_val = torch.from_numpy(a_val).float()
+
+    return (x_train, t_train, e_train, a_train,
+            x_val, t_val, e_val, a_val)
+
+  def _gen_torch_model(self, inputdim, optimizer):
+    """Helper function to return a torch model."""
+    return DeepCMHETorch(inputdim, k=self.k, layers=self.layers,
+                         optimizer=optimizer)
+
+  def fit(self, x, t, e, a, vsize=0.15, val_data=None,
+          iters=1, learning_rate=1e-3, batch_size=100,
+          optimizer="Adam", random_state=100):
+
+    r"""This method is used to train an instance of the CMHE model.
+
+    Parameters
+    ----------
+    x: np.ndarray
+        A numpy array of the input features, \( x \).
+    t: np.ndarray
+        A numpy array of the event/censoring times, \( t \).
+    e: np.ndarray
+        A numpy array of the event/censoring indicators, \( \delta \).
+        \( \delta = 1 \) means the event took place.
+    a: np.ndarray
+        A numpy array of the treatment assignment indicators, \( a \).
+    vsize: float
+        Amount of data to set aside as the validation set.
+    val_data: tuple
+        A tuple of the validation dataset. If passed, vsize is ignored.
+    iters: int
+        The maximum number of training iterations on the training dataset.
+    learning_rate: float
+        The learning rate for the `Adam` optimizer.
+    batch_size: int
+        Learning is performed on mini-batches of input data. This parameter
+        specifies the size of each mini-batch.
+    optimizer: str
+        The choice of the gradient-based optimization method. One of
+        'Adam', 'RMSProp' or 'SGD'.
+    random_state: int
+        Random seed that determines how the validation set is chosen.
+
+    """
+
+    processed_data = self._preprocess_training_data(x, t, e, a,
+                                                    vsize, val_data,
+                                                    random_state)
+
+    (x_train, t_train, e_train, a_train,
+     x_val, t_val, e_val, a_val) = processed_data
+
+    # TODO: Change this somehow. The base design shouldn't depend on child.
+
+    inputdim = x_train.shape[-1]
+
+    model = self._gen_torch_model(inputdim, optimizer)
+
+    model, _ = train_cmhe(model,
+                          (x_train, t_train, e_train, a_train),
+                          (x_val, t_val, e_val, a_val),
+                          epochs=iters,
+                          lr=learning_rate,
+                          bs=batch_size,
+                          return_losses=True)
+
+    # `train_cmhe` returns a (torch model, breslow splines) tuple.
+    self.torch_model = (model[0].eval(), model[1])
+    self.fitted = True
+
+    return self
+
+  def predict_risk(self, x, a, t=None):
+
+    if self.fitted:
+      return 1-self.predict_survival(x, a, t)
+    else:
+      raise Exception("The model has not been fitted yet. Please fit the " +
+                      "model using the `fit` method on some training data " +
+                      "before calling `predict_risk`.")
+
+  def predict_survival(self, x, a, t=None):
+    r"""Returns the estimated survival probability at time \( t \),
+    \( \widehat{\mathbb{P}}(T > t|X) \) for some input data \( x \).
+
+    Parameters
+    ----------
+    x: np.ndarray
+        A numpy array of the input features, \( x \).
+    a: np.ndarray
+        A numpy array of the treatment assignment indicators, \( a \).
+    t: list or float
+        A list or float of the times at which the survival probability is
+        to be computed.
+
+    Returns:
+      np.array: numpy array of the survival probabilities at each time in t.
+
+    """
+    if not self.fitted:
+      raise Exception("The model has not been fitted yet. Please fit the " +
+                      "model using the `fit` method on some training data " +
+                      "before calling `predict_survival`.")
+
+    x, a = self._preprocess_test_data(x, a)
+
+    if t is not None:
+      if not isinstance(t, list):
+        t = [t]
+
+    scores = predict_scores(self.torch_model, x, a, t)
+    return scores
diff --git a/auton_survival/models/cmhe/cmhe_utilities.py b/auton_survival/models/cmhe/cmhe_utilities.py
index bdc64f7..7cc69b4 100644
--- a/auton_survival/models/cmhe/cmhe_utilities.py
+++ b/auton_survival/models/cmhe/cmhe_utilities.py
@@ -21,8 +21,7 @@ def partial_ll_loss(lrisks, tb, eb, eps=1e-2):
   lrisks = lrisks[sindex] # lrisks = tf.gather(lrisks, sindex)
 
-  lrisksdenom = torch.logcumsumexp(lrisks, dim = 0) # lrisksdenom = tf.math.cumulative_logsumexp(lrisks)
-
+  lrisksdenom = torch.logcumsumexp(lrisks, dim = 0)
   plls = lrisks - lrisksdenom
 
   pll = plls[eb == 1]
@@ -173,7 +172,7 @@ def fit_breslow(model, x, t, e, a, log_likelihoods=None, smoothing_factor=1e-4,
   z = get_hard_z(z_posteriors)
   zeta = get_hard_z(zeta_posteriors)
 
-  breslow_splines = {} 
+  breslow_splines = {}
   for i in range(model.k):
     breslowk = BreslowEstimator().fit(lrisks[:, i, :][range(len(zeta)), zeta][z==i], e[z==i], t[z==i])
     breslow_splines[i] = smooth_bl_survival(breslowk, smoothing_factor=smoothing_factor)
@@ -240,11 +239,11 @@ def test_step(model, x, t, e, a, breslow_splines, loss='q', typ='soft'):
 
   return float(loss/x.shape[0])
 
-def train(model, train_data, val_data, epochs=50,
-          patience=2, vloss='q', bs=256, typ='soft', lr=1e-3,
-          use_posteriors=False, debug=False, random_state=0,
-          return_losses=False, update_splines_after=10,
-          smoothing_factor=1e-4):
+def train_cmhe(model, train_data, val_data, epochs=50,
+               patience=2, vloss='q', bs=256, typ='soft', lr=1e-3,
+               use_posteriors=False, debug=False, random_state=0,
+               return_losses=False, update_splines_after=10,
+               smoothing_factor=1e-4):
 
   torch.manual_seed(random_state)
   np.random.seed(random_state)
@@ -298,7 +297,7 @@ def train(model, train_data, val_data, epochs=50,
 
 def predict_scores(model, x, a, t):
 
-  if isinstance(t, int) or isinstance(t, float): t = [t]
+  if isinstance(t, (int, float)): t = [t]
 
   model, breslow_splines = model
   gates, lrisks = model(x, a=a)
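
For reviewers, a minimal usage sketch of the API introduced by this diff. This is illustrative only: the synthetic arrays and hyperparameter values are made up, and the import path uses `cmhe_api` directly since a package-level re-export is not part of this change.

    import numpy as np
    from auton_survival.models.cmhe.cmhe_api import CoxMixturesHeterogenousEffects

    # Synthetic data: features, event/censoring times, event indicators,
    # and binary treatment assignments (shapes for illustration only).
    n, d = 1000, 10
    x = np.random.randn(n, d)
    t = np.random.exponential(2, size=n)
    e = np.random.binomial(1, 0.7, size=n).astype('float64')
    a = np.random.binomial(1, 0.5, size=n).astype('float64')

    # Fit a CMHE model with k=3 underlying Cox distributions.
    model = CoxMixturesHeterogenousEffects(k=3, layers=[64, 64])
    model.fit(x, t, e, a, iters=50, batch_size=128)

    # Survival probabilities for each individual at chosen horizons.
    scores = model.predict_survival(x, a, t=[1.0, 2.0, 5.0])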