Skip to content

Commit

Permalink
Trivial fix
Browse files Browse the repository at this point in the history
Former-commit-id: 7dae8c1 [formerly 7dae8c1 [formerly 94d5422]]
Former-commit-id: e35c0b63a5fdcc29aefd4b9eea114dab1d938463
Former-commit-id: 12bfdb5
  • Loading branch information
jihoonerd committed Sep 28, 2021
1 parent 95a27ce commit 4a478b6
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 3 deletions.
3 changes: 1 addition & 2 deletions test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,8 @@ def test(opt, device):
l2p_mean = np.mean(l2p)
l2q_mean = np.mean(l2q)

pred_quaternions = torch.cat(pred_rot_npss, dim=0)

# Drop end nodes for fair comparison
pred_quaternions = torch.cat(pred_rot_npss, dim=0)
npss_gt = global_q[:,:,skeleton_mocap.has_children()].reshape(global_q.shape[0],global_q.shape[1], -1)
npss_pred = pred_quaternions[:,:,skeleton_mocap.has_children()].reshape(pred_quaternions.shape[0],pred_quaternions.shape[1], -1)
npss = benchmarks.npss(npss_gt, npss_pred).item()
Expand Down
2 changes: 1 addition & 1 deletion train_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def train(opt, device):
rot_pred = output[1:,:,pos_dim:].permute(1,0,2)
rot_gt = minibatch_pose_gt[:,:,pos_dim:]
rot_loss = l1_loss(rot_pred, rot_gt)
recon_rot_loss.append(opt.loss_rot_weight * rot_loss)
recon_rot_loss.append(opt.loss_rot_weight * rot_loss)

total_g_loss = opt.loss_pos_weight * pos_loss + \
opt.loss_rot_weight * rot_loss + \
Expand Down

0 comments on commit 4a478b6

Please sign in to comment.