From 8b4773bcfed488f9cc9f22217814246f8c2936b2 Mon Sep 17 00:00:00 2001 From: "Jeffrey (Dongkyu) Kim" Date: Fri, 13 Dec 2024 10:22:03 +0900 Subject: [PATCH 1/4] working new model v2 client --- autorag/nodes/passagereranker/cohere.py | 10 +++++----- .../nodes/passagereranker/test_cohere_reranker.py | 8 +++++--- tests/conftest.py | 2 -- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/autorag/nodes/passagereranker/cohere.py b/autorag/nodes/passagereranker/cohere.py index ad90d5698..78be241e2 100644 --- a/autorag/nodes/passagereranker/cohere.py +++ b/autorag/nodes/passagereranker/cohere.py @@ -30,7 +30,7 @@ def __init__(self, project_dir: str, *args, **kwargs): "or directly set it on the config YAML file." ) - self.cohere_client = cohere.AsyncClient(api_key) + self.cohere_client = cohere.AsyncClientV2(api_key=api_key) def __del__(self): del self.cohere_client @@ -41,7 +41,7 @@ def pure(self, previous_result: pd.DataFrame, *args, **kwargs): queries, contents, scores, ids = self.cast_to_run(previous_result) top_k = kwargs.pop("top_k") batch = kwargs.pop("batch", 64) - model = kwargs.pop("model", "rerank-multilingual-v2.0") + model = kwargs.pop("model", "rerank-v3.5") return self._pure(queries, contents, scores, ids, top_k, batch, model) def _pure( @@ -52,7 +52,7 @@ def _pure( ids_list: List[List[str]], top_k: int, batch: int = 64, - model: str = "rerank-multilingual-v2.0", + model: str = "rerank-v3.5", ) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]: """ Rerank a list of contents with Cohere rerank models. @@ -65,8 +65,8 @@ def _pure( :param top_k: The number of passages to be retrieved :param batch: The number of queries to be processed in a batch :param model: The model name for Cohere rerank. - You can choose between "rerank-multilingual-v2.0" and "rerank-english-v2.0". - Default is "rerank-multilingual-v2.0". + You can choose between "rerank-v3.5", "rerank-english-v3.0", and "rerank-multilingual-v3.0". + Default is "rerank-v3.5". :return: Tuple of lists containing the reranked contents, ids, and scores """ # Run async cohere_rerank_pure function diff --git a/tests/autorag/nodes/passagereranker/test_cohere_reranker.py b/tests/autorag/nodes/passagereranker/test_cohere_reranker.py index 51f4e0d07..c7d7996d0 100644 --- a/tests/autorag/nodes/passagereranker/test_cohere_reranker.py +++ b/tests/autorag/nodes/passagereranker/test_cohere_reranker.py @@ -1,6 +1,7 @@ from unittest.mock import patch -import cohere.base_client +# import cohere.base_client +import cohere import pytest from cohere import RerankResponse, RerankResponseResultsItem @@ -45,10 +46,11 @@ async def mock_cohere_reranker( @pytest.fixture def cohere_reranker_instance(): - return CohereReranker(project_dir=project_dir, api_key="test") + # return CohereReranker(project_dir=project_dir, api_key="test") + return CohereReranker(project_dir=project_dir) -@patch.object(cohere.base_client.AsyncBaseCohere, "rerank", mock_cohere_reranker) +# @patch.object(cohere.base_client.AsyncBaseCohere, "rerank", mock_cohere_reranker) def test_cohere_reranker(cohere_reranker_instance): top_k = 3 contents_result, id_result, score_result = cohere_reranker_instance._pure( diff --git a/tests/conftest.py b/tests/conftest.py index b546f06f4..f61385f09 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,5 +3,3 @@ def pytest_sessionstart(session): os.environ["BM25"] = "bm25" - os.environ["JINAAI_API_KEY"] = "mock_jinaai_api_key" - os.environ["COHERE_API_KEY"] = "mock_cohere_api_key" From 8f5969dcfdf0dc344d0d332d872b56a834f90113 Mon Sep 17 00:00:00 2001 From: "Jeffrey (Dongkyu) Kim" Date: Fri, 13 Dec 2024 10:23:20 +0900 Subject: [PATCH 2/4] Update the cohere docs --- docs/source/nodes/passage_reranker/cohere.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/nodes/passage_reranker/cohere.md b/docs/source/nodes/passage_reranker/cohere.md index 0f77d21b9..c6ec3d9fa 100644 --- a/docs/source/nodes/passage_reranker/cohere.md +++ b/docs/source/nodes/passage_reranker/cohere.md @@ -34,8 +34,9 @@ Or, you can set your Cohere API key in the config.yaml file directly. It sends the batch size of passages to cohere API at once. If it is too large, it can cause some error. (default: 64) -- **model** : The type of model you want to use for reranking. Default is "rerank-multilingual-v2.0" and you can change - it to "rerank-multilingual-v1.0" or "rerank-english-v2.0" (default: "rerank-multilingual-v2.0") +- **model** : The type of model you want to use for reranking. + Default is "rerank-v3.5" and you can change + it to "rerank-v3.5" or "rerank-english-v3.0" or "rerank-multilingual-v3.0" - **api_key** : The cohere api key. ## **Example config.yaml** From 8755967bbaa7c1324f5d8b7c1fe12559f570de91 Mon Sep 17 00:00:00 2001 From: "Jeffrey (Dongkyu) Kim" Date: Fri, 13 Dec 2024 10:41:41 +0900 Subject: [PATCH 3/4] Add CO_API_KEY env variable and set new mock testing for v2 client --- autorag/nodes/passagereranker/cohere.py | 2 ++ docs/source/nodes/passage_reranker/cohere.md | 6 ++++++ .../nodes/passagereranker/test_cohere_reranker.py | 10 ++++------ 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/autorag/nodes/passagereranker/cohere.py b/autorag/nodes/passagereranker/cohere.py index 78be241e2..34ecc90b5 100644 --- a/autorag/nodes/passagereranker/cohere.py +++ b/autorag/nodes/passagereranker/cohere.py @@ -24,6 +24,8 @@ def __init__(self, project_dir: str, *args, **kwargs): super().__init__(project_dir) api_key = kwargs.pop("api_key", None) api_key = os.getenv("COHERE_API_KEY", None) if api_key is None else api_key + if api_key is None: + api_key = os.getenv("CO_API_KEY", None) if api_key is None: raise KeyError( "Please set the API key for Cohere rerank in the environment variable COHERE_API_KEY " diff --git a/docs/source/nodes/passage_reranker/cohere.md b/docs/source/nodes/passage_reranker/cohere.md index c6ec3d9fa..5d9a6e693 100644 --- a/docs/source/nodes/passage_reranker/cohere.md +++ b/docs/source/nodes/passage_reranker/cohere.md @@ -21,6 +21,12 @@ Next, you can set your Cohere API key in the environment variable. export COHERE_API_KEY=your_cohere_api_key ``` +or + +```bash +export CO_API_KEY=your_cohere_api_key +``` + Or, you can set your Cohere API key in the config.yaml file directly. ```yaml diff --git a/tests/autorag/nodes/passagereranker/test_cohere_reranker.py b/tests/autorag/nodes/passagereranker/test_cohere_reranker.py index c7d7996d0..bd94d29b1 100644 --- a/tests/autorag/nodes/passagereranker/test_cohere_reranker.py +++ b/tests/autorag/nodes/passagereranker/test_cohere_reranker.py @@ -1,6 +1,5 @@ from unittest.mock import patch -# import cohere.base_client import cohere import pytest from cohere import RerankResponse, RerankResponseResultsItem @@ -46,11 +45,10 @@ async def mock_cohere_reranker( @pytest.fixture def cohere_reranker_instance(): - # return CohereReranker(project_dir=project_dir, api_key="test") - return CohereReranker(project_dir=project_dir) + return CohereReranker(project_dir=project_dir, api_key="test") -# @patch.object(cohere.base_client.AsyncBaseCohere, "rerank", mock_cohere_reranker) +@patch.object(cohere.client_v2.AsyncClientV2, "rerank", mock_cohere_reranker) def test_cohere_reranker(cohere_reranker_instance): top_k = 3 contents_result, id_result, score_result = cohere_reranker_instance._pure( @@ -59,7 +57,7 @@ def test_cohere_reranker(cohere_reranker_instance): base_reranker_test(contents_result, id_result, score_result, top_k) -@patch.object(cohere.base_client.AsyncBaseCohere, "rerank", mock_cohere_reranker) +@patch.object(cohere.client_v2.AsyncClientV2, "rerank", mock_cohere_reranker) def test_cohere_reranker_batch_one(cohere_reranker_instance): top_k = 3 batch = 1 @@ -74,7 +72,7 @@ def test_cohere_reranker_batch_one(cohere_reranker_instance): base_reranker_test(contents_result, id_result, score_result, top_k) -@patch.object(cohere.base_client.AsyncBaseCohere, "rerank", mock_cohere_reranker) +@patch.object(cohere.client_v2.AsyncClientV2, "rerank", mock_cohere_reranker) def test_cohere_node(): top_k = 1 result_df = CohereReranker.run_evaluator( From 26ff9464c42fbd0a793611e4cb9cdc22c811953b Mon Sep 17 00:00:00 2001 From: "Jeffrey (Dongkyu) Kim" Date: Fri, 13 Dec 2024 11:23:20 +0900 Subject: [PATCH 4/4] resolve test issue --- tests/autorag/nodes/passagereranker/test_cohere_reranker.py | 2 ++ tests/autorag/nodes/passagereranker/test_jina_reranker.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tests/autorag/nodes/passagereranker/test_cohere_reranker.py b/tests/autorag/nodes/passagereranker/test_cohere_reranker.py index bd94d29b1..556363ff3 100644 --- a/tests/autorag/nodes/passagereranker/test_cohere_reranker.py +++ b/tests/autorag/nodes/passagereranker/test_cohere_reranker.py @@ -1,3 +1,4 @@ +import os from unittest.mock import patch import cohere @@ -74,6 +75,7 @@ def test_cohere_reranker_batch_one(cohere_reranker_instance): @patch.object(cohere.client_v2.AsyncClientV2, "rerank", mock_cohere_reranker) def test_cohere_node(): + os.environ["CO_API_KEY"] = "test" top_k = 1 result_df = CohereReranker.run_evaluator( project_dir=project_dir, previous_result=previous_result, top_k=top_k diff --git a/tests/autorag/nodes/passagereranker/test_jina_reranker.py b/tests/autorag/nodes/passagereranker/test_jina_reranker.py index 09ae3e784..ca426b1fd 100644 --- a/tests/autorag/nodes/passagereranker/test_jina_reranker.py +++ b/tests/autorag/nodes/passagereranker/test_jina_reranker.py @@ -1,3 +1,4 @@ +import os from unittest.mock import patch import aiohttp @@ -103,6 +104,7 @@ def test_jina_reranker_batch_one(jina_reranker_instance): autorag.nodes.passagereranker.jina, "jina_reranker_pure", mock_jina_reranker_pure ) def test_jina_reranker_node(): + os.environ["JINAAI_API_KEY"] = "test" top_k = 1 result_df = JinaReranker.run_evaluator( project_dir=project_dir, previous_result=previous_result, top_k=top_k