From c5d6942c819b261df18702e1726cb02bc369089b Mon Sep 17 00:00:00 2001 From: "kangxueyang@126.com" Date: Mon, 2 Dec 2024 01:41:55 +1100 Subject: [PATCH] only rotation and rank loss used now, layer norm added, along with egde embedding with invraiant feature and multi-head at 8-dim for multi-update instead of single high dimensional mlp layer --- datasets/ThreeDMatch.py | 16 +-- src/train_eval_egnn.py | 257 +++++++++++++++++++++++++--------------- 2 files changed, 168 insertions(+), 105 deletions(-) diff --git a/datasets/ThreeDMatch.py b/datasets/ThreeDMatch.py index a493473..89351af 100644 --- a/datasets/ThreeDMatch.py +++ b/datasets/ThreeDMatch.py @@ -327,8 +327,8 @@ def __getitem__(self, index): # Sample fixed number of points sample_size = self.num_node - if sample_size > N_src or sample_size > N_tgt: - print("Warning: Not enough sample points for the fixed number, sampling with repetitions.") + # if sample_size > N_src or sample_size > N_tgt: + # print("Warning: Not enough sample points for the fixed number, sampling with repetitions.") # Separate indices for positive and negative labels pos_indices = np.where(labels == 1)[0] @@ -338,8 +338,8 @@ def __getitem__(self, index): num_pos = int(self.num_node * 0.6) num_neg = self.num_node - num_pos - if len(pos_indices) < num_pos or len(neg_indices) < num_neg: - print("Not enough positive or negative points to satisfy the 0.60-0.40 ratio. so repeating sampling will be used!") + # if len(pos_indices) < num_pos or len(neg_indices) < num_neg: + # print("Not enough positive or negative points to satisfy the 0.60-0.40 ratio. so repeating sampling will be used!") if len(pos_indices) < 5: sampled_indices = np.random.choice(len(labels), self.num_node, replace=True) @@ -484,8 +484,8 @@ def __getitem__(self, index): # Sample fixed number of points sample_size = self.num_node - if sample_size > N_src or sample_size > N_tgt: - print("Warning: Not enough sample points for the fixed number, sampling with repetitions.") + # if sample_size > N_src or sample_size > N_tgt: + # print("Warning: Not enough sample points for the fixed number, sampling with repetitions.") # Separate indices for positive and negative labels pos_indices = np.where(labels == 1)[0] @@ -495,8 +495,8 @@ def __getitem__(self, index): num_pos = int(self.num_node * 0.6) num_neg = self.num_node - num_pos - if len(pos_indices) < num_pos or len(neg_indices) < num_neg: - print("Not enough positive or negative points to satisfy the 0.60-0.40 ratio. so repeating sampling will be used!") + # if len(pos_indices) < num_pos or len(neg_indices) < num_neg: + # print("Not enough positive or negative points to satisfy the 0.60-0.40 ratio. so repeating sampling will be used!") if len(pos_indices) < 5: diff --git a/src/train_eval_egnn.py b/src/train_eval_egnn.py index 5ae2e5f..840b05b 100644 --- a/src/train_eval_egnn.py +++ b/src/train_eval_egnn.py @@ -188,127 +188,122 @@ def compute_so3_matrix(x, graph_idx): return so3_flat -class E_GCL(nn.Module): - """ - E(n) Equivariant Convolutional Layer, basic moduel to buil up the eqnn network model - re - """ +# Helper: Compute Edge Distance and Other Features +def compute_edge_features(coord, edge_index): + row, col = edge_index + rel_coord = coord[row] - coord[col] # Relative positions + dist = torch.norm(rel_coord, dim=1, keepdim=True) # ∥xi - xj∥ + dot_product = (coord[row] * coord[col]).sum(dim=1, keepdim=True) # xi ⋅ xj + return rel_coord, dist, dot_product + - def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, act_fn=nn.SiLU(), residual=True, attention=False, normalize=False, coords_agg='mean', tanh=False, device='cuda:0'): +# Updated Edge Model with Richer Features +class E_GCL(nn.Module): + def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, num_heads=4, + act_fn=nn.SiLU(), residual=True, attention=False, normalize=False, tanh=False, device='cuda:0'): super(E_GCL, self).__init__() - input_edge = input_nf * 2 self.residual = residual self.attention = attention self.normalize = normalize - self.coords_agg = coords_agg self.tanh = tanh - self.epsilon = 1e-8 - edge_coords_nf = 1 - self.device =device - - # Edge MLP, with 9 additional dimensions for SO(3) flattened feature - self.edge_mlp = nn.Sequential( - nn.Linear(input_edge + edge_coords_nf + edges_in_d + 9, hidden_nf), ########2*input node feature, edge feature, edge attribute dimension features - act_fn, - nn.Linear(hidden_nf, hidden_nf), - act_fn) + self.num_heads = num_heads + self.device = device + input_edge = input_nf * 2 + edge_coords_nf = 1 + so3_feat_dim = 9 + feature_dim = input_edge + edges_in_d + edge_coords_nf + so3_feat_dim + 2 # Include distance, dot product + + # Multi-Head Edge MLP + self.edge_mlps = nn.ModuleList([ + nn.Sequential( + nn.Linear(feature_dim, hidden_nf // num_heads), + act_fn, + nn.Linear(hidden_nf // num_heads, hidden_nf // num_heads) + ) for _ in range(num_heads) + ]) + self.layer_norm = nn.LayerNorm(hidden_nf) # Layer norm after combining heads + + # Node MLP self.node_mlp = nn.Sequential( nn.Linear(hidden_nf + input_nf, hidden_nf), act_fn, - nn.Linear(hidden_nf, output_nf)) + nn.Linear(hidden_nf, output_nf) + ) + # Coordinate MLP layer = nn.Linear(hidden_nf, edge_coords_nf, bias=False) - torch.nn.init.xavier_uniform_(layer.weight, gain=1e-3) + nn.init.xavier_uniform_(layer.weight, gain=1e-3) - coord_mlp = [] - coord_mlp.append(nn.Linear(hidden_nf, hidden_nf)) - coord_mlp.append(act_fn) - coord_mlp.append(layer) + coord_mlp = [ + nn.Linear(hidden_nf, hidden_nf), + act_fn, + layer + ] if self.tanh: coord_mlp.append(nn.Tanh()) self.coord_mlp = nn.Sequential(*coord_mlp) - if self.attention: - self.att_mlp = nn.Sequential( - nn.Linear(hidden_nf, edge_coords_nf), - nn.Sigmoid()) - # Initialize SO(3) Tensor Product Layer - self.so3_tensor_product = SO3TensorProductLayer(input_dim=3, output_dim=hidden_nf) - def edge_model(self, source, target, radial, edge_attr, coord, edge_index): - # Compute the SO(3) matrix features (N x 9) - so3_flat = compute_so3_matrix(coord, edge_index) - # so3_output = self.so3_tensor_product(so3_flat) # Apply SO(3) tensor product layer + rel_coord, dist, dot_product = compute_edge_features(coord, edge_index) + so3_flat = compute_so3_matrix(coord, edge_index) # From your code + # dihedral = compute_dihedral_angles(rel_coord, so3_flat[:, :3]) # Placeholder - if edge_attr is None: # Unused for edge attributes. - out = torch.cat([source, target, radial, so3_flat], dim=1) - else: - # Concatenate source, target, radial, edge_attr, and SO(3) flattened features - out = torch.cat([source, target, radial, edge_attr, so3_flat], dim=1) + # Combine features + features = [source, target, radial, dist, dot_product, so3_flat] + if edge_attr is not None: + features.append(edge_attr) + out = torch.cat(features, dim=1) - out = self.edge_mlp(out) - if self.attention: - att_val = self.att_mlp(out) - out = out * att_val + # Multi-head MLP + head_outputs = [mlp(out) for mlp in self.edge_mlps] + combined = torch.cat(head_outputs, dim=1) + # Layer normalization + out = self.layer_norm(combined) return out - def node_model(self, x, edge_index, edge_attr, node_attr): - row, col = edge_index + def node_model(self, x, edge_index, edge_attr): + row = edge_index[0] agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0)) - if node_attr is not None: - agg = torch.cat([x, agg, node_attr], dim=1) - else: - agg = torch.cat([x, agg], dim=1) - out = self.node_mlp(agg) + out = torch.cat([x, agg], dim=1) + out = self.node_mlp(out) if self.residual: out = x + out - return out, agg + return out def coord_model(self, coord, edge_index, coord_diff, edge_feat): - row, col = edge_index - if edge_feat is not None: - trans = coord_diff * self.coord_mlp(edge_feat) - trans = coord_diff - # print(self.coord_mlp(edge_feat)) - if self.coords_agg == 'sum': - agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0)) - elif self.coords_agg == 'mean': - agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0)) - else: - raise Exception('Wrong coords_agg parameter' % self.coords_agg) + row = edge_index[0] + trans = coord_diff * self.coord_mlp(edge_feat) + agg = unsorted_segment_sum(trans, row, num_segments=coord.size(0)) + #agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0)) coord = coord + agg return coord + def coord2radial(self, edge_index, coord): row, col = edge_index coord_diff = coord[row] - coord[col] - radial = torch.sum(coord_diff**2, 1).unsqueeze(1) - + radial = torch.sum(coord_diff**2, -1).unsqueeze(-1) if self.normalize: norm = torch.sqrt(radial).detach() + self.epsilon coord_diff = coord_diff / norm - return radial, coord_diff - def forward(self, h, edge_index, coord, edge_attr=None, node_attr=None): - self.to(self.device) + def forward(self, h, edge_index, coord, edge_attr=None): row, col = edge_index radial, coord_diff = self.coord2radial(edge_index, coord) - - # Edge model with SO(3) features concatenated edge_feat = self.edge_model(h[row], h[col], radial, edge_attr, coord, edge_index) - coord = self.coord_model(coord, edge_index, coord_diff, edge_feat) - h, agg = self.node_model(h, edge_index, edge_feat, node_attr) - # print(coord.shape) - # print(h.shape) + h = self.node_model(h, edge_index, edge_feat) + + # Return h, coord, and a dummy third value return h, coord, edge_attr + class EGNN(nn.Module): def __init__(self, in_node_nf, hidden_nf, out_node_nf, in_edge_nf=0, device='cuda:0', act_fn=nn.SiLU(), n_layers=5, residual=True, attention=True, normalize=False, tanh=False): ''' @@ -342,6 +337,7 @@ def __init__(self, in_node_nf, hidden_nf, out_node_nf, in_edge_nf=0, device='cud self.add_module("gcl_%d" % i, E_GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf, act_fn=act_fn, residual=residual, attention=attention, normalize=normalize, tanh=tanh)) + self.to(device) # Move entire model to device during initialization def forward(self, h, x, edges, edge_attr): # self.to(self.device) @@ -501,6 +497,52 @@ def quaternion_to_matrix(quaternion, device): return rotation_matrix +def matrix_log(R): + """ + Compute the matrix logarithm of a 3x3 rotation matrix R on the GPU. + + Args: + - R: Rotation matrix [3x3] on the GPU + + Returns: + - Matrix logarithm of R [3x3] on the GPU + """ + trace = torch.trace(R) + theta = torch.acos((trace - 1) / 2) + + # Handle the case when theta is close to 0 (identity matrix) + theta_abs = torch.abs(theta) + log_R = torch.where(theta_abs < 1e-6, + torch.zeros_like(R), + theta / (2 * torch.sin(theta)) * (R - R.T)) + + return log_R + +def center_and_normalize(src_pts, tar_pts): + """ + Normalize source and target points by centering them at the origin and scaling them to have unit norm. + + Args: + - src_pts: Source points [N x 3] + - tar_pts: Target points [N x 3] + + Returns: + - Normalized source points [N x 3] + - Normalized target points [N x 3] + """ + # Center the points at the origin + src_center = src_pts.mean(dim=0) + tar_center = tar_pts.mean(dim=0) + + src_pts_centered = src_pts - src_center + tar_pts_centered = tar_pts - tar_center + + # Normalize the points to have unit norm + src_pts_normalized = src_pts_centered / torch.norm(src_pts_centered, dim=1, keepdim=True) + tar_pts_normalized = tar_pts_centered / torch.norm(tar_pts_centered, dim=1, keepdim=True) + + return src_pts_normalized, tar_pts_normalized + class CrossAttentionPoseRegression(nn.Module): def __init__(self, egnn: EGNN, num_nodes: int = 2048, hidden_nf: int = 35, device='cuda:0'): super(CrossAttentionPoseRegression, self).__init__() @@ -600,8 +642,8 @@ def forward(self, h_src, x_src, edges_src, edge_attr_src, h_tgt, x_tgt, edges_tg total_corr_loss = corr_loss + rank_loss # Weigh the source and target descriptors using the similarity matrix - weighted_h_src = torch.mm(sim_matrix.transpose(0, 1), compressed_h_src) # Shape: [128, 35] - weighted_h_tgt = torch.mm(sim_matrix, compressed_h_tgt) # Shape: [128, 35] + weighted_h_src = torch.mm(sim_matrix.transpose(0, 1), compressed_h_src_norm) # Shape: [128, 35] + weighted_h_tgt = torch.mm(sim_matrix, compressed_h_tgt_norm) # Shape: [128, 35] # Concatenate the source and target weighted features combined_features = torch.cat([weighted_h_src, weighted_h_tgt], dim=-1) # Shape: [128, 70] @@ -618,7 +660,7 @@ def forward(self, h_src, x_src, edges_src, edge_attr_src, h_tgt, x_tgt, edges_tg return quaternion, translation, total_corr_loss -def pose_loss(pred_quaternion, pred_translation, gt_pose, delta=1.0): +def pose_loss(pred_quaternion, pred_translation, gt_pose, delta=1.5): """ Compute the loss between the predicted pose (quaternion + translation) and the ground truth pose matrix. @@ -629,26 +671,45 @@ def pose_loss(pred_quaternion, pred_translation, gt_pose, delta=1.0): Returns: - Total loss combining quaternion and translation loss. - """ - - # Extract ground truth translation (3D) and rotation (3x3) from the 4x4 gt_pose matrix - gt_translation = gt_pose[:3, 3] # Translation vector [3] - gt_rotation = gt_pose[:3, :3] # Rotation matrix [3x3] + """ + # # Convert ground truth rotation matrix to quaternion + # gt_quaternion = rotation_matrix_to_quaternion(gt_rotation) # Convert [3x3] to [4] + # gt_quaternion = F.normalize(gt_quaternion, p=2, dim=-1) + # # Normalize the predicted quaternion + # pred_quaternion = F.normalize(pred_quaternion, p=2, dim=-1) - # Convert ground truth rotation matrix to quaternion - gt_quaternion = rotation_matrix_to_quaternion(gt_rotation) # Convert [3x3] to [4] - gt_quaternion = F.normalize(gt_quaternion, p=2, dim=-1) - # Normalize the predicted quaternion - pred_quaternion = F.normalize(pred_quaternion, p=2, dim=-1) + # # Huber loss for quaternion and translation + # huber_loss = nn.HuberLoss(delta=delta) - # Huber loss for quaternion and translation - huber_loss = nn.HuberLoss(delta=delta) + # quaternion_loss = huber_loss(pred_quaternion, gt_quaternion) + # translation_loss = huber_loss(pred_translation, gt_translation) - quaternion_loss = huber_loss(pred_quaternion, gt_quaternion) - translation_loss = huber_loss(pred_translation, gt_translation) + # # Return the combined loss + # return quaternion_loss + translation_loss - # Return the combined loss - return quaternion_loss + translation_loss + # Extract ground truth translation and rotation + gt_translation = gt_pose[:3, 3] + gt_rotation = gt_pose[:3, :3] + + # Convert predicted quaternion to rotation matrix + pred_rotation = quaternion_to_matrix(pred_quaternion, device=pred_quaternion.device) + + # # SO(3) geodesic rotation loss + # R = torch.matmul(pred_rotation.T, gt_rotation) + # log_R = matrix_log(R) + # rotation_loss = torch.norm(log_R, p='fro') + + # Compute geodesic rotation loss + rotation_loss = torch.arccos(torch.clamp((torch.trace(torch.matmul(pred_rotation.T, gt_rotation)) - 1) / 2, min=-1, max=1)) + + # Translation loss + huber_loss = nn.HuberLoss(delta=delta) + translation_loss = huber_loss(pred_translation, gt_translation) + + # Combined loss + total_loss = rotation_loss #######+ translation_loss + + return total_loss # Function to train for one epoch def train_one_epoch(model, dataloader, optimizer, device, epoch, writer, use_pointnet=False, log_interval=5, beta=0.1): @@ -736,8 +797,9 @@ def train_one_epoch(model, dataloader, optimizer, device, epoch, writer, use_poi print(f'Iteration {batch_idx}/{len(train_loader)} - Loss: {loss.item():.4f}') avg_loss = running_loss / len(dataloader) print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$") + print(avg_loss) # Log the average loss for this epoch - writer.add_scalar('Loss/train', avg_loss, epoch) + # writer.add_scalar('Loss/train', avg_loss, epoch) return avg_loss @@ -763,6 +825,7 @@ def validate(model, dataloader, device, epoch, writer, use_pointnet=False): feat_0, feat_1 = src_features.to(device), tgt_features.to(device) gt_pose = gt_pose.to(device) + # xyz_0, xyz_1 = center_and_normalize(xyz_0, xyz_1) # print("################@@@@@@@@@@@@@@@@######################") # print(xyz_0.shape) # print(xyz_1.shape) @@ -830,10 +893,10 @@ def validate(model, dataloader, device, epoch, writer, use_pointnet=False): print(f'Validation Loss: {avg_loss:.4f} | Pose Loss: {avg_pose_loss:.4f} | Correspondence Loss: {avg_corr_loss:.4f}') - # Log validation losses - writer.add_scalar('Loss/validation', avg_loss, epoch) - writer.add_scalar('Pose_Loss/validation', avg_pose_loss, epoch) - writer.add_scalar('Correspondence_Loss/validation', avg_corr_loss, epoch) + # # Log validation losses + # writer.add_scalar('Loss/validation', avg_loss, epoch) + # writer.add_scalar('Pose_Loss/validation', avg_pose_loss, epoch) + # writer.add_scalar('Correspondence_Loss/validation', avg_corr_loss, epoch) return avg_loss, avg_pose_loss, avg_corr_loss