diff --git a/libs/community/langchain_community/retrievers/azure_cognitive_search.py b/libs/community/langchain_community/retrievers/azure_cognitive_search.py index fb91dd188f6d0..19e13234bfb92 100644 --- a/libs/community/langchain_community/retrievers/azure_cognitive_search.py +++ b/libs/community/langchain_community/retrievers/azure_cognitive_search.py @@ -28,7 +28,7 @@ class AzureCognitiveSearchRetriever(BaseRetriever): api_key: str = "" """API Key. Both Admin and Query keys work, but for reading data it's recommended to use a Query key.""" - api_version: str = "2020-06-30" + api_version: str = "2023-11-01" """API version""" aiosession: Optional[aiohttp.ClientSession] = None """ClientSession, in case we want to reuse connection for better performance.""" @@ -59,7 +59,14 @@ def _build_search_url(self, query: str) -> str: url_suffix = get_from_env( "", "AZURE_COGNITIVE_SEARCH_URL_SUFFIX", DEFAULT_URL_SUFFIX ) - base_url = f"https://{self.service_name}.{url_suffix}/" + if url_suffix in self.service_name and "https://" in self.service_name: + base_url = f"{self.service_name}/" + elif url_suffix in self.service_name and "https://" not in self.service_name: + base_url = f"https://{self.service_name}/" + elif url_suffix not in self.service_name and "https://" in self.service_name: + base_url = f"{self.service_name}.{url_suffix}/" + else: + base_url = self.service_name endpoint_path = f"indexes/{self.index_name}/docs?api-version={self.api_version}" top_param = f"&$top={self.top_k}" if self.top_k else "" return base_url + endpoint_path + f"&search={query}" + top_param diff --git a/libs/community/tests/integration_tests/retrievers/test_azure_cognitive_search.py b/libs/community/tests/integration_tests/retrievers/test_azure_cognitive_search.py index 352588bef3e70..80c08e2e71695 100644 --- a/libs/community/tests/integration_tests/retrievers/test_azure_cognitive_search.py +++ b/libs/community/tests/integration_tests/retrievers/test_azure_cognitive_search.py @@ -7,22 +7,31 @@ def test_azure_cognitive_search_get_relevant_documents() -> None: - """Test valid call to Azure Cognitive Search.""" + """Test valid call to Azure Cognitive Search. + + In order to run this test, you should provide a service name, azure search api key + and an index_name as arguments for the AzureCognitiveSearchRetriever in both tests. + """ retriever = AzureCognitiveSearchRetriever() - documents = retriever.get_relevant_documents("what is langchain") + + documents = retriever.get_relevant_documents("what is langchain?") for doc in documents: assert isinstance(doc, Document) assert doc.page_content - retriever = AzureCognitiveSearchRetriever(top_k=1) - documents = retriever.get_relevant_documents("what is langchain") + retriever = AzureCognitiveSearchRetriever() + documents = retriever.get_relevant_documents("what is langchain?") assert len(documents) <= 1 async def test_azure_cognitive_search_aget_relevant_documents() -> None: - """Test valid async call to Azure Cognitive Search.""" + """Test valid async call to Azure Cognitive Search. + + In order to run this test, you should provide a service name, azure search api key + and an index_name as arguments for the AzureCognitiveSearchRetriever. + """ retriever = AzureCognitiveSearchRetriever() - documents = await retriever.aget_relevant_documents("what is langchain") + documents = await retriever.aget_relevant_documents("what is langchain?") for doc in documents: assert isinstance(doc, Document) assert doc.page_content