Skip to content

Commit

Permalink
Add Cohere rerank model v3.5 (#1050)
Browse files Browse the repository at this point in the history
* working new model v2 client

* Update the cohere docs

* Add CO_API_KEY env variable and set new mock testing for v2 client
  • Loading branch information
vkehfdl1 authored Dec 19, 2024
1 parent 70a317c commit 0f6bba4
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 13 deletions.
12 changes: 7 additions & 5 deletions autorag/nodes/passagereranker/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ def __init__(self, project_dir: str, *args, **kwargs):
super().__init__(project_dir)
api_key = kwargs.pop("api_key", None)
api_key = os.getenv("COHERE_API_KEY", None) if api_key is None else api_key
if api_key is None:
api_key = os.getenv("CO_API_KEY", None)
if api_key is None:
raise KeyError(
"Please set the API key for Cohere rerank in the environment variable COHERE_API_KEY "
"or directly set it on the config YAML file."
)

self.cohere_client = cohere.AsyncClient(api_key)
self.cohere_client = cohere.AsyncClientV2(api_key=api_key)

def __del__(self):
del self.cohere_client
Expand All @@ -41,7 +43,7 @@ def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
queries, contents, scores, ids = self.cast_to_run(previous_result)
top_k = kwargs.pop("top_k")
batch = kwargs.pop("batch", 64)
model = kwargs.pop("model", "rerank-multilingual-v2.0")
model = kwargs.pop("model", "rerank-v3.5")
return self._pure(queries, contents, scores, ids, top_k, batch, model)

def _pure(
Expand All @@ -52,7 +54,7 @@ def _pure(
ids_list: List[List[str]],
top_k: int,
batch: int = 64,
model: str = "rerank-multilingual-v2.0",
model: str = "rerank-v3.5",
) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
"""
Rerank a list of contents with Cohere rerank models.
Expand All @@ -65,8 +67,8 @@ def _pure(
:param top_k: The number of passages to be retrieved
:param batch: The number of queries to be processed in a batch
:param model: The model name for Cohere rerank.
You can choose between "rerank-multilingual-v2.0" and "rerank-english-v2.0".
Default is "rerank-multilingual-v2.0".
You can choose between "rerank-v3.5", "rerank-english-v3.0", and "rerank-multilingual-v3.0".
Default is "rerank-v3.5".
:return: Tuple of lists containing the reranked contents, ids, and scores
"""
# Run async cohere_rerank_pure function
Expand Down
11 changes: 9 additions & 2 deletions docs/source/nodes/passage_reranker/cohere.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ Next, you can set your Cohere API key in the environment variable.
export COHERE_API_KEY=your_cohere_api_key
```

or

```bash
export CO_API_KEY=your_cohere_api_key
```

Or, you can set your Cohere API key in the config.yaml file directly.

```yaml
Expand All @@ -34,8 +40,9 @@ Or, you can set your Cohere API key in the config.yaml file directly.
It sends the batch size of passages to cohere API at once.
If it is too large, it can cause some error.
(default: 64)
- **model** : The type of model you want to use for reranking. Default is "rerank-multilingual-v2.0" and you can change
it to "rerank-multilingual-v1.0" or "rerank-english-v2.0" (default: "rerank-multilingual-v2.0")
- **model** : The type of model you want to use for reranking.
Default is "rerank-v3.5" and you can change
it to "rerank-v3.5" or "rerank-english-v3.0" or "rerank-multilingual-v3.0"
- **api_key** : The cohere api key.
## **Example config.yaml**
Expand Down
10 changes: 6 additions & 4 deletions tests/autorag/nodes/passagereranker/test_cohere_reranker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from unittest.mock import patch

import cohere.base_client
import cohere
import pytest
from cohere import RerankResponse, RerankResponseResultsItem

Expand Down Expand Up @@ -48,7 +49,7 @@ def cohere_reranker_instance():
return CohereReranker(project_dir=project_dir, api_key="test")


@patch.object(cohere.base_client.AsyncBaseCohere, "rerank", mock_cohere_reranker)
@patch.object(cohere.client_v2.AsyncClientV2, "rerank", mock_cohere_reranker)
def test_cohere_reranker(cohere_reranker_instance):
top_k = 3
contents_result, id_result, score_result = cohere_reranker_instance._pure(
Expand All @@ -57,7 +58,7 @@ def test_cohere_reranker(cohere_reranker_instance):
base_reranker_test(contents_result, id_result, score_result, top_k)


@patch.object(cohere.base_client.AsyncBaseCohere, "rerank", mock_cohere_reranker)
@patch.object(cohere.client_v2.AsyncClientV2, "rerank", mock_cohere_reranker)
def test_cohere_reranker_batch_one(cohere_reranker_instance):
top_k = 3
batch = 1
Expand All @@ -72,8 +73,9 @@ def test_cohere_reranker_batch_one(cohere_reranker_instance):
base_reranker_test(contents_result, id_result, score_result, top_k)


@patch.object(cohere.base_client.AsyncBaseCohere, "rerank", mock_cohere_reranker)
@patch.object(cohere.client_v2.AsyncClientV2, "rerank", mock_cohere_reranker)
def test_cohere_node():
os.environ["CO_API_KEY"] = "test"
top_k = 1
result_df = CohereReranker.run_evaluator(
project_dir=project_dir, previous_result=previous_result, top_k=top_k
Expand Down
2 changes: 2 additions & 0 deletions tests/autorag/nodes/passagereranker/test_jina_reranker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from unittest.mock import patch

import aiohttp
Expand Down Expand Up @@ -103,6 +104,7 @@ def test_jina_reranker_batch_one(jina_reranker_instance):
autorag.nodes.passagereranker.jina, "jina_reranker_pure", mock_jina_reranker_pure
)
def test_jina_reranker_node():
os.environ["JINAAI_API_KEY"] = "test"
top_k = 1
result_df = JinaReranker.run_evaluator(
project_dir=project_dir, previous_result=previous_result, top_k=top_k
Expand Down
2 changes: 0 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,3 @@

def pytest_sessionstart(session):
os.environ["BM25"] = "bm25"
os.environ["JINAAI_API_KEY"] = "mock_jinaai_api_key"
os.environ["COHERE_API_KEY"] = "mock_cohere_api_key"

0 comments on commit 0f6bba4

Please sign in to comment.