Skip to content

Commit

Permalink
modified: dsm/__init__.py
Browse files Browse the repository at this point in the history
	modified:   dsm/dsm_api.py
	modified:   dsm/dsm_torch.py
	modified:   dsm/losses.py
  • Loading branch information
chiragnagpal committed Nov 6, 2020
1 parent be7a102 commit 6b14c2b
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 20 deletions.
18 changes: 7 additions & 11 deletions dsm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,12 @@
# along with Deep Survival Machines.
# If not, see <https://www.gnu.org/licenses/>.

"""
[![Build Status](https://travis-ci.org/autonlab/DeepSurvivalMachines.svg?\
branch=master)](https://travis-ci.org/autonlab/DeepSurvivalMachines)
&nbsp;&nbsp;&nbsp;\
[![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)]\
(https://www.gnu.org/licenses/gpl-3.0)
&nbsp;&nbsp;&nbsp;\
[![GitHub Repo stars](https://img.shields.io/github/stars/autonlab/Deep\
SurvivalMachines?style=social)](https://github.com/autonlab/DeepSurvival\
Machines)
r"""
[![Build Status](https://travis-ci.org/autonlab/DeepSurvivalMachines.svg?branch=master)](https://travis-ci.org/autonlab/DeepSurvivalMachines)
&nbsp;&nbsp;&nbsp;
[![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0)
&nbsp;&nbsp;&nbsp;
[![GitHub Repo stars](https://img.shields.io/github/stars/autonlab/DeepSurvivalMachines?style=social)](https://github.com/autonlab/DeepSurvivalMachines)
Python package `dsm` provides an API to train the Deep Survival Machines
Expand Down Expand Up @@ -77,7 +73,7 @@
data like vital signs, degradation monitoring signals in predictive
maintainance. **DRSM** allows the learnt representations at each time step to
involve historical context from previous time steps. **DRSM** implementation in
`dsm` is carried out through an easy to use API,
`dsm` is carried out through an easy to use API,
`DeepRecurrentSurvivalMachines` that accepts lists of data streams and
corresponding failure times. The module automatically takes care of appropriate
batching and padding of variable length sequences.
Expand Down
18 changes: 14 additions & 4 deletions dsm/dsm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def fit(self, x, t, e, vsize=0.15,
iters=1, learning_rate=1e-3, batch_size=100,
elbo=True, optimizer="Adam", random_state=100):

"""This method is used to train an instance of the DSM model.
r"""This method is used to train an instance of the DSM model.
Parameters
----------
Expand Down Expand Up @@ -137,7 +137,7 @@ def _prepocess_training_data(self, x, t, e, vsize, random_state):


def predict_risk(self, x, t):
"""Returns the estimated risk of an event occuring before time \( t \)
r"""Returns the estimated risk of an event occuring before time \( t \)
\( \widehat{\mathbb{P}}(T\leq t|X) \) for some input data \( x \).
Parameters
Expand All @@ -161,7 +161,7 @@ def predict_risk(self, x, t):


def predict_survival(self, x, t):
"""Returns the estimated survival probability at time \( t \),
r"""Returns the estimated survival probability at time \( t \),
\( \widehat{\mathbb{P}}(T > t|X) \) for some input data \( x \).
Parameters
Expand Down Expand Up @@ -249,15 +249,23 @@ class DeepRecurrentSurvivalMachines(DSMBase):
"""

def __init__(self, k=3, layers=None, hidden=None,
distribution='Weibull', temp=1000., discount=1.0, typ='LSTM'):
super(DeepRecurrentSurvivalMachines, self).__init__(k=k, layers=layers,
distribution=distribution)
self.hidden = hidden
self.typ = typ
def _gen_torch_model(self, inputdim, optimizer):
"""Helper function to return a torch model."""
return DeepRecurrentSurvivalMachinesTorch(inputdim,
k=self.k,
layers=self.layers,
hidden=self.hidden,
dist=self.dist,
temp=self.temp,
discount=self.discount,
optimizer=optimizer)
optimizer=optimizer,
typ=self.typ)

def _prepocess_test_data(self, x):
return torch.from_numpy(_get_padded_features(x))
Expand All @@ -273,6 +281,8 @@ def _prepocess_training_data(self, x, t, e, vsize, random_state):
t = _get_padded_targets(t)
e = _get_padded_targets(e)

print (x.shape)

x_train, t_train, e_train = x[idx], t[idx], e[idx]

x_train = torch.from_numpy(x_train).double()
Expand Down
10 changes: 6 additions & 4 deletions dsm/dsm_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@


def create_representation(inputdim, layers, activation):
"""Helper function to generate the representation function for DSM.
r"""Helper function to generate the representation function for DSM.
Deep Survival Machines learns a representation (\ Phi(X) \) for the input
data. This representation is parameterized using a Non Linear Multilayer
Expand Down Expand Up @@ -258,10 +258,12 @@ def __init__(self, inputdim, k, typ='LSTM', layers=1,
bias=False, batch_first=True)
if self.typ == 'RNN':
self.embedding = nn.RNN(inputdim, hidden, layers,
bias=False, batch_first=True,
nonlinearity='relu')
if self.typ == 'GRU':
self.embedding = nn.GRU(inputdim, hidden, layers,
bias=False, batch_first=True)



def forward(self, x):
"""The forward function that is called when data is passed through DSM.
Expand All @@ -280,7 +282,7 @@ def forward(self, x):
xrep, _ = self.embedding(x)
xrep = xrep.contiguous().view(-1, self.hidden)
xrep = xrep[inputmask]
#xrep = nn.ReLU6()(xrep)
xrep = nn.ReLU6()(xrep)
return(self.act(self.shapeg(xrep))+self.shape.expand(xrep.shape[0], -1),
self.act(self.scaleg(xrep))+self.scale.expand(xrep.shape[0], -1),
self.gate(xrep)/self.temp)
Expand Down
2 changes: 1 addition & 1 deletion dsm/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,4 +288,4 @@ def predict_cdf(model, x, t_horizon):
return _lognormal_cdf(model, x, t_horizon)
else:
raise NotImplementedError('Distribution: '+model.dist+
' not implemented yet.')
' not implemented yet.')

0 comments on commit 6b14c2b

Please sign in to comment.