Skip to content

Commit

Permalink
add filter support for RAG (#338)
Browse files Browse the repository at this point in the history
  • Loading branch information
ouonline authored Nov 19, 2024
1 parent c934d8f commit 8035a65
Show file tree
Hide file tree
Showing 26 changed files with 609 additions and 315 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ jobs:
export PYTHONPATH=$PWD:$PYTHONPATH
export LAZYLLM_DATA_PATH=/mnt/lustre/share_data/lazyllm/data/
export LAZYLLM_MODEL_PATH=/mnt/lustre/share_data/lazyllm/models
export LAZYLLM_HOME="${{ env.CI_PATH }}/${{ github.run_id }}-${{ github.job }}"
mkdir -p $LAZYLLM_HOME
python -m pytest --lf --last-failed-no-failures=all --durations=0 --reruns=2 -v tests/basic_tests/
AdvancedStandardTests:
Expand All @@ -85,6 +87,8 @@ jobs:
export PYTHONPATH=$PWD:$PYTHONPATH
export LAZYLLM_DATA_PATH=/mnt/lustre/share_data/lazyllm/data/
export LAZYLLM_MODEL_PATH=/mnt/lustre/share_data/lazyllm/models
export LAZYLLM_HOME="${{ env.CI_PATH }}/${{ github.run_id }}-${{ github.job }}"
mkdir -p $LAZYLLM_HOME
source ~/ENV/env.sh
python -m pytest --lf --last-failed-no-failures=all --durations=0 --reruns=2 -v tests/advanced_tests/standard_test/
Expand All @@ -101,6 +105,8 @@ jobs:
export PYTHONPATH=$PWD:$PYTHONPATH
export LAZYLLM_DATA_PATH=/mnt/lustre/share_data/lazyllm/data/
export LAZYLLM_MODEL_PATH=/mnt/lustre/share_data/lazyllm/models
export LAZYLLM_HOME="${{ env.CI_PATH }}/${{ github.run_id }}-${{ github.job }}"
mkdir -p $LAZYLLM_HOME
python -m pytest --lf --last-failed-no-failures=all --durations=0 --reruns=2 -v tests/advanced_tests/full_test/
ChargeTests:
Expand All @@ -114,5 +120,7 @@ jobs:
export PYTHONPATH=$PWD:$PYTHONPATH
export LAZYLLM_DATA_PATH=/mnt/lustre/share_data/lazyllm/data/
export LAZYLLM_MODEL_PATH=/mnt/lustre/share_data/lazyllm/models
export LAZYLLM_HOME="${{ env.CI_PATH }}/${{ github.run_id }}-${{ github.job }}"
mkdir -p $LAZYLLM_HOME
source ~/ENV/env.sh
python -m pytest --lf --last-failed-no-failures=all --durations=0 --reruns=2 -v tests/charge_tests
88 changes: 50 additions & 38 deletions examples/rag_milvus_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,56 +2,68 @@

import os
import lazyllm
from lazyllm import bind
import tempfile
from lazyllm import bind, config
from lazyllm.tools.rag import DocField
import shutil

def run(query):
_, store_file = tempfile.mkstemp(suffix=".db")
class TmpDir:
def __init__(self):
self.root_dir = os.path.expanduser(os.path.join(config['home'], 'rag_for_ut'))
self.rag_dir = os.path.join(self.root_dir, 'rag_master')
os.makedirs(self.rag_dir, exist_ok=True)
# creates a dummy file for rag
with open(os.path.join(self.rag_dir, '_dummy.txt'), "wb") as fd:
fd.write(b'dsfjfasfkjdsfewifjewofjefiejw')
self.store_file = os.path.join(self.root_dir, "milvus.db")

milvus_store_conf = {
'type': 'milvus',
'kwargs': {
'uri': store_file,
'embedding_index_type': 'HNSW',
'embedding_metric_type': 'COSINE',
},
}
def __del__(self):
shutil.rmtree(self.root_dir)

documents = lazyllm.Document(dataset_path="rag_master",
embed=lazyllm.TrainableModule("bge-large-zh-v1.5"),
manager=False,
store_conf=milvus_store_conf)
tmp_dir = TmpDir()

documents.create_node_group(name="sentences",
transform=lambda s: '。'.split(s))
milvus_store_conf = {
'type': 'milvus',
'kwargs': {
'uri': tmp_dir.store_file,
'embedding_index_type': 'HNSW',
'embedding_metric_type': 'COSINE',
},
}

prompt = 'You will play the role of an AI Q&A assistant and complete a dialogue task.'\
' In this task, you need to provide your answer based on the given context and question.'
doc_fields = {
'comment': DocField(data_type=DocField.DTYPE_VARCHAR, max_size=65535, default_value=' '),
'signature': DocField(data_type=DocField.DTYPE_VARCHAR, max_size=32, default_value=' '),
}

with lazyllm.pipeline() as ppl:
ppl.retriever = lazyllm.Retriever(doc=documents, group_name="sentences", topk=3)
prompt = 'You will play the role of an AI Q&A assistant and complete a dialogue task.'\
' In this task, you need to provide your answer based on the given context and question.'

ppl.reranker = lazyllm.Reranker(name='ModuleReranker',
model="bge-reranker-large",
topk=1,
output_format='content',
join=True) | bind(query=ppl.input)
documents = lazyllm.Document(dataset_path=tmp_dir.rag_dir,
embed=lazyllm.TrainableModule("bge-large-zh-v1.5"),
manager=True,
store_conf=milvus_store_conf,
doc_fields=doc_fields)

ppl.formatter = (
lambda nodes, query: dict(context_str=nodes, query=query)
) | bind(query=ppl.input)
documents.create_node_group(name="block", transform=lambda s: s.split("\n") if s else '')

ppl.llm = lazyllm.TrainableModule('internlm2-chat-7b').prompt(
lazyllm.ChatPrompter(instruction=prompt, extro_keys=['context_str']))
with lazyllm.pipeline() as ppl:
ppl.retriever = lazyllm.Retriever(doc=documents, group_name="block", topk=3)

rag = lazyllm.ActionModule(ppl)
rag.start()
res = rag(query)
ppl.reranker = lazyllm.Reranker(name='ModuleReranker',
model="bge-reranker-large",
topk=1,
output_format='content',
join=True) | bind(query=ppl.input)

os.remove(store_file)
ppl.formatter = (
lambda nodes, query: dict(context_str=nodes, query=query)
) | bind(query=ppl.input)

return res
ppl.llm = lazyllm.TrainableModule('internlm2-chat-7b').prompt(
lazyllm.ChatPrompter(instruction=prompt, extro_keys=['context_str']))

if __name__ == '__main__':
res = run('何为天道?')
rag = lazyllm.ActionModule(ppl)
rag.start()
res = rag('何为天道?')
print(f'answer: {res}')
26 changes: 19 additions & 7 deletions lazyllm/common/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
from typing import Type
from lazyllm.thirdparty import redis
from filelock import FileLock

config.add("default_fsqueue", str, "sqlite", "DEFAULT_FSQUEUE")
config.add("fsqredis_url", str, "", "FSQREDIS_URL")
Expand Down Expand Up @@ -64,16 +65,27 @@ def _size(self, id): pass
@abstractmethod
def _clear(self, id): pass

def sqlite3_check_threadsafety() -> bool:
    """Report whether the linked SQLite library was compiled with THREADSAFE=1.

    A True result means one connection can be used in multiple threads;
    refer to https://sqlite.org/compile.html#threadsafe.

    Returns:
        True when the reported compile options contain ``THREADSAFE=1``.
        False for any other setting, or when the library reports no
        THREADSAFE option at all.
    """
    conn = sqlite3.connect(":memory:")
    try:
        # Use the plain PRAGMA statement rather than the pragma_compile_options
        # table-valued function: the statement form is also understood by older
        # SQLite releases that lack table-valued pragma functions (added in 3.16).
        options = [row[0] for row in conn.execute("PRAGMA compile_options")]
    finally:
        # Always release the throwaway in-memory connection, even if the
        # query raises.
        conn.close()
    # A membership test yields False for an empty or THREADSAFE-less option
    # list instead of raising IndexError on a missing first row.
    return 'THREADSAFE=1' in options

class SQLiteQueue(FileSystemQueue):
def __init__(self, klass='__default__'):
super(__class__, self).__init__(klass=klass)
self.db_path = os.path.expanduser(os.path.join(config['home'], '.lazyllm_filesystem_queue.db'))
self._lock = threading.Lock()
self._lock = FileLock(self.db_path + '.lock')
self._check_same_thread = not sqlite3_check_threadsafety()
self._initialize_db()

def _initialize_db(self):
with sqlite3.connect(self.db_path) as conn:
with self._lock, sqlite3.connect(self.db_path, check_same_thread=self._check_same_thread) as conn:
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS queue (
Expand All @@ -87,7 +99,7 @@ def _initialize_db(self):

def _enqueue(self, id, message):
with self._lock:
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(self.db_path, check_same_thread=self._check_same_thread) as conn:
cursor = conn.cursor()
cursor.execute('''
SELECT MAX(position) FROM queue WHERE id = ?
Expand All @@ -103,7 +115,7 @@ def _enqueue(self, id, message):
def _dequeue(self, id, limit=None):
"""Retrieve and remove all messages from the queue."""
with self._lock:
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(self.db_path, check_same_thread=self._check_same_thread) as conn:
cursor = conn.cursor()
if limit:
cursor.execute('SELECT message, position FROM queue WHERE id = ? '
Expand All @@ -123,7 +135,7 @@ def _dequeue(self, id, limit=None):

def _peek(self, id):
with self._lock:
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(self.db_path, check_same_thread=self._check_same_thread) as conn:
cursor = conn.cursor()
cursor.execute('''
SELECT message FROM queue WHERE id = ? ORDER BY position ASC LIMIT 1
Expand All @@ -135,7 +147,7 @@ def _peek(self, id):

def _size(self, id):
with self._lock:
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(self.db_path, check_same_thread=self._check_same_thread) as conn:
cursor = conn.cursor()
cursor.execute('''
SELECT COUNT(*) FROM queue WHERE id = ?
Expand All @@ -144,7 +156,7 @@ def _size(self, id):

def _clear(self, id):
with self._lock:
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(self.db_path, check_same_thread=self._check_same_thread) as conn:
cursor = conn.cursor()
cursor.execute('''
DELETE FROM queue WHERE id = ?
Expand Down
5 changes: 2 additions & 3 deletions lazyllm/components/embedding/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from lazyllm import LOG
from lazyllm.thirdparty import transformers as tf
from lazyllm.thirdparty import torch

from sentence_transformers import CrossEncoder
from lazyllm.thirdparty import sentence_transformers
import numpy as np


Expand Down Expand Up @@ -59,7 +58,7 @@ def __init__(self, base_rerank, source=None, init=False):
lazyllm.call_once(self.init_flag, self.load_reranker)

def load_reranker(self):
self.reranker = CrossEncoder(self.base_rerank)
self.reranker = sentence_transformers.CrossEncoder(self.base_rerank)

def __call__(self, inps):
lazyllm.call_once(self.init_flag, self.load_reranker)
Expand Down
8 changes: 4 additions & 4 deletions lazyllm/components/utils/downloader/model_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, model_source=lazyllm.config['model_source'],
self.model_source = model_source
self.token = token
self.cache_dir = cache_dir
self.model_pathes = model_path.split(":") if len(model_path) > 0 else []
self.model_paths = model_path.split(":") if len(model_path) > 0 else []

@classmethod
def get_model_type(cls, model) -> str:
Expand Down Expand Up @@ -111,7 +111,7 @@ def download(self, model=''):
return model_save_dir if model_save_dir else model

def _model_exists_at_path(self, model_name):
if len(self.model_pathes) == 0:
if len(self.model_paths) == 0:
return None
model_dirs = []

Expand All @@ -122,10 +122,10 @@ def _model_exists_at_path(self, model_name):
model_dirs.append(model_name_mapping[model_name.lower()]['source'][source].replace('/', os.sep))
model_dirs.append(model_name.replace('/', os.sep))

for model_path in self.model_pathes:
for model_path in self.model_paths:
if len(model_path) == 0: continue
if model_path[0] != os.sep:
print(f"[WARNING] skipping path {model_path} as only absolute pathes is accepted.")
print(f"[WARNING] skipping path {model_path} as only absolute paths is accepted.")
continue
for model_dir in model_dirs:
full_model_dir = os.path.join(model_path, model_dir)
Expand Down
3 changes: 2 additions & 1 deletion lazyllm/thirdparty/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __getattribute__(self, __name):

modules = ['redis', 'huggingface_hub', 'jieba', 'modelscope', 'pandas', 'jwt', 'rank_bm25', 'redisvl', 'datasets',
'deepspeed', 'fire', 'numpy', 'peft', 'torch', 'transformers', 'collie', 'faiss', 'flash_attn', 'google',
'lightllm', 'vllm', 'ChatTTS', 'wandb', 'funasr', 'sklearn', 'torchvision', 'scipy', 'pymilvus']
'lightllm', 'vllm', 'ChatTTS', 'wandb', 'funasr', 'sklearn', 'torchvision', 'scipy', 'pymilvus',
'sentence_transformers', 'gradio', 'chromadb']
for m in modules:
vars()[m] = PackageWrapper(m)
2 changes: 2 additions & 0 deletions lazyllm/tools/rag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader, VideoAudioReader)
from .dataReader import SimpleDirectoryReader
from .doc_manager import DocManager, DocListManager
from .global_metadata import GlobalMetadataDesc as DocField


__all__ = [
Expand Down Expand Up @@ -37,4 +38,5 @@
"SimpleDirectoryReader",
'DocManager',
'DocListManager',
'DocField',
]
41 changes: 23 additions & 18 deletions lazyllm/tools/rag/chroma_store.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Any, Dict, List, Optional, Callable, Set
import chromadb
from lazyllm.thirdparty import chromadb
from lazyllm import LOG
from lazyllm.common import override
from chromadb.api.models.Collection import Collection
from .store_base import StoreBase, LAZY_ROOT_NAME
from .doc_node import DocNode
from .index_base import IndexBase
Expand All @@ -20,7 +19,7 @@ def __init__(self, group_embed_keys: Dict[str, Set[str]], embed: Dict[str, Calla
self._db_client = chromadb.PersistentClient(path=dir)
LOG.success(f"Initialzed chromadb in path: {dir}")
node_groups = list(group_embed_keys.keys())
self._collections: Dict[str, Collection] = {
self._collections: Dict[str, chromadb.api.models.Collection.Collection] = {
group: self._db_client.get_or_create_collection(group)
for group in node_groups
}
Expand Down Expand Up @@ -78,21 +77,23 @@ def _load_store(self, embed_dims: Dict[str, int]) -> None:
return

# Restore all nodes
uid2node = {}
for group in self._collections.keys():
results = self._peek_all_documents(group)
nodes = self._build_nodes_from_chroma(results, embed_dims)
self._map_store.update_nodes(nodes)
for node in nodes:
uid2node[node.uid] = node

# Rebuild relationships
for group_name in self._map_store.all_groups():
nodes = self._map_store.get_nodes(group_name)
for node in nodes:
if node.parent:
parent_uid = node.parent
parent_node = self._map_store.find_node_by_uid(parent_uid)
node.parent = parent_node
parent_node.children[node.group].append(node)
LOG.debug(f"build {group} nodes from chromadb: {nodes}")
for node in uid2node.values():
if node.parent:
parent_uid = node.parent
parent_node = uid2node.get(parent_uid)
node.parent = parent_node
parent_node.children[node.group].append(node)
LOG.debug(f"build {group} nodes from chromadb: {nodes}")

self._map_store.update_nodes(list(uid2node.values()))
LOG.success("Successfully Built nodes from chromadb.")

def _save_nodes(self, nodes: List[DocNode]) -> None:
Expand Down Expand Up @@ -131,16 +132,19 @@ def _build_nodes_from_chroma(self, results: Dict[str, List], embed_dims: Dict[st
chroma_metadata = results['metadatas'][i]

parent = chroma_metadata['parent']
fields = pickle.loads(base64.b64decode(chroma_metadata['fields'].encode('utf-8')))\
if parent else None
local_metadata = pickle.loads(base64.b64decode(chroma_metadata['metadata'].encode('utf-8')))

global_metadata = pickle.loads(base64.b64decode(chroma_metadata['global_metadata'].encode('utf-8')))\
if not parent else None

node = DocNode(
uid=uid,
text=results["documents"][i],
group=chroma_metadata["group"],
embedding=pickle.loads(base64.b64decode(chroma_metadata['embedding'].encode('utf-8'))),
parent=parent,
fields=fields,
metadata=local_metadata,
global_metadata=global_metadata,
)

if node.embedding:
Expand All @@ -167,10 +171,11 @@ def _make_chroma_metadata(self, node: DocNode) -> Dict[str, Any]:
"group": node.group,
"parent": node.parent.uid if node.parent else "",
"embedding": base64.b64encode(pickle.dumps(node.embedding)).decode('utf-8'),
"metadata": base64.b64encode(pickle.dumps(node.metadata)).decode('utf-8'),
}

if node.parent:
metadata["fields"] = base64.b64encode(pickle.dumps(node.fields)).decode('utf-8')
if not node.parent:
metadata["global_metadata"] = base64.b64encode(pickle.dumps(node.global_metadata)).decode('utf-8')

return metadata

Expand Down
Loading

0 comments on commit 8035a65

Please sign in to comment.