diff --git a/.travis.yml b/.travis.yml
index 1f74dbe..a817389 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -5,7 +5,6 @@ python:
   - "3.8"
 os:
   - linux
-  - osx
 # command to install dependencies
 install:
   - pip install -r requirements.txt
@@ -15,4 +14,4 @@ install:
 # command to run tests
 script:
   - python -m pytest tests/
-  - pylint --fail-under=9 dsm/
+  - pylint --fail-under=8 dsm/
diff --git a/docs/datasets.html b/docs/datasets.html
index 1e77ccd..8a0a89a 100644
--- a/docs/datasets.html
+++ b/docs/datasets.html
@@ -31,19 +31,24 @@ Module dsm.datasets

# coding=utf-8
-# Copyright 2020 Chirag Nagpal, Auton Lab.
+# Copyright 2020 Chirag Nagpal
 #
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# This file is part of Deep Survival Machines.
+
+# Deep Survival Machines is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# Deep Survival Machines is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+
+# You should have received a copy of the GNU General Public License
+# along with Deep Survival Machines.
+# If not, see <https://www.gnu.org/licenses/>.
+
 
 """Utility functions to load standard datasets to train and evaluate the
 Deep Survival Machines models.
@@ -75,13 +80,21 @@ Module dsm.datasets
   return e, t

-def _load_pbc_dataset():
+def _load_pbc_dataset(sequential):
   """Helper function to load and preprocess the PBC dataset

   The Primary biliary cirrhosis (PBC) Dataset [1] is a well-known dataset
   for evaluating survival analysis models with time-dependent covariates.

+  Parameters
+  ----------
+  sequential: bool
+      If True, returns a list of np.arrays for each individual.
+      Else, returns collapsed results for each time step. To train
+      recurrent neural models you would typically use True.
+
   References
   ----------
   [1] Fleming, Thomas R., and David P. Harrington. Counting processes and
@@ -89,7 +102,36 @@ Module dsm.datasets
   """
-  raise NotImplementedError('')
+  data = pkgutil.get_data(__name__, 'datasets/pbc2.csv')
+  data = pd.read_csv(io.BytesIO(data))
+
+  data['histologic'] = data['histologic'].astype(str)
+  dat_cat = data[['drug', 'sex', 'ascites', 'hepatomegaly',
+                  'spiders', 'edema', 'histologic']]
+  dat_num = data[['serBilir', 'serChol', 'albumin', 'alkaline',
+                  'SGOT', 'platelets', 'prothrombin']]
+  age = data['age'] + data['years']
+
+  x1 = pd.get_dummies(dat_cat).values
+  x2 = dat_num.values
+  x3 = age.values.reshape(-1, 1)
+  x = np.hstack([x1, x2, x3])
+
+  time = (data['years'] - data['year']).values
+  event = data['status2'].values
+
+  x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
+  x_ = StandardScaler().fit_transform(x)
+
+  if not sequential:
+    return x_, time, event
+  else:
+    x, t, e = [], [], []
+    for id_ in sorted(list(set(data['id']))):
+      x.append(x_[data['id'] == id_])
+      t.append(time[data['id'] == id_])
+      e.append(event[data['id'] == id_])
+    return x, t, e

 def _load_support_dataset():
   """Helper function to load and preprocess the SUPPORT dataset.
@@ -128,13 +170,16 @@ Module dsm.datasets
   return x[remove], t[remove], e[remove]

-def load_dataset(dataset='SUPPORT'):
+def load_dataset(dataset='SUPPORT', **kwargs):
   """Helper function to load datasets to test Survival Analysis models.

   Parameters
   ----------
   dataset: str
-      The choice of dataset to load. Currently implemented is 'SUPPORT'.
+      The choice of dataset to load. Currently implemented is 'SUPPORT'
+      and 'PBC'.
+  **kwargs: dict
+      Dataset specific keyword arguments.

   Returns
   ----------
@@ -146,6 +191,9 @@ Module dsm.datasets
   if dataset == 'SUPPORT':
     return _load_support_dataset()
+  if dataset == 'PBC':
+    sequential = kwargs.get('sequential', False)
+    return _load_pbc_dataset(sequential)
   else:
     return NotImplementedError('Dataset '+dataset+' not implemented.')
@@ -184,14 +232,17 @@ Functions
-def load_dataset(dataset='SUPPORT')
+def load_dataset(dataset='SUPPORT', **kwargs)

 Helper function to load datasets to test Survival Analysis models.

 Parameters

 dataset : str
-    The choice of dataset to load. Currently implemented is 'SUPPORT'.
+    The choice of dataset to load. Currently implemented is 'SUPPORT'
+    and 'PBC'.
+**kwargs : dict
+    Dataset specific keyword arguments.

 Returns
@@ -203,13 +254,16 @@ Returns
-def load_dataset(dataset='SUPPORT'):
+def load_dataset(dataset='SUPPORT', **kwargs):
   """Helper function to load datasets to test Survival Analysis models.

   Parameters
   ----------
   dataset: str
-      The choice of dataset to load. Currently implemented is 'SUPPORT'.
+      The choice of dataset to load. Currently implemented is 'SUPPORT'
+      and 'PBC'.
+  **kwargs: dict
+      Dataset specific keyword arguments.

   Returns
   ----------
@@ -221,6 +275,9 @@ Returns
   if dataset == 'SUPPORT':
     return _load_support_dataset()
+  if dataset == 'PBC':
+    sequential = kwargs.get('sequential', False)
+    return _load_pbc_dataset(sequential)
   else:
     return NotImplementedError('Dataset '+dataset+' not implemented.')
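For orientation (not part of the patch itself), the new PBC code path can be exercised like this; the import path follows the module name shown in these docs:

    from dsm import datasets

    # Collapsed form: one row per visit, with times and event indicators.
    x, t, e = datasets.load_dataset('PBC')

    # Per-individual sequences, e.g. for the recurrent (DRSM) models below.
    x_seq, t_seq, e_seq = datasets.load_dataset('PBC', sequential=True)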
diff --git a/docs/datautils.html b/docs/datautils.html
deleted file mode 100644
index bd5779f..0000000
--- a/docs/datautils.html
+++ /dev/null
@@ -1,278 +0,0 @@
-dsm.datautils API documentation
-
-
-

Module dsm.datautils

-
-
-
- -Expand source code - -
# coding=utf-8
-# Copyright 2020 Chirag Nagpal, Auton Lab.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import io, pkgutil
-
-import pandas as pd
-import numpy as np
-
-from sklearn.impute import SimpleImputer
-from sklearn.preprocessing import StandardScaler
-
-def increase_censoring(e, t, p):
-
-  uncens = np.where(e == 1)[0]
-  mask = np.random.choice([False, True], len(uncens), p=[1-p, p])
-  toswitch = uncens[mask]
-
-  e[toswitch] = 0
-  t_ = t[toswitch]
-
-  newt = []
-  for t__ in t_:
-    newt.append(np.random.uniform(1, t__))
-  t[toswitch] = newt
-
-  return e, t
-
-def load_support_dataset():
-
-  """Helper function to load and preprocess the SUPPORT dataset.
-
-  The SUPPORT Dataset comes from the Vanderbilt University study
-  to estimate survival for seriously ill hospitalized adults [1].
-
-  Please refer to http://biostat.mc.vanderbilt.edu/wiki/Main/SupportDesc.
-  for the original datasource.
-
-  [1]: Knaus WA, Harrell FE, Lynn J et al. (1995): The SUPPORT prognostic
-  model: Objective estimates of survival for seriously ill hospitalized
-  adults. Annals of Internal Medicine 122:191-203.
-
-  """
-  data = pkgutil.get_data(__name__, 'datasets/support2.csv')
-  data = pd.read_csv(io.BytesIO(data))
-  x1 = data[['age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp', 'temp',
-             'pafi', 'alb', 'bili', 'crea', 'sod', 'ph', 'glucose', 'bun',
-             'urine', 'adlp', 'adls']]
-
-  catfeats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca']
-  x2 = pd.get_dummies(data[catfeats])
-
-  x = np.concatenate([x1, x2], axis=1)
-  t = data['d.time'].values
-  e = data['death'].values
-
-  x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
-  x = StandardScaler().fit_transform(x)
-
-  remove = ~np.isnan(t)
-  return x[remove], t[remove], e[remove]
-
-
-def load_dataset(dataset='SUPPORT'):
-  """Helper function to load datasets to test Survival Analysis models.
-
-  Parameters
-  ----------
-  dataset: str
-      The choice of dataset to load. Currently implemented is 'SUPPORT'.
-
-  Returns
-  ----------
-  tuple: (np.ndarray, np.ndarray, np.ndarray)
-      A tuple of the form of (x, t, e) where x, t, e are the input covariates,
-      event times and the censoring indicators respectively.
-
-  """
-
-  if dataset == 'SUPPORT':
-    return _load_support_dataset()
-
-
-
-
-
-
-
-

Functions

-
-
-def increase_censoring(e, t, p)
-
-
-
- -Expand source code - -
def increase_censoring(e, t, p):
-
-  uncens = np.where(e == 1)[0]
-  mask = np.random.choice([False, True], len(uncens), p=[1-p, p])
-  toswitch = uncens[mask]
-
-  e[toswitch] = 0
-  t_ = t[toswitch]
-
-  newt = []
-  for t__ in t_:
-    newt.append(np.random.uniform(1, t__))
-  t[toswitch] = newt
-
-  return e, t
-
-
-
-def load_dataset(dataset='SUPPORT')
-
-

Helper function to load datasets to test Survival Analysis models.

-

Parameters

-
-
dataset : str
-
The choice of dataset to load. Currently implemented is 'SUPPORT'.
-
-

Returns

-
-
tuple : (np.ndarray, np.ndarray, np.ndarray)
-A tuple of the form of (x, t, e) where x, t, e are the input covariates,
-event times and the censoring indicators respectively.
-
-
- -Expand source code - -
def load_dataset(dataset='SUPPORT'):
-  """Helper function to load datasets to test Survival Analysis models.
-
-  Parameters
-  ----------
-  dataset: str
-      The choice of dataset to load. Currently implemented is 'SUPPORT'.
-
-  Returns
-  ----------
-  tuple: (np.ndarray, np.ndarray, np.ndarray)
-      A tuple of the form of (x, t, e) where x, t, e are the input covariates,
-      event times and the censoring indicators respectively.
-
-  """
-
-  if dataset == 'SUPPORT':
-    return _load_support_dataset()
-
-
-
-def load_support_dataset()
-
-

Helper function to load and preprocess the SUPPORT dataset.

-

-The SUPPORT Dataset comes from the Vanderbilt University study
-to estimate survival for seriously ill hospitalized adults [1].

-

-Please refer to http://biostat.mc.vanderbilt.edu/wiki/Main/SupportDesc.
-for the original datasource.

-

-[1]: Knaus WA, Harrell FE, Lynn J et al. (1995): The SUPPORT prognostic
-model: Objective estimates of survival for seriously ill hospitalized
-adults. Annals of Internal Medicine 122:191-203.

-
- -Expand source code - -
def load_support_dataset():
-
-  """Helper function to load and preprocess the SUPPORT dataset.
-
-  The SUPPORT Dataset comes from the Vanderbilt University study
-  to estimate survival for seriously ill hospitalized adults [1].
-
-  Please refer to http://biostat.mc.vanderbilt.edu/wiki/Main/SupportDesc.
-  for the original datasource.
-
-  [1]: Knaus WA, Harrell FE, Lynn J et al. (1995): The SUPPORT prognostic
-  model: Objective estimates of survival for seriously ill hospitalized
-  adults. Annals of Internal Medicine 122:191-203.
-
-  """
-  data = pkgutil.get_data(__name__, 'datasets/support2.csv')
-  data = pd.read_csv(io.BytesIO(data))
-  x1 = data[['age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp', 'temp',
-             'pafi', 'alb', 'bili', 'crea', 'sod', 'ph', 'glucose', 'bun',
-             'urine', 'adlp', 'adls']]
-
-  catfeats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca']
-  x2 = pd.get_dummies(data[catfeats])
-
-  x = np.concatenate([x1, x2], axis=1)
-  t = data['d.time'].values
-  e = data['death'].values
-
-  x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x)
-  x = StandardScaler().fit_transform(x)
-
-  remove = ~np.isnan(t)
-  return x[remove], t[remove], e[remove]
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/docs/dsm_api.html b/docs/dsm_api.html
index 8d180e5..d8009a2 100644
--- a/docs/dsm_api.html
+++ b/docs/dsm_api.html
@@ -31,19 +31,24 @@ Module dsm.dsm_api

# coding=utf-8
-# Copyright 2020 Chirag Nagpal, Auton Lab.
+# Copyright 2020 Chirag Nagpal
 #
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# This file is part of Deep Survival Machines.
+
+# Deep Survival Machines is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# Deep Survival Machines is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+
+# You should have received a copy of the GNU General Public License
+# along with Deep Survival Machines.
+# If not, see <https://www.gnu.org/licenses/>.
+
 
 """
 This module is a wrapper around torch implementations and
@@ -55,7 +60,6 @@ Module dsm.dsm_api
 from dsm.utilities import train_dsm

 import torch
-import numpy as np

 class DeepSurvivalMachines():
@@ -70,7 +74,7 @@ Module dsm.dsm_api
   References
   ----------
-  [1] <a href="https://arxiv.org/abs/2003.01176">Deep Survival Machines: 
+  [1] <a href="https://arxiv.org/abs/2003.01176">Deep Survival Machines:
   Fully Parametric Survival Regression and
   Representation Learning for Censored Data with Competing Risks."
   arXiv preprint arXiv:2003.01176 (2020)</a>
@@ -124,7 +128,7 @@ Module dsm.dsm_api
     print("Distribution Choice:", self.dist)

-  def fit(self, x, t, e, vsize=0.15, 
+  def fit(self, x, t, e, vsize=0.15,
           iters=1, learning_rate=1e-3, batch_size=100,
           elbo=True, optimizer="Adam", random_state=100):
@@ -150,10 +154,10 @@ Module dsm.dsm_api
         learning is performed on mini-batches of input data.
         this parameter specifies the size of each mini-batch.
     elbo: bool
-        Whether to use the Evidence Lower Bound for Optimization.
+        Whether to use the Evidence Lower Bound for optimization.
         Default is True.
     optimizer: str
-        The choice of the gradient based optimization method. One of 
+        The choice of the gradient based optimization method. One of
         'Adam', 'RMSProp' or 'SGD'.
     random_state: float
         random seed that determines how the validation set is chosen.
@@ -182,7 +186,6 @@ Module dsm.dsm_api
     model = DeepSurvivalMachinesTorch(inputdim,
                                       k=self.k,
                                       layers=self.layers,
-                                      init=False,
                                       dist=self.dist,
                                       temp=self.temp,
                                       discount=self.discount,
@@ -204,7 +207,7 @@ Module dsm.dsm_api
   def predict_risk(self, x, t):
-    """Returns the estimated risk of an event occurring before time \( t \),
+    """Returns the estimated risk of an event occurring before time \( t \)
     \( \widehat{\mathbb{P}}(T\leq t|X) \) for some input data \( x \).

     Parameters
@@ -396,7 +399,7 @@ Example
 References
 ----------
-[1] <a href="https://arxiv.org/abs/2003.01176">Deep Survival Machines: 
+[1] <a href="https://arxiv.org/abs/2003.01176">Deep Survival Machines:
 Fully Parametric Survival Regression and
 Representation Learning for Censored Data with Competing Risks."
 arXiv preprint arXiv:2003.01176 (2020)</a>
@@ -450,7 +453,7 @@ Example
     print("Distribution Choice:", self.dist)

-  def fit(self, x, t, e, vsize=0.15, 
+  def fit(self, x, t, e, vsize=0.15,
           iters=1, learning_rate=1e-3, batch_size=100,
           elbo=True, optimizer="Adam", random_state=100):
@@ -476,10 +479,10 @@ Example
         learning is performed on mini-batches of input data.
         this parameter specifies the size of each mini-batch.
     elbo: bool
-        Whether to use the Evidence Lower Bound for Optimization.
+        Whether to use the Evidence Lower Bound for optimization.
         Default is True.
     optimizer: str
-        The choice of the gradient based optimization method. One of 
+        The choice of the gradient based optimization method. One of
         'Adam', 'RMSProp' or 'SGD'.
     random_state: float
         random seed that determines how the validation set is chosen.
@@ -508,7 +511,6 @@ Example
     model = DeepSurvivalMachinesTorch(inputdim,
                                       k=self.k,
                                       layers=self.layers,
-                                      init=False,
                                       dist=self.dist,
                                       temp=self.temp,
                                       discount=self.discount,
@@ -530,7 +532,7 @@ Example
   def predict_risk(self, x, t):
-    """Returns the estimated risk of an event occurring before time \( t \),
+    """Returns the estimated risk of an event occurring before time \( t \)
     \( \widehat{\mathbb{P}}(T\leq t|X) \) for some input data \( x \).

     Parameters
@@ -613,7 +615,7 @@ Parameters

 learning is performed on mini-batches of input data.
 this parameter specifies the size of each mini-batch.
 elbo : bool
-    Whether to use the Evidence Lower Bound for Optimization.
+    Whether to use the Evidence Lower Bound for optimization.
     Default is True.
 optimizer : str
     The choice of the gradient based optimization method. One of
@@ -625,7 +627,7 @@ Parameters

-def fit(self, x, t, e, vsize=0.15, 
+def fit(self, x, t, e, vsize=0.15,
         iters=1, learning_rate=1e-3, batch_size=100,
         elbo=True, optimizer="Adam", random_state=100):

@@ -651,10 +653,10 @@ Parameters

     learning is performed on mini-batches of input data.
     this parameter specifies the size of each mini-batch.
 elbo: bool
-    Whether to use the Evidence Lower Bound for Optimization.
+    Whether to use the Evidence Lower Bound for optimization.
     Default is True.
 optimizer: str
-    The choice of the gradient based optimization method. One of 
+    The choice of the gradient based optimization method. One of
     'Adam', 'RMSProp' or 'SGD'.
 random_state: float
     random seed that determines how the validation set is chosen.
@@ -683,7 +685,6 @@ Parameters
     model = DeepSurvivalMachinesTorch(inputdim,
                                       k=self.k,
                                       layers=self.layers,
-                                      init=False,
                                       dist=self.dist,
                                       temp=self.temp,
                                       discount=self.discount,
@@ -708,7 +709,7 @@ Parameters

 def predict_risk(self, x, t)

-Returns the estimated risk of an event occurring before time t ,
+Returns the estimated risk of an event occurring before time t
 \widehat{\mathbb{P}}(T\leq t|X) for some input data x .

 Parameters
@@ -728,7 +729,7 @@ Returns
def predict_risk(self, x, t):
-  """Returns the estimated risk of an event occuring before time \( t \),
+  """Returns the estimated risk of an event occuring before time \( t \)
     \( \widehat{\mathbb{P}}(T\leq t|X) \) for some input data \( x \).
 
   Parameters
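To tie these hunks together, here is a minimal usage sketch of the wrapper API they touch. The constructor keywords are assumptions inferred from the attributes (`k`, `layers`, `dist`) that `fit` passes to the torch model, not a confirmed signature:

    from dsm.datasets import load_dataset
    from dsm.dsm_api import DeepSurvivalMachines

    x, t, e = load_dataset('SUPPORT')

    # Constructor keywords assumed from the attributes used inside fit().
    model = DeepSurvivalMachines(k=3, layers=[100, 100])
    model.fit(x, t, e, iters=50, learning_rate=1e-3, batch_size=100)

    # Estimated risk of an event before t = 365, per the docstring above.
    risk_at_1yr = model.predict_risk(x, 365.)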
diff --git a/docs/dsm_torch.html b/docs/dsm_torch.html
index a823b89..36d8a1d 100644
--- a/docs/dsm_torch.html
+++ b/docs/dsm_torch.html
@@ -33,19 +33,24 @@ Module dsm.dsm_torch
# coding=utf-8
-# Copyright 2020 Chirag Nagpal, Auton Lab.
+# Copyright 2020 Chirag Nagpal
 #
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# This file is part of Deep Survival Machines.
+
+# Deep Survival Machines is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# Deep Survival Machines is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+
+# You should have received a copy of the GNU General Public License
+# along with Deep Survival Machines.
+# If not, see <https://www.gnu.org/licenses/>.
+
 
 """Torch model definitons for the Deep Survival Machines model
 
@@ -63,8 +68,8 @@ Module dsm.dsm_torch
 def create_representation(inputdim, layers, activation):
   """Helper function to generate the representation function for DSM.

-  Deep Survival Machines learns a representation $$\Phi(X)$$ for the input data.
-  This representation is parameterized using a Non Linear Multilayer
+  Deep Survival Machines learns a representation \( \Phi(X) \) for the input
+  data. This representation is parameterized using a Non Linear Multilayer
   Perceptron (`torch.nn.Module`). This is a helper function designed to
   instantiate the representation for Deep Survival Machines.
@@ -145,7 +150,7 @@ Module dsm.dsm_torch
       Default is 1.
   """
-  def __init__(self, inputdim, k, layers=None, init=False, dist='Weibull',
+  def __init__(self, inputdim, k, layers=None, dist='Weibull',
                temp=1000., discount=1.0, optimizer='Adam'):
     super(DeepSurvivalMachinesTorch, self).__init__()
@@ -163,11 +168,13 @@ Module dsm.dsm_torch
       self.act = nn.SELU()
       self.scale = nn.Parameter(-torch.ones(k))
       self.shape = nn.Parameter(-torch.ones(k))
-
     elif self.dist == 'LogNormal':
       self.act = nn.Tanh()
       self.scale = nn.Parameter(torch.ones(k))
       self.shape = nn.Parameter(torch.ones(k))
+    else:
+      raise NotImplementedError('Distribution: '+self.dist+' not implemented'+
+                                ' yet.')

     self.embedding = create_representation(inputdim, layers, 'ReLU6')
@@ -181,10 +188,6 @@ Module dsm.dsm_torch
     self.scaleg = nn.Sequential(nn.Linear(layers[-1], k, bias=True))
     self.shapeg = nn.Sequential(nn.Linear(layers[-1], k, bias=True))

-    if init is not False:
-      self.shape.data.fill_(init[0])
-      self.scale.data.fill_(init[1])
-
   def forward(self, x):
     """The forward function that is called when data is passed through DSM.
@@ -197,6 +200,112 @@ Module dsm.dsm_torch
            self.act(self.scaleg(xrep))+self.scale.expand(x.shape[0], -1),
            self.gate(xrep)/self.temp)

+  def get_shape_scale(self):
+    return(self.shape,
+           self.scale)
+
+class DeepRecurrentSurvivalMachinesTorch(DeepSurvivalMachinesTorch):
+  """A Torch implementation of Deep Recurrent Survival Machines model.
+
+  This is an implementation of Deep Recurrent Survival Machines model
+  in torch. It inherits from `DeepSurvivalMachinesTorch` and replaces the
+  input representation learning MLP with an LSTM or RNN, the parameters of the
+  underlying distributions and the forward function which is called whenever
+  data is passed to the module. Each of the parameters are nn.Parameters and
+  torch automatically keeps track and computes gradients for them.
+
+  .. warning::
+    Not designed to be used directly.
+    Please use the API interface `dsm.dsm_api.DeepRecurrentSurvivalMachines`!!
+
+  Parameters
+  ----------
+  inputdim: int
+      Dimensionality of the input features.
+  k: int
+      The number of underlying parametric distributions.
+  layers: int
+      The number of hidden layers in the LSTM or RNN cell.
+  hidden: int
+      The number of neurons in each hidden layer.
+  init: tuple
+      A tuple for initialization of the parameters for the underlying
+      distributions. (shape, scale).
+  dist: str
+      Choice of the underlying survival distributions.
+      One of 'Weibull', 'LogNormal'.
+      Default is 'Weibull'.
+  temp: float
+      The logits for the gate are rescaled with this value.
+      Default is 1000.
+  discount: float
+      a float in [0,1] that determines how to discount the tail bias
+      from the uncensored instances.
+      Default is 1.
+  """
+
+  def __init__(self, inputdim, k, typ='LSTM', layers=1,
+               hidden=None, dist='Weibull',
+               temp=1000., discount=1.0, optimizer='Adam'):
+    super(DeepSurvivalMachinesTorch, self).__init__()
+
+    self.k = k
+    self.dist = dist
+    self.temp = float(temp)
+    self.discount = float(discount)
+    self.optimizer = optimizer
+    self.hidden = hidden
+    self.layers = layers
+    self.typ = typ
+
+    if self.dist == 'Weibull':
+      self.act = nn.SELU()
+      self.scale = nn.Parameter(-torch.ones(k))
+      self.shape = nn.Parameter(-torch.ones(k))
+    elif self.dist == 'LogNormal':
+      self.act = nn.Tanh()
+      self.scale = nn.Parameter(torch.ones(k))
+      self.shape = nn.Parameter(torch.ones(k))
+    else:
+      raise NotImplementedError('Distribution: '+self.dist+' not implemented'+
+                                ' yet.')
+
+    self.gate = nn.Sequential(nn.Linear(hidden, k, bias=False))
+    self.scaleg = nn.Sequential(nn.Linear(hidden, k, bias=True))
+    self.shapeg = nn.Sequential(nn.Linear(hidden, k, bias=True))
+
+    if self.typ == 'LSTM':
+      self.embedding = nn.LSTM(inputdim, hidden, layers,
+                               bias=False, batch_first=True)
+    if self.typ == 'RNN':
+      self.embedding = nn.RNN(inputdim, hidden, layers,
+                              bias=False, batch_first=True)
+
+  def forward(self, x):
+    """The forward function that is called when data is passed through DSM.
+
+    Note: As compared to DSM, the input data for DRSM is a tensor. The forward
+    function involves unpacking the tensor in-order to directly use the
+    DSM loss functions.
+
+    Args:
+      x:
+        a torch.tensor of the input features.
+    """
+    x = x.detach().clone()
+    inputmask = ~torch.isnan(x[:, :, 0]).reshape(-1)
+    x[torch.isnan(x)] = 0
+    xrep, _ = self.embedding(x)
+    xrep = xrep.contiguous().view(-1, self.hidden)
+    xrep = xrep[inputmask]
+    return(self.act(self.shapeg(xrep))+self.shape.expand(xrep.shape[0], -1),
+           self.act(self.scaleg(xrep))+self.scale.expand(xrep.shape[0], -1),
+           self.gate(xrep)/self.temp)
+
   def get_shape_scale(self):
     return(self.shape,
            self.scale)
@@ -214,8 +323,8 @@ Functions
 Helper function to generate the representation function for DSM.

-Deep Survival Machines learns a representation \Phi(X) for the input data.
-This representation is parameterized using a Non Linear Multilayer
+Deep Survival Machines learns a representation \( \Phi(X) \) for the input
+data. This representation is parameterized using a Non Linear Multilayer
 Perceptron (torch.nn.Module). This is a helper function designed to
 instantiate the representation for Deep Survival Machines.
@@ -240,8 +349,8 @@ Returns
def create_representation(inputdim, layers, activation):
   """Helper function to generate the representation function for DSM.
 
-  Deep Survival Machines learns a representation $$\Phi(X)$$ for the input data.
-  This representation is parameterized using a Non Linear Multilayer
+  Deep Survival Machines learns a representation \( \Phi(X) \) for the input
+  data. This representation is parameterized using a Non Linear Multilayer
   Perceptron (`torch.nn.Module`). This is a helper function designed to
   instantiate the representation for Deep Survival Machines.
 
@@ -285,9 +394,234 @@ Classes
+class DeepRecurrentSurvivalMachinesTorch
+(inputdim, k, typ='LSTM', layers=1, hidden=None, dist='Weibull', temp=1000.0, discount=1.0, optimizer='Adam')
+
+A Torch implementation of Deep Recurrent Survival Machines model.
+
+This is an implementation of Deep Recurrent Survival Machines model
+in torch. It inherits from DeepSurvivalMachinesTorch and replaces the
+input representation learning MLP with an LSTM or RNN, the parameters of the
+underlying distributions and the forward function which is called whenever
+data is passed to the module. Each of the parameters are nn.Parameters and
+torch automatically keeps track and computes gradients for them.
+
+Warning
+
+Not designed to be used directly.
+Please use the API interface DeepRecurrentSurvivalMachines!!
+
+Parameters
+
+inputdim : int
+    Dimensionality of the input features.
+k : int
+    The number of underlying parametric distributions.
+layers : int
+    The number of hidden layers in the LSTM or RNN cell.
+hidden : int
+    The number of neurons in each hidden layer.
+init : tuple
+    A tuple for initialization of the parameters for the underlying
+    distributions. (shape, scale).
+dist : str
+    Choice of the underlying survival distributions.
+    One of 'Weibull', 'LogNormal'.
+    Default is 'Weibull'.
+temp : float
+    The logits for the gate are rescaled with this value.
+    Default is 1000.
+discount : float
+    a float in [0,1] that determines how to discount the tail bias
+    from the uncensored instances.
+    Default is 1.
+
+Initializes internal Module state, shared by both nn.Module and ScriptModule.
class DeepRecurrentSurvivalMachinesTorch(DeepSurvivalMachinesTorch):
+  """A Torch implementation of Deep Recurrent Survival Machines model.
+
+  This is an implementation of Deep Recurrent Survival Machines model
+  in torch. It inherits from `DeepSurvivalMachinesTorch` and replaces the
+  input representation learning MLP with an LSTM or RNN, the parameters of the
+  underlying distributions and the forward function which is called whenever
+  data is passed to the module. Each of the parameters are nn.Parameters and
+  torch automatically keeps track and computes gradients for them.
+
+  .. warning::
+    Not designed to be used directly.
+    Please use the API interface `dsm.dsm_api.DeepRecurrentSurvivalMachines`!!
+
+  Parameters
+  ----------
+  inputdim: int
+      Dimensionality of the input features.
+  k: int
+      The number of underlying parametric distributions.
+  layers: int
+      The number of hidden layers in the LSTM or RNN cell.
+  hidden: int
+      The number of neurons in each hidden layer.
+  init: tuple
+      A tuple for initialization of the parameters for the underlying
+      distributions. (shape, scale).
+  dist: str
+      Choice of the underlying survival distributions.
+      One of 'Weibull', 'LogNormal'.
+      Default is 'Weibull'.
+  temp: float
+      The logits for the gate are rescaled with this value.
+      Default is 1000.
+  discount: float
+      a float in [0,1] that determines how to discount the tail bias
+      from the uncensored instances.
+      Default is 1.
+  """
+
+  def __init__(self, inputdim, k, typ='LSTM', layers=1,
+               hidden=None, dist='Weibull',
+               temp=1000., discount=1.0, optimizer='Adam'):
+    super(DeepSurvivalMachinesTorch, self).__init__()
+
+    self.k = k
+    self.dist = dist
+    self.temp = float(temp)
+    self.discount = float(discount)
+    self.optimizer = optimizer
+    self.hidden = hidden
+    self.layers = layers
+    self.typ = typ
+
+    if self.dist == 'Weibull':
+      self.act = nn.SELU()
+      self.scale = nn.Parameter(-torch.ones(k))
+      self.shape = nn.Parameter(-torch.ones(k))
+    elif self.dist == 'LogNormal':
+      self.act = nn.Tanh()
+      self.scale = nn.Parameter(torch.ones(k))
+      self.shape = nn.Parameter(torch.ones(k))
+    else:
+      raise NotImplementedError('Distribution: '+self.dist+' not implemented'+
+                                ' yet.')
+
+    self.gate = nn.Sequential(nn.Linear(hidden, k, bias=False))
+    self.scaleg = nn.Sequential(nn.Linear(hidden, k, bias=True))
+    self.shapeg = nn.Sequential(nn.Linear(hidden, k, bias=True))
+
+    if self.typ == 'LSTM':
+      self.embedding = nn.LSTM(inputdim, hidden, layers,
+                               bias=False, batch_first=True)
+    if self.typ == 'RNN':
+      self.embedding = nn.RNN(inputdim, hidden, layers,
+                              bias=False, batch_first=True)
+
+    #self.embedding = nn.ReLU6(self.embedding)
+
+
+  def forward(self, x):
+    """The forward function that is called when data is passed through DSM.
+
+    Note: As compared to DSM, the input data for DRSM is a tensor. The forward
+    function involves unpacking the tensor in-order to directly use the
+    DSM loss functions.
+
+    Args:
+      x:
+        a torch.tensor of the input features.
+    """
+    x = x.detach().clone()
+    inputmask = ~torch.isnan(x[:, :, 0]).reshape(-1)
+    x[torch.isnan(x)] = 0
+    xrep, _ = self.embedding(x)
+    xrep = xrep.contiguous().view(-1, self.hidden)
+    xrep = xrep[inputmask]
+    #xrep = nn.ReLU6()(xrep)
+    return(self.act(self.shapeg(xrep))+self.shape.expand(xrep.shape[0], -1),
+           self.act(self.scaleg(xrep))+self.scale.expand(xrep.shape[0], -1),
+           self.gate(xrep)/self.temp)
+
+  def get_shape_scale(self):
+    return(self.shape,
+           self.scale)
+
+Ancestors
+
+DeepSurvivalMachinesTorch
+torch.nn.modules.module.Module
+
+Class variables
+
+var dump_patches : bool
+var training : bool
+
+Methods
+
+def forward(self, x) ‑> Callable[..., Any]
+
+The forward function that is called when data is passed through DSM.
+
+Note: As compared to DSM, the input data for DRSM is a tensor. The forward
+function involves unpacking the tensor in-order to directly use the
+DSM loss functions.
+
+Args
+
+x:
+    a torch.tensor of the input features.
+
def forward(self, x):
+  """The forward function that is called when data is passed through DSM.
+
+  Note: As compared to DSM, the input data for DRSM is a tensor. The forward
+  function involves unpacking the tensor in-order to directly use the
+  DSM loss functions.
+
+  Args:
+    x:
+      a torch.tensor of the input features.
+  """
+  x = x.detach().clone()
+  inputmask = ~torch.isnan(x[:, :, 0]).reshape(-1)
+  x[torch.isnan(x)] = 0
+  xrep, _ = self.embedding(x)
+  xrep = xrep.contiguous().view(-1, self.hidden)
+  xrep = xrep[inputmask]
+  #xrep = nn.ReLU6()(xrep)
+  return(self.act(self.shapeg(xrep))+self.shape.expand(xrep.shape[0], -1),
+         self.act(self.scaleg(xrep))+self.scale.expand(xrep.shape[0], -1),
+         self.gate(xrep)/self.temp)
+
+def get_shape_scale(self)
+
+def get_shape_scale(self):
+  return(self.shape,
+         self.scale)
+
 class DeepSurvivalMachinesTorch
-(inputdim, k, layers=None, init=False, dist='Weibull', temp=1000.0, discount=1.0, optimizer='Adam')
+(inputdim, k, layers=None, dist='Weibull', temp=1000.0, discount=1.0, optimizer='Adam')

A Torch implementation of Deep Survival Machines model.

@@ -378,7 +712,7 @@ Parameters
     Default is 1.
 """

-def __init__(self, inputdim, k, layers=None, init=False, dist='Weibull',
+def __init__(self, inputdim, k, layers=None, dist='Weibull',
              temp=1000., discount=1.0, optimizer='Adam'):
   super(DeepSurvivalMachinesTorch, self).__init__()
@@ -396,11 +730,13 @@ Parameters
     self.act = nn.SELU()
     self.scale = nn.Parameter(-torch.ones(k))
     self.shape = nn.Parameter(-torch.ones(k))
-
   elif self.dist == 'LogNormal':
     self.act = nn.Tanh()
     self.scale = nn.Parameter(torch.ones(k))
     self.shape = nn.Parameter(torch.ones(k))
+  else:
+    raise NotImplementedError('Distribution: '+self.dist+' not implemented'+
+                              ' yet.')

   self.embedding = create_representation(inputdim, layers, 'ReLU6')
@@ -414,10 +750,6 @@ Parameters
   self.scaleg = nn.Sequential(nn.Linear(layers[-1], k, bias=True))
   self.shapeg = nn.Sequential(nn.Linear(layers[-1], k, bias=True))

-  if init is not False:
-    self.shape.data.fill_(init[0])
-    self.scale.data.fill_(init[1])
-
   def forward(self, x):
     """The forward function that is called when data is passed through DSM.
@@ -438,6 +770,10 @@ Ancestors
  • torch.nn.modules.module.Module
+
+Subclasses
+
+DeepRecurrentSurvivalMachinesTorch

Class variables

var dump_patches : bool
@@ -514,6 +850,15 @@ Index
   • Classes
+    • DeepRecurrentSurvivalMachinesTorch
+      • forward
+      • get_shape_scale
     • DeepSurvivalMachinesTorch
       • dump_patches
diff --git a/docs/index.html b/docs/index.html
index 0abbc9e..f460ff4 100644
--- a/docs/index.html
+++ b/docs/index.html
@@ -99,24 +99,46 @@ License

 Contributing

 License

 Copyright 2020 Chirag Nagpal, Auton Lab.

-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
+Deep Survival Machines is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+Deep Survival Machines is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+See the GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with Deep Survival Machines.
+If not, see https://www.gnu.org/licenses/.

-"""
+# coding=utf-8
        +# Copyright 2020 Chirag Nagpal
        +#
        +# This file is part of Deep Survival Machines.
        +
        +# Deep Survival Machines is free software: you can redistribute it and/or modify
        +# it under the terms of the GNU General Public License as published by
        +# the Free Software Foundation, either version 3 of the License, or
        +# (at your option) any later version.
        +
        +# Deep Survival Machines is distributed in the hope that it will be useful,
        +# but WITHOUT ANY WARRANTY; without even the implied warranty of
        +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
        +# GNU General Public License for more details.
        +
        +# You should have received a copy of the GNU General Public License
        +# along with Deep Survival Machines.
        +# If not, see <https://www.gnu.org/licenses/>.
        +
        +"""
         Python package `dsm` provides an API to train the Deep Survival Machines
         and associated models for problems in survival analysis. The underlying model
         is implemented in `pytorch`.
@@ -144,9 +166,9 @@ License
 This is the caption of the figure (a simple paragraph).
-**Deep Survival Machines (DSM)** is a fully parametric approach to model
-Time-to-Event outcomes in the presence of Censoring first introduced in
-[\[1\]](https://arxiv.org/abs/2003.01176). 
+**Deep Survival Machines (DSM)** is a fully parametric approach to model
+Time-to-Event outcomes in the presence of Censoring first introduced in
+[\[1\]](https://arxiv.org/abs/2003.01176).
 In the context of Healthcare ML and Biostatistics,
 this is known as 'Survival Analysis'. The key idea behind
 Deep Survival Machines is to model the underlying event outcome
 distribution as a mixture of some fixed \( k \)
@@ -187,7 +209,7 @@ License
 Please cite the following papers if you are using the `dsm` package.

-[1] [Deep Survival Machines: 
+[1] [Deep Survival Machines:
 Fully Parametric Survival Regression and
 Representation Learning for Censored Data with Competing Risks."
 arXiv preprint arXiv:2003.01176 (2020)](https://arxiv.org/abs/2003.01176)</a>
@@ -220,22 +242,23 @@ License
 Copyright 2020 [Chirag Nagpal](http://cs.cmu.edu/~chiragn),
 [Auton Lab](http://www.autonlab.org).

-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
+Deep Survival Machines is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.

-http://www.apache.org/licenses/LICENSE-2.0
+Deep Survival Machines is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.

-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
+You should have received a copy of the GNU General Public License
+along with Deep Survival Machines. If not, see <https://www.gnu.org/licenses/>.

 <img style="float: right;" width ="200px" src="https://www.cmu.edu/brand/downloads/assets/images/wordmarks-600x600-min.jpg">
-<img style="float: right;padding-top:50px" src="https://www.autonlab.org/user/themes/auton/images/AutonLogo.png"> 
+<img style="float: right;padding-top:50px" src="https://www.autonlab.org/user/themes/auton/images/AutonLogo.png">

 <br><br><br><br><br>
 <br><br><br><br><br>
diff --git a/docs/losses.html b/docs/losses.html
index 494b3e7..5df6074 100644
--- a/docs/losses.html
+++ b/docs/losses.html
@@ -40,19 +40,23 @@ Module dsm.losses
        # coding=utf-8
        -# Copyright 2020 Chirag Nagpal, Auton Lab.
        +# Copyright 2020 Chirag Nagpal
         #
        -# Licensed under the Apache License, Version 2.0 (the "License");
        -# you may not use this file except in compliance with the License.
        -# You may obtain a copy of the License at
        -#
        -#     http://www.apache.org/licenses/LICENSE-2.0
        -#
        -# Unless required by applicable law or agreed to in writing, software
        -# distributed under the License is distributed on an "AS IS" BASIS,
        -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
        -# See the License for the specific language governing permissions and
        -# limitations under the License.
        +# This file is part of Deep Survival Machines.
        +
        +# Deep Survival Machines is free software: you can redistribute it and/or modify
        +# it under the terms of the GNU General Public License as published by
        +# the Free Software Foundation, either version 3 of the License, or
        +# (at your option) any later version.
        +
        +# Deep Survival Machines is distributed in the hope that it will be useful,
        +# but WITHOUT ANY WARRANTY; without even the implied warranty of
        +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
        +# GNU General Public License for more details.
        +
        +# You should have received a copy of the GNU General Public License
        +# along with Deep Survival Machines.
        +# If not, see <https://www.gnu.org/licenses/>.
         
         """Loss function definitions for the Deep Survival Machines model
         
@@ -93,7 +97,6 @@ Module dsm.losses
   uncens = np.where(e == 1)[0]
   cens = np.where(e == 0)[0]
-
   ll += f[uncens].sum() + s[cens].sum()

   return -ll.mean()
@@ -179,9 +182,9 @@ Module dsm.losses
   uncens = np.where(e.cpu().data.numpy() == 1)[0]
   cens = np.where(e.cpu().data.numpy() == 0)[0]
-
   ll = lossf[uncens].sum() + alpha*losss[cens].sum()
-  return -ll/x.shape[0]
+
+  return -ll.mean()

 def _conditional_weibull_loss(model, x, t, e, elbo=True):
@@ -189,6 +192,8 @@ Module dsm.losses
   alpha = model.discount
   shape, scale, logits = model.forward(x)

+  #print (shape, scale, logits)
+
   k_ = shape
   b_ = scale
@@ -228,9 +233,9 @@ Module dsm.losses
   uncens = np.where(e.cpu().data.numpy() == 1)[0]
   cens = np.where(e.cpu().data.numpy() == 0)[0]
-
   ll = lossf[uncens].sum() + alpha*losss[cens].sum()
-  return -ll/x.shape[0]
+
+  return -ll.mean()

 def conditional_loss(model, x, t, e, elbo=True):
@@ -269,9 +274,7 @@ Module dsm.losses
     lcdfs.append(s)

   lcdfs = torch.stack(lcdfs, dim=1)
-
   lcdfs = lcdfs+logits
-
   lcdfs = torch.logsumexp(lcdfs, dim=1)
   cdfs.append(lcdfs.detach().numpy())
@@ -309,7 +312,6 @@ Module dsm.losses
     lcdfs.append(s)

   lcdfs = torch.stack(lcdfs, dim=1)
-
   lcdfs = lcdfs+logits
   lcdfs = torch.logsumexp(lcdfs, dim=1)
   cdfs.append(lcdfs.detach().numpy())
@@ -321,7 +323,6 @@ Module dsm.losses
   torch.no_grad()
   if model.dist == 'Weibull':
     return _weibull_cdf(model, x, t_horizon)
-
   if model.dist == 'LogNormal':
     return _lognormal_cdf(model, x, t_horizon)
@@ -364,7 +365,6 @@ Functions
   torch.no_grad()
   if model.dist == 'Weibull':
     return _weibull_cdf(model, x, t_horizon)
-
   if model.dist == 'LogNormal':
     return _lognormal_cdf(model, x, t_horizon)
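The loss hunks above keep the usual censored-likelihood split: uncensored rows contribute the log density, censored rows the log survival. A self-contained toy illustration of that split for a single Weibull (illustrative only, not the library's mixture loss):

    import torch

    def toy_weibull_loglik(t, e, shape, scale):
        # Weibull: S(t) = exp(-(t/b)^k), f(t) = (k/b) (t/b)^(k-1) S(t).
        logS = -(t / scale) ** shape
        logf = (torch.log(shape) - torch.log(scale)
                + (shape - 1) * (torch.log(t) - torch.log(scale)) + logS)
        # Uncensored (e == 1) rows use log f(t); censored rows use log S(t).
        ll = torch.where(e == 1, logf, logS)
        return -ll.mean()  # the reduction the PR switches to

    t = torch.tensor([2.0, 5.0, 3.5])
    e = torch.tensor([1, 0, 1])
    print(toy_weibull_loglik(t, e, torch.tensor(1.5), torch.tensor(4.0)))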
diff --git a/docs/utilities.html b/docs/utilities.html
index 7a86166..aa9dece 100644
--- a/docs/utilities.html
+++ b/docs/utilities.html
@@ -29,22 +29,27 @@ Module dsm.utilities
    # coding=utf-8
    -# Copyright 2020 Chirag Nagpal, Auton Lab.
    +# Copyright 2020 Chirag Nagpal
     #
    -# Licensed under the Apache License, Version 2.0 (the "License");
    -# you may not use this file except in compliance with the License.
    -# You may obtain a copy of the License at
    -#
    -#     http://www.apache.org/licenses/LICENSE-2.0
    -#
    -# Unless required by applicable law or agreed to in writing, software
    -# distributed under the License is distributed on an "AS IS" BASIS,
    -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    -# See the License for the specific language governing permissions and
    -# limitations under the License.
    +# This file is part of Deep Survival Machines.
    +
    +# Deep Survival Machines is free software: you can redistribute it and/or modify
    +# it under the terms of the GNU General Public License as published by
    +# the Free Software Foundation, either version 3 of the License, or
    +# (at your option) any later version.
    +
    +# Deep Survival Machines is distributed in the hope that it will be useful,
    +# but WITHOUT ANY WARRANTY; without even the implied warranty of
    +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    +# GNU General Public License for more details.
    +
    +# You should have received a copy of the GNU General Public License
    +# along with Deep Survival Machines.
    +# If not, see <https://www.gnu.org/licenses/>.
     
     """Utility functions to train the Deep Survival Machines models"""
     
    +from dsm.dsm_torch import DeepSurvivalMachinesTorch
     from dsm.losses import unconditional_loss, conditional_loss
     
     from tqdm import tqdm
@@ -55,7 +60,6 @@ Module dsm.utilities
 import gc

-from dsm.dsm_torch import DeepSurvivalMachinesTorch

 def get_optimizer(model, lr):
@@ -66,18 +70,18 @@ Module dsm.utilities
   elif model.optimizer == 'RMSProp':
     return torch.optim.RMSprop(model.parameters(), lr=lr)
   else:
-    raise NotImplementedError("Optimizer "+model.optimizer+
-                              " is not implemented")
-
+    raise NotImplementedError('Optimizer '+model.optimizer+
+                              ' is not implemented')
+
 def pretrain_dsm(model, t_train, e_train, t_valid, e_valid,
                  n_iter=10000, lr=1e-2, thres=1e-4):

   premodel = DeepSurvivalMachinesTorch(1, 1,
-                                       init=False,
-                                       dist=model.dist)
+                                       dist=model.dist)
   premodel.double()
   optimizer = torch.optim.Adam(premodel.parameters(), lr=lr)

-  oldcost = -float('inf')
+  oldcost = float('inf')
   patience = 0
   costs = []
@@ -102,6 +106,10 @@ Module dsm.utilities
   return premodel

+def _reshape_tensor_with_nans(data):
+  """Helper function to unroll padded RNN inputs"""
+  data = data.reshape(-1)
+  return data[~torch.isnan(data)]

 def train_dsm(model,
               x_train, t_train, e_train,
@@ -111,24 +119,27 @@ Module dsm.utilities
   print('Pretraining the Underlying Distributions...')

+  print(t_train.shape, e_train.shape)
+
+  t_train_ = _reshape_tensor_with_nans(t_train)
+  e_train_ = _reshape_tensor_with_nans(e_train)
+  t_valid_ = _reshape_tensor_with_nans(t_valid)
+  e_valid_ = _reshape_tensor_with_nans(e_valid)
+
+  print(t_train_.shape, e_train_.shape)
+
   premodel = pretrain_dsm(model,
-                          t_train,
-                          e_train,
-                          t_valid,
-                          e_valid,
+                          t_train_,
+                          e_train_,
+                          t_valid_,
+                          e_valid_,
                           n_iter=10000,
                           lr=1e-2,
                           thres=1e-4)
   model.shape.data.fill_(float(premodel.shape))
   model.scale.data.fill_(float(premodel.scale))

-  # print(premodel.shape, premodel.scale)
-  # print(model.shape, model.scale)
-
-  # init=(float(premodel.shape[0]),
-  #       float(premodel.scale[0])),
-  # print(torch.exp(-premodel.scale).cpu().data.numpy()[0],
-  #       torch.exp(premodel.shape).cpu().data.numpy()[0])
+  print(float(premodel.shape), float(premodel.scale))

   model.double()
   optimizer = torch.optim.Adam(model.parameters(), lr=lr)
@@ -144,19 +155,24 @@ Module dsm.utilities
   for i in tqdm(range(n_iter)):
     for j in range(nbatches):

+      xb = x_train[j*bs:(j+1)*bs]
+      tb = t_train[j*bs:(j+1)*bs]
+      eb = e_train[j*bs:(j+1)*bs]
+
       optimizer.zero_grad()
       loss = conditional_loss(model,
-                              x_train[j*bs:(j+1)*bs],
-                              t_train[j*bs:(j+1)*bs],
-                              e_train[j*bs:(j+1)*bs],
+                              xb,
+                              _reshape_tensor_with_nans(tb),
+                              _reshape_tensor_with_nans(eb),
                               elbo=elbo)
+      #print ("Train Loss:", float(loss))
       loss.backward()
       optimizer.step()

     valid_loss = conditional_loss(model,
                                   x_valid,
-                                  t_valid,
-                                  e_valid,
+                                  t_valid_,
+                                  e_valid_,
                                   elbo=False)

     valid_loss = valid_loss.detach().cpu().numpy()
@@ -206,8 +222,8 @@ Functions
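To make the padding convention concrete, here is a small standalone check of what the new `_reshape_tensor_with_nans` helper does (the helper body is taken verbatim from the hunk above; the input values are illustrative):

    import torch

    def _reshape_tensor_with_nans(data):
        """Helper function to unroll padded RNN inputs."""
        data = data.reshape(-1)
        return data[~torch.isnan(data)]

    # Padded event times for two individuals (3 slots each, NaN = padding).
    t = torch.tensor([[2.0, 5.0, float('nan')],
                      [1.0, float('nan'), float('nan')]])
    print(_reshape_tensor_with_nans(t))  # tensor([2., 5., 1.])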

 elif model.optimizer == 'RMSProp':
   return torch.optim.RMSprop(model.parameters(), lr=lr)
 else:
-  raise NotImplementedError("Optimizer "+model.optimizer+
-                            " is not implemented")
+  raise NotImplementedError('Optimizer '+model.optimizer+
+                            ' is not implemented')
@@ -223,11 +239,11 @@ Functions

 n_iter=10000, lr=1e-2, thres=1e-4):

 premodel = DeepSurvivalMachinesTorch(1, 1,
-                                     init=False,
-                                     dist=model.dist)
+                                     dist=model.dist)
 premodel.double()
 optimizer = torch.optim.Adam(premodel.parameters(), lr=lr)

-oldcost = -float('inf')
+oldcost = float('inf')
 patience = 0
 costs = []
@@ -270,24 +286,27 @@ Functions
 print('Pretraining the Underlying Distributions...')

+print(t_train.shape, e_train.shape)
+
+t_train_ = _reshape_tensor_with_nans(t_train)
+e_train_ = _reshape_tensor_with_nans(e_train)
+t_valid_ = _reshape_tensor_with_nans(t_valid)
+e_valid_ = _reshape_tensor_with_nans(e_valid)
+
+print(t_train_.shape, e_train_.shape)
+
 premodel = pretrain_dsm(model,
-                        t_train,
-                        e_train,
-                        t_valid,
-                        e_valid,
+                        t_train_,
+                        e_train_,
+                        t_valid_,
+                        e_valid_,
                         n_iter=10000,
                         lr=1e-2,
                         thres=1e-4)
 model.shape.data.fill_(float(premodel.shape))
 model.scale.data.fill_(float(premodel.scale))

-# print(premodel.shape, premodel.scale)
-# print(model.shape, model.scale)
-
-# init=(float(premodel.shape[0]),
-#       float(premodel.scale[0])),
-# print(torch.exp(-premodel.scale).cpu().data.numpy()[0],
-#       torch.exp(premodel.shape).cpu().data.numpy()[0])
+print(float(premodel.shape), float(premodel.scale))

 model.double()
 optimizer = torch.optim.Adam(model.parameters(), lr=lr)
@@ -303,19 +322,24 @@ Functions
 for i in tqdm(range(n_iter)):
   for j in range(nbatches):

+    xb = x_train[j*bs:(j+1)*bs]
+    tb = t_train[j*bs:(j+1)*bs]
+    eb = e_train[j*bs:(j+1)*bs]
+
     optimizer.zero_grad()
     loss = conditional_loss(model,
-                            x_train[j*bs:(j+1)*bs],
-                            t_train[j*bs:(j+1)*bs],
-                            e_train[j*bs:(j+1)*bs],
+                            xb,
+                            _reshape_tensor_with_nans(tb),
+                            _reshape_tensor_with_nans(eb),
                             elbo=elbo)
+    #print ("Train Loss:", float(loss))
     loss.backward()
     optimizer.step()

   valid_loss = conditional_loss(model,
                                 x_valid,
-                                t_valid,
-                                e_valid,
+                                t_valid_,
+                                e_valid_,
                                 elbo=False)

   valid_loss = valid_loss.detach().cpu().numpy()
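One last note on the `oldcost = -float('inf')` to `float('inf')` fix that appears twice above: with a cost that is being minimized, the previous best must start at positive infinity, otherwise the comparison against the improvement threshold misbehaves from the first step. A self-contained sketch of that early-stopping pattern; the threshold, patience limit, and loop structure are assumed for illustration, since `pretrain_dsm`'s internals are not fully shown in this diff:

    # Hypothetical early-stopping loop (values illustrative).
    costs = [5.0, 4.2, 4.19995, 4.19991, 4.19990]  # toy, flattening curve
    oldcost, patience, thres = float('inf'), 0, 1e-4

    for cost in costs:
        if abs(oldcost - cost) < thres:  # no meaningful improvement
            patience += 1
            if patience == 3:
                break
        else:
            patience = 0
        oldcost = cost
    print('stopped with cost', cost)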