Skip to content

Commit

Permalink
Add server for Document (#331)
Browse files Browse the repository at this point in the history
  • Loading branch information
wzh1994 authored Oct 31, 2024
1 parent 549db38 commit 589f53d
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 14 deletions.
3 changes: 2 additions & 1 deletion lazyllm/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .common import package, kwargs, arguments, LazyLLMCMD, timeout, final, ReadOnlyWrapper, DynamicDescriptor
from .common import FlatList, Identity, ResultCollector, ArgsDict, CaseInsensitiveDict
from .common import ReprRule, make_repr, modify_repr
from .common import once_flag, call_once, once_wrapper, singleton
from .common import once_flag, call_once, once_wrapper, singleton, reset_on_pickle
from .option import Option, OptionIter
from .threading import Thread, ThreadPoolExecutor
from .multiprocessing import SpawnProcess, ForkProcess
Expand Down Expand Up @@ -32,6 +32,7 @@
'compile_func',
'DynamicDescriptor',
'singleton',
'reset_on_pickle',

# arg praser
'LazyLLMCMD',
Expand Down
25 changes: 25 additions & 0 deletions lazyllm/common/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,3 +375,28 @@ def get_instance(*args, **kwargs):
if cls not in instances: instances[cls] = cls(*args, **kwargs)
return instances[cls]
return get_instance


def reset_on_pickle(*fields):
    """Class decorator that blanks the given instance attributes when pickling
    and optionally re-creates them on unpickling.

    Each item in ``fields`` is either an attribute name (``str``) or a
    ``(name, factory)`` pair.  On ``__getstate__`` the attribute's entry in the
    state dict is set to ``None`` (so unpicklable members such as
    ``threading.Lock`` do not break pickling); on ``__setstate__`` the
    attribute is rebuilt by calling ``factory()`` when a factory was given.
    Any pre-existing ``__getstate__``/``__setstate__`` on the class is wrapped,
    not replaced.
    """
    # Normalize every entry to a (name, factory-or-None) pair up front.
    # The original code unpacked `for field, *_ in fields` in __getstate__,
    # which silently broke plain-string fields: iterating a str unpacks its
    # characters, so '_cache' yielded field == '_' and the real attribute was
    # pickled anyway.  Normalizing once keeps both hooks in agreement.
    normalized = [f if isinstance(f, (tuple, list)) else (f, None) for f in fields]

    def decorator(cls):
        original_getstate = cls.__getstate__ if hasattr(cls, '__getstate__') else lambda self: self.__dict__
        original_setstate = (cls.__setstate__ if hasattr(cls, '__setstate__') else
                             lambda self, state: self.__dict__.update(state))

        def __getstate__(self):
            # Copy so the live instance's __dict__ is never mutated.
            state = original_getstate(self).copy()
            for name, _ in normalized:
                state[name] = None
            return state

        def __setstate__(self, state):
            original_setstate(self, state)
            for name, factory in normalized:
                # Rebuild only fields that were blanked and have a factory;
                # factory-less fields simply stay None after unpickling.
                if factory is not None and name in state and state[name] is None:
                    setattr(self, name, factory())

        cls.__getstate__ = __getstate__
        cls.__setstate__ = __setstate__
        return cls
    return decorator
5 changes: 3 additions & 2 deletions lazyllm/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,10 @@ def make_intention(base_model: str, nodes: Dict[str, List[dict]],


@NodeConstructor.register('Document')
def make_document(dataset_path: str, embed: Node = None, create_ui: bool = False, node_group: List = []):
def make_document(dataset_path: str, embed: Node = None, create_ui: bool = False,
server: bool = False, node_group: List = []):
document = lazyllm.tools.rag.Document(
dataset_path, Engine().build_node(embed).func if embed else None, manager=create_ui)
dataset_path, Engine().build_node(embed).func if embed else None, server=server, manager=create_ui)
for group in node_group:
if group['transform'] == 'LLMParser': group['llm'] = Engine().build_node(group['llm']).func
elif group['transform'] == 'FuncNode': group['function'] = make_code(group['function'])
Expand Down
5 changes: 5 additions & 0 deletions lazyllm/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,11 @@ def __init__(self, m, pre=None, post=None, stream=False, return_trace=False,

_url_id = property(lambda self: self._impl._module_id)

def __call__(self, *args, **kw):
if len(args) > 1:
return super(__class__, self).__call__(package(args), **kw)
return super(__class__, self).__call__(*args, **kw)

def wait(self):
self._impl._launcher.wait()

Expand Down
3 changes: 3 additions & 0 deletions lazyllm/tools/rag/doc_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,9 @@ def recurse_children(node: DocNode, visited: Set[DocNode]) -> bool:
LOG.debug(f"Found children nodes for {group}: {result}")
return list(result)

def __call__(self, func_name: str, *args, **kwargs):
return getattr(self, func_name)(*args, **kwargs)


DocImpl._create_builtin_node_group(name="CoarseChunk", transform=SentenceSplitter, chunk_size=1024, chunk_overlap=100)
DocImpl._create_builtin_node_group(name="MediumChunk", transform=SentenceSplitter, chunk_size=256, chunk_overlap=25)
Expand Down
42 changes: 34 additions & 8 deletions lazyllm/tools/rag/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Callable, Optional, Dict, Union, List
import lazyllm
from lazyllm import ModuleBase, ServerModule, DynamicDescriptor
from lazyllm.launcher import LazyLLMLaunchersBase as Launcher

from .doc_manager import DocManager
from .doc_impl import DocImpl
Expand All @@ -12,32 +13,43 @@
import functools


class CallableDict(dict):
    """A ``dict`` whose instances are callable: ``d(key, *args, **kw)`` looks
    up ``d[key]`` and invokes it with the remaining arguments.

    Raises ``KeyError`` (from the lookup) when ``key`` is not present, exactly
    like plain subscription.
    """

    def __call__(self, cls, *args, **kw):
        handler = self[cls]
        return handler(*args, **kw)

class Document(ModuleBase):
class _Impl(ModuleBase):
def __init__(self, dataset_path: str, embed: Optional[Union[Callable, Dict[str, Callable]]] = None,
manager: bool = False, server: bool = False, name: Optional[str] = None, launcher=None):
manager: bool = False, server: bool = False, name: Optional[str] = None,
launcher: Launcher = None):
super().__init__()
if not os.path.exists(dataset_path):
defatult_path = os.path.join(lazyllm.config["data_path"], dataset_path)
if os.path.exists(defatult_path):
dataset_path = defatult_path
launcher = launcher if launcher else lazyllm.launchers.remote(sync=False)
self._launcher: Launcher = launcher if launcher else lazyllm.launchers.remote(sync=False)
self._dataset_path = dataset_path
self._embed = embed if isinstance(embed, dict) else {EMBED_DEFAULT_KEY: embed} if embed else {}
self.name = name
for embed in self._embed.values():
if isinstance(embed, ModuleBase):
self._submodules.append(embed)
self._dlm = DocListManager(dataset_path, name).init_tables()
self._kbs = {DocListManager.DEDAULT_GROUP_NAME: DocImpl(embed=self._embed, dlm=self._dlm)}
self._kbs = CallableDict({DocListManager.DEDAULT_GROUP_NAME: DocImpl(embed=self._embed, dlm=self._dlm)})
if manager: self._manager = ServerModule(DocManager(self._dlm))
if server: self._doc = ServerModule(self._doc)
if server: self._kbs = ServerModule(self._kbs)

def add_kb_group(self, name):
self._kbs[name] = DocImpl(dlm=self._dlm, embed=self._embed, kb_group_name=name)
self._dlm.add_kb_group(name)

def get_doc_by_kb_group(self, name): return self._kbs[name]
def get_doc_by_kb_group(self, name):
return self._kbs._impl._m[name] if isinstance(self._kbs, ServerModule) else self._kbs[name]

def stop(self): self._launcher.cleanup()

def __call__(self, *args, **kw):
return self._kbs(*args, **kw)

def __init__(self, dataset_path: str, embed: Optional[Union[Callable, Dict[str, Callable]]] = None,
create_ui: bool = False, manager: bool = False, server: bool = False,
Expand All @@ -57,6 +69,9 @@ def create_kb_group(self, name: str) -> "Document":
@property
def _impl(self): return self._impls.get_doc_by_kb_group(self._curr_group)

@property
def manager(self): return getattr(self._impls, '_manager', None)

@DynamicDescriptor
def create_node_group(self, name: str = None, *, transform: Callable, parent: str = LAZY_ROOT_NAME,
trans_node: bool = None, num_workers: int = 0, **kwargs) -> None:
Expand All @@ -78,14 +93,25 @@ def add_reader(self, pattern: str, func: Optional[Callable] = None):
def register_global_reader(cls, pattern: str, func: Optional[Callable] = None):
return cls.add_reader(pattern, func)

def _forward(self, func_name: str, *args, **kw):
return self._impls(self._curr_group, func_name, *args, **kw)

def find_parent(self, target) -> Callable:
return functools.partial(DocImpl.find_parent, group=target)
# TODO: Currently, when a DocNode is returned from the server, it will carry all parent nodes and child nodes.
# So the query of parent and child nodes can be performed locally, and there is no need to search the
# document service through the server for the time being. When this item is optimized, the code will become:
# return functools.partial(self._forward, 'find_parent', group=target)
return functools.partial(Document.find_parent, group=target)

def find_children(self, target) -> Callable:
return functools.partial(DocImpl.find_children, group=target)
# TODO: Currently, when a DocNode is returned from the server, it will carry all parent nodes and child nodes.
# So the query of parent and child nodes can be performed locally, and there is no need to search the
# document service through the server for the time being. When this item is optimized, the code will become:
# return functools.partial(self._forward, 'find_children', group=target)
return functools.partial(Document.find_children, group=target)

def forward(self, *args, **kw) -> List[DocNode]:
return self._impl.retrieve(*args, **kw)
return self._forward('retrieve', *args, **kw)

def __repr__(self):
return lazyllm.make_repr("Module", "Document", manager=hasattr(self._impl, '_manager'))
3 changes: 2 additions & 1 deletion lazyllm/tools/rag/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import uuid
from typing import Any, Callable, Dict, List, Optional, Union
import chromadb
from lazyllm import LOG, config
from lazyllm import LOG, config, reset_on_pickle
from chromadb.api.models.Collection import Collection
import threading
import json
Expand All @@ -24,6 +24,7 @@ class MetadataMode(str, Enum):
NONE = auto()


@reset_on_pickle(('_lock', threading.Lock))
class DocNode:
def __init__(self, uid: Optional[str] = None, text: Optional[str] = None, group: Optional[str] = None,
embedding: Optional[Dict[str, List[float]]] = None, parent: Optional["DocNode"] = None,
Expand Down
4 changes: 2 additions & 2 deletions tests/basic_tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,8 @@ def test_rag(self):

# test add doc_group
resources[-1] = dict(id='0', kind='Document', name='d1', args=dict(
dataset_path='rag_master', node_group=[dict(name='sentence', transform='SentenceSplitter',
chunk_size=100, chunk_overlap=10)]))
dataset_path='rag_master', server=True, node_group=[
dict(name='sentence', transform='SentenceSplitter', chunk_size=100, chunk_overlap=10)]))
nodes.extend([dict(id='2', kind='Retriever', name='ret2',
args=dict(doc='0', group_name='sentence', similarity='bm25', topk=3)),
dict(id='3', kind='JoinFormatter', name='c', args=dict(type='sum'))])
Expand Down

0 comments on commit 589f53d

Please sign in to comment.