Commit
Merge pull request #15 from autonlab/competing_risks
Added support for competing risks
Chirag Nagpal authored Dec 2, 2020
2 parents cfea0f3 + 0a78534 commit 96e4ed9
Showing 5 changed files with 216 additions and 133 deletions.
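For orientation, here is a minimal usage sketch of the competing-risks interface this PR introduces. It is a sketch, not part of the commit: the synthetic data and the import path are assumptions, and the label convention (`e = 0` for censoring, `e = 1..R` for the R competing events) is inferred from the new `maxrisk = int(e_train.max())` line in dsm_api.py below.

```python
import numpy as np
from dsm import DeepSurvivalMachines  # assumed public entry point

# Synthetic data: 0 = censored, 1 and 2 = two competing event types.
x = np.random.randn(1000, 10)
t = np.random.exponential(10, size=1000)
e = np.random.randint(0, 3, size=1000)

model = DeepSurvivalMachines(k=3, distribution='Weibull')
model.fit(x, t, e, iters=100, learning_rate=1e-3)

# Cause-specific queries use the new `risk` argument (default risk=1).
risk_cause2 = model.predict_risk(x, t=5, risk=2)      # P(T <= 5, cause 2 | x)
surv_cause2 = model.predict_survival(x, t=5, risk=2)
```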
38 changes: 21 additions & 17 deletions dsm/datasets.py
@@ -1,21 +1,25 @@
 # coding=utf-8
-# Copyright 2020 Chirag Nagpal
-#
-# This file is part of Deep Survival Machines.
-
-# Deep Survival Machines is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-
-# Deep Survival Machines is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-
-# You should have received a copy of the GNU General Public License
-# along with Deep Survival Machines.
-# If not, see <https://www.gnu.org/licenses/>.
+# MIT License
+
+# Copyright (c) 2020 Carnegie Mellon University, Auton Lab
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.


"""Utility functions to load standard datasets to train and evaluate the
52 changes: 37 additions & 15 deletions dsm/dsm_api.py
@@ -30,6 +30,7 @@
from dsm.dsm_torch import DeepSurvivalMachinesTorch
from dsm.dsm_torch import DeepRecurrentSurvivalMachinesTorch
from dsm.losses import predict_cdf
+import dsm.losses as losses
from dsm.utilities import train_dsm, _get_padded_features, _get_padded_targets

import torch
@@ -52,15 +53,16 @@ def __init__(self, k=3, layers=None, distribution="Weibull",
self.discount = discount
self.fitted = False

-  def _gen_torch_model(self, inputdim, optimizer):
+  def _gen_torch_model(self, inputdim, optimizer, risks):
"""Helper function to return a torch model."""
return DeepSurvivalMachinesTorch(inputdim,
k=self.k,
layers=self.layers,
dist=self.dist,
temp=self.temp,
discount=self.discount,
-                                     optimizer=optimizer)
+                                     optimizer=optimizer,
+                                     risks=risks)

def fit(self, x, t, e, vsize=0.15,
iters=1, learning_rate=1e-3, batch_size=100,
@@ -102,8 +104,8 @@ def fit(self, x, t, e, vsize=0.15,
x_train, t_train, e_train, x_val, t_val, e_val = processed_data

    inputdim = x_train.shape[-1]
-
-    model = self._gen_torch_model(inputdim, optimizer)
+    maxrisk = int(e_train.max())
+    model = self._gen_torch_model(inputdim, optimizer, risks=maxrisk)
model, _ = train_dsm(model,
x_train, t_train, e_train,
x_val, t_val, e_val,
@@ -139,8 +141,27 @@ def _prepocess_training_data(self, x, t, e, vsize, random_state):
return (x_train, t_train, e_train,
x_val, t_val, e_val)

+  def predict_mean(self, x, risk=1):
+    r"""Returns the mean Time-to-Event \( t \)
+
+    Parameters
+    ----------
+    x: np.ndarray
+      A numpy array of the input features, \( x \).
+    Returns:
+      np.array: numpy array of the mean time to event.
+
+    """
+    if self.fitted:
+      x = self._prepocess_test_data(x)
+      scores = losses.predict_mean(self.torch_model, x, risk=str(risk))
+      return scores
+    else:
+      raise Exception("The model has not been fitted yet. Please fit the " +
+                      "model using the `fit` method on some training data " +
+                      "before calling `predict_mean`.")
+
-  def predict_risk(self, x, t):
+  def predict_risk(self, x, t, risk=1):
r"""Returns the estimated risk of an event occuring before time \( t \)
\( \widehat{\mathbb{P}}(T\leq t|X) \) for some input data \( x \).
@@ -157,14 +178,14 @@ def predict_risk(self, x, t):
"""

    if self.fitted:
-      return 1-self.predict_survival(x, t)
+      return 1-self.predict_survival(x, t, risk=str(risk))
    else:
      raise Exception("The model has not been fitted yet. Please fit the " +
                      "model using the `fit` method on some training data " +
-                      "before calling `predict_survival`.")
+                      "before calling `predict_risk`.")


-  def predict_survival(self, x, t):
+  def predict_survival(self, x, t, risk=1):
r"""Returns the estimated survival probability at time \( t \),
\( \widehat{\mathbb{P}}(T > t|X) \) for some input data \( x \).
@@ -183,7 +204,7 @@ def predict_survival(self, x, t):
if not isinstance(t, list):
t = [t]
if self.fitted:
-      scores = predict_cdf(self.torch_model, x, t)
+      scores = predict_cdf(self.torch_model, x, t, risk=str(risk))
return np.exp(np.array(scores)).T
else:
raise Exception("The model has not been fitted yet. Please fit the " +
@@ -255,12 +276,14 @@ class DeepRecurrentSurvivalMachines(DSMBase):

def __init__(self, k=3, layers=None, hidden=None,
distribution='Weibull', temp=1000., discount=1.0, typ='LSTM'):
-    super(DeepRecurrentSurvivalMachines, self).__init__(k=k, layers=layers,
+    super(DeepRecurrentSurvivalMachines, self).__init__(k=k,
+                                                        layers=layers,
                                                         distribution=distribution,
-                                                        temp=temp, discount=discount)
+                                                        temp=temp,
+                                                        discount=discount)
self.hidden = hidden
self.typ = typ
-  def _gen_torch_model(self, inputdim, optimizer):
+  def _gen_torch_model(self, inputdim, optimizer, risks):
"""Helper function to return a torch model."""
return DeepRecurrentSurvivalMachinesTorch(inputdim,
k=self.k,
@@ -270,7 +293,8 @@ def _gen_torch_model(self, inputdim, optimizer):
temp=self.temp,
discount=self.discount,
optimizer=optimizer,
-                                              typ=self.typ)
+                                              typ=self.typ,
+                                              risks=risks)

def _prepocess_test_data(self, x):
return torch.from_numpy(_get_padded_features(x))
@@ -286,8 +310,6 @@ def _prepocess_training_data(self, x, t, e, vsize, random_state):
t = _get_padded_targets(t)
e = _get_padded_targets(e)

-    print (x.shape)
-
x_train, t_train, e_train = x[idx], t[idx], e[idx]

x_train = torch.from_numpy(x_train).double()
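A short note on the `predict_mean` path added above: the integer `risk` is cast with `str(risk)` because the torch modules below key their per-risk parameters by the strings '1', '2', and so on. A hedged sketch of per-cause calls, reusing the hypothetical fitted `model` from the first sketch:

```python
# Continuing the hypothetical fitted `model` from the first sketch.
mean_t_cause1 = model.predict_mean(x, risk=1)  # mean time-to-event, cause 1
mean_t_cause2 = model.predict_mean(x, risk=2)

# Cause-specific cumulative incidence over a grid of horizons.
cif_cause1 = [model.predict_risk(x, t=h, risk=1) for h in (1, 5, 10)]
```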
119 changes: 81 additions & 38 deletions dsm/dsm_torch.py
@@ -133,43 +133,64 @@ class DeepSurvivalMachinesTorch(nn.Module):
"""

def __init__(self, inputdim, k, layers=None, dist='Weibull',
-               temp=1000., discount=1.0, optimizer='Adam'):
+               temp=1000., discount=1.0, optimizer='Adam',
+               risks=1):
super(DeepSurvivalMachinesTorch, self).__init__()

self.k = k
self.dist = dist
self.temp = float(temp)
self.discount = float(discount)
self.optimizer = optimizer
+    self.risks = risks

if layers is None:
layers = []
self.layers = layers

-    if self.dist == 'Weibull':
+    if self.dist in ['Weibull']:
       self.act = nn.SELU()
-      self.scale = nn.Parameter(-torch.ones(k))
-      self.shape = nn.Parameter(-torch.ones(k))
-    elif self.dist == 'LogNormal':
+      self.shape = nn.ParameterDict({str(r+1): nn.Parameter(-torch.ones(k))
+                                     for r in range(self.risks)})
+      self.scale = nn.ParameterDict({str(r+1): nn.Parameter(-torch.ones(k))
+                                     for r in range(self.risks)})
+    elif self.dist in ['Normal']:
+      self.act = nn.Identity()
+      self.shape = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k))
+                                     for r in range(self.risks)})
+      self.scale = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k))
+                                     for r in range(self.risks)})
+    elif self.dist in ['LogNormal']:
       self.act = nn.Tanh()
-      self.scale = nn.Parameter(torch.ones(k))
-      self.shape = nn.Parameter(torch.ones(k))
+      self.shape = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k))
+                                     for r in range(self.risks)})
+      self.scale = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k))
+                                     for r in range(self.risks)})
else:
raise NotImplementedError('Distribution: '+self.dist+' not implemented'+
' yet.')

self.embedding = create_representation(inputdim, layers, 'ReLU6')

if len(layers) == 0:
-      self.gate = nn.Sequential(nn.Linear(inputdim, k, bias=False))
-      self.scaleg = nn.Sequential(nn.Linear(inputdim, k, bias=True))
-      self.shapeg = nn.Sequential(nn.Linear(inputdim, k, bias=True))
+      lastdim = inputdim
    else:
-      self.gate = nn.Sequential(nn.Linear(layers[-1], k, bias=False))
-      self.scaleg = nn.Sequential(nn.Linear(layers[-1], k, bias=True))
-      self.shapeg = nn.Sequential(nn.Linear(layers[-1], k, bias=True))
+      lastdim = layers[-1]
+
+    self.gate = nn.ModuleDict({str(r+1): nn.Sequential(
+        nn.Linear(lastdim, k, bias=False)
+        ) for r in range(self.risks)})
+
+    self.scaleg = nn.ModuleDict({str(r+1): nn.Sequential(
+        nn.Linear(lastdim, k, bias=True)
+        ) for r in range(self.risks)})
+
+    self.shapeg = nn.ModuleDict({str(r+1): nn.Sequential(
+        nn.Linear(lastdim, k, bias=True)
+        ) for r in range(self.risks)})

-  def forward(self, x):
+  def forward(self, x, risk='1'):
"""The forward function that is called when data is passed through DSM.
Args:
@@ -178,13 +199,14 @@ def forward(self, x):
"""
xrep = self.embedding(x)
-    return(self.act(self.shapeg(xrep))+self.shape.expand(x.shape[0], -1),
-           self.act(self.scaleg(xrep))+self.scale.expand(x.shape[0], -1),
-           self.gate(xrep)/self.temp)
+    dim = x.shape[0]
+    return(self.act(self.shapeg[risk](xrep))+self.shape[risk].expand(dim, -1),
+           self.act(self.scaleg[risk](xrep))+self.scale[risk].expand(dim, -1),
+           self.gate[risk](xrep)/self.temp)

-  def get_shape_scale(self):
-    return(self.shape,
-           self.scale)
+  def get_shape_scale(self, risk='1'):
+    return(self.shape[risk],
+           self.scale[risk])
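The structural change in DeepSurvivalMachinesTorch is easiest to see in isolation: one small head per competing risk over a shared representation, held in nn.ParameterDict / nn.ModuleDict keyed by the strings '1'..'R' so every per-risk tensor is registered with the module. A self-contained sketch of the pattern (the class name and dimensions are illustrative, not part of this commit):

```python
import torch
import torch.nn as nn

class PerRiskHeads(nn.Module):
  """Illustrative per-risk heads over a shared representation."""

  def __init__(self, embdim, k, risks=2):
    super().__init__()
    # ParameterDict/ModuleDict register each per-risk tensor/layer, so they
    # show up in .parameters() and move with .to(device); plain dicts would not.
    self.shape = nn.ParameterDict({str(r+1): nn.Parameter(-torch.ones(k))
                                   for r in range(risks)})
    self.gate = nn.ModuleDict({str(r+1): nn.Linear(embdim, k, bias=False)
                               for r in range(risks)})

  def forward(self, xrep, risk='1'):
    # Keys must be strings, hence risk='1', '2', ... as in the diff.
    return (self.gate[risk](xrep)
            + self.shape[risk].expand(xrep.shape[0], -1))

heads = PerRiskHeads(embdim=16, k=3, risks=2)
out = heads(torch.randn(8, 16), risk='2')  # shape (8, 3)
```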

class DeepRecurrentSurvivalMachinesTorch(nn.Module):
"""A Torch implementation of Deep Recurrent Survival Machines model.
@@ -229,7 +251,8 @@ class DeepRecurrentSurvivalMachinesTorch(nn.Module):

def __init__(self, inputdim, k, typ='LSTM', layers=1,
hidden=None, dist='Weibull',
-               temp=1000., discount=1.0, optimizer='Adam'):
+               temp=1000., discount=1.0,
+               optimizer='Adam', risks=1):
super(DeepRecurrentSurvivalMachinesTorch, self).__init__()

self.k = k
@@ -240,22 +263,41 @@ def __init__(self, inputdim, k, typ='LSTM', layers=1,
self.hidden = hidden
self.layers = layers
self.typ = typ
+    self.risks = risks

-    if self.dist == 'Weibull':
+    if self.dist in ['Weibull']:
       self.act = nn.SELU()
-      self.scale = nn.Parameter(-torch.ones(k))
-      self.shape = nn.Parameter(-torch.ones(k))
-    elif self.dist == 'LogNormal':
+      self.shape = nn.ParameterDict({str(r+1): nn.Parameter(-torch.ones(k))
+                                     for r in range(self.risks)})
+      self.scale = nn.ParameterDict({str(r+1): nn.Parameter(-torch.ones(k))
+                                     for r in range(self.risks)})
+    elif self.dist in ['Normal']:
+      self.act = nn.Identity()
+      self.shape = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k))
+                                     for r in range(self.risks)})
+      self.scale = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k))
+                                     for r in range(self.risks)})
+    elif self.dist in ['LogNormal']:
       self.act = nn.Tanh()
-      self.scale = nn.Parameter(torch.ones(k))
-      self.shape = nn.Parameter(torch.ones(k))
+      self.shape = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k))
+                                     for r in range(self.risks)})
+      self.scale = nn.ParameterDict({str(r+1): nn.Parameter(torch.ones(k))
+                                     for r in range(self.risks)})
else:
raise NotImplementedError('Distribution: '+self.dist+' not implemented'+
' yet.')

-    self.gate = nn.Sequential(nn.Linear(hidden, k, bias=False))
-    self.scaleg = nn.Sequential(nn.Linear(hidden, k, bias=True))
-    self.shapeg = nn.Sequential(nn.Linear(hidden, k, bias=True))
+    self.gate = nn.ModuleDict({str(r+1): nn.Sequential(
+        nn.Linear(hidden, k, bias=False)
+        ) for r in range(self.risks)})
+
+    self.scaleg = nn.ModuleDict({str(r+1): nn.Sequential(
+        nn.Linear(hidden, k, bias=True)
+        ) for r in range(self.risks)})
+
+    self.shapeg = nn.ModuleDict({str(r+1): nn.Sequential(
+        nn.Linear(hidden, k, bias=True)
+        ) for r in range(self.risks)})

if self.typ == 'LSTM':
self.embedding = nn.LSTM(inputdim, hidden, layers,
@@ -268,7 +310,7 @@ def __init__(self, inputdim, k, typ='LSTM', layers=1,
self.embedding = nn.GRU(inputdim, hidden, layers,
bias=False, batch_first=True)

-  def forward(self, x):
+  def forward(self, x, risk='1'):
"""The forward function that is called when data is passed through DSM.
Note: As compared to DSM, the input data for DRSM is a tensor. The forward
@@ -287,10 +329,11 @@ def forward(self, x):
xrep = xrep.contiguous().view(-1, self.hidden)
xrep = xrep[inputmask]
xrep = nn.ReLU6()(xrep)
-    return(self.act(self.shapeg(xrep))+self.shape.expand(xrep.shape[0], -1),
-           self.act(self.scaleg(xrep))+self.scale.expand(xrep.shape[0], -1),
-           self.gate(xrep)/self.temp)
-
-  def get_shape_scale(self):
-    return(self.shape,
-           self.scale)
+    dim = xrep.shape[0]
+    return(self.act(self.shapeg[risk](xrep))+self.shape[risk].expand(dim, -1),
+           self.act(self.scaleg[risk](xrep))+self.scale[risk].expand(dim, -1),
+           self.gate[risk](xrep)/self.temp)
+
+  def get_shape_scale(self, risk='1'):
+    return(self.shape[risk],
+           self.scale[risk])
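With these changes, forward() returns one (shape, scale, gate-logit) triple for the requested cause. A hedged call sketch against the constructor as it appears in this diff (the dimensions are illustrative, and `risk` must be a string key):

```python
import torch
from dsm.dsm_torch import DeepSurvivalMachinesTorch

model = DeepSurvivalMachinesTorch(inputdim=10, k=3, layers=[32],
                                  dist='Weibull', risks=2)
x = torch.randn(16, 10)

# One (shape, scale, gate-logit) triple for cause 2; each has shape (16, 3).
shape, scale, gate_logits = model(x, risk='2')
shape1, scale1 = model.get_shape_scale(risk='1')  # per-risk base parameters
```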
