Skip to content

Commit

Permalink
fix, test: trying new manually computed values
Browse files Browse the repository at this point in the history
  • Loading branch information
SergioQuijanoRey committed Jun 22, 2024
1 parent 3ec97af commit 35c1ae4
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/test/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def test_basic_cases(self):
labels = torch.tensor([1, 1, 1, 2, 2, 2, 3, 3, 3])

loss_computed = float(loss(points, labels))
loss_expected = 9.741091053890488 # Manually computed loss
loss_expected = 1.1105841398239136 # Manually computed loss
self.assertAlmostEqual(loss_computed, loss_expected, places = 3)

def test_basic_cases_gt_than_zero(self):
Expand All @@ -341,7 +341,7 @@ def test_basic_cases_gt_than_zero(self):
labels = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])

loss_computed = float(loss(points, labels))
loss_expected = 9.741091053890488 # Manually computed loss
loss_expected = 1.1105841398239136 # Manually computed loss
self.assertAlmostEqual(loss_computed, loss_expected, places = 3)

class TestBatchAllTripletLoss(unittest.TestCase):
Expand Down Expand Up @@ -370,7 +370,7 @@ def test_basic_cases(self):
labels = torch.tensor([1, 1, 1, 2, 2, 2, 3, 3, 3])

loss_computed = float(loss(points, labels))
loss_expected = 2.6339713808044256
loss_expected = 1.0025441646575928
self.assertAlmostEqual(loss_computed, loss_expected, places = 3)

def test_basic_cases_gt_than_zero(self):
Expand All @@ -397,5 +397,5 @@ def test_basic_cases_gt_than_zero(self):
labels = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])

loss_computed = float(loss(points, labels))
loss_expected = 4.515379509950444
loss_expected = 1.0025441646575928
self.assertAlmostEqual(loss_computed, loss_expected, places = 5)

0 comments on commit 35c1ae4

Please sign in to comment.