From 4093ba5de2a2c1eb12f6807fefa2f09df09b7be6 Mon Sep 17 00:00:00 2001 From: Chirag Nagpal Date: Mon, 14 Dec 2020 21:46:21 +0530 Subject: [PATCH] modified: dsm/dsm_api.py modified: dsm/dsm_torch.py --- dsm/dsm_api.py | 4 ++-- dsm/dsm_torch.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dsm/dsm_api.py b/dsm/dsm_api.py index 4ccd340..a205029 100644 --- a/dsm/dsm_api.py +++ b/dsm/dsm_api.py @@ -274,8 +274,8 @@ class DeepRecurrentSurvivalMachines(DSMBase): """ - def __init__(self, k=3, layers=None, hidden=None, - distribution='Weibull', temp=1000., discount=1.0, typ='LSTM'): + def __init__(self, k=3, layers=None, hidden=None, + distribution="Weibull", temp=1000., discount=1.0, typ="LSTM"): super(DeepRecurrentSurvivalMachines, self).__init__(k=k, layers=layers, distribution=distribution, diff --git a/dsm/dsm_torch.py b/dsm/dsm_torch.py index 23bce8d..1a19483 100644 --- a/dsm/dsm_torch.py +++ b/dsm/dsm_torch.py @@ -38,7 +38,7 @@ __pdoc__ = {} -for clsn in ['DeepSurvivalMachinesTorch', +for clsn in ['DeepSurvivalMachinesTorch', 'DeepRecurrentSurvivalMachinesTorch']: for membr in ['training', 'dump_patches']: @@ -269,7 +269,7 @@ def __init__(self, inputdim, k, typ='LSTM', layers=1, self.act = nn.SELU() self.shape = nn.ParameterDict({str(r+1): nn.Parameter(-torch.ones(k)) for r in range(self.risks)}) - self.scale = nn.ParameterDict({str(r+1):nn.Parameter(-torch.ones(k)) + self.scale = nn.ParameterDict({str(r+1): nn.Parameter(-torch.ones(k)) for r in range(self.risks)}) elif self.dist in ['Normal']: self.act = nn.Identity()