From d69124182302e5eebaacbe899db3f9a51d8d2a89 Mon Sep 17 00:00:00 2001 From: Sergio Quijano Date: Sun, 21 Apr 2024 18:56:43 +0200 Subject: [PATCH] refactor #70: removed pair selectors from Adam's lib --- src/adambielski_lib/utils.py | 130 ++++++++++++----------------------- 1 file changed, 45 insertions(+), 85 deletions(-) diff --git a/src/adambielski_lib/utils.py b/src/adambielski_lib/utils.py index 981f8a0..8a3a211 100644 --- a/src/adambielski_lib/utils.py +++ b/src/adambielski_lib/utils.py @@ -5,75 +5,14 @@ def pdist(vectors): - distance_matrix = -2 * vectors.mm(torch.t(vectors)) + vectors.pow(2).sum(dim=1).view(1, -1) + vectors.pow(2).sum( - dim=1).view(-1, 1) + distance_matrix = ( + -2 * vectors.mm(torch.t(vectors)) + + vectors.pow(2).sum(dim=1).view(1, -1) + + vectors.pow(2).sum(dim=1).view(-1, 1) + ) return distance_matrix -class PairSelector: - """ - Implementation should return indices of positive pairs and negative pairs that will be passed to compute - Contrastive Loss - return positive_pairs, negative_pairs - """ - - def __init__(self): - pass - - def get_pairs(self, embeddings, labels): - raise NotImplementedError - - -class AllPositivePairSelector(PairSelector): - """ - Discards embeddings and generates all possible pairs given labels. - If balance is True, negative pairs are a random sample to match the number of positive samples - """ - def __init__(self, balance=True): - super(AllPositivePairSelector, self).__init__() - self.balance = balance - - def get_pairs(self, embeddings, labels): - labels = labels.cpu().data.numpy() - all_pairs = np.array(list(combinations(range(len(labels)), 2))) - all_pairs = torch.LongTensor(all_pairs) - positive_pairs = all_pairs[(labels[all_pairs[:, 0]] == labels[all_pairs[:, 1]]).nonzero()] - negative_pairs = all_pairs[(labels[all_pairs[:, 0]] != labels[all_pairs[:, 1]]).nonzero()] - if self.balance: - negative_pairs = negative_pairs[torch.randperm(len(negative_pairs))[:len(positive_pairs)]] - - return positive_pairs, negative_pairs - - -class HardNegativePairSelector(PairSelector): - """ - Creates all possible positive pairs. For negative pairs, pairs with smallest distance are taken into consideration, - matching the number of positive pairs. - """ - - def __init__(self, cpu=True): - super(HardNegativePairSelector, self).__init__() - self.cpu = cpu - - def get_pairs(self, embeddings, labels): - if self.cpu: - embeddings = embeddings.cpu() - distance_matrix = pdist(embeddings) - - labels = labels.cpu().data.numpy() - all_pairs = np.array(list(combinations(range(len(labels)), 2))) - all_pairs = torch.LongTensor(all_pairs) - positive_pairs = all_pairs[(labels[all_pairs[:, 0]] == labels[all_pairs[:, 1]]).nonzero()] - negative_pairs = all_pairs[(labels[all_pairs[:, 0]] != labels[all_pairs[:, 1]]).nonzero()] - - negative_distances = distance_matrix[negative_pairs[:, 0], negative_pairs[:, 1]] - negative_distances = negative_distances.cpu().data.numpy() - top_negatives = np.argpartition(negative_distances, len(positive_pairs))[:len(positive_pairs)] - top_negative_pairs = negative_pairs[torch.LongTensor(top_negatives)] - - return positive_pairs, top_negative_pairs - - class TripletSelector: """ Implementation should return indices of anchors, positive and negative samples @@ -100,16 +39,21 @@ def get_triplets(self, embeddings, labels): labels = labels.cpu().data.numpy() triplets = [] for label in set(labels): - label_mask = (labels == label) + label_mask = labels == label label_indices = np.where(label_mask)[0] if len(label_indices) < 2: continue negative_indices = np.where(np.logical_not(label_mask))[0] - anchor_positives = list(combinations(label_indices, 2)) # All anchor-positive pairs + anchor_positives = list( + combinations(label_indices, 2) + ) # All anchor-positive pairs # Add all negatives for all positive pairs - temp_triplets = [[anchor_positive[0], anchor_positive[1], neg_ind] for anchor_positive in anchor_positives - for neg_ind in negative_indices] + temp_triplets = [ + [anchor_positive[0], anchor_positive[1], neg_ind] + for anchor_positive in anchor_positives + for neg_ind in negative_indices + ] triplets += temp_triplets return torch.LongTensor(np.array(triplets)) @@ -126,7 +70,9 @@ def random_hard_negative(loss_values): def semihard_negative(loss_values, margin): - semihard_negatives = np.where(np.logical_and(loss_values < margin, loss_values > 0))[0] + semihard_negatives = np.where( + np.logical_and(loss_values < margin, loss_values > 0) + )[0] return np.random.choice(semihard_negatives) if len(semihard_negatives) > 0 else None @@ -154,15 +100,20 @@ def get_triplets(self, embeddings, labels): triplets = [] for label in set(labels): - label_mask = (labels == label) + label_mask = labels == label label_indices = np.where(label_mask)[0] if len(label_indices) < 2: continue negative_indices = np.where(np.logical_not(label_mask))[0] - anchor_positives = list(combinations(label_indices, 2)) # All anchor-positive pairs + anchor_positives = list( + combinations(label_indices, 2) + ) # All anchor-positive pairs anchor_positives = np.array(anchor_positives) - ap_distances = distance_matrix[anchor_positives[:, 0], anchor_positives[:, 1]] + self.margin + ap_distances = ( + distance_matrix[anchor_positives[:, 0], anchor_positives[:, 1]] + + self.margin + ) idxs = np.ix_(anchor_positives[:, 0], negative_indices) loss_values = ap_distances.unsqueeze(dim=1) - distance_matrix[idxs] loss_values = loss_values.data.cpu().numpy() @@ -170,24 +121,33 @@ def get_triplets(self, embeddings, labels): hard_negative = self.negative_selection_fn(loss_val) if hard_negative is not None: hard_negative = negative_indices[hard_negative] - triplets.append([anchor_positives[i][0], anchor_positives[i][1], hard_negative]) + triplets.append( + [anchor_positives[i][0], anchor_positives[i][1], hard_negative] + ) if len(triplets) == 0: - triplets.append([anchor_positives[-1][0], anchor_positives[-1][1], negative_indices[0]]) + triplets.append( + [anchor_positives[-1][0], anchor_positives[-1][1], negative_indices[0]] + ) return torch.LongTensor(triplets) -def HardestNegativeTripletSelector(margin, cpu=False): return FunctionNegativeTripletSelector(margin=margin, - negative_selection_fn=hardest_negative, - cpu=cpu) +def HardestNegativeTripletSelector(margin, cpu=False): + return FunctionNegativeTripletSelector( + margin=margin, negative_selection_fn=hardest_negative, cpu=cpu + ) -def RandomNegativeTripletSelector(margin, cpu=False): return FunctionNegativeTripletSelector(margin=margin, - negative_selection_fn=random_hard_negative, - cpu=cpu) +def RandomNegativeTripletSelector(margin, cpu=False): + return FunctionNegativeTripletSelector( + margin=margin, negative_selection_fn=random_hard_negative, cpu=cpu + ) -def SemihardNegativeTripletSelector(margin, cpu=False): return FunctionNegativeTripletSelector(margin=margin, - negative_selection_fn=lambda x: semihard_negative(x, margin), - cpu=cpu) +def SemihardNegativeTripletSelector(margin, cpu=False): + return FunctionNegativeTripletSelector( + margin=margin, + negative_selection_fn=lambda x: semihard_negative(x, margin), + cpu=cpu, + )