From 79b77f50c8d0e3e10c471d96aa745844b223b978 Mon Sep 17 00:00:00 2001 From: Lucashsmello Date: Tue, 28 Jul 2020 15:31:12 -0300 Subject: [PATCH 1/2] Faster triplet selector --- utils.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/utils.py b/utils.py index 1eb2d977..6140b063 100644 --- a/utils.py +++ b/utils.py @@ -162,20 +162,19 @@ def get_triplets(self, embeddings, labels): anchor_positives = list(combinations(label_indices, 2)) # All anchor-positive pairs anchor_positives = np.array(anchor_positives) - ap_distances = distance_matrix[anchor_positives[:, 0], anchor_positives[:, 1]] - for anchor_positive, ap_distance in zip(anchor_positives, ap_distances): - loss_values = ap_distance - distance_matrix[torch.LongTensor(np.array([anchor_positive[0]])), torch.LongTensor(negative_indices)] + self.margin - loss_values = loss_values.data.cpu().numpy() - hard_negative = self.negative_selection_fn(loss_values) + ap_distances = distance_matrix[anchor_positives[:, 0], anchor_positives[:, 1]] + self.margin + idxs = np.ix_(anchor_positives[:, 0], negative_indices) + loss_values = ap_distances.unsqueeze(dim=1) - distance_matrix[idxs] + loss_values = loss_values.data.cpu().numpy() + for i, loss_val in enumerate(loss_values): + hard_negative = self.negative_selection_fn(loss_val) if hard_negative is not None: hard_negative = negative_indices[hard_negative] - triplets.append([anchor_positive[0], anchor_positive[1], hard_negative]) + triplets.append([anchor_positives[i][0], anchor_positives[i][1], hard_negative]) if len(triplets) == 0: triplets.append([anchor_positive[0], anchor_positive[1], negative_indices[0]]) - triplets = np.array(triplets) - return torch.LongTensor(triplets) From 99f7b0ccc40bb58223b26b1a9a24f0fce63e44d8 Mon Sep 17 00:00:00 2001 From: Lucashsmello Date: Tue, 4 Aug 2020 16:41:27 -0300 Subject: [PATCH 2/2] Fixed bug when there is no hard triplet --- utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils.py b/utils.py index 6140b063..981f8a0a 100644 --- a/utils.py +++ b/utils.py @@ -173,7 +173,7 @@ def get_triplets(self, embeddings, labels): triplets.append([anchor_positives[i][0], anchor_positives[i][1], hard_negative]) if len(triplets) == 0: - triplets.append([anchor_positive[0], anchor_positive[1], negative_indices[0]]) + triplets.append([anchor_positives[-1][0], anchor_positives[-1][1], negative_indices[0]]) return torch.LongTensor(triplets)