Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rename DocNode::text to DocNode::content for image/video/... support #364

Merged
merged 7 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion lazyllm/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .globals import globals, LazyLlmResponse, LazyLlmRequest, encode_request, decode_request
from .bind import root, Bind as bind, _0, _1, _2, _3, _4, _5, _6, _7, _8, _9
from .queue import FileSystemQueue
from .utils import compile_func
from .utils import compile_func, obj2str, str2obj

__all__ = [
# registry
Expand All @@ -36,6 +36,8 @@
'reset_on_pickle',
'Color',
'colored_text',
'obj2str',
'str2obj',

# arg praser
'LazyLLMCMD',
Expand Down
7 changes: 3 additions & 4 deletions lazyllm/common/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
import contextvars
import copy
from typing import Any, Tuple, Optional, List, Dict
import pickle
from pydantic import BaseModel as struct
from .common import package, kwargs
from .deprecated import deprecated
import asyncio
import base64
from .utils import obj2str, str2obj


class ReadWriteLock(object):
Expand Down Expand Up @@ -226,9 +225,9 @@ def __str__(self): return str(self.messages)


def encode_request(input):
return base64.b64encode(pickle.dumps(input)).decode('utf-8')
return obj2str(input)


def decode_request(input, default=None):
if input is None: return default
return pickle.loads(base64.b64decode(input.encode('utf-8')))
return str2obj(input)
8 changes: 8 additions & 0 deletions lazyllm/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import Union, Dict, Callable, Any, Optional
import re
import ast
import pickle
import base64

def check_path(
path: Union[str, PathLike],
Expand Down Expand Up @@ -34,3 +36,9 @@ def compile_func(func_code: str, global_env: Optional[Dict[str, Any]] = None) ->
local_dict = {}
exec(func, global_env, local_dict)
return local_dict[fname]

def obj2str(obj: Any) -> str:
return base64.b64encode(pickle.dumps(obj)).decode('utf-8')

def str2obj(data: str) -> Any:
return None if data is None else pickle.loads(base64.b64decode(data.encode('utf-8')))
22 changes: 9 additions & 13 deletions lazyllm/tools/rag/chroma_store.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from typing import Any, Dict, List, Optional, Callable, Set
from lazyllm.thirdparty import chromadb
from lazyllm import LOG
from lazyllm.common import override
from lazyllm.common import override, obj2str, str2obj
from .store_base import StoreBase, LAZY_ROOT_NAME
from .doc_node import DocNode
from .index_base import IndexBase
from .utils import _FileNodeIndex
from .default_index import DefaultIndex
from .map_store import MapStore
import pickle
import base64

# ---------------------------------------------------------------------------- #

Expand Down Expand Up @@ -111,7 +109,7 @@ def _save_nodes(self, nodes: List[DocNode]) -> None:
ids.append(node.uid)
embeddings.append([0]) # we don't use chroma for retrieving
metadatas.append(metadata)
documents.append(node.get_text())
documents.append(obj2str(node.content))
if ids:
collection.upsert(
embeddings=embeddings,
Expand All @@ -132,16 +130,14 @@ def _build_nodes_from_chroma(self, results: Dict[str, List], embed_dims: Dict[st
chroma_metadata = results['metadatas'][i]

parent = chroma_metadata['parent']
local_metadata = pickle.loads(base64.b64decode(chroma_metadata['metadata'].encode('utf-8')))

global_metadata = pickle.loads(base64.b64decode(chroma_metadata['global_metadata'].encode('utf-8')))\
if not parent else None
local_metadata = str2obj(chroma_metadata['metadata'])
global_metadata = str2obj(chroma_metadata['global_metadata']) if not parent else None

node = DocNode(
uid=uid,
text=results["documents"][i],
content=str2obj(results["documents"][i]),
group=chroma_metadata["group"],
embedding=pickle.loads(base64.b64decode(chroma_metadata['embedding'].encode('utf-8'))),
embedding=str2obj(chroma_metadata['embedding']),
parent=parent,
metadata=local_metadata,
global_metadata=global_metadata,
Expand Down Expand Up @@ -170,12 +166,12 @@ def _make_chroma_metadata(self, node: DocNode) -> Dict[str, Any]:
metadata = {
"group": node.group,
"parent": node.parent.uid if node.parent else "",
"embedding": base64.b64encode(pickle.dumps(node.embedding)).decode('utf-8'),
"metadata": base64.b64encode(pickle.dumps(node.metadata)).decode('utf-8'),
"embedding": obj2str(node.embedding),
"metadata": obj2str(node.metadata),
}

if not node.parent:
metadata["global_metadata"] = base64.b64encode(pickle.dumps(node.global_metadata)).decode('utf-8')
metadata["global_metadata"] = obj2str(node.global_metadata)

return metadata

Expand Down
4 changes: 2 additions & 2 deletions lazyllm/tools/rag/doc_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import hashlib
import json
from typing import List, Optional, Dict
from pydantic import BaseModel, Field
Expand All @@ -10,6 +9,7 @@
import lazyllm
from lazyllm import FastapiApp as app
from .utils import DocListManager, BaseResponse
from .doc_impl import gen_docid
from .global_metadata import RAG_DOC_ID, RAG_DOC_PATH


Expand Down Expand Up @@ -74,7 +74,7 @@ def upload_files(self, files: List[UploadFile], override: bool = False, # noqa
lazyllm.LOG.error(f'writing file [{path}] to disk failed: [{e}]')
raise e

file_id = hashlib.sha256(path.encode()).hexdigest()
file_id = gen_docid(path)
self._manager.update_file_status([file_id], status=DocListManager.Status.success)
results.append('Success')

Expand Down
22 changes: 16 additions & 6 deletions lazyllm/tools/rag/doc_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@ class MetadataMode(str, Enum):

@reset_on_pickle(('_lock', threading.Lock))
class DocNode:
def __init__(self, uid: Optional[str] = None, text: Optional[str] = None, group: Optional[str] = None,
embedding: Optional[Dict[str, List[float]]] = None, parent: Optional["DocNode"] = None,
metadata: Optional[Dict[str, Any]] = None, global_metadata: Optional[Dict[str, Any]] = None):
def __init__(self, uid: Optional[str] = None, content: Optional[Union[str, List[Any]]] = None,
group: Optional[str] = None, embedding: Optional[Dict[str, List[float]]] = None,
parent: Optional["DocNode"] = None, metadata: Optional[Dict[str, Any]] = None,
ouonline marked this conversation as resolved.
Show resolved Hide resolved
global_metadata: Optional[Dict[str, Any]] = None, *, text: Optional[str] = None):
if text and content:
raise ValueError('`text` and `content` cannot be set at the same time.')

self.uid: str = uid if uid else str(uuid.uuid4())
self.text: Optional[str] = text
self.content: Optional[Union[str, List[Any]]] = content if content else text
self.group: Optional[str] = group
self.embedding: Optional[Dict[str, List[float]]] = embedding or {}
self._metadata: Dict[str, Any] = metadata or {}
Expand All @@ -37,6 +41,12 @@ def __init__(self, uid: Optional[str] = None, text: Optional[str] = None, group:
raise ValueError('only ROOT node can set global metadata.')
self._global_metadata = global_metadata if global_metadata else {}

@property
def text(self) -> str:
if not isinstance(self.content, str):
raise TypeError(f"node content type '{type(self.content)}' is not a string")
return self.content

@property
def root_node(self) -> Optional["DocNode"]:
root = self.parent
Expand Down Expand Up @@ -97,7 +107,7 @@ def get_parent_id(self) -> str:

def __str__(self) -> str:
return (
f"DocNode(id: {self.uid}, group: {self.group}, text: {self.get_text()}) parent: {self.get_parent_id()}, "
f"DocNode(id: {self.uid}, group: {self.group}, content: {self.content}) parent: {self.get_parent_id()}, "
f"children: {self.get_children_str()}"
)

Expand Down Expand Up @@ -159,4 +169,4 @@ def get_text(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
return f"{metadata_str}\n\n{self.text}".strip()

def to_dict(self) -> Dict:
return dict(text=self.text, embedding=self.embedding, metadata=self.metadata)
return dict(content=self.content, embedding=self.embedding, metadata=self.metadata)
14 changes: 6 additions & 8 deletions lazyllm/tools/rag/milvus_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
from .index_base import IndexBase
from .store_base import StoreBase
from .global_metadata import GlobalMetadataDesc, RAG_DOC_PATH, RAG_DOC_ID
from lazyllm.common import override
import pickle
import base64
from lazyllm.common import override, obj2str, str2obj

class MilvusStore(StoreBase):
# we define these variables as members so that pymilvus is not imported until MilvusStore is instantiated.
Expand All @@ -29,7 +27,7 @@ def _def_constants(self) -> None:
'dtype': pymilvus.DataType.VARCHAR,
'max_length': 256,
},
'text': {
'content': {
'dtype': pymilvus.DataType.VARCHAR,
'max_length': 65535,
},
Expand Down Expand Up @@ -253,9 +251,9 @@ def _construct_filter_expr(self, filters: Dict[str, Union[str, int, List, Set]])
def _serialize_node_partial(self, node: DocNode) -> Dict:
res = {
'uid': node.uid,
'text': node.text,
'content': obj2str(node.content),
'parent': node.parent.uid if node.parent else '',
'metadata': base64.b64encode(pickle.dumps(node._metadata)).decode('utf-8'),
'metadata': obj2str(node._metadata),
}

for k, v in node.embedding.items():
Expand All @@ -273,9 +271,9 @@ def _deserialize_node_partial(self, result: Dict) -> DocNode:

doc = DocNode(
uid=record.pop('uid'),
text=record.pop('text'),
content=str2obj(record.pop('content')),
parent=record.pop('parent'), # this is the parent's uid
metadata=pickle.loads(base64.b64decode(record.pop('metadata').encode('utf-8'))),
metadata=str2obj(record.pop('metadata')),
)

for k, v in record.items():
Expand Down
4 changes: 2 additions & 2 deletions tests/basic_tests/test_rag_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ def _load_data(self, file, extra_info=None, fs=None):
with open(file, 'r') as f:
data = f.read()
node = DocNode(text=data, metadata=extra_info or {})
node.text = "Call the class YmlReader."
node.content = "Call the class YmlReader."
return [node]

def processYml(file, extra_info=None):
with open(file, 'r') as f:
data = f.read()
node = DocNode(text=data, metadata=extra_info or {})
node.text = "Call the function processYml."
node.content = "Call the function processYml."
return [node]

class TestRagReader(object):
Expand Down
12 changes: 10 additions & 2 deletions tests/basic_tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,16 @@ def test_reload(self):
self.store = MilvusStore(group_embed_keys=self.group_embed_keys, embed=self.mock_embed,
embed_dims=self.embed_dims, global_metadata_desc=self.global_metadata_desc,
uri=self.store_file)
self.assertEqual(set([node.uid for node in self.store.get_nodes('group1')]),
set([self.node1.uid, self.node2.uid, self.node3.uid]))

nodes = self.store.get_nodes('group1')
orig_nodes = [self.node1, self.node2, self.node3]
self.assertEqual(set([node.uid for node in nodes]), set([node.uid for node in orig_nodes]))

for node in nodes:
for orig_node in orig_nodes:
if node.uid == orig_node.uid:
self.assertEqual(node.text, orig_node.text)
break

# XXX `array_contains_any` is not supported in local(aka lite) mode. skip this ut
def _test_query_with_array_filter(self):
Expand Down