Skip to content

Commit

Permalink
Normalize quaternions to make valid quaternion
Browse files Browse the repository at this point in the history
Former-commit-id: 0d58e99
  • Loading branch information
jihoonerd committed Sep 28, 2021
1 parent 7dae8c1 commit 5f013bc
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
9 changes: 5 additions & 4 deletions test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,17 +149,18 @@ def test(opt, device):
pred_global_pos[0,-1] = gt_global_pos[0,-1]

pred_global_rot = output[1:,:,pos_dim:].permute(1,0,2).reshape(1,horizon-1,22,4)
pred_global_rot_normalized = nn.functional.normalize(pred_global_rot, p=2.0, dim=3)
gt_global_rot = global_q[test_idx[i]:test_idx[i]+1]
pred_global_rot[0,0] = gt_global_rot[0,0]
pred_global_rot[0,-1] = gt_global_rot[0,-1]
pred_rot_npss.append(pred_global_rot)
pred_global_rot_normalized[0,0] = gt_global_rot[0,0]
pred_global_rot_normalized[0,-1] = gt_global_rot[0,-1]
pred_rot_npss.append(pred_global_rot_normalized)

# Normalize for L2P
normalized_gt_pos = torch.Tensor((lafan_dataset.data['global_pos'][test_idx[i]:test_idx[i]+1, from_idx:target_idx+1].reshape(1, -1, lafan_dataset.num_joints * 3).transpose(0,2,1) - x_mean) / x_std)
normalized_pred_pos = torch.Tensor((pred_global_pos.reshape(1, -1, lafan_dataset.num_joints * 3).transpose(0,2,1) - x_mean) / x_std)

l2p.append(torch.mean(torch.norm(normalized_pred_pos[0] - normalized_gt_pos[0], dim=(0))).item())
l2q.append(torch.mean(torch.norm(pred_global_rot[0] - global_q[test_idx[i]], dim=(1,2))).item())
l2q.append(torch.mean(torch.norm(pred_global_rot_normalized[0] - global_q[test_idx[i]], dim=(1,2))).item())
print(f"ID {test_idx[i]}: test completed.")

l2p_mean = np.mean(l2p)
Expand Down
8 changes: 6 additions & 2 deletions train_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,12 @@ def train(opt, device):
recon_pos_loss.append(opt.loss_pos_weight * pos_loss)

rot_pred = output[1:,:,pos_dim:].permute(1,0,2)
rot_pred_reshaped = rot_pred.reshape(rot_pred.shape[0], rot_pred.shape[1], lafan_dataset.num_joints, 4)
rot_pred_normalized = nn.functional.normalize(rot_pred_reshaped, p=2.0, dim=3)

rot_gt = minibatch_pose_gt[:,:,pos_dim:]
rot_loss = l1_loss(rot_pred, rot_gt)
rot_gt_reshaped = rot_gt.reshape(rot_gt.shape[0], rot_gt.shape[1], lafan_dataset.num_joints, 4)
rot_loss = l1_loss(rot_pred_normalized, rot_gt_reshaped)
recon_rot_loss.append(opt.loss_rot_weight * rot_loss)

total_g_loss = opt.loss_pos_weight * pos_loss + \
Expand Down Expand Up @@ -190,7 +194,7 @@ def parse_opt():
parser.add_argument('--processed_data_dir', type=str, default='processed_data_original/', help='path to save pickled processed data')
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
parser.add_argument('--epochs', type=int, default=1000)
parser.add_argument('--device', default='1', help='cuda device, i.e. 0 or -1 or cpu')
parser.add_argument('--device', default='0', help='cuda device, i.e. 0 or -1 or cpu')
parser.add_argument('--entity', default=None, help='W&B entity')
parser.add_argument('--exp_name', default='exp', help='save to project/name')
parser.add_argument('--save_interval', type=int, default=1, help='Log model after every "save_period" epoch')
Expand Down

0 comments on commit 5f013bc

Please sign in to comment.