diff --git a/tests/ann_test.py b/tests/ann_test.py index 9f5deb2ee011..98c4c5695efd 100644 --- a/tests/ann_test.py +++ b/tests/ann_test.py @@ -49,7 +49,7 @@ def test_approx_max_k(self, qy_shape, db_shape, dtype, k, recall): _, gt_args = lax.top_k(scores, k) _, ann_args = ann.approx_max_k(scores, k, recall_target=recall) self.assertEqual(k, len(ann_args[0])) - gt_args_sets = [set(x) for x in gt_args] + gt_args_sets = [set(np.asarray(x)) for x in gt_args] hits = sum( len(list(x for x in ann_args_per_q @@ -77,7 +77,7 @@ def test_approx_min_k(self, qy_shape, db_shape, dtype, k, recall): _, gt_args = lax.top_k(-scores, k) _, ann_args = ann.approx_min_k(scores, k, recall_target=recall) self.assertEqual(k, len(ann_args[0])) - gt_args_sets = [set(x) for x in gt_args] + gt_args_sets = [set(np.asarray(x)) for x in gt_args] hits = sum( len(list(x for x in ann_args_per_q