# coding=utf-8
-# Copyright 2020 Chirag Nagpal, Auton Lab.
+# Copyright 2020 Chirag Nagpal
#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# This file is part of Deep Survival Machines.
+
+# Deep Survival Machines is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# Deep Survival Machines is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+
+# You should have received a copy of the GNU General Public License
+# along with Deep Survival Machines.
+# If not, see <https://www.gnu.org/licenses/>.
+
"""Utility functions to load standard datasets to train and evaluate the
Deep Survival Machines models.
@@ -75,13 +80,21 @@
Module dsm.datasets
return e, t
-def _load_pbc_dataset():
+def _load_pbc_dataset(sequential):
"""Helper function to load and preprocess the PBC dataset
The Primary biliary cirrhosis (PBC) Dataset [1] is well known
dataset for evaluating survival analysis models with time
dependent covariates.
+ Parameters
+ ----------
+ sequential: bool
+ If True returns a list of np.arrays for each individual.
+ else, returns collapsed results for each time step. To train
+ recurrent neural models you would typically use True.
+
+
References
----------
[1] Fleming, Thomas R., and David P. Harrington. Counting processes and
@@ -89,7 +102,36 @@
Module dsm.datasets
"""
- raise NotImplementedError('')
+ data = pkgutil.get_data(__name__, 'datasets/pbc2.csv')
+ data = pd.read_csv(io.BytesIO(data))
+
+ data['histologic'] = data['histologic'].astype(str)
+ dat_cat = data[['drug', 'sex', 'ascites', 'hepatomegaly',
+ 'spiders', 'edema', 'histologic']]
+ dat_num = data[['serBilir', 'serChol', 'albumin', 'alkaline',
+ 'SGOT', 'platelets', 'prothrombin']]
+ age = data['age'] + data['years']
+
+ x1 = pd.get_dummies(dat_cat).values
+ x2 = dat_num.values
+ x3 = age.values.reshape(-1, 1)
+ x = np.hstack([x1, x2, x3])
+
+ time = (data['years'] - data['year']).values
+ event = data['status2'].values
+
+ x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
+ x_ = StandardScaler().fit_transform(x)
+
+ if not sequential:
+ return x_, time, event
+ else:
+ x, t, e = [], [], []
+ for id_ in sorted(list(set(data['id']))):
+ x.append(x_[data['id'] == id_])
+ t.append(time[data['id'] == id_])
+ e.append(event[data['id'] == id_])
+ return x, t, e
def _load_support_dataset():
"""Helper function to load and preprocess the SUPPORT dataset.
@@ -128,13 +170,16 @@
Module dsm.datasets
return x[remove], t[remove], e[remove]
-def load_dataset(dataset='SUPPORT'):
+def load_dataset(dataset='SUPPORT', **kwargs):
"""Helper function to load datasets to test Survival Analysis models.
Parameters
----------
dataset: str
- The choice of dataset to load. Currently implemented is 'SUPPORT'.
+ The choice of dataset to load. Currently implemented is 'SUPPORT'
+ and 'PBC'.
+ **kwargs: dict
+ Dataset specific keyword arguments.
Returns
----------
@@ -146,6 +191,9 @@
Module dsm.datasets
if dataset == 'SUPPORT':
return _load_support_dataset()
+ if dataset == 'PBC':
+ sequential = kwargs.get('sequential', False)
+ return _load_pbc_dataset(sequential)
else:
return NotImplementedError('Dataset '+dataset+' not implemented.')
Helper function to load datasets to test Survival Analysis models.
Parameters
dataset : str
-
The choice of dataset to load. Currently implemented is 'SUPPORT'.
+
The choice of dataset to load. Currently implemented is 'SUPPORT'
+and 'PBC'.
+
**kwargs : dict
+
Dataset specific keyword arguments.
Returns
@@ -203,13 +254,16 @@
Returns
Expand source code
-
def load_dataset(dataset='SUPPORT'):
+
def load_dataset(dataset='SUPPORT', **kwargs):
"""Helper function to load datasets to test Survival Analysis models.
Parameters
----------
dataset: str
- The choice of dataset to load. Currently implemented is 'SUPPORT'.
+ The choice of dataset to load. Currently implemented is 'SUPPORT'
+ and 'PBC'.
+ **kwargs: dict
+ Dataset specific keyword arguments.
Returns
----------
@@ -221,6 +275,9 @@
Returns
if dataset == 'SUPPORT':
return _load_support_dataset()
+ if dataset == 'PBC':
+ sequential = kwargs.get('sequential', False)
+ return _load_pbc_dataset(sequential)
else:
return NotImplementedError('Dataset '+dataset+' not implemented.')
# coding=utf-8
-# Copyright 2020 Chirag Nagpal, Auton Lab.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import io, pkgutil
-
-import pandas as pd
-import numpy as np
-
-from sklearn.impute import SimpleImputer
-from sklearn.preprocessing import StandardScaler
-
-def increase_censoring(e, t, p):
-
- uncens = np.where(e == 1)[0]
- mask = np.random.choice([False, True], len(uncens), p=[1-p, p])
- toswitch = uncens[mask]
-
- e[toswitch] = 0
- t_ = t[toswitch]
-
- newt = []
- for t__ in t_:
- newt.append(np.random.uniform(1, t__))
- t[toswitch] = newt
-
- return e, t
-
-def load_support_dataset():
-
- """Helper function to load and preprocess the SUPPORT dataset.
-
- The SUPPORT Dataset comes from the Vanderbilt University study
- to estimate survival for seriously ill hospitalized adults [1].
-
- Please refer to http://biostat.mc.vanderbilt.edu/wiki/Main/SupportDesc.
- for the original datasource.
-
- [1]: Knaus WA, Harrell FE, Lynn J et al. (1995): The SUPPORT prognostic
- model: Objective estimates of survival for seriously ill hospitalized
- adults. Annals of Internal Medicine 122:191-203.
-
- """
- data = pkgutil.get_data(__name__, 'datasets/support2.csv')
- data = pd.read_csv(io.BytesIO(data))
- x1 = data[['age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp', 'temp',
- 'pafi', 'alb', 'bili', 'crea', 'sod', 'ph', 'glucose', 'bun',
- 'urine', 'adlp', 'adls']]
-
- catfeats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca']
- x2 = pd.get_dummies(data[catfeats])
-
- x = np.concatenate([x1, x2], axis=1)
- t = data['d.time'].values
- e = data['death'].values
-
- x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
- x = StandardScaler().fit_transform(x)
-
- remove = ~np.isnan(t)
- return x[remove], t[remove], e[remove]
-
-
-def load_dataset(dataset='SUPPORT'):
- """Helper function to load datasets to test Survival Analysis models.
-
- Parameters
- ----------
- dataset: str
- The choice of dataset to load. Currently implemented is 'SUPPORT'.
-
- Returns
- ----------
- tuple: (np.ndarray, np.ndarray, np.ndarray)
- A tuple of the form of (x, t, e) where x, t, e are the input covariates,
- event times and the censoring indicators respectively.
-
- """
-
- if dataset == 'SUPPORT':
- return _load_support_dataset()
Helper function to load datasets to test Survival Analysis models.
-
Parameters
-
-
dataset : str
-
The choice of dataset to load. Currently implemented is 'SUPPORT'.
-
-
Returns
-
-
tuple : (np.ndarray, np.ndarray, np.ndarray)
-
A tuple of the form of (x, t, e) where x, t, e are the input covariates,
-event times and the censoring indicators respectively.
-
-
-
-Expand source code
-
-
def load_dataset(dataset='SUPPORT'):
- """Helper function to load datasets to test Survival Analysis models.
-
- Parameters
- ----------
- dataset: str
- The choice of dataset to load. Currently implemented is 'SUPPORT'.
-
- Returns
- ----------
- tuple: (np.ndarray, np.ndarray, np.ndarray)
- A tuple of the form of (x, t, e) where x, t, e are the input covariates,
- event times and the censoring indicators respectively.
-
- """
-
- if dataset == 'SUPPORT':
- return _load_support_dataset()
-
-
-
-def load_support_dataset()
-
-
-
Helper function to load and preprocess the SUPPORT dataset.
-
The SUPPORT Dataset comes from the Vanderbilt University study
-to estimate survival for seriously ill hospitalized adults [1].
[1]: Knaus WA, Harrell FE, Lynn J et al. (1995): The SUPPORT prognostic
-model: Objective estimates of survival for seriously ill hospitalized
-adults. Annals of Internal Medicine 122:191-203.
-
-
-Expand source code
-
-
def load_support_dataset():
-
- """Helper function to load and preprocess the SUPPORT dataset.
-
- The SUPPORT Dataset comes from the Vanderbilt University study
- to estimate survival for seriously ill hospitalized adults [1].
-
- Please refer to http://biostat.mc.vanderbilt.edu/wiki/Main/SupportDesc.
- for the original datasource.
-
- [1]: Knaus WA, Harrell FE, Lynn J et al. (1995): The SUPPORT prognostic
- model: Objective estimates of survival for seriously ill hospitalized
- adults. Annals of Internal Medicine 122:191-203.
-
- """
- data = pkgutil.get_data(__name__, 'datasets/support2.csv')
- data = pd.read_csv(io.BytesIO(data))
- x1 = data[['age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp', 'temp',
- 'pafi', 'alb', 'bili', 'crea', 'sod', 'ph', 'glucose', 'bun',
- 'urine', 'adlp', 'adls']]
-
- catfeats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca']
- x2 = pd.get_dummies(data[catfeats])
-
- x = np.concatenate([x1, x2], axis=1)
- t = data['d.time'].values
- e = data['death'].values
-
- x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
- x = StandardScaler().fit_transform(x)
-
- remove = ~np.isnan(t)
- return x[remove], t[remove], e[remove]
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/docs/dsm_api.html b/docs/dsm_api.html
index 8d180e5..d8009a2 100644
--- a/docs/dsm_api.html
+++ b/docs/dsm_api.html
@@ -31,19 +31,24 @@
Module dsm.dsm_api
Expand source code
# coding=utf-8
-# Copyright 2020 Chirag Nagpal, Auton Lab.
+# Copyright 2020 Chirag Nagpal
#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# This file is part of Deep Survival Machines.
+
+# Deep Survival Machines is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# Deep Survival Machines is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+
+# You should have received a copy of the GNU General Public License
+# along with Deep Survival Machines.
+# If not, see <https://www.gnu.org/licenses/>.
+
"""
This module is a wrapper around torch implementations and
@@ -55,7 +60,6 @@
Module dsm.dsm_api
from dsm.utilities import train_dsm
import torch
-
import numpy as np
class DeepSurvivalMachines():
@@ -70,7 +74,7 @@
Module dsm.dsm_api
References
----------
- [1] <a href="https://arxiv.org/abs/2003.01176">Deep Survival Machines:
+ [1] <a href="https://arxiv.org/abs/2003.01176">Deep Survival Machines:
Fully Parametric Survival Regression and
Representation Learning for Censored Data with Competing Risks."
arXiv preprint arXiv:2003.01176 (2020)</a>
@@ -124,7 +128,7 @@
learning is performed on mini-batches of input data. this parameter
specifies the size of each mini-batch.
elbo: bool
- Whether to use the Evidence Lower Bound for Optimization.
+ Whether to use the Evidence Lower Bound for optimization.
Default is True.
optimizer: str
- The choice of the gradient based optimization method. One of
+ The choice of the gradient based optimization method. One of
'Adam', 'RMSProp' or 'SGD'.
random_state: float
random seed that determines how the validation set is chosen.
@@ -182,7 +186,6 @@
def predict_risk(self, x, t):
- """Returns the estimated risk of an event occuring before time \( t \),
+ """Returns the estimated risk of an event occuring before time \( t \)
\( \widehat{\mathbb{P}}(T\leq t|X) \) for some input data \( x \).
Parameters
@@ -396,7 +399,7 @@
Example
References
----------
- [1] <a href="https://arxiv.org/abs/2003.01176">Deep Survival Machines:
+ [1] <a href="https://arxiv.org/abs/2003.01176">Deep Survival Machines:
Fully Parametric Survival Regression and
Representation Learning for Censored Data with Competing Risks."
arXiv preprint arXiv:2003.01176 (2020)</a>
@@ -450,7 +453,7 @@
learning is performed on mini-batches of input data. this parameter
specifies the size of each mini-batch.
elbo: bool
- Whether to use the Evidence Lower Bound for Optimization.
+ Whether to use the Evidence Lower Bound for optimization.
Default is True.
optimizer: str
- The choice of the gradient based optimization method. One of
+ The choice of the gradient based optimization method. One of
'Adam', 'RMSProp' or 'SGD'.
random_state: float
random seed that determines how the validation set is chosen.
@@ -508,7 +511,6 @@
def predict_risk(self, x, t):
- """Returns the estimated risk of an event occuring before time \( t \),
+ """Returns the estimated risk of an event occuring before time \( t \)
\( \widehat{\mathbb{P}}(T\leq t|X) \) for some input data \( x \).
Parameters
@@ -613,7 +615,7 @@
Parameters
learning is performed on mini-batches of input data. this parameter
specifies the size of each mini-batch.
elbo : bool
-
Whether to use the Evidence Lower Bound for Optimization.
+
Whether to use the Evidence Lower Bound for optimization.
Default is True.
optimizer : str
The choice of the gradient based optimization method. One of
@@ -625,7 +627,7 @@
learning is performed on mini-batches of input data. this parameter
specifies the size of each mini-batch.
elbo: bool
- Whether to use the Evidence Lower Bound for Optimization.
+ Whether to use the Evidence Lower Bound for optimization.
Default is True.
optimizer: str
- The choice of the gradient based optimization method. One of
+ The choice of the gradient based optimization method. One of
'Adam', 'RMSProp' or 'SGD'.
random_state: float
random seed that determines how the validation set is chosen.
@@ -683,7 +685,6 @@
Returns the estimated risk of an event occuring before time t ,
+
Returns the estimated risk of an event occuring before time t \widehat{\mathbb{P}}(T\leq t|X) for some input data x .
Parameters
@@ -728,7 +729,7 @@
Returns
Expand source code
def predict_risk(self, x, t):
- """Returns the estimated risk of an event occuring before time \( t \),
+ """Returns the estimated risk of an event occuring before time \( t \)
\( \widehat{\mathbb{P}}(T\leq t|X) \) for some input data \( x \).
Parameters
diff --git a/docs/dsm_torch.html b/docs/dsm_torch.html
index a823b89..36d8a1d 100644
--- a/docs/dsm_torch.html
+++ b/docs/dsm_torch.html
@@ -33,19 +33,24 @@
Module dsm.dsm_torch
Expand source code
# coding=utf-8
-# Copyright 2020 Chirag Nagpal, Auton Lab.
+# Copyright 2020 Chirag Nagpal
#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# This file is part of Deep Survival Machines.
+
+# Deep Survival Machines is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# Deep Survival Machines is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+
+# You should have received a copy of the GNU General Public License
+# along with Deep Survival Machines.
+# If not, see <https://www.gnu.org/licenses/>.
+
"""Torch model definitons for the Deep Survival Machines model
@@ -63,8 +68,8 @@
Module dsm.dsm_torch
def create_representation(inputdim, layers, activation):
"""Helper function to generate the representation function for DSM.
- Deep Survival Machines learns a representation $$\Phi(X)$$ for the input data.
- This representation is parameterized using a Non Linear Multilayer
+ Deep Survival Machines learns a representation (\ Phi(X) \) for the input
+ data. This representation is parameterized using a Non Linear Multilayer
Perceptron (`torch.nn.Module`). This is a helper function designed to
instantiate the representation for Deep Survival Machines.
@@ -145,7 +150,7 @@
self.scaleg = nn.Sequential(nn.Linear(layers[-1], k, bias=True))
self.shapeg = nn.Sequential(nn.Linear(layers[-1], k, bias=True))
- if init is not False:
- self.shape.data.fill_(init[0])
- self.scale.data.fill_(init[1])
-
def forward(self, x):
"""The forward function that is called when data is passed through DSM.
@@ -197,6 +200,112 @@
Module dsm.dsm_torch
self.act(self.scaleg(xrep))+self.scale.expand(x.shape[0], -1),
self.gate(xrep)/self.temp)
+ def get_shape_scale(self):
+ return(self.shape,
+ self.scale)
+
+class DeepRecurrentSurvivalMachinesTorch(DeepSurvivalMachinesTorch):
+ """A Torch implementation of Deep Recurrent Survival Machines model.
+
+ This is an implementation of Deep Recurrent Survival Machines model
+ in torch. It inherits from `DeepSurvivalMachinesTorch` and replaces the
+ input representation learning MLP with an LSTM or RNN, the parameters of the
+ underlying distributions and the forward function which is called whenever
+ data is passed to the module. Each of the parameters are nn.Parameters and
+ torch automatically keeps track and computes gradients for them.
+
+ .. warning::
+ Not designed to be used directly.
+ Please use the API inferface `dsm.dsm_api.DeepRecurrentSurvivalMachines`!!
+
+ Parameters
+ ----------
+ inputdim: int
+ Dimensionality of the input features.
+ k: int
+ The number of underlying parametric distributions.
+ layers: int
+ The number of hidden layers in the LSTM or RNN cell.
+ hidden: int
+ The number of neurons in each hidden layer.
+ init: tuple
+ A tuple for initialization of the parameters for the underlying
+ distributions. (shape, scale).
+ dist: str
+ Choice of the underlying survival distributions.
+ One of 'Weibull', 'LogNormal'.
+ Default is 'Weibull'.
+ temp: float
+ The logits for the gate are rescaled with this value.
+ Default is 1000.
+ discount: float
+ a float in [0,1] that determines how to discount the tail bias
+ from the uncensored instances.
+ Default is 1.
+ """
+
+ def __init__(self, inputdim, k, typ='LSTM', layers=1,
+ hidden=None, dist='Weibull',
+ temp=1000., discount=1.0, optimizer='Adam'):
+ super(DeepSurvivalMachinesTorch, self).__init__()
+
+ self.k = k
+ self.dist = dist
+ self.temp = float(temp)
+ self.discount = float(discount)
+ self.optimizer = optimizer
+ self.hidden = hidden
+ self.layers = layers
+ self.typ = typ
+
+ if self.dist == 'Weibull':
+ self.act = nn.SELU()
+ self.scale = nn.Parameter(-torch.ones(k))
+ self.shape = nn.Parameter(-torch.ones(k))
+ elif self.dist == 'LogNormal':
+ self.act = nn.Tanh()
+ self.scale = nn.Parameter(torch.ones(k))
+ self.shape = nn.Parameter(torch.ones(k))
+ else:
+ raise NotImplementedError('Distribution: '+self.dist+' not implemented'+
+ ' yet.')
+
+ self.gate = nn.Sequential(nn.Linear(hidden, k, bias=False))
+ self.scaleg = nn.Sequential(nn.Linear(hidden, k, bias=True))
+ self.shapeg = nn.Sequential(nn.Linear(hidden, k, bias=True))
+
+ if self.typ == 'LSTM':
+ self.embedding = nn.LSTM(inputdim, hidden, layers,
+ bias=False, batch_first=True)
+ if self.typ == 'RNN':
+ self.embedding = nn.RNN(inputdim, hidden, layers,
+ bias=False, batch_first=True)
+
+ #self.embedding = nn.ReLU6(self.embedding)
+
+
+ def forward(self, x):
+ """The forward function that is called when data is passed through DSM.
+
+ Note: As compared to DSM, the input data for DRSM is a tensor. The forward
+ function involves unpacking the tensor in-order to directly use the
+ DSM loss functions.
+
+ Args:
+ x:
+ a torch.tensor of the input features.
+ """
+ x = x.detach().clone()
+ inputmask = ~torch.isnan(x[:, :, 0]).reshape(-1)
+ x[torch.isnan(x)] = 0
+ xrep, _ = self.embedding(x)
+ xrep = xrep.contiguous().view(-1, self.hidden)
+ xrep = xrep[inputmask]
+ #xrep = nn.ReLU6()(xrep)
+ return(self.act(self.shapeg(xrep))+self.shape.expand(xrep.shape[0], -1),
+ self.act(self.scaleg(xrep))+self.scale.expand(xrep.shape[0], -1),
+ self.gate(xrep)/self.temp)
+
def get_shape_scale(self):
return(self.shape,
self.scale)
@@ -214,8 +323,8 @@
Functions
Helper function to generate the representation function for DSM.
-
Deep Survival Machines learns a representation \Phi(X) for the input data.
-This representation is parameterized using a Non Linear Multilayer
+
Deep Survival Machines learns a representation (\ Phi(X) ) for the input
+data. This representation is parameterized using a Non Linear Multilayer
Perceptron (torch.nn.Module). This is a helper function designed to
instantiate the representation for Deep Survival Machines.
@@ -240,8 +349,8 @@
Returns
def create_representation(inputdim, layers, activation):
"""Helper function to generate the representation function for DSM.
- Deep Survival Machines learns a representation $$\Phi(X)$$ for the input data.
- This representation is parameterized using a Non Linear Multilayer
+ Deep Survival Machines learns a representation (\ Phi(X) \) for the input
+ data. This representation is parameterized using a Non Linear Multilayer
Perceptron (`torch.nn.Module`). This is a helper function designed to
instantiate the representation for Deep Survival Machines.
@@ -285,9 +394,234 @@
A Torch implementation of Deep Recurrent Survival Machines model.
+
This is an implementation of Deep Recurrent Survival Machines model
+in torch. It inherits from DeepSurvivalMachinesTorch and replaces the
+input representation learning MLP with an LSTM or RNN, the parameters of the
+underlying distributions and the forward function which is called whenever
+data is passed to the module. Each of the parameters are nn.Parameters and
+torch automatically keeps track and computes gradients for them.
The number of underlying parametric distributions.
+
layers : int
+
The number of hidden layers in the LSTM or RNN cell.
+
hidden : int
+
The number of neurons in each hidden layer.
+
init : tuple
+
A tuple for initialization of the parameters for the underlying
+distributions. (shape, scale).
+
dist : str
+
Choice of the underlying survival distributions.
+One of 'Weibull', 'LogNormal'.
+Default is 'Weibull'.
+
temp : float
+
The logits for the gate are rescaled with this value.
+Default is 1000.
+
discount : float
+
a float in [0,1] that determines how to discount the tail bias
+from the uncensored instances.
+Default is 1.
+
+
Initializes internal Module state, shared by both nn.Module and ScriptModule.
+
+
+Expand source code
+
+
class DeepRecurrentSurvivalMachinesTorch(DeepSurvivalMachinesTorch):
+ """A Torch implementation of Deep Recurrent Survival Machines model.
+
+ This is an implementation of Deep Recurrent Survival Machines model
+ in torch. It inherits from `DeepSurvivalMachinesTorch` and replaces the
+ input representation learning MLP with an LSTM or RNN, the parameters of the
+ underlying distributions and the forward function which is called whenever
+ data is passed to the module. Each of the parameters are nn.Parameters and
+ torch automatically keeps track and computes gradients for them.
+
+ .. warning::
+ Not designed to be used directly.
+ Please use the API inferface `dsm.dsm_api.DeepRecurrentSurvivalMachines`!!
+
+ Parameters
+ ----------
+ inputdim: int
+ Dimensionality of the input features.
+ k: int
+ The number of underlying parametric distributions.
+ layers: int
+ The number of hidden layers in the LSTM or RNN cell.
+ hidden: int
+ The number of neurons in each hidden layer.
+ init: tuple
+ A tuple for initialization of the parameters for the underlying
+ distributions. (shape, scale).
+ dist: str
+ Choice of the underlying survival distributions.
+ One of 'Weibull', 'LogNormal'.
+ Default is 'Weibull'.
+ temp: float
+ The logits for the gate are rescaled with this value.
+ Default is 1000.
+ discount: float
+ a float in [0,1] that determines how to discount the tail bias
+ from the uncensored instances.
+ Default is 1.
+ """
+
+ def __init__(self, inputdim, k, typ='LSTM', layers=1,
+ hidden=None, dist='Weibull',
+ temp=1000., discount=1.0, optimizer='Adam'):
+ super(DeepSurvivalMachinesTorch, self).__init__()
+
+ self.k = k
+ self.dist = dist
+ self.temp = float(temp)
+ self.discount = float(discount)
+ self.optimizer = optimizer
+ self.hidden = hidden
+ self.layers = layers
+ self.typ = typ
+
+ if self.dist == 'Weibull':
+ self.act = nn.SELU()
+ self.scale = nn.Parameter(-torch.ones(k))
+ self.shape = nn.Parameter(-torch.ones(k))
+ elif self.dist == 'LogNormal':
+ self.act = nn.Tanh()
+ self.scale = nn.Parameter(torch.ones(k))
+ self.shape = nn.Parameter(torch.ones(k))
+ else:
+ raise NotImplementedError('Distribution: '+self.dist+' not implemented'+
+ ' yet.')
+
+ self.gate = nn.Sequential(nn.Linear(hidden, k, bias=False))
+ self.scaleg = nn.Sequential(nn.Linear(hidden, k, bias=True))
+ self.shapeg = nn.Sequential(nn.Linear(hidden, k, bias=True))
+
+ if self.typ == 'LSTM':
+ self.embedding = nn.LSTM(inputdim, hidden, layers,
+ bias=False, batch_first=True)
+ if self.typ == 'RNN':
+ self.embedding = nn.RNN(inputdim, hidden, layers,
+ bias=False, batch_first=True)
+
+ #self.embedding = nn.ReLU6(self.embedding)
+
+
+ def forward(self, x):
+ """The forward function that is called when data is passed through DSM.
+
+ Note: As compared to DSM, the input data for DRSM is a tensor. The forward
+ function involves unpacking the tensor in-order to directly use the
+ DSM loss functions.
+
+ Args:
+ x:
+ a torch.tensor of the input features.
+ """
+ x = x.detach().clone()
+ inputmask = ~torch.isnan(x[:, :, 0]).reshape(-1)
+ x[torch.isnan(x)] = 0
+ xrep, _ = self.embedding(x)
+ xrep = xrep.contiguous().view(-1, self.hidden)
+ xrep = xrep[inputmask]
+ #xrep = nn.ReLU6()(xrep)
+ return(self.act(self.shapeg(xrep))+self.shape.expand(xrep.shape[0], -1),
+ self.act(self.scaleg(xrep))+self.scale.expand(xrep.shape[0], -1),
+ self.gate(xrep)/self.temp)
+
+ def get_shape_scale(self):
+ return(self.shape,
+ self.scale)
The forward function that is called when data is passed through DSM.
+
Note: As compared to DSM, the input data for DRSM is a tensor. The forward
+function involves unpacking the tensor in-order to directly use the
+DSM loss functions.
+
Args
+
x:
+a torch.tensor of the input features.
+
+
+Expand source code
+
+
def forward(self, x):
+ """The forward function that is called when data is passed through DSM.
+
+ Note: As compared to DSM, the input data for DRSM is a tensor. The forward
+ function involves unpacking the tensor in-order to directly use the
+ DSM loss functions.
+
+ Args:
+ x:
+ a torch.tensor of the input features.
+ """
+ x = x.detach().clone()
+ inputmask = ~torch.isnan(x[:, :, 0]).reshape(-1)
+ x[torch.isnan(x)] = 0
+ xrep, _ = self.embedding(x)
+ xrep = xrep.contiguous().view(-1, self.hidden)
+ xrep = xrep[inputmask]
+ #xrep = nn.ReLU6()(xrep)
+ return(self.act(self.shapeg(xrep))+self.shape.expand(xrep.shape[0], -1),
+ self.act(self.scaleg(xrep))+self.scale.expand(xrep.shape[0], -1),
+ self.gate(xrep)/self.temp)
self.scaleg = nn.Sequential(nn.Linear(layers[-1], k, bias=True))
self.shapeg = nn.Sequential(nn.Linear(layers[-1], k, bias=True))
- if init is not False:
- self.shape.data.fill_(init[0])
- self.scale.data.fill_(init[1])
-
def forward(self, x):
"""The forward function that is called when data is passed through DSM.
@@ -438,6 +770,10 @@
Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
+
Deep Survival Machines is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
Deep Survival Machines is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+See the
+GNU General Public License for more details.
+
You should have received a copy of the GNU General Public License
+along with Foobar.
+If not, see https://www.gnu.org/licenses/.
-
+
Expand source code
-
"""
+
# coding=utf-8
+# Copyright 2020 Chirag Nagpal
+#
+# This file is part of Deep Survival Machines.
+
+# Deep Survival Machines is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# Deep Survival Machines is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+
+# You should have received a copy of the GNU General Public License
+# along with Deep Survival Machines.
+# If not, see <https://www.gnu.org/licenses/>.
+
+"""
Python package `dsm` provides an API to train the Deep Survival Machines
and associated models for problems in survival analysis. The underlying model
is implemented in `pytorch`.
@@ -144,9 +166,9 @@
License
This is the caption of the figure (a simple paragraph).
-**Deep Survival Machines (DSM)** is a fully parametric approach to model
-Time-to-Event outcomes in the presence of Censoring first introduced in
-[\[1\]](https://arxiv.org/abs/2003.01176).
+**Deep Survival Machines (DSM)** is a fully parametric approach to model
+Time-to-Event outcomes in the presence of Censoring first introduced in
+[\[1\]](https://arxiv.org/abs/2003.01176).
In the context of Healthcare ML and Biostatistics, this is known as 'Survival
Analysis'. The key idea behind Deep Survival Machines is to model the
underlying event outcome distribution as a mixure of some fixed \( k \)
@@ -187,7 +209,7 @@
License
Please cite the following papers if you are using the `dsm` package.
-[1] [Deep Survival Machines:
+[1] [Deep Survival Machines:
Fully Parametric Survival Regression and
Representation Learning for Censored Data with Competing Risks."
arXiv preprint arXiv:2003.01176 (2020)](https://arxiv.org/abs/2003.01176)</a>
@@ -220,22 +242,23 @@
License
Copyright 2020 [Chirag Nagpal](http://cs.cmu.edu/~chiragn),
[Auton Lab](http://www.autonlab.org).
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
+Deep Survival Machines is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
-http://www.apache.org/licenses/LICENSE-2.0
+Deep Survival Machines is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
+You should have received a copy of the GNU General Public License
+along with Foobar. If not, see <https://www.gnu.org/licenses/>.
<img style="float: right;" width ="200px" src="https://www.cmu.edu/brand/downloads/assets/images/wordmarks-600x600-min.jpg">
-<img style="float: right;padding-top:50px" src="https://www.autonlab.org/user/themes/auton/images/AutonLogo.png">
+<img style="float: right;padding-top:50px" src="https://www.autonlab.org/user/themes/auton/images/AutonLogo.png">
<br><br><br><br><br>
<br><br><br><br><br>
diff --git a/docs/losses.html b/docs/losses.html
index 494b3e7..5df6074 100644
--- a/docs/losses.html
+++ b/docs/losses.html
@@ -40,19 +40,23 @@
Module dsm.losses
Expand source code
# coding=utf-8
-# Copyright 2020 Chirag Nagpal, Auton Lab.
+# Copyright 2020 Chirag Nagpal
#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# This file is part of Deep Survival Machines.
+
+# Deep Survival Machines is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# Deep Survival Machines is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+
+# You should have received a copy of the GNU General Public License
+# along with Deep Survival Machines.
+# If not, see <https://www.gnu.org/licenses/>.
"""Loss function definitions for the Deep Survival Machines model
@@ -93,7 +97,6 @@
# coding=utf-8
-# Copyright 2020 Chirag Nagpal, Auton Lab.
+# Copyright 2020 Chirag Nagpal
#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# This file is part of Deep Survival Machines.
+
+# Deep Survival Machines is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# Deep Survival Machines is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+
+# You should have received a copy of the GNU General Public License
+# along with Deep Survival Machines.
+# If not, see <https://www.gnu.org/licenses/>.
"""Utility functions to train the Deep Survival Machines models"""
+from dsm.dsm_torch import DeepSurvivalMachinesTorch
from dsm.losses import unconditional_loss, conditional_loss
from tqdm import tqdm
@@ -55,7 +60,6 @@