From 0a785349030946310fd7aa3e1321a1d47ffd5bde Mon Sep 17 00:00:00 2001 From: Chirag Nagpal Date: Thu, 3 Dec 2020 03:23:19 +0530 Subject: [PATCH] modified: dsm/dsm_api.py modified: dsm/dsm_torch.py modified: dsm/losses.py modified: dsm/utilities.py --- dsm/datasets.py | 38 ++++++++------- dsm/dsm_api.py | 52 +++++++++++++++------ dsm/dsm_torch.py | 119 ++++++++++++++++++++++++++++++++--------------- dsm/losses.py | 93 ++++++++++++++++++------------------ dsm/utilities.py | 47 ++++++++++++------- 5 files changed, 216 insertions(+), 133 deletions(-) diff --git a/dsm/datasets.py b/dsm/datasets.py index 50c6273..1e53e1e 100644 --- a/dsm/datasets.py +++ b/dsm/datasets.py @@ -1,21 +1,25 @@ # coding=utf-8 -# Copyright 2020 Chirag Nagpal -# -# This file is part of Deep Survival Machines. - -# Deep Survival Machines is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. - -# Deep Survival Machines is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. - -# You should have received a copy of the GNU General Public License -# along with Deep Survival Machines. -# If not, see . +# MIT License + +# Copyright (c) 2020 Carnegie Mellon University, Auton Lab + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
"""Utility functions to load standard datasets to train and evaluate the diff --git a/dsm/dsm_api.py b/dsm/dsm_api.py index f2e8fdd..4ccd340 100644 --- a/dsm/dsm_api.py +++ b/dsm/dsm_api.py @@ -30,6 +30,7 @@ from dsm.dsm_torch import DeepSurvivalMachinesTorch from dsm.dsm_torch import DeepRecurrentSurvivalMachinesTorch from dsm.losses import predict_cdf +import dsm.losses as losses from dsm.utilities import train_dsm, _get_padded_features, _get_padded_targets import torch @@ -52,7 +53,7 @@ def __init__(self, k=3, layers=None, distribution="Weibull", self.discount = discount self.fitted = False - def _gen_torch_model(self, inputdim, optimizer): + def _gen_torch_model(self, inputdim, optimizer, risks): """Helper function to return a torch model.""" return DeepSurvivalMachinesTorch(inputdim, k=self.k, @@ -60,7 +61,8 @@ def _gen_torch_model(self, inputdim, optimizer): dist=self.dist, temp=self.temp, discount=self.discount, - optimizer=optimizer) + optimizer=optimizer, + risks=risks) def fit(self, x, t, e, vsize=0.15, iters=1, learning_rate=1e-3, batch_size=100, @@ -102,8 +104,8 @@ def fit(self, x, t, e, vsize=0.15, x_train, t_train, e_train, x_val, t_val, e_val = processed_data inputdim = x_train.shape[-1] - - model = self._gen_torch_model(inputdim, optimizer) + maxrisk = int(e_train.max()) + model = self._gen_torch_model(inputdim, optimizer, risks=maxrisk) model, _ = train_dsm(model, x_train, t_train, e_train, x_val, t_val, e_val, @@ -139,8 +141,27 @@ def _prepocess_training_data(self, x, t, e, vsize, random_state): return (x_train, t_train, e_train, x_val, t_val, e_val) + def predict_mean(self, x, risk=1): + r"""Returns the mean Time-to-Event \( t \) + + Parameters + ---------- + x: np.ndarray + A numpy array of the input features, \( x \). + Returns: + np.array: numpy array of the mean time to event. - def predict_risk(self, x, t): + """ + + if self.fitted: + x = self._prepocess_test_data(x) + scores = losses.predict_mean(self.torch_model, x, risk=str(risk)) + return scores + else: + raise Exception("The model has not been fitted yet. Please fit the " + + "model using the `fit` method on some training data " + + "before calling `predict_mean`.") + def predict_risk(self, x, t, risk=1): r"""Returns the estimated risk of an event occuring before time \( t \) \( \widehat{\mathbb{P}}(T\leq t|X) \) for some input data \( x \). @@ -157,14 +178,14 @@ def predict_risk(self, x, t): """ if self.fitted: - return 1-self.predict_survival(x, t) + return 1-self.predict_survival(x, t, risk=str(risk)) else: raise Exception("The model has not been fitted yet. Please fit the " + "model using the `fit` method on some training data " + - "before calling `predict_survival`.") + "before calling `predict_risk`.") - def predict_survival(self, x, t): + def predict_survival(self, x, t, risk=1): r"""Returns the estimated survival probability at time \( t \), \( \widehat{\mathbb{P}}(T > t|X) \) for some input data \( x \). @@ -183,7 +204,7 @@ def predict_survival(self, x, t): if not isinstance(t, list): t = [t] if self.fitted: - scores = predict_cdf(self.torch_model, x, t) + scores = predict_cdf(self.torch_model, x, t, risk=str(risk)) return np.exp(np.array(scores)).T else: raise Exception("The model has not been fitted yet. 
Please fit the " + @@ -255,12 +276,14 @@ class DeepRecurrentSurvivalMachines(DSMBase): def __init__(self, k=3, layers=None, hidden=None, distribution='Weibull', temp=1000., discount=1.0, typ='LSTM'): - super(DeepRecurrentSurvivalMachines, self).__init__(k=k, layers=layers, + super(DeepRecurrentSurvivalMachines, self).__init__(k=k, + layers=layers, distribution=distribution, - temp=temp, discount=discount) + temp=temp, + discount=discount) self.hidden = hidden self.typ = typ - def _gen_torch_model(self, inputdim, optimizer): + def _gen_torch_model(self, inputdim, optimizer, risks): """Helper function to return a torch model.""" return DeepRecurrentSurvivalMachinesTorch(inputdim, k=self.k, @@ -270,7 +293,8 @@ def _gen_torch_model(self, inputdim, optimizer): temp=self.temp, discount=self.discount, optimizer=optimizer, - typ=self.typ) + typ=self.typ, + risks=risks) def _prepocess_test_data(self, x): return torch.from_numpy(_get_padded_features(x)) @@ -286,8 +310,6 @@ def _prepocess_training_data(self, x, t, e, vsize, random_state): t = _get_padded_targets(t) e = _get_padded_targets(e) - print (x.shape) - x_train, t_train, e_train = x[idx], t[idx], e[idx] x_train = torch.from_numpy(x_train).double() diff --git a/dsm/dsm_torch.py b/dsm/dsm_torch.py index 5fa76da..551a875 100644 --- a/dsm/dsm_torch.py +++ b/dsm/dsm_torch.py @@ -133,7 +133,8 @@ class DeepSurvivalMachinesTorch(nn.Module): """ def __init__(self, inputdim, k, layers=None, dist='Weibull', - temp=1000., discount=1.0, optimizer='Adam'): + temp=1000., discount=1.0, optimizer='Adam', + risks=1): super(DeepSurvivalMachinesTorch, self).__init__() self.k = k @@ -141,19 +142,30 @@ def __init__(self, inputdim, k, layers=None, dist='Weibull', self.temp = float(temp) self.discount = float(discount) self.optimizer = optimizer + self.risks = risks if layers is None: layers = [] self.layers = layers - if self.dist == 'Weibull': + if self.dist in ['Weibull']: self.act = nn.SELU() - self.scale = nn.Parameter(-torch.ones(k)) - self.shape = nn.Parameter(-torch.ones(k)) - elif self.dist == 'LogNormal': + self.shape = nn.ParameterDict({str(r+1): nn.Parameter(-torch.ones(k)) + for r in range(self.risks)}) + self.scale = nn.ParameterDict({str(r+1):nn.Parameter(-torch.ones(k)) + for r in range(self.risks)}) + elif self.dist in ['Normal']: + self.act = nn.Identity() + self.shape = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k)) + for r in range(self.risks)}) + self.scale = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k)) + for r in range(self.risks)}) + elif self.dist in ['LogNormal']: self.act = nn.Tanh() - self.scale = nn.Parameter(torch.ones(k)) - self.shape = nn.Parameter(torch.ones(k)) + self.shape = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k)) + for r in range(self.risks)}) + self.scale = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k)) + for r in range(self.risks)}) else: raise NotImplementedError('Distribution: '+self.dist+' not implemented'+ ' yet.') @@ -161,15 +173,24 @@ def __init__(self, inputdim, k, layers=None, dist='Weibull', self.embedding = create_representation(inputdim, layers, 'ReLU6') if len(layers) == 0: - self.gate = nn.Sequential(nn.Linear(inputdim, k, bias=False)) - self.scaleg = nn.Sequential(nn.Linear(inputdim, k, bias=True)) - self.shapeg = nn.Sequential(nn.Linear(inputdim, k, bias=True)) + lastdim = inputdim else: - self.gate = nn.Sequential(nn.Linear(layers[-1], k, bias=False)) - self.scaleg = nn.Sequential(nn.Linear(layers[-1], k, bias=True)) - self.shapeg = nn.Sequential(nn.Linear(layers[-1], 
k, bias=True)) + lastdim = layers[-1] - def forward(self, x): + self.gate = nn.ModuleDict({str(r+1): nn.Sequential( + nn.Linear(lastdim, k, bias=False) + ) for r in range(self.risks)}) + + self.scaleg = nn.ModuleDict({str(r+1): nn.Sequential( + nn.Linear(lastdim, k, bias=True) + ) for r in range(self.risks)}) + + self.shapeg = nn.ModuleDict({str(r+1): nn.Sequential( + nn.Linear(lastdim, k, bias=True) + ) for r in range(self.risks)}) + + + def forward(self, x, risk='1'): """The forward function that is called when data is passed through DSM. Args: @@ -178,13 +199,14 @@ def forward(self, x): """ xrep = self.embedding(x) - return(self.act(self.shapeg(xrep))+self.shape.expand(x.shape[0], -1), - self.act(self.scaleg(xrep))+self.scale.expand(x.shape[0], -1), - self.gate(xrep)/self.temp) + dim = x.shape[0] + return(self.act(self.shapeg[risk](xrep))+self.shape[risk].expand(dim, -1), + self.act(self.scaleg[risk](xrep))+self.scale[risk].expand(dim, -1), + self.gate[risk](xrep)/self.temp) - def get_shape_scale(self): - return(self.shape, - self.scale) + def get_shape_scale(self, risk='1'): + return(self.shape[risk], + self.scale[risk]) class DeepRecurrentSurvivalMachinesTorch(nn.Module): """A Torch implementation of Deep Recurrent Survival Machines model. @@ -229,7 +251,8 @@ class DeepRecurrentSurvivalMachinesTorch(nn.Module): def __init__(self, inputdim, k, typ='LSTM', layers=1, hidden=None, dist='Weibull', - temp=1000., discount=1.0, optimizer='Adam'): + temp=1000., discount=1.0, + optimizer='Adam', risks=1): super(DeepRecurrentSurvivalMachinesTorch, self).__init__() self.k = k @@ -240,22 +263,41 @@ def __init__(self, inputdim, k, typ='LSTM', layers=1, self.hidden = hidden self.layers = layers self.typ = typ + self.risks = risks - if self.dist == 'Weibull': + if self.dist in ['Weibull']: self.act = nn.SELU() - self.scale = nn.Parameter(-torch.ones(k)) - self.shape = nn.Parameter(-torch.ones(k)) - elif self.dist == 'LogNormal': + self.shape = nn.ParameterDict({str(r+1): nn.Parameter(-torch.ones(k)) + for r in range(self.risks)}) + self.scale = nn.ParameterDict({str(r+1):nn.Parameter(-torch.ones(k)) + for r in range(self.risks)}) + elif self.dist in ['Normal']: + self.act = nn.Identity() + self.shape = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k)) + for r in range(self.risks)}) + self.scale = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k)) + for r in range(self.risks)}) + elif self.dist in ['LogNormal']: self.act = nn.Tanh() - self.scale = nn.Parameter(torch.ones(k)) - self.shape = nn.Parameter(torch.ones(k)) + self.shape = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k)) + for r in range(self.risks)}) + self.scale = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k)) + for r in range(self.risks)}) else: raise NotImplementedError('Distribution: '+self.dist+' not implemented'+ ' yet.') - self.gate = nn.Sequential(nn.Linear(hidden, k, bias=False)) - self.scaleg = nn.Sequential(nn.Linear(hidden, k, bias=True)) - self.shapeg = nn.Sequential(nn.Linear(hidden, k, bias=True)) + self.gate = nn.ModuleDict({str(r+1): nn.Sequential( + nn.Linear(hidden, k, bias=False) + ) for r in range(self.risks)}) + + self.scaleg = nn.ModuleDict({str(r+1): nn.Sequential( + nn.Linear(hidden, k, bias=True) + ) for r in range(self.risks)}) + + self.shapeg = nn.ModuleDict({str(r+1): nn.Sequential( + nn.Linear(hidden, k, bias=True) + ) for r in range(self.risks)}) if self.typ == 'LSTM': self.embedding = nn.LSTM(inputdim, hidden, layers, @@ -268,7 +310,7 @@ def __init__(self, inputdim, k, typ='LSTM', 
layers=1, self.embedding = nn.GRU(inputdim, hidden, layers, bias=False, batch_first=True) - def forward(self, x): + def forward(self, x, risk='1'): """The forward function that is called when data is passed through DSM. Note: As compared to DSM, the input data for DRSM is a tensor. The forward @@ -287,10 +329,11 @@ def forward(self, x): xrep = xrep.contiguous().view(-1, self.hidden) xrep = xrep[inputmask] xrep = nn.ReLU6()(xrep) - return(self.act(self.shapeg(xrep))+self.shape.expand(xrep.shape[0], -1), - self.act(self.scaleg(xrep))+self.scale.expand(xrep.shape[0], -1), - self.gate(xrep)/self.temp) - - def get_shape_scale(self): - return(self.shape, - self.scale) + dim = xrep.shape[0] + return(self.act(self.shapeg[risk](xrep))+self.shape[risk].expand(dim, -1), + self.act(self.scaleg[risk](xrep))+self.scale[risk].expand(dim, -1), + self.gate[risk](xrep)/self.temp) + + def get_shape_scale(self, risk='1'): + return(self.shape[risk], + self.scale[risk]) diff --git a/dsm/losses.py b/dsm/losses.py index 4194b21..915b1c4 100644 --- a/dsm/losses.py +++ b/dsm/losses.py @@ -39,8 +39,9 @@ import torch import torch.nn as nn -def _normal_loss(model, t, e): - shape, scale = model.get_shape_scale() +def _normal_loss(model, t, e, risk='1'): + + shape, scale = model.get_shape_scale(risk) k_ = shape.expand(t.shape[0], -1) b_ = scale.expand(t.shape[0], -1) @@ -57,16 +58,16 @@ def _normal_loss(model, t, e): s = 0.5 - 0.5*torch.erf(s) s = torch.log(s) - uncens = np.where(e == 1)[0] - cens = np.where(e == 0)[0] + uncens = np.where(e.cpu().data.numpy() == int(risk))[0] + cens = np.where(e.cpu().data.numpy() != int(risk))[0] ll += f[uncens].sum() + s[cens].sum() return -ll.mean() -def _lognormal_loss(model, t, e): +def _lognormal_loss(model, t, e, risk='1'): - shape, scale = model.get_shape_scale() + shape, scale = model.get_shape_scale(risk) k_ = shape.expand(t.shape[0], -1) b_ = scale.expand(t.shape[0], -1) @@ -83,16 +84,16 @@ def _lognormal_loss(model, t, e): s = 0.5 - 0.5*torch.erf(s) s = torch.log(s) - uncens = np.where(e == 1)[0] - cens = np.where(e == 0)[0] + uncens = np.where(e.cpu().data.numpy() == int(risk))[0] + cens = np.where(e.cpu().data.numpy() != int(risk))[0] ll += f[uncens].sum() + s[cens].sum() return -ll.mean() -def _weibull_loss(model, t, e): +def _weibull_loss(model, t, e, risk='1'): - shape, scale = model.get_shape_scale() + shape, scale = model.get_shape_scale(risk) k_ = shape.expand(t.shape[0], -1) b_ = scale.expand(t.shape[0], -1) @@ -107,29 +108,29 @@ def _weibull_loss(model, t, e): f = k + b + ((torch.exp(k)-1)*(b+torch.log(t))) f = f + s - uncens = np.where(e.cpu().data.numpy() == 1)[0] - cens = np.where(e.cpu().data.numpy() == 0)[0] + uncens = np.where(e.cpu().data.numpy() == int(risk))[0] + cens = np.where(e.cpu().data.numpy() != int(risk))[0] ll += f[uncens].sum() + s[cens].sum() return -ll.mean() -def unconditional_loss(model, t, e): +def unconditional_loss(model, t, e, risk='1'): if model.dist == 'Weibull': - return _weibull_loss(model, t, e) + return _weibull_loss(model, t, e, risk) elif model.dist == 'LogNormal': - return _lognormal_loss(model, t, e) + return _lognormal_loss(model, t, e, risk) elif model.dist == 'Normal': - return _normal_loss(model, t, e) + return _normal_loss(model, t, e, risk) else: raise NotImplementedError('Distribution: '+model.dist+ ' not implemented yet.') -def _conditional_normal_loss(model, x, t, e, elbo=True): +def _conditional_normal_loss(model, x, t, e, elbo=True, risk='1'): alpha = model.discount - shape, scale, logits = model.forward(x) + shape, 
scale, logits = model.forward(x, risk) lossf = [] losss = [] @@ -172,16 +173,16 @@ def _conditional_normal_loss(model, x, t, e, elbo=True): losss = torch.logsumexp(losss, dim=1) lossf = torch.logsumexp(lossf, dim=1) - uncens = np.where(e.cpu().data.numpy() == 1)[0] - cens = np.where(e.cpu().data.numpy() == 0)[0] + uncens = np.where(e.cpu().data.numpy() == int(risk))[0] + cens = np.where(e.cpu().data.numpy() != int(risk))[0] ll = lossf[uncens].sum() + alpha*losss[cens].sum() return -ll/float(len(uncens)+len(cens)) -def _conditional_lognormal_loss(model, x, t, e, elbo=True): +def _conditional_lognormal_loss(model, x, t, e, elbo=True, risk='1'): alpha = model.discount - shape, scale, logits = model.forward(x) + shape, scale, logits = model.forward(x, risk) lossf = [] losss = [] @@ -224,17 +225,17 @@ def _conditional_lognormal_loss(model, x, t, e, elbo=True): losss = torch.logsumexp(losss, dim=1) lossf = torch.logsumexp(lossf, dim=1) - uncens = np.where(e.cpu().data.numpy() == 1)[0] - cens = np.where(e.cpu().data.numpy() == 0)[0] + uncens = np.where(e.cpu().data.numpy() == int(risk))[0] + cens = np.where(e.cpu().data.numpy() != int(risk))[0] ll = lossf[uncens].sum() + alpha*losss[cens].sum() return -ll/float(len(uncens)+len(cens)) -def _conditional_weibull_loss(model, x, t, e, elbo=True): +def _conditional_weibull_loss(model, x, t, e, elbo=True, risk='1'): alpha = model.discount - shape, scale, logits = model.forward(x) + shape, scale, logits = model.forward(x, risk) k_ = shape b_ = scale @@ -273,31 +274,31 @@ def _conditional_weibull_loss(model, x, t, e, elbo=True): losss = torch.logsumexp(losss, dim=1) lossf = torch.logsumexp(lossf, dim=1) - uncens = np.where(e.cpu().data.numpy() == 1)[0] - cens = np.where(e.cpu().data.numpy() == 0)[0] + uncens = np.where(e.cpu().data.numpy() == int(risk))[0] + cens = np.where(e.cpu().data.numpy() != int(risk))[0] ll = lossf[uncens].sum() + alpha*losss[cens].sum() return -ll/float(len(uncens)+len(cens)) -def conditional_loss(model, x, t, e, elbo=True): +def conditional_loss(model, x, t, e, elbo=True, risk='1'): if model.dist == 'Weibull': - return _conditional_weibull_loss(model, x, t, e, elbo) + return _conditional_weibull_loss(model, x, t, e, elbo, risk) elif model.dist == 'LogNormal': - return _conditional_lognormal_loss(model, x, t, e, elbo) + return _conditional_lognormal_loss(model, x, t, e, elbo, risk) elif model.dist == 'Normal': - return _conditional_normal_loss(model, x, t, e, elbo) + return _conditional_normal_loss(model, x, t, e, elbo, risk) else: raise NotImplementedError('Distribution: '+model.dist+ ' not implemented yet.') -def _weibull_cdf(model, x, t_horizon): +def _weibull_cdf(model, x, t_horizon, risk='1'): squish = nn.LogSoftmax(dim=1) - shape, scale, logits = model.forward(x) + shape, scale, logits = model.forward(x, risk) logits = squish(logits) k_ = shape @@ -327,11 +328,11 @@ def _weibull_cdf(model, x, t_horizon): return cdfs -def _lognormal_cdf(model, x, t_horizon): +def _lognormal_cdf(model, x, t_horizon, risk='1'): squish = nn.LogSoftmax(dim=1) - shape, scale, logits = model.forward(x) + shape, scale, logits = model.forward(x, risk) logits = squish(logits) k_ = shape @@ -364,11 +365,11 @@ def _lognormal_cdf(model, x, t_horizon): return cdfs -def _normal_cdf(model, x, t_horizon): +def _normal_cdf(model, x, t_horizon, risk='1'): squish = nn.LogSoftmax(dim=1) - shape, scale, logits = model.forward(x) + shape, scale, logits = model.forward(x, risk) logits = squish(logits) k_ = shape @@ -401,10 +402,10 @@ def _normal_cdf(model, x, 
t_horizon): return cdfs -def _normal_mean(model, x): +def _normal_mean(model, x, risk='1'): squish = nn.Softmax(dim=1) - shape, scale, logits = model.forward(x) + shape, scale, logits = model.forward(x, risk) logits = squish(logits) k_ = shape @@ -423,10 +424,10 @@ def _normal_mean(model, x): return lmeans.detach().numpy() -def predict_mean(model, x): +def predict_mean(model, x, risk='1'): torch.no_grad() if model.dist == 'Normal': - return _normal_mean(model, x) + return _normal_mean(model, x, risk) else: raise NotImplementedError('Mean of Distribution: '+model.dist+ ' not implemented yet.') @@ -434,14 +435,14 @@ def predict_mean(model, x): -def predict_cdf(model, x, t_horizon): +def predict_cdf(model, x, t_horizon, risk='1'): torch.no_grad() if model.dist == 'Weibull': - return _weibull_cdf(model, x, t_horizon) + return _weibull_cdf(model, x, t_horizon, risk) if model.dist == 'LogNormal': - return _lognormal_cdf(model, x, t_horizon) + return _lognormal_cdf(model, x, t_horizon, risk) if model.dist == 'Normal': - return _normal_cdf(model, x, t_horizon) + return _normal_cdf(model, x, t_horizon, risk) else: raise NotImplementedError('Distribution: '+model.dist+ ' not implemented yet.') diff --git a/dsm/utilities.py b/dsm/utilities.py index 576eb8e..61f1725 100644 --- a/dsm/utilities.py +++ b/dsm/utilities.py @@ -53,7 +53,8 @@ def pretrain_dsm(model, t_train, e_train, t_valid, e_valid, n_iter=10000, lr=1e-2, thres=1e-4): premodel = DeepSurvivalMachinesTorch(1, 1, - dist=model.dist) + dist=model.dist, + risks=model.risks) premodel.double() optimizer = torch.optim.Adam(premodel.parameters(), lr=lr) @@ -64,14 +65,18 @@ def pretrain_dsm(model, t_train, e_train, t_valid, e_valid, for _ in tqdm(range(n_iter)): optimizer.zero_grad() - loss = unconditional_loss(premodel, t_train, e_train) + loss = 0 + for r in range(model.risks): + loss += unconditional_loss(premodel, t_train, e_train, str(r+1)) loss.backward() optimizer.step() - valid_loss = unconditional_loss(premodel, t_valid, e_valid) + valid_loss = 0 + for r in range(model.risks): + valid_loss += unconditional_loss(premodel, t_valid, e_valid, str(r+1)) valid_loss = valid_loss.detach().cpu().numpy() costs.append(valid_loss) - + #print(valid_loss) if np.abs(costs[-1] - oldcost) < thres: patience += 1 if patience == 3: @@ -126,8 +131,12 @@ def train_dsm(model, n_iter=10000, lr=1e-2, thres=1e-4) - model.shape.data.fill_(float(premodel.shape)) - model.scale.data.fill_(float(premodel.scale)) + + for r in range(model.risks): + model.shape[str(r+1)].data.fill_(float(premodel.shape[str(r+1)])) + model.scale[str(r+1)].data.fill_(float(premodel.scale[str(r+1)])) + + print(premodel.shape, premodel.scale) model.double() optimizer = torch.optim.Adam(model.parameters(), lr=lr) @@ -148,20 +157,24 @@ def train_dsm(model, eb = e_train[j*bs:(j+1)*bs] optimizer.zero_grad() - loss = conditional_loss(model, - xb, - _reshape_tensor_with_nans(tb), - _reshape_tensor_with_nans(eb), - elbo=elbo) + loss = 0 + for r in range(model.risks): + loss += conditional_loss(model, + xb, + _reshape_tensor_with_nans(tb), + _reshape_tensor_with_nans(eb), + elbo=elbo, + risk=str(r+1)) #print ("Train Loss:", float(loss)) loss.backward() optimizer.step() - - valid_loss = conditional_loss(model, - x_valid, - t_valid_, - e_valid_, - elbo=False) + valid_loss = 0 + for r in range(model.risks): + valid_loss += conditional_loss(model, + x_valid, + t_valid_, + e_valid_, + elbo=False, + risk=str(r+1)) valid_loss = valid_loss.detach().cpu().numpy() costs.append(float(valid_loss))
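Usage note: the changes above extend DSM to competing risks by keeping one set of mixture parameters (shape, scale) and one head (gate, shapeg, scaleg) per risk, keyed by the strings '1', '2', ... in ParameterDict/ModuleDict containers; each per-risk loss treats every other event label as censoring, and training sums the per-risk losses. Below is a minimal, hypothetical sketch of the resulting API on synthetic data. It assumes the public DeepSurvivalMachines wrapper is importable from the dsm package and exposes the fit/predict_survival/predict_risk signatures patched in dsm_api.py; the data, layer sizes, iteration count, and horizons are purely illustrative.

import numpy as np
from dsm import DeepSurvivalMachines  # assumed public import path

np.random.seed(0)
n, d = 1000, 10
x = np.random.randn(n, d)                      # covariates
t = np.random.exponential(scale=10., size=n)   # event / censoring times
e = np.random.randint(0, 3, size=n)            # 0 = censored, 1 and 2 = competing events;
                                               # fit() infers the number of risks from int(e.max())

model = DeepSurvivalMachines(k=3, layers=[100], distribution='Weibull')
model.fit(x, t, e, iters=100, learning_rate=1e-3)

horizons = [5., 10., 15.]
surv_1 = model.predict_survival(x, horizons, risk=1)  # cause-specific survival for risk 1
risk_2 = model.predict_risk(x, horizons, risk=2)      # 1 - survival, with risk 2 as the event
print(surv_1.shape, risk_2.shape)                     # (n, len(horizons)) each

The risk argument is converted to a string inside the API and used to index the per-risk dictionaries defined in dsm_torch.py, so risk=1 selects self.gate['1'], self.shape['1'], and so on.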