From c123ddebd5f5518ce60c6bbf8aa73b773d1368ea Mon Sep 17 00:00:00 2001 From: Chirag Nagpal Date: Sat, 9 Jul 2022 16:42:34 -0400 Subject: [PATCH] Update losses.py --- auton_survival/models/dsm/losses.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/auton_survival/models/dsm/losses.py b/auton_survival/models/dsm/losses.py index c74d744..4043ea3 100644 --- a/auton_survival/models/dsm/losses.py +++ b/auton_survival/models/dsm/losses.py @@ -303,7 +303,7 @@ def _weibull_pdf(model, x, t_horizon, risk='1'): k_ = shape b_ = scale - t_horz = torch.tensor(t_horizon).double() + t_horz = torch.tensor(t_horizon).double().to(x.device) t_horz = t_horz.repeat(shape.shape[0], 1) pdfs = [] @@ -338,7 +338,7 @@ def _weibull_cdf(model, x, t_horizon, risk='1'): k_ = shape b_ = scale - t_horz = torch.tensor(t_horizon).double() + t_horz = torch.tensor(t_horizon).double().to(x.device) t_horz = t_horz.repeat(shape.shape[0], 1) cdfs = [] @@ -401,7 +401,7 @@ def _lognormal_cdf(model, x, t_horizon, risk='1'): k_ = shape b_ = scale - t_horz = torch.tensor(t_horizon).double() + t_horz = torch.tensor(t_horizon).double().to(x.device) t_horz = t_horz.repeat(shape.shape[0], 1) cdfs = [] @@ -438,7 +438,7 @@ def _normal_cdf(model, x, t_horizon, risk='1'): k_ = shape b_ = scale - t_horz = torch.tensor(t_horizon).double() + t_horz = torch.tensor(t_horizon).double().to(x.device) t_horz = t_horz.repeat(shape.shape[0], 1) cdfs = []