diff --git a/src/test/loss_functions.py b/src/test/loss_functions.py index 11b03a2..484a7c9 100644 --- a/src/test/loss_functions.py +++ b/src/test/loss_functions.py @@ -314,7 +314,7 @@ def test_basic_cases(self): labels = torch.tensor([1, 1, 1, 2, 2, 2, 3, 3, 3]) loss_computed = float(loss(points, labels)) - loss_expected = 9.741091053890488 # Manually computed loss + loss_expected = 1.1105841398239136 # Manually computed loss self.assertAlmostEqual(loss_computed, loss_expected, places = 3) def test_basic_cases_gt_than_zero(self): @@ -341,7 +341,7 @@ def test_basic_cases_gt_than_zero(self): labels = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2]) loss_computed = float(loss(points, labels)) - loss_expected = 9.741091053890488 # Manually computed loss + loss_expected = 1.1105841398239136 # Manually computed loss self.assertAlmostEqual(loss_computed, loss_expected, places = 3) class TestBatchAllTripletLoss(unittest.TestCase): @@ -370,7 +370,7 @@ def test_basic_cases(self): labels = torch.tensor([1, 1, 1, 2, 2, 2, 3, 3, 3]) loss_computed = float(loss(points, labels)) - loss_expected = 2.6339713808044256 + loss_expected = 1.0025441646575928 self.assertAlmostEqual(loss_computed, loss_expected, places = 3) def test_basic_cases_gt_than_zero(self): @@ -397,5 +397,5 @@ def test_basic_cases_gt_than_zero(self): labels = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2]) loss_computed = float(loss(points, labels)) - loss_expected = 4.515379509950444 + loss_expected = 1.0025441646575928 self.assertAlmostEqual(loss_computed, loss_expected, places = 5)