Skip to content

Commit

Permalink
modified: __init__.py
Browse files Browse the repository at this point in the history
	modified:   dcm_utilities.py
  • Loading branch information
chiragnagpal committed Mar 29, 2022
1 parent 2806a45 commit 45c7f38
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
21 changes: 14 additions & 7 deletions auton_survival/models/dcm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,16 @@ class DeepCoxMixtures:
"""

def __init__(self, k=3, layers=None, gamma=10,
smoothing_factor=1e-4, use_activation=False):
smoothing_factor=1e-4, use_activation=False,
random_seed=0):

self.k = k
self.layers = layers
self.fitted = False
self.gamma = gamma
self.smoothing_factor = smoothing_factor
self.use_activation = use_activation
self.random_seed = random_seed

def __call__(self):
if self.fitted:
Expand All @@ -109,10 +111,10 @@ def __call__(self):
def _preprocess_test_data(self, x):
return torch.from_numpy(x).float()

def _preprocess_training_data(self, x, t, e, vsize, val_data, random_state):
def _preprocess_training_data(self, x, t, e, vsize, val_data, random_seed):

idx = list(range(x.shape[0]))
np.random.seed(random_state)
np.random.seed(random_seed)
np.random.shuffle(idx)
x_train, t_train, e_train = x[idx], t[idx], e[idx]

Expand Down Expand Up @@ -141,6 +143,10 @@ def _preprocess_training_data(self, x, t, e, vsize, val_data, random_state):

def _gen_torch_model(self, inputdim, optimizer):
"""Helper function to return a torch model."""

np.random.seed(self.random_seed)
torch.manual_seed(self.random_seed)

return DeepCoxMixturesTorch(inputdim,
k=self.k,
gamma=self.gamma,
Expand All @@ -150,7 +156,7 @@ def _gen_torch_model(self, inputdim, optimizer):

def fit(self, x, t, e, vsize=0.15, val_data=None,
iters=1, learning_rate=1e-3, batch_size=100,
optimizer="Adam", random_state=100):
optimizer="Adam"):

r"""This method is used to train an instance of the DSM model.
Expand All @@ -177,14 +183,14 @@ def fit(self, x, t, e, vsize=0.15, val_data=None,
optimizer: str
The choice of the gradient based optimization method. One of
'Adam', 'RMSProp' or 'SGD'.
random_state: float
random_seed: float
random seed that determines how the validation set is chosen.
"""

processed_data = self._preprocess_training_data(x, t, e,
vsize, val_data,
random_state)
self.random_seed)
x_train, t_train, e_train, x_val, t_val, e_val = processed_data

#Todo: Change this somehow. The base design shouldn't depend on child
Expand All @@ -201,7 +207,8 @@ def fit(self, x, t, e, vsize=0.15, val_data=None,
bs=batch_size,
return_losses=True,
smoothing_factor=self.smoothing_factor,
use_posteriors=True)
use_posteriors=True,
random_seed=self.random_seed)

self.torch_model = (model[0].eval(), model[1])
self.fitted = True
Expand Down
6 changes: 3 additions & 3 deletions auton_survival/models/dcm/dcm_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,12 @@ def test_step(model, x, t, e, breslow_splines, loss='q', typ='soft'):

def train_dcm(model, train_data, val_data, epochs=50,
patience=3, vloss='q', bs=256, typ='soft', lr=1e-3,
use_posteriors=True, debug=False, random_state=0,
use_posteriors=True, debug=False, random_seed=0,
return_losses=False, update_splines_after=10,
smoothing_factor=1e-2):

torch.manual_seed(random_state)
np.random.seed(random_state)
torch.manual_seed(random_seed)
np.random.seed(random_seed)

if val_data is None:
val_data = train_data
Expand Down

0 comments on commit 45c7f38

Please sign in to comment.