-
Notifications
You must be signed in to change notification settings - Fork 83
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ced0a9d
commit 569fa3c
Showing
2 changed files
with
174 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from functools import lru_cache | ||
from typing import Callable, List, Optional | ||
from lazyllm import ModuleBase, config, LOG | ||
from lazyllm.tools.rag.store import DocNode, MetadataMode | ||
from lazyllm.components.utils.downloader import ModelManager | ||
import numpy as np | ||
|
||
|
||
class RerankerV2(ModuleBase):
    """Module that reranks or filters nodes through a named, registered function.

    Functions are stored in the class-level ``registered_reranker`` mapping,
    keyed by the function's ``__name__``. An instance only remembers which
    name to use and the keyword arguments to forward; the lookup happens on
    every ``forward`` call, so functions registered after construction are
    still found.
    """

    # Class-level registry shared by every instance: function name -> wrapper.
    registered_reranker = {}

    def __init__(self, name: str = "ModuleReranker", **kwargs) -> None:
        """Remember the reranker name and the keyword arguments passed to it.

        Args:
            name: Key of the registered reranker to run in ``forward``.
            **kwargs: Extra keyword arguments forwarded unchanged to the
                registered function on every call.
        """
        super().__init__()
        self.name = name
        self.kwargs = kwargs

    def forward(self, nodes: List[DocNode], query: str = "") -> List[DocNode]:
        """Run the selected reranker over ``nodes``.

        Args:
            nodes: Candidate nodes to rerank or filter.
            query: Query string, passed to the reranker as ``query=``.

        Returns:
            Whatever the registered reranker returns for these nodes.

        Raises:
            KeyError: If no reranker is registered under ``self.name``.
        """
        try:
            rerank = self.registered_reranker[self.name]
        except KeyError:
            # Same exception type as a plain dict lookup, but with a message
            # that tells the caller which names are actually usable.
            available = ", ".join(sorted(self.registered_reranker))
            raise KeyError(
                f"Reranker `{self.name}` is not registered; "
                f"available rerankers: {available}"
            ) from None
        results = rerank(nodes, query=query, **self.kwargs)
        LOG.debug(f"Rerank use `{self.name}` and get nodes: {results}")
        return results

    @classmethod
    def register_reranker(
        cls, func: Optional[Callable] = None, batch: bool = False
    ):
        """Register ``func`` under its ``__name__``; usable with or without arguments.

        Args:
            func: Function to register. When omitted, a decorator is returned
                so that ``@register_reranker(batch=True)`` works.
            batch: If true, ``func`` receives the whole node list at once.
                Otherwise it is called once per node and falsy results
                (e.g. ``None``) are dropped from the output.

        Returns:
            The list-accepting wrapper that was stored in the registry, or a
            decorator producing it when ``func`` is not given. Note that the
            wrapper always takes a list of nodes, even when ``func`` itself
            is per-node.
        """
        from functools import wraps

        def decorator(f):
            @wraps(f)  # keep the user function's name/docstring on the wrapper
            def wrapper(nodes, **kwargs):
                if batch:
                    return f(nodes, **kwargs)
                results = [f(node, **kwargs) for node in nodes]
                return [result for result in results if result]

            cls.registered_reranker[f.__name__] = wrapper
            return wrapper

        return decorator(func) if func else decorator
|
||
|
||
@lru_cache(maxsize=None)
def get_nlp_and_matchers(language):
    """Build a blank spaCy pipeline for ``language`` plus two phrase matchers.

    The result is memoised per ``language``, so every call with the same
    language receives the *same* pipeline and the *same* two matcher objects;
    anything a caller adds to those matchers stays there for later callers.

    Returns:
        A ``(nlp, required_matcher, exclude_matcher)`` tuple. Both matchers
        start empty and share the pipeline's vocabulary.
    """
    # spaCy is imported lazily so it is only needed when this is called.
    import spacy
    from spacy.matcher import PhraseMatcher

    pipeline = spacy.blank(language)
    vocab = pipeline.vocab
    required = PhraseMatcher(vocab)
    excluded = PhraseMatcher(vocab)
    return pipeline, required, excluded
|
||
|
||
@lru_cache(maxsize=128)
def _build_keyword_matcher(language: str, label: str, keys: tuple):
    """Return a phrase matcher containing exactly ``keys`` under ``label``.

    Cached per ``(language, label, keys)`` so the per-node calls made for one
    node list reuse a single matcher instead of re-tokenising the keywords.
    """
    from spacy.matcher import PhraseMatcher

    # Only the pipeline is taken from the shared helper; the matcher is new,
    # so it never contains phrases belonging to another key set.
    nlp = get_nlp_and_matchers(language)[0]
    matcher = PhraseMatcher(nlp.vocab)
    matcher.add(label, list(nlp.pipe(keys)))
    return matcher


@RerankerV2.register_reranker
def KeywordFilter(
    node: DocNode,
    required_keys: List[str],
    exclude_keys: List[str],
    language: str = "en",
    **kwargs,
) -> Optional[DocNode]:
    """Keep ``node`` only if its content satisfies the keyword constraints.

    Previously the phrases were added, on every call, to matcher objects that
    are cached and shared per language, so keywords from earlier calls kept
    matching in later calls that used different keys. Each key set now gets
    its own matcher.

    Args:
        node: Node whose ``get_content()`` text is inspected.
        required_keys: Phrases of which at least one must occur; an empty
            list disables this check.
        exclude_keys: Phrases of which none may occur; an empty list disables
            this check.
        language: Language code used to create the blank spaCy pipeline.
        **kwargs: Ignored; accepts extra arguments such as ``query``.

    Returns:
        ``node`` when it passes both checks, otherwise ``None``.
    """
    nlp = get_nlp_and_matchers(language)[0]
    doc = nlp(node.get_content())

    if required_keys:
        required_matcher = _build_keyword_matcher(
            language, "RequiredKeywords", tuple(required_keys)
        )
        if not required_matcher(doc):
            return None
    if exclude_keys:
        exclude_matcher = _build_keyword_matcher(
            language, "ExcludeKeywords", tuple(exclude_keys)
        )
        if exclude_matcher(doc):
            return None
    return node
|
||
|
||
@lru_cache(maxsize=None)
def get_cross_encoder_model(model_name: str):
    """Return a ``CrossEncoder`` for ``model_name``, memoised per name.

    The model is fetched through ``ModelManager`` configured with
    ``config["model_source"]``; whatever ``download`` returns is handed to the
    ``CrossEncoder`` constructor.
    """
    # Imported lazily so sentence_transformers is only required on first use.
    from sentence_transformers import CrossEncoder

    manager = ModelManager(config["model_source"])
    downloaded = manager.download(model_name)
    return CrossEncoder(downloaded)
|
||
|
||
@RerankerV2.register_reranker(batch=True)
def ModuleReranker(
    nodes: List[DocNode], model: str, query: str, topk: int = -1, **kwargs
) -> List[DocNode]:
    """Order ``nodes`` by cross-encoder relevance to ``query``, best first.

    Args:
        nodes: Candidate nodes; each is scored on its content rendered with
            ``MetadataMode.EMBED``.
        model: Model name passed to ``get_cross_encoder_model``.
        query: Query text paired with every node.
        topk: When positive, keep only the ``topk`` highest-scoring nodes;
            otherwise return all of them.
        **kwargs: Ignored.

    Returns:
        The nodes sorted by descending score, truncated to ``topk`` if set.
        An empty input yields an empty list.
    """
    if not nodes:
        # Nothing to score: skip loading the model and avoid handing an empty
        # batch to the encoder, which is not guaranteed to accept one.
        return []

    cross_encoder = get_cross_encoder_model(model)
    query_pairs = [
        (query, node.get_content(metadata_mode=MetadataMode.EMBED)) for node in nodes
    ]
    scores = cross_encoder.predict(query_pairs)
    sorted_indices = np.argsort(scores)[::-1]  # Descending order
    if topk > 0:
        sorted_indices = sorted_indices[:topk]

    return [nodes[i] for i in sorted_indices]
|
||
|
||
# User-defined reranker decorator
def register_reranker(func=None, batch=False):
    """Module-level shortcut that delegates to ``RerankerV2.register_reranker``."""
    return RerankerV2.register_reranker(func=func, batch=batch)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import unittest | ||
from unittest.mock import patch, MagicMock | ||
from lazyllm.tools.rag.store import DocNode | ||
from lazyllm.tools.rag.rerank import RerankerV2, register_reranker | ||
|
||
|
||
class TestRerankerV2(unittest.TestCase):
    """Exercise the built-in and user-registered rerankers through RerankerV2."""

    def setUp(self):
        """Create three one-keyword documents and a query shared by the tests."""
        self.doc1 = DocNode(text="This is a test document with the keyword apple.")
        self.doc2 = DocNode(
            text="This is another test document with the keyword banana."
        )
        self.doc3 = DocNode(text="This document contains the keyword cherry.")
        self.nodes = [self.doc1, self.doc2, self.doc3]
        self.query = "test query"

    def test_keyword_filter_with_required_keys(self):
        """Only the document containing a required keyword is kept."""
        reranker = RerankerV2(
            name="KeywordFilter", required_keys=["apple"], exclude_keys=[]
        )
        kept = reranker.forward(self.nodes, query=self.query)
        self.assertEqual(
            [node.get_content() for node in kept], [self.doc1.get_content()]
        )

    def test_keyword_filter_with_exclude_keys(self):
        """The document containing an excluded keyword is dropped."""
        reranker = RerankerV2(
            name="KeywordFilter", required_keys=[], exclude_keys=["banana"]
        )
        kept = reranker.forward(self.nodes, query=self.query)
        self.assertEqual(len(kept), 2)
        self.assertNotIn(self.doc2, kept)

    @patch("lazyllm.components.utils.downloader.ModelManager.download")
    @patch("sentence_transformers.CrossEncoder")
    def test_module_reranker(self, MockCrossEncoder, mock_download):
        """Nodes come back ordered by mocked score and truncated to topk."""
        mock_download.return_value = "mock_model_path"
        encoder = MagicMock()
        encoder.predict.return_value = [0.8, 0.6, 0.9]
        MockCrossEncoder.return_value = encoder

        reranker = RerankerV2(name="ModuleReranker", model="dummy-model", topk=2)
        ranked = reranker.forward(self.nodes, query=self.query)

        # doc3 has the highest score (0.9), doc1 the second highest (0.8).
        expected = [self.doc3.get_content(), self.doc1.get_content()]
        self.assertEqual([node.get_content() for node in ranked], expected)

    def test_register_reranker_decorator(self):
        """A function registered via the decorator is usable by name."""

        @register_reranker
        def CustomReranker(node, **kwargs):
            return node if "custom" in node.get_content() else None

        custom_doc = DocNode(text="This document contains custom keyword.")
        candidates = [*self.nodes, custom_doc]

        kept = RerankerV2(name="CustomReranker").forward(candidates)

        self.assertEqual(
            [node.get_content() for node in kept], [custom_doc.get_content()]
        )
|
||
|
||
if __name__ == "__main__":
    # Run the whole suite when this file is executed directly.
    unittest.main()