Skip to content

Commit

Permalink
Add pass to cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
nunonmg committed Sep 8, 2021
1 parent ed96c93 commit dd01298
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions lpsmap/api/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ class _LPSparseMAP(torch.autograd.Function):

@classmethod
def forward(cls, ctx, fg, eta_u, *eta_v):
fg.set_log_potentials(eta_u.detach().numpy())
detached_eta_v = [np.atleast_1d(x.detach().numpy()) for x in eta_v]
fg.set_log_potentials(eta_u.detach().cpu().numpy())
detached_eta_v = [np.atleast_1d(x.detach().cpu().numpy()) for x in eta_v]
fg.set_all_additionals(detached_eta_v)
ctx.fg = fg
ctx.shape = eta_u.shape
Expand All @@ -30,7 +30,7 @@ def backward(cls, ctx, du):
dtype = du.dtype
device = du.device

du = du.to(dtype=torch.double, device="cpu").detach().numpy()
du = du.to(dtype=torch.double, device="cpu").detach().cpu().numpy()
out = torch.empty(ctx.shape, dtype=torch.double, device='cpu')
add = ctx.fg.jacobian_vec(du, out.numpy())

Expand Down

0 comments on commit dd01298

Please sign in to comment.