Skip to content

Commit

Permalink
refactor #70: removed useless all triplets selector
Browse files Browse the repository at this point in the history
  • Loading branch information
SergioQuijanoRey committed Apr 21, 2024
1 parent d691241 commit 34b4ad9
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 35 deletions.
4 changes: 2 additions & 2 deletions src/MNIST Adam Bielski.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ def extract_embeddings(dataloader, model):
# Set up the network and training parameters
from adambielski_lib.networks import EmbeddingNet
from adambielski_lib.utils import ( # Strategies for selecting triplets within a minibatch
AllTripletSelector, HardestNegativeTripletSelector,
RandomNegativeTripletSelector, SemihardNegativeTripletSelector)
HardestNegativeTripletSelector, RandomNegativeTripletSelector,
SemihardNegativeTripletSelector)

margin = 1.0
embedding_net = EmbeddingNet()
Expand Down
33 changes: 0 additions & 33 deletions src/adambielski_lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,39 +26,6 @@ def get_triplets(self, embeddings, labels):
raise NotImplementedError


class AllTripletSelector(TripletSelector):
"""
Returns all possible triplets
May be impractical in most cases
"""

def __init__(self):
super(AllTripletSelector, self).__init__()

def get_triplets(self, embeddings, labels):
labels = labels.cpu().data.numpy()
triplets = []
for label in set(labels):
label_mask = labels == label
label_indices = np.where(label_mask)[0]
if len(label_indices) < 2:
continue
negative_indices = np.where(np.logical_not(label_mask))[0]
anchor_positives = list(
combinations(label_indices, 2)
) # All anchor-positive pairs

# Add all negatives for all positive pairs
temp_triplets = [
[anchor_positive[0], anchor_positive[1], neg_ind]
for anchor_positive in anchor_positives
for neg_ind in negative_indices
]
triplets += temp_triplets

return torch.LongTensor(np.array(triplets))


def hardest_negative(loss_values):
hard_negative = np.argmax(loss_values)
return hard_negative if loss_values[hard_negative] > 0 else None
Expand Down

0 comments on commit 34b4ad9

Please sign in to comment.