Skip to content

Commit

Permalink
input scan points normazlied to the center of scan point to avoid num…
Browse files Browse the repository at this point in the history
…beric optim issue in training
  • Loading branch information
alexandor91 committed Dec 3, 2024
1 parent 917b8f2 commit 09454dc
Showing 1 changed file with 49 additions and 6 deletions.
55 changes: 49 additions & 6 deletions src/train_eval_egnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,15 +732,58 @@ def train_one_epoch(model, dataloader, optimizer, device, epoch, writer, use_poi
corr = corr.to(device)
labels = labels.to(device)
xyz_0, xyz_1 = src_pts.to(device), tar_pts.to(device)
xyz_0_center = src_pts.mean(dim=1, keepdim=True).to(device)
xyz_1_center = tar_pts.mean(dim=1, keepdim=True).to(device)
xyz_0 = xyz_0 - xyz_0_center
xyz_1 = xyz_1 - xyz_1_center
scale1 = xyz_0.norm(dim=2).max(dim=1, keepdim=True).values.to(device)
scale2 = xyz_1.norm(dim=2).max(dim=1, keepdim=True).values.to(device)
# xyz_0 = xyz_0 / scale1
# xyz_1 = xyz_1/ scale2
feat_0, feat_1 = src_features.to(device), tgt_features.to(device)
gt_pose = gt_pose.to(device)

R = gt_pose[:, :3, :3].to(device)
T = gt_pose[:, :3, 3].to(device)
# Adjust xyz_0_center to match R's expected shape
# Ensure xyz_0_center is of shape [1, 3, 1]
xyz_0_center = xyz_0_center.squeeze(1).unsqueeze(-1) # From [1, 1, 3] -> [1, 3] -> [1, 3, 1]

# Ensure xyz_1_center is of shape [1, 3, 1]
xyz_1_center = xyz_1_center.squeeze(1).unsqueeze(-1) # From [1, 1, 3] -> [1, 3] -> [1, 3, 1]

# Ensure T is of shape [1, 3, 1]
T = T.unsqueeze(-1) # From [1, 3] -> [1, 3, 1]

# Perform the computation
T_norm = (T - xyz_1_center + torch.bmm(R, xyz_0_center)).to(device) # Result shape: [1, 3, 1]
gt_pose[:, :3, :3] = R
gt_pose[:, :3, 3] = T_norm.transpose(1, 2)
# T = T.squeeze(-1) # Shape: [B, 3] -> T is already [B, 3, 1], no need for transpose.
# T = T.transpose(1, 2)
# print(R.shape)
# print(T.shape)
# print(xyz_0_center.shape)
# print(gt_pose.shape)
# # gt_pose_norm = (gt_pose - xyz_1_center + R @ xyz_0_center.T) /scale2
# gt_pose_norm = (T - xyz_1_center + torch.bmm(R, xyz_0_center.T)).to(device)
# gt_pose = gt_pose_norm

# # Step 1: Append a homogeneous coordinate (1) to xyz_0
# ones = torch.ones(1, xyz_0.size(1), 1, device=xyz_0.device) # Shape: 1 x N x 1
# xyz_homo = torch.cat([xyz_0, ones], dim=-1) # Shape: 1 x N x 4

# # Step 2: Apply the extrinsic transformation
# xyz_transformed_homo = torch.matmul(xyz_homo, gt_pose) # Shape: 1 x N x 4

# # Step 3: Convert back to 3D by dropping the homogeneous coordinate
# xyz_transformed = xyz_transformed_homo[..., :3] # Shape: 1 x N x 3
# xyz_diff = xyz_transformed - xyz_1
# print("################@@@@@@@@@@@@@@@@######################")
# print(xyz_0.shape)
# print(xyz_1.shape)
# print(feat_0.shape)
# print(feat_1.shape)

# # print(xyz_0)
# # print(xyz_1)
# # print(feat_0)
# # print(feat_1)
# print(xyz_diff)
# Initialize KNN graphs for source and target point clouds
k = 12
####### remove the batch size when it is one ##############
Expand Down

0 comments on commit 09454dc

Please sign in to comment.