Skip to content

Commit

Permalink
refactor #70: removed pair selectors from Adam's lib
Browse files Browse the repository at this point in the history
  • Loading branch information
SergioQuijanoRey committed Apr 21, 2024
1 parent 851cef7 commit d691241
Showing 1 changed file with 45 additions and 85 deletions.
130 changes: 45 additions & 85 deletions src/adambielski_lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,75 +5,14 @@


def pdist(vectors):
distance_matrix = -2 * vectors.mm(torch.t(vectors)) + vectors.pow(2).sum(dim=1).view(1, -1) + vectors.pow(2).sum(
dim=1).view(-1, 1)
distance_matrix = (
-2 * vectors.mm(torch.t(vectors))
+ vectors.pow(2).sum(dim=1).view(1, -1)
+ vectors.pow(2).sum(dim=1).view(-1, 1)
)
return distance_matrix


class PairSelector:
"""
Implementation should return indices of positive pairs and negative pairs that will be passed to compute
Contrastive Loss
return positive_pairs, negative_pairs
"""

def __init__(self):
pass

def get_pairs(self, embeddings, labels):
raise NotImplementedError


class AllPositivePairSelector(PairSelector):
"""
Discards embeddings and generates all possible pairs given labels.
If balance is True, negative pairs are a random sample to match the number of positive samples
"""
def __init__(self, balance=True):
super(AllPositivePairSelector, self).__init__()
self.balance = balance

def get_pairs(self, embeddings, labels):
labels = labels.cpu().data.numpy()
all_pairs = np.array(list(combinations(range(len(labels)), 2)))
all_pairs = torch.LongTensor(all_pairs)
positive_pairs = all_pairs[(labels[all_pairs[:, 0]] == labels[all_pairs[:, 1]]).nonzero()]
negative_pairs = all_pairs[(labels[all_pairs[:, 0]] != labels[all_pairs[:, 1]]).nonzero()]
if self.balance:
negative_pairs = negative_pairs[torch.randperm(len(negative_pairs))[:len(positive_pairs)]]

return positive_pairs, negative_pairs


class HardNegativePairSelector(PairSelector):
"""
Creates all possible positive pairs. For negative pairs, pairs with smallest distance are taken into consideration,
matching the number of positive pairs.
"""

def __init__(self, cpu=True):
super(HardNegativePairSelector, self).__init__()
self.cpu = cpu

def get_pairs(self, embeddings, labels):
if self.cpu:
embeddings = embeddings.cpu()
distance_matrix = pdist(embeddings)

labels = labels.cpu().data.numpy()
all_pairs = np.array(list(combinations(range(len(labels)), 2)))
all_pairs = torch.LongTensor(all_pairs)
positive_pairs = all_pairs[(labels[all_pairs[:, 0]] == labels[all_pairs[:, 1]]).nonzero()]
negative_pairs = all_pairs[(labels[all_pairs[:, 0]] != labels[all_pairs[:, 1]]).nonzero()]

negative_distances = distance_matrix[negative_pairs[:, 0], negative_pairs[:, 1]]
negative_distances = negative_distances.cpu().data.numpy()
top_negatives = np.argpartition(negative_distances, len(positive_pairs))[:len(positive_pairs)]
top_negative_pairs = negative_pairs[torch.LongTensor(top_negatives)]

return positive_pairs, top_negative_pairs


class TripletSelector:
"""
Implementation should return indices of anchors, positive and negative samples
Expand All @@ -100,16 +39,21 @@ def get_triplets(self, embeddings, labels):
labels = labels.cpu().data.numpy()
triplets = []
for label in set(labels):
label_mask = (labels == label)
label_mask = labels == label
label_indices = np.where(label_mask)[0]
if len(label_indices) < 2:
continue
negative_indices = np.where(np.logical_not(label_mask))[0]
anchor_positives = list(combinations(label_indices, 2)) # All anchor-positive pairs
anchor_positives = list(
combinations(label_indices, 2)
) # All anchor-positive pairs

# Add all negatives for all positive pairs
temp_triplets = [[anchor_positive[0], anchor_positive[1], neg_ind] for anchor_positive in anchor_positives
for neg_ind in negative_indices]
temp_triplets = [
[anchor_positive[0], anchor_positive[1], neg_ind]
for anchor_positive in anchor_positives
for neg_ind in negative_indices
]
triplets += temp_triplets

return torch.LongTensor(np.array(triplets))
Expand All @@ -126,7 +70,9 @@ def random_hard_negative(loss_values):


def semihard_negative(loss_values, margin):
semihard_negatives = np.where(np.logical_and(loss_values < margin, loss_values > 0))[0]
semihard_negatives = np.where(
np.logical_and(loss_values < margin, loss_values > 0)
)[0]
return np.random.choice(semihard_negatives) if len(semihard_negatives) > 0 else None


Expand Down Expand Up @@ -154,40 +100,54 @@ def get_triplets(self, embeddings, labels):
triplets = []

for label in set(labels):
label_mask = (labels == label)
label_mask = labels == label
label_indices = np.where(label_mask)[0]
if len(label_indices) < 2:
continue
negative_indices = np.where(np.logical_not(label_mask))[0]
anchor_positives = list(combinations(label_indices, 2)) # All anchor-positive pairs
anchor_positives = list(
combinations(label_indices, 2)
) # All anchor-positive pairs
anchor_positives = np.array(anchor_positives)

ap_distances = distance_matrix[anchor_positives[:, 0], anchor_positives[:, 1]] + self.margin
ap_distances = (
distance_matrix[anchor_positives[:, 0], anchor_positives[:, 1]]
+ self.margin
)
idxs = np.ix_(anchor_positives[:, 0], negative_indices)
loss_values = ap_distances.unsqueeze(dim=1) - distance_matrix[idxs]
loss_values = loss_values.data.cpu().numpy()
for i, loss_val in enumerate(loss_values):
hard_negative = self.negative_selection_fn(loss_val)
if hard_negative is not None:
hard_negative = negative_indices[hard_negative]
triplets.append([anchor_positives[i][0], anchor_positives[i][1], hard_negative])
triplets.append(
[anchor_positives[i][0], anchor_positives[i][1], hard_negative]
)

if len(triplets) == 0:
triplets.append([anchor_positives[-1][0], anchor_positives[-1][1], negative_indices[0]])
triplets.append(
[anchor_positives[-1][0], anchor_positives[-1][1], negative_indices[0]]
)

return torch.LongTensor(triplets)


def HardestNegativeTripletSelector(margin, cpu=False): return FunctionNegativeTripletSelector(margin=margin,
negative_selection_fn=hardest_negative,
cpu=cpu)
def HardestNegativeTripletSelector(margin, cpu=False):
return FunctionNegativeTripletSelector(
margin=margin, negative_selection_fn=hardest_negative, cpu=cpu
)


def RandomNegativeTripletSelector(margin, cpu=False): return FunctionNegativeTripletSelector(margin=margin,
negative_selection_fn=random_hard_negative,
cpu=cpu)
def RandomNegativeTripletSelector(margin, cpu=False):
return FunctionNegativeTripletSelector(
margin=margin, negative_selection_fn=random_hard_negative, cpu=cpu
)


def SemihardNegativeTripletSelector(margin, cpu=False): return FunctionNegativeTripletSelector(margin=margin,
negative_selection_fn=lambda x: semihard_negative(x, margin),
cpu=cpu)
def SemihardNegativeTripletSelector(margin, cpu=False):
return FunctionNegativeTripletSelector(
margin=margin,
negative_selection_fn=lambda x: semihard_negative(x, margin),
cpu=cpu,
)

0 comments on commit d691241

Please sign in to comment.