Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Cohere rerank model v3.5 #1050

Merged
merged 7 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions autorag/nodes/passagereranker/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ def __init__(self, project_dir: str, *args, **kwargs):
super().__init__(project_dir)
api_key = kwargs.pop("api_key", None)
api_key = os.getenv("COHERE_API_KEY", None) if api_key is None else api_key
if api_key is None:
api_key = os.getenv("CO_API_KEY", None)
if api_key is None:
raise KeyError(
"Please set the API key for Cohere rerank in the environment variable COHERE_API_KEY "
"or directly set it on the config YAML file."
)

self.cohere_client = cohere.AsyncClient(api_key)
self.cohere_client = cohere.AsyncClientV2(api_key=api_key)

def __del__(self):
del self.cohere_client
Expand All @@ -41,7 +43,7 @@ def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
queries, contents, scores, ids = self.cast_to_run(previous_result)
top_k = kwargs.pop("top_k")
batch = kwargs.pop("batch", 64)
model = kwargs.pop("model", "rerank-multilingual-v2.0")
model = kwargs.pop("model", "rerank-v3.5")
return self._pure(queries, contents, scores, ids, top_k, batch, model)

def _pure(
Expand All @@ -52,7 +54,7 @@ def _pure(
ids_list: List[List[str]],
top_k: int,
batch: int = 64,
model: str = "rerank-multilingual-v2.0",
model: str = "rerank-v3.5",
) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
"""
Rerank a list of contents with Cohere rerank models.
Expand All @@ -65,8 +67,8 @@ def _pure(
:param top_k: The number of passages to be retrieved
:param batch: The number of queries to be processed in a batch
:param model: The model name for Cohere rerank.
You can choose between "rerank-multilingual-v2.0" and "rerank-english-v2.0".
Default is "rerank-multilingual-v2.0".
You can choose between "rerank-v3.5", "rerank-english-v3.0", and "rerank-multilingual-v3.0".
Default is "rerank-v3.5".
:return: Tuple of lists containing the reranked contents, ids, and scores
"""
# Run async cohere_rerank_pure function
Expand Down
11 changes: 9 additions & 2 deletions docs/source/nodes/passage_reranker/cohere.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ Next, you can set your Cohere API key in the environment variable.
export COHERE_API_KEY=your_cohere_api_key
```

or

```bash
export CO_API_KEY=your_cohere_api_key
```

Or, you can set your Cohere API key in the config.yaml file directly.

```yaml
Expand All @@ -34,8 +40,9 @@ Or, you can set your Cohere API key in the config.yaml file directly.
It sends this many passages to the Cohere API at once.
If it is too large, it can cause errors.
(default: 64)
- **model** : The type of model you want to use for reranking. Default is "rerank-multilingual-v2.0" and you can change
it to "rerank-multilingual-v1.0" or "rerank-english-v2.0" (default: "rerank-multilingual-v2.0")
- **model** : The type of model you want to use for reranking.
Default is "rerank-v3.5"; you can also set
it to "rerank-english-v3.0" or "rerank-multilingual-v3.0".
- **api_key** : The cohere api key.

## **Example config.yaml**
Expand Down
10 changes: 6 additions & 4 deletions tests/autorag/nodes/passagereranker/test_cohere_reranker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from unittest.mock import patch

import cohere.base_client
import cohere
import pytest
from cohere import RerankResponse, RerankResponseResultsItem

Expand Down Expand Up @@ -48,7 +49,7 @@ def cohere_reranker_instance():
return CohereReranker(project_dir=project_dir, api_key="test")


@patch.object(cohere.base_client.AsyncBaseCohere, "rerank", mock_cohere_reranker)
@patch.object(cohere.client_v2.AsyncClientV2, "rerank", mock_cohere_reranker)
def test_cohere_reranker(cohere_reranker_instance):
top_k = 3
contents_result, id_result, score_result = cohere_reranker_instance._pure(
Expand All @@ -57,7 +58,7 @@ def test_cohere_reranker(cohere_reranker_instance):
base_reranker_test(contents_result, id_result, score_result, top_k)


@patch.object(cohere.base_client.AsyncBaseCohere, "rerank", mock_cohere_reranker)
@patch.object(cohere.client_v2.AsyncClientV2, "rerank", mock_cohere_reranker)
def test_cohere_reranker_batch_one(cohere_reranker_instance):
top_k = 3
batch = 1
Expand All @@ -72,8 +73,9 @@ def test_cohere_reranker_batch_one(cohere_reranker_instance):
base_reranker_test(contents_result, id_result, score_result, top_k)


@patch.object(cohere.base_client.AsyncBaseCohere, "rerank", mock_cohere_reranker)
@patch.object(cohere.client_v2.AsyncClientV2, "rerank", mock_cohere_reranker)
def test_cohere_node():
os.environ["CO_API_KEY"] = "test"
top_k = 1
result_df = CohereReranker.run_evaluator(
project_dir=project_dir, previous_result=previous_result, top_k=top_k
Expand Down
2 changes: 2 additions & 0 deletions tests/autorag/nodes/passagereranker/test_jina_reranker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from unittest.mock import patch

import aiohttp
Expand Down Expand Up @@ -103,6 +104,7 @@ def test_jina_reranker_batch_one(jina_reranker_instance):
autorag.nodes.passagereranker.jina, "jina_reranker_pure", mock_jina_reranker_pure
)
def test_jina_reranker_node():
os.environ["JINAAI_API_KEY"] = "test"
top_k = 1
result_df = JinaReranker.run_evaluator(
project_dir=project_dir, previous_result=previous_result, top_k=top_k
Expand Down
2 changes: 0 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,3 @@

def pytest_sessionstart(session):
os.environ["BM25"] = "bm25"
os.environ["JINAAI_API_KEY"] = "mock_jinaai_api_key"
os.environ["COHERE_API_KEY"] = "mock_cohere_api_key"
Loading