Skip to content

Commit

Permalink
add reranker (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
yewentao256 authored Jul 30, 2024
1 parent ced0a9d commit 569fa3c
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 0 deletions.
99 changes: 99 additions & 0 deletions lazyllm/tools/rag/rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from functools import lru_cache
from typing import Callable, List, Optional
from lazyllm import ModuleBase, config, LOG
from lazyllm.tools.rag.store import DocNode, MetadataMode
from lazyllm.components.utils.downloader import ModelManager
import numpy as np


class RerankerV2(ModuleBase):
    """Module that reranks or filters a list of ``DocNode`` with a named strategy.

    Strategies are plain functions kept in the class-level registry
    ``registered_reranker`` (populated through :meth:`register_reranker`).
    The strategy is looked up by name each time :meth:`forward` runs, so a
    strategy may be registered after the module has been constructed.
    """

    # Maps a strategy name (the registered function's ``__name__``) to its
    # list-level wrapper. Shared by every instance of the class.
    registered_reranker = dict()

    def __init__(self, name: str = "ModuleReranker", **kwargs) -> None:
        """Remember which strategy to use and the arguments to forward to it.

        Args:
            name: Key of the strategy inside ``registered_reranker``.
            **kwargs: Extra keyword arguments passed to the strategy on every
                :meth:`forward` call.
        """
        super().__init__()
        self.name = name
        self.kwargs = kwargs

    def forward(self, nodes: List[DocNode], query: str = "") -> List[DocNode]:
        """Run the selected strategy on *nodes*.

        Args:
            nodes: Candidate nodes to rerank or filter.
            query: Query text, forwarded to the strategy as ``query=``.

        Returns:
            Whatever the registered strategy returns for *nodes*.

        Raises:
            KeyError: If no strategy is registered under ``self.name``.
        """
        try:
            reranker = self.registered_reranker[self.name]
        except KeyError:
            # Same exception type as the plain dict lookup, but the message
            # tells the caller which names are actually available.
            raise KeyError(
                f"Reranker `{self.name}` is not registered; "
                f"available rerankers: {sorted(self.registered_reranker)}"
            ) from None
        results = reranker(nodes, query=query, **self.kwargs)
        LOG.debug(f"Rerank use `{self.name}` and get nodes: {results}")
        return results

    @classmethod
    def register_reranker(cls, func: Optional[Callable] = None, batch: bool = False):
        """Register *func* as a strategy under its ``__name__``.

        Works both as a bare decorator (``@register_reranker``) and as a
        decorator factory (``@register_reranker(batch=True)``).

        Args:
            func: The strategy function when used as a bare decorator,
                otherwise ``None``.
            batch: When true, *func* receives the whole node list at once.
                When false, *func* is called once per node and falsy results
                are dropped from the output.

        Returns:
            The list-level wrapper when *func* is given, else a decorator
            that produces it.
        """
        # Local import keeps this block self-contained.
        from functools import wraps

        def decorator(f):
            @wraps(f)  # keep the strategy's name/docstring on the wrapper
            def wrapper(nodes, **kwargs):
                if batch:
                    return f(nodes, **kwargs)
                results = [f(node, **kwargs) for node in nodes]
                return [result for result in results if result]

            cls.registered_reranker[f.__name__] = wrapper
            return wrapper

        return decorator(func) if func else decorator


@lru_cache(maxsize=None)
def get_nlp_and_matchers(language):
    """Build a blank spaCy pipeline for *language* plus two empty phrase matchers.

    The result is memoized per ``language`` by ``lru_cache``, so repeated
    calls with the same argument hand back the very same three objects.

    Returns:
        A ``(nlp, required_matcher, exclude_matcher)`` tuple; both matchers
        are ``PhraseMatcher`` instances created on ``nlp.vocab``.
    """
    # spaCy is imported lazily, only when this function is actually called.
    from spacy.matcher import PhraseMatcher
    import spacy

    pipeline = spacy.blank(language)
    vocab = pipeline.vocab
    return pipeline, PhraseMatcher(vocab), PhraseMatcher(vocab)


@lru_cache(maxsize=128)
def _get_keyword_matcher(language: str, label: str, keywords: tuple):
    """Return a phrase matcher that holds exactly *keywords* for *language*.

    A separate matcher is built (and cached) for every distinct keyword
    tuple, so patterns from one keyword list never leak into another.
    """
    from spacy.matcher import PhraseMatcher

    nlp = get_nlp_and_matchers(language)[0]
    matcher = PhraseMatcher(nlp.vocab)
    matcher.add(label, list(nlp.pipe(keywords)))
    return matcher


@RerankerV2.register_reranker
def KeywordFilter(
    node: DocNode,
    required_keys: List[str],
    exclude_keys: List[str],
    language: str = "en",
    **kwargs,
) -> Optional[DocNode]:
    """Keep *node* only if its content satisfies the keyword constraints.

    Args:
        node: The node whose ``get_content()`` text is inspected.
        required_keys: Phrases of which at least one must match; an empty
            list disables this check.
        exclude_keys: Phrases of which none may match; an empty list
            disables this check.
        language: Language code passed to ``get_nlp_and_matchers``.
        **kwargs: Ignored; accepts extra arguments such as ``query``.

    Returns:
        *node* when it passes both checks, otherwise ``None``.
    """
    # Only the pipeline is taken from the shared cache. The matchers cached
    # alongside it are shared by every call for the same language, and
    # re-adding patterns to them on each call lets keywords from earlier
    # calls accumulate, so dedicated matchers are used instead.
    nlp = get_nlp_and_matchers(language)[0]

    doc = nlp(node.get_content())
    if required_keys:
        required_matcher = _get_keyword_matcher(
            language, "RequiredKeywords", tuple(required_keys)
        )
        if not required_matcher(doc):
            return None
    if exclude_keys:
        exclude_matcher = _get_keyword_matcher(
            language, "ExcludeKeywords", tuple(exclude_keys)
        )
        if exclude_matcher(doc):
            return None
    return node


@lru_cache(maxsize=None)
def get_cross_encoder_model(model_name: str):
    """Create a ``CrossEncoder`` for *model_name*, memoized per name.

    The model is first resolved through ``ModelManager`` using the source
    configured under ``config["model_source"]``; whatever ``download``
    returns is handed to the ``CrossEncoder`` constructor.
    """
    # sentence_transformers is imported lazily, only when a model is requested.
    from sentence_transformers import CrossEncoder

    source = config["model_source"]
    resolved_model = ModelManager(source).download(model_name)
    return CrossEncoder(resolved_model)


@RerankerV2.register_reranker(batch=True)
def ModuleReranker(
    nodes: List[DocNode], model: str, query: str, topk: int = -1, **kwargs
) -> List[DocNode]:
    """Order *nodes* by cross-encoder relevance to *query*, best first.

    Args:
        nodes: Candidate nodes; each is scored on its
            ``get_content(metadata_mode=MetadataMode.EMBED)`` text.
        model: Model name passed to ``get_cross_encoder_model``.
        query: Query text paired with every node.
        topk: When positive, keep only the ``topk`` best nodes; otherwise
            return all of them.
        **kwargs: Ignored; accepts extra arguments from the caller.

    Returns:
        The nodes sorted by descending score. Nodes with equal scores keep
        their original relative order.
    """
    # Nothing to score: avoid loading the model and predicting on an empty batch.
    if not nodes:
        return []

    cross_encoder = get_cross_encoder_model(model)
    query_pairs = [
        (query, node.get_content(metadata_mode=MetadataMode.EMBED)) for node in nodes
    ]
    scores = np.asarray(cross_encoder.predict(query_pairs))
    # Stable sort on the negated scores gives descending order without
    # flipping the input order of ties (reversing an ascending sort would).
    sorted_indices = np.argsort(-scores, kind="stable")
    if topk > 0:
        sorted_indices = sorted_indices[:topk]

    return [nodes[i] for i in sorted_indices]


# User-defined reranker decorator
def register_reranker(func=None, batch=False):
    """Module-level shortcut for ``RerankerV2.register_reranker``.

    Can be applied bare (``@register_reranker``) or called with arguments
    (``@register_reranker(batch=True)``); both arguments are forwarded
    unchanged.
    """
    return RerankerV2.register_reranker(func=func, batch=batch)
75 changes: 75 additions & 0 deletions tests/basic_tests/test_reranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import unittest
from unittest.mock import patch, MagicMock
from lazyllm.tools.rag.store import DocNode
from lazyllm.tools.rag.rerank import RerankerV2, register_reranker


class TestRerankerV2(unittest.TestCase):
    """Tests for the built-in rerankers and the registration decorator."""

    def setUp(self):
        """Create three single-keyword documents and a dummy query."""
        self.doc1 = DocNode(text="This is a test document with the keyword apple.")
        self.doc2 = DocNode(
            text="This is another test document with the keyword banana."
        )
        self.doc3 = DocNode(text="This document contains the keyword cherry.")
        self.nodes = [self.doc1, self.doc2, self.doc3]
        self.query = "test query"

    def test_keyword_filter_with_required_keys(self):
        """Only the document containing the required keyword survives."""
        required_keys = ["apple"]
        exclude_keys = []
        reranker = RerankerV2(
            name="KeywordFilter", required_keys=required_keys, exclude_keys=exclude_keys
        )
        results = reranker.forward(self.nodes, query=self.query)
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].get_content(), self.doc1.get_content())

    def test_keyword_filter_with_exclude_keys(self):
        """The document containing the excluded keyword is dropped."""
        required_keys = []
        exclude_keys = ["banana"]
        reranker = RerankerV2(
            name="KeywordFilter", required_keys=required_keys, exclude_keys=exclude_keys
        )
        results = reranker.forward(self.nodes, query=self.query)
        self.assertEqual(len(results), 2)
        self.assertNotIn(self.doc2, results)

    @patch("lazyllm.components.utils.downloader.ModelManager.download")
    @patch("sentence_transformers.CrossEncoder")
    def test_module_reranker(self, MockCrossEncoder, mock_download):
        """Nodes come back ordered by mocked score and truncated to ``topk``."""
        mock_model = MagicMock()
        mock_download.return_value = "mock_model_path"
        MockCrossEncoder.return_value = mock_model
        mock_model.predict.return_value = [0.8, 0.6, 0.9]

        reranker = RerankerV2(name="ModuleReranker", model="dummy-model", topk=2)
        results = reranker.forward(self.nodes, query=self.query)

        self.assertEqual(len(results), 2)
        self.assertEqual(
            results[0].get_content(), self.doc3.get_content()
        )  # highest score
        self.assertEqual(
            results[1].get_content(), self.doc1.get_content()
        )  # second highest score

    def test_register_reranker_decorator(self):
        """A user-defined per-node function can be registered and used by name."""

        @register_reranker
        def CustomReranker(node, **kwargs):
            if "custom" in node.get_content():
                return node
            return None

        # The registry is class-level state; remove the temporary entry so it
        # does not leak into other tests.
        self.addCleanup(RerankerV2.registered_reranker.pop, "CustomReranker", None)

        custom_doc = DocNode(text="This document contains custom keyword.")
        nodes = [self.doc1, self.doc2, self.doc3, custom_doc]

        reranker = RerankerV2(name="CustomReranker")
        results = reranker.forward(nodes)

        self.assertEqual(len(results), 1)
        self.assertEqual(results[0].get_content(), custom_doc.get_content())


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 569fa3c

Please sign in to comment.