Skip to content

Commit

Permalink
only rotation and rank loss used now, layer norm added, along with eg…
Browse files Browse the repository at this point in the history
…de embedding with invraiant feature and multi-head at 8-dim for multi-update instead of single high dimensional mlp layer
  • Loading branch information
alexandor91 committed Dec 1, 2024
1 parent 45dc0e1 commit c5d6942
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 105 deletions.
16 changes: 8 additions & 8 deletions datasets/ThreeDMatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,8 @@ def __getitem__(self, index):

# Sample fixed number of points
sample_size = self.num_node
if sample_size > N_src or sample_size > N_tgt:
print("Warning: Not enough sample points for the fixed number, sampling with repetitions.")
# if sample_size > N_src or sample_size > N_tgt:
# print("Warning: Not enough sample points for the fixed number, sampling with repetitions.")

# Separate indices for positive and negative labels
pos_indices = np.where(labels == 1)[0]
Expand All @@ -338,8 +338,8 @@ def __getitem__(self, index):
num_pos = int(self.num_node * 0.6)
num_neg = self.num_node - num_pos

if len(pos_indices) < num_pos or len(neg_indices) < num_neg:
print("Not enough positive or negative points to satisfy the 0.60-0.40 ratio. so repeating sampling will be used!")
# if len(pos_indices) < num_pos or len(neg_indices) < num_neg:
# print("Not enough positive or negative points to satisfy the 0.60-0.40 ratio. so repeating sampling will be used!")

if len(pos_indices) < 5:
sampled_indices = np.random.choice(len(labels), self.num_node, replace=True)
Expand Down Expand Up @@ -484,8 +484,8 @@ def __getitem__(self, index):

# Sample fixed number of points
sample_size = self.num_node
if sample_size > N_src or sample_size > N_tgt:
print("Warning: Not enough sample points for the fixed number, sampling with repetitions.")
# if sample_size > N_src or sample_size > N_tgt:
# print("Warning: Not enough sample points for the fixed number, sampling with repetitions.")

# Separate indices for positive and negative labels
pos_indices = np.where(labels == 1)[0]
Expand All @@ -495,8 +495,8 @@ def __getitem__(self, index):
num_pos = int(self.num_node * 0.6)
num_neg = self.num_node - num_pos

if len(pos_indices) < num_pos or len(neg_indices) < num_neg:
print("Not enough positive or negative points to satisfy the 0.60-0.40 ratio. so repeating sampling will be used!")
# if len(pos_indices) < num_pos or len(neg_indices) < num_neg:
# print("Not enough positive or negative points to satisfy the 0.60-0.40 ratio. so repeating sampling will be used!")


if len(pos_indices) < 5:
Expand Down
Loading

0 comments on commit c5d6942

Please sign in to comment.