From dd01298d00727bde87e30da2bf493ce8831c01c8 Mon Sep 17 00:00:00 2001 From: nunonmg Date: Wed, 8 Sep 2021 15:07:17 +0100 Subject: [PATCH] Add pass to cpu --- lpsmap/api/autograd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lpsmap/api/autograd.py b/lpsmap/api/autograd.py index f47683c..531d460 100644 --- a/lpsmap/api/autograd.py +++ b/lpsmap/api/autograd.py @@ -17,8 +17,8 @@ class _LPSparseMAP(torch.autograd.Function): @classmethod def forward(cls, ctx, fg, eta_u, *eta_v): - fg.set_log_potentials(eta_u.detach().numpy()) - detached_eta_v = [np.atleast_1d(x.detach().numpy()) for x in eta_v] + fg.set_log_potentials(eta_u.detach().cpu().numpy()) + detached_eta_v = [np.atleast_1d(x.detach().cpu().numpy()) for x in eta_v] fg.set_all_additionals(detached_eta_v) ctx.fg = fg ctx.shape = eta_u.shape @@ -30,7 +30,7 @@ def backward(cls, ctx, du): dtype = du.dtype device = du.device - du = du.to(dtype=torch.double, device="cpu").detach().numpy() + du = du.to(dtype=torch.double, device="cpu").detach().cpu().numpy() out = torch.empty(ctx.shape, dtype=torch.double, device='cpu') add = ctx.fg.jacobian_vec(du, out.numpy())