diff --git a/tests/unit/test_router.py b/tests/unit/test_router.py index d30c3397..6848a430 100644 --- a/tests/unit/test_router.py +++ b/tests/unit/test_router.py @@ -13,10 +13,12 @@ from semantic_router.routers import RouterConfig, SemanticRouter, HybridRouter from semantic_router.llms import BaseLLM, OpenAILLM from semantic_router.route import Route +from semantic_router.utils.logger import logger from platform import python_version -PINECONE_SLEEP = 12 +PINECONE_SLEEP = 8 +RETRY_COUNT = 5 def mock_encoder_call(utterances): @@ -266,9 +268,6 @@ def test_initialization(self, routes, index_cls, encoder_cls, router_cls): auto_sync="local", top_k=10, ) - if index_cls is PineconeIndex: - time.sleep(PINECONE_SLEEP * 2) # allow for index to be populated - if isinstance(route_layer, HybridRouter): assert ( route_layer.score_threshold @@ -277,7 +276,16 @@ def test_initialization(self, routes, index_cls, encoder_cls, router_cls): else: assert route_layer.score_threshold == encoder.score_threshold assert route_layer.top_k == 10 - assert len(route_layer.index) == 5 + # allow for 5 retries in case of index not being populated + count = 0 + while count < RETRY_COUNT: + try: + assert len(route_layer.index) == 5 + break + except AssertionError: + logger.warning(f"Index not populated, waiting for retry (try {count})") + time.sleep(PINECONE_SLEEP) + count += 1 assert ( len(set(route_layer._get_route_names())) if route_layer._get_route_names() is not None @@ -718,10 +726,20 @@ def test_query_and_classification(self, routes, index_cls, encoder_cls, router_c index=index, auto_sync="local", ) - if index_cls is PineconeIndex: - time.sleep(PINECONE_SLEEP * 2) # allow for index to be populated - query_result = route_layer(text="Hello").name - assert query_result in ["Route 1", "Route 2"] + count = 0 + # we allow for 5 retries to allow for index to be populated + while count < RETRY_COUNT: + query_result = route_layer(text="Hello").name + try: + assert query_result in ["Route 1", "Route 2"] + break + except AssertionError: + logger.warning( + f"Query result not in expected routes, waiting for retry (try {count})" + ) + if index_cls is PineconeIndex: + time.sleep(PINECONE_SLEEP) # allow for index to be populated + count += 1 def test_query_filter(self, routes, index_cls, encoder_cls, router_cls): encoder = encoder_cls() @@ -781,15 +799,23 @@ def test_namespace_pinecone_index(self, routes, index_cls, encoder_cls, router_c index=pineconeindex, auto_sync="local", ) - time.sleep(PINECONE_SLEEP * 2) # allow for index to be populated - query_result = route_layer(text="Hello", route_filter=["Route 1"]).name - - try: - route_layer(text="Hello", route_filter=["Route 8"]).name - except ValueError: - assert True - - assert query_result in ["Route 1"] + count = 0 + while count < RETRY_COUNT: + try: + query_result = route_layer( + text="Hello", route_filter=["Route 1"] + ).name + assert query_result in ["Route 1"] + break + except AssertionError: + logger.warning( + f"Query result not in expected routes, waiting for retry (try {count})" + ) + if index_cls is PineconeIndex: + time.sleep( + PINECONE_SLEEP * 2 + ) # allow for index to be populated + count += 1 route_layer.index.index.delete(namespace="test", delete_all=True) def test_query_with_no_index(self, index_cls, encoder_cls, router_cls):