From 1357c32df3a03e2f9914b8db1402e78c1c684eda Mon Sep 17 00:00:00 2001
From: Jihoon Kim
Date: Tue, 28 Sep 2021 14:39:13 +0900
Subject: [PATCH] Normalize quaternions to make valid quaternions

Former-commit-id: 5f013bc684dfa84c47324c370eb0336aeae56cad [formerly 5f013bc684dfa84c47324c370eb0336aeae56cad [formerly 0d58e99160651e0b52b3c39c24af74fde46700e0]]
Former-commit-id: bb01147058302a83b83c6f7951d492823f450d99
Former-commit-id: 9b0980c085795dd47281cfe0119164f26469fbde
---
 test_benchmark.py | 9 +++++----
 train_mmm.py      | 8 ++++++--
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/test_benchmark.py b/test_benchmark.py
index bf500c4..5634f42 100644
--- a/test_benchmark.py
+++ b/test_benchmark.py
@@ -149,17 +149,18 @@ def test(opt, device):
         pred_global_pos[0,-1] = gt_global_pos[0,-1]
 
         pred_global_rot = output[1:,:,pos_dim:].permute(1,0,2).reshape(1,horizon-1,22,4)
+        pred_global_rot_normalized = nn.functional.normalize(pred_global_rot, p=2.0, dim=3)
         gt_global_rot = global_q[test_idx[i]:test_idx[i]+1]
-        pred_global_rot[0,0] = gt_global_rot[0,0]
-        pred_global_rot[0,-1] = gt_global_rot[0,-1]
-        pred_rot_npss.append(pred_global_rot)
+        pred_global_rot_normalized[0,0] = gt_global_rot[0,0]
+        pred_global_rot_normalized[0,-1] = gt_global_rot[0,-1]
+        pred_rot_npss.append(pred_global_rot_normalized)
 
         # Normalize for L2P
         normalized_gt_pos = torch.Tensor((lafan_dataset.data['global_pos'][test_idx[i]:test_idx[i]+1, from_idx:target_idx+1].reshape(1, -1, lafan_dataset.num_joints * 3).transpose(0,2,1) - x_mean) / x_std)
         normalized_pred_pos = torch.Tensor((pred_global_pos.reshape(1, -1, lafan_dataset.num_joints * 3).transpose(0,2,1) - x_mean) / x_std)
         l2p.append(torch.mean(torch.norm(normalized_pred_pos[0] - normalized_gt_pos[0], dim=(0))).item())
 
-        l2q.append(torch.mean(torch.norm(pred_global_rot[0] - global_q[test_idx[i]], dim=(1,2))).item())
+        l2q.append(torch.mean(torch.norm(pred_global_rot_normalized[0] - global_q[test_idx[i]], dim=(1,2))).item())
 
         print(f"ID {test_idx[i]}: test completed.")
     l2p_mean = np.mean(l2p)
diff --git a/train_mmm.py b/train_mmm.py
index 888dbe5..1f824cd 100644
--- a/train_mmm.py
+++ b/train_mmm.py
@@ -133,8 +133,12 @@ def train(opt, device):
             recon_pos_loss.append(opt.loss_pos_weight * pos_loss)
 
             rot_pred = output[1:,:,pos_dim:].permute(1,0,2)
+            rot_pred_reshaped = rot_pred.reshape(rot_pred.shape[0], rot_pred.shape[1], lafan_dataset.num_joints, 4)
+            rot_pred_normalized = nn.functional.normalize(rot_pred_reshaped, p=2.0, dim=3)
+
             rot_gt = minibatch_pose_gt[:,:,pos_dim:]
-            rot_loss = l1_loss(rot_pred, rot_gt)
+            rot_gt_reshaped = rot_gt.reshape(rot_gt.shape[0], rot_gt.shape[1], lafan_dataset.num_joints, 4)
+            rot_loss = l1_loss(rot_pred_normalized, rot_gt_reshaped)
             recon_rot_loss.append(opt.loss_rot_weight * rot_loss)
 
             total_g_loss = opt.loss_pos_weight * pos_loss + \
@@ -190,7 +194,7 @@ def parse_opt():
     parser.add_argument('--processed_data_dir', type=str, default='processed_data_original/', help='path to save pickled processed data')
     parser.add_argument('--batch_size', type=int, default=128, help='batch size')
    parser.add_argument('--epochs', type=int, default=1000)
-    parser.add_argument('--device', default='1', help='cuda device, i.e. 0 or -1 or cpu')
+    parser.add_argument('--device', default='0', help='cuda device, i.e. 0 or -1 or cpu')
     parser.add_argument('--entity', default=None, help='W&B entity')
     parser.add_argument('--exp_name', default='exp', help='save to project/name')
     parser.add_argument('--save_interval', type=int, default=1, help='Log model after every "save_period" epoch')
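
Note (not part of the patch): a minimal standalone sketch of the normalization this commit applies. The raw rotation output is reshaped to (batch, frames, num_joints, 4) and L2-normalized along the last axis so each joint rotation becomes a unit quaternion; the shapes and variable names below are illustrative only.

    # Illustrative sketch of the quaternion normalization step; names/shapes are assumptions.
    import torch
    import torch.nn as nn

    batch, frames, num_joints = 2, 10, 22
    raw_quat = torch.randn(batch, frames, num_joints, 4)  # unconstrained network output

    # L2-normalize along the quaternion axis, as in the diff above
    unit_quat = nn.functional.normalize(raw_quat, p=2.0, dim=3)

    # Every quaternion now has unit norm, i.e. it represents a valid rotation
    assert torch.allclose(unit_quat.norm(dim=3), torch.ones(batch, frames, num_joints), atol=1e-5)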