Skip to content

Commit

Permalink
add node transform (#253)
Browse files Browse the repository at this point in the history
  • Loading branch information
wzh1994 authored Sep 14, 2024
1 parent 09988d8 commit 58b0a30
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 35 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ build
*.db
mkdocs.yml
.temp
lazyllm_chroma/
docs/en/assets
docs/zh/assets
2 changes: 2 additions & 0 deletions lazyllm/docs/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
Args:
name (str): The name of the node group.
transform (Callable): The transformation rule that converts a node into a node group. The function prototype is `(DocNode, group_name, **kwargs) -> List[DocNode]`. Currently built-in options include [SentenceSplitter][lazyllm.tools.SentenceSplitter], and users can define their own transformation rules.
trans_node (bool): Determines whether the input and output of `transform` are `DocNode` or `str`; defaults to None. It can only be set to `True` when `transform` is a `Callable`.
parent (str): The node that needs further transformation. The series of new nodes obtained after transformation will be child nodes of this parent node. If not specified, the transformation starts from the root node.
kwargs: Parameters related to the specific implementation.
''')
Expand All @@ -58,6 +59,7 @@
Args:
name (str): node group 的名称。
transform (Callable): 将 node 转换成 node group 的转换规则,函数原型是 `(DocNode, group_name, **kwargs) -> List[DocNode]`。目前内置的有 [SentenceSplitter][lazyllm.tools.SentenceSplitter]。用户也可以自定义转换规则。
trans_node (bool): 决定了 transform 的输入和输出是 `DocNode` 还是 `str` ,默认为 None。只有在 `transform` 为 `Callable` 时才可以设置为 `True`。
parent (str): 需要进一步转换的节点。转换之后得到的一系列新的节点将会作为该父节点的子节点。如果不指定则从根节点开始转换。
kwargs: 和具体实现相关的参数。
''')
Expand Down
3 changes: 2 additions & 1 deletion lazyllm/tools/rag/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .document import Document
from .retriever import Retriever
from .rerank import Reranker, register_reranker
from .transform import SentenceSplitter, LLMParser
from .transform import SentenceSplitter, LLMParser, NodeTransform
from .index import register_similarity
from .store import DocNode

Expand All @@ -10,6 +10,7 @@
"Document",
"Reranker",
"Retriever",
"NodeTransform",
"SentenceSplitter",
"LLMParser",
"register_similarity",
Expand Down
26 changes: 16 additions & 10 deletions lazyllm/tools/rag/doc_impl.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import ast
from collections import defaultdict
from functools import wraps
from typing import Callable, Dict, List, Optional, Set
from typing import Callable, Dict, List, Optional, Set, Union
from lazyllm import LOG, config, once_flag, call_once
from lazyllm.common import LazyLlmRequest
from .transform import FuncNodeTransform, SentenceSplitter, LLMParser
from .transform import NodeTransform, FuncNodeTransform, SentenceSplitter, LLMParser
from .store import MapStore, DocNode, ChromadbStore, LAZY_ROOT_NAME, BaseStore
from .data_loaders import DirectoryReader
from .index import DefaultIndex
Expand Down Expand Up @@ -75,16 +75,22 @@ def _create_node_group_default(self):
chunk_overlap=12,
)

def create_node_group(self, name, transform: Union[str, Callable] = None, parent: str = LAZY_ROOT_NAME,
                      trans_node: bool = None, **kwargs) -> None:
    """Register a node group produced by applying `transform` to the nodes of `parent`.

    Args:
        name (str): Name of the new node group. Re-registering an existing name only
            logs a warning and then overwrites the previous definition.
        transform (Union[str, Callable]): Either the name of a built-in transform
            (resolved case-insensitively through `_transmap`) or a callable/class that
            performs the transformation.
        parent (str): Group whose nodes are transformed; defaults to the root group.
        trans_node (bool): Whether a callable `transform` consumes/produces `DocNode`
            instead of `str`. May only be set when `transform` is a plain callable,
            not a class.
        kwargs: Implementation-specific arguments stored and later passed to the
            transform when it is instantiated/applied.
    """
    if name in self.node_groups:
        LOG.warning(f"Duplicate group name: {name}")
    if isinstance(transform, str):
        transform = _transmap[transform.lower()]
    if isinstance(transform, type):
        assert trans_node is None, 'Is not allowed to set `trans_node` when transform is `type`'
        # BUG FIX: the original checked `issubclass(type, NodeTransform)`, which is always
        # False (it tests the builtin `type`), so this warning fired even for built-in
        # NodeTransform subclasses such as SentenceSplitter. Check `transform` instead.
        if not issubclass(transform, NodeTransform):
            LOG.warning('Please note! You are trying to use a completely custom transform class. The relationship '
                        'between nodes may become unreliable, `Document.get_parent/get_child` functions and the '
                        'target parameter of Retriever may have strange anomalies. Please use it at your own risk.')
    else:
        assert callable(transform), "transform should be callable"
    self.node_groups[name] = dict(
        transform=transform, trans_node=trans_node, transform_kwargs=kwargs, parent_name=parent
    )

def add_files(self, input_files: List[str]) -> None:
Expand Down Expand Up @@ -137,11 +143,11 @@ def _get_transform(self, name):
"Please check the group name or add a new one through `create_node_group`."
)

transform = node_group["transform"]
transform, trans_node = node_group["transform"], node_group["trans_node"]
return (
transform(**node_group["transform_kwargs"])
if isinstance(transform, type)
else FuncNodeTransform(transform)
else FuncNodeTransform(transform, trans_node=trans_node)
)

def _dynamic_create_nodes(self, group_name: str, store: BaseStore) -> None:
Expand All @@ -150,7 +156,7 @@ def _dynamic_create_nodes(self, group_name: str, store: BaseStore) -> None:
node_group = self.node_groups.get(group_name)
transform = self._get_transform(group_name)
parent_nodes = self._get_nodes(node_group["parent_name"], store)
nodes = transform(parent_nodes, group_name)
nodes = transform.batch_forward(parent_nodes, group_name)
store.add_nodes(nodes)
LOG.debug(f"building {group_name} nodes: {nodes}")

Expand Down
32 changes: 20 additions & 12 deletions lazyllm/tools/rag/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,24 +48,27 @@ def split_text_keep_separator(text: str, separator: str) -> List[str]:

class NodeTransform(ABC):
    """Base class for node-group transforms.

    Subclasses implement `transform` (one node in, list of text chunks out);
    `__call__` wraps the chunks into fresh `DocNode`s, and `batch_forward`
    applies the transform across many nodes while wiring parent/child links.
    """

    def batch_forward(
        self, documents: Union[DocNode, List[DocNode]], node_group: str, **kwargs
    ) -> List[DocNode]:
        """Transform every input node and attach the results as its children in `node_group`."""
        if not isinstance(documents, list):
            documents = [documents]
        produced: List[DocNode] = []
        for parent in documents:
            children = self(parent, **kwargs)
            for child in children:
                child.parent = parent
                child.group = node_group
            parent.children[node_group] = children
            produced.extend(children)
        return produced

    @abstractmethod
    def transform(self, document: DocNode, **kwargs) -> List[str]:
        """Split one document into chunks; subclasses must override."""
        raise NotImplementedError("Not implemented")

    def __call__(self, node: DocNode, **kwargs: Any) -> List[DocNode]:
        # Parent/child relations are deliberately NOT set here; batch_forward owns that.
        chunks = self.transform(node, **kwargs)
        return [DocNode(text=chunk) for chunk in chunks if chunk]


class SentenceSplitter(NodeTransform):
Expand Down Expand Up @@ -253,17 +256,22 @@ class FuncNodeTransform(NodeTransform):
    Wraps a user callable into a transform: List[DocNode] -> List[DocNode].

    This wrapper supports, when trans_node is False:
        1. str -> list: transform=lambda t: t.split('\n')
        2. str -> str: transform=lambda t: t[:3]
    This wrapper supports, when trans_node is True:
        1. DocNode -> list: pipeline(lambda x:x, SentenceSplitter)
        2. DocNode -> DocNode: pipeline(LLMParser)
    """

def __init__(self, func: Union[Callable[[str], List[str]], Callable[[DocNode], List[DocNode]]],
             trans_node: bool = None):
    """Store the wrapped callable; `trans_node` selects DocNode-in/out (True) vs str-in/out."""
    self._func = func
    self._trans_node = trans_node

def transform(self, node: DocNode, **kwargs) -> List[str]:
    """Apply the wrapped callable to the node (or its text) and normalize the result to a list."""
    payload = node if self._trans_node else node.get_text()
    output = self._func(payload)
    # A single str/DocNode result is promoted to a one-element list.
    return [output] if isinstance(output, (str, DocNode)) else output


Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# flake8: noqa: E501

import lazyllm
from lazyllm.tools.rag.transform import SentenceSplitter
from lazyllm.tools.rag.store import DocNode, MetadataMode
from lazyllm.tools.rag.store import DocNode


class TestSentenceSplitter:
Expand All @@ -10,17 +9,17 @@ def setup_method(self):
self.splitter = SentenceSplitter(chunk_size=30, chunk_overlap=10)

def test_forward(self):
    """SentenceSplitter.batch_forward and pipeline invocation must yield identical chunking."""
    text = """ Before college the two main things I worked on, outside of school, were writing and programming. I didn't write essays. I wrote what beginning writers were supposed to write then, and probably still are: short stories. My stories were awful. They had hardly any plot, just characters with strong feelings, which I imagined made them deep."""  # noqa: E501
    docs = [DocNode(text=text)]

    expected_texts = [
        "Before college the two main things I worked on, outside of school, were writing and programming.I didn't write essays.",  # noqa: E501
        "I didn't write essays.I wrote what beginning writers were supposed to write then, and probably still are: short stories.My stories were awful.",  # noqa: E501
        "My stories were awful.They had hardly any plot, just characters with strong feelings, which I imagined made them deep.",  # noqa: E501
    ]

    nodes = self.splitter.batch_forward(docs, node_group='default')
    got = [n.get_text() for n in nodes]
    assert got == expected_texts

    # The splitter must behave identically when invoked through a pipeline on a single node.
    trans = lazyllm.pipeline(lambda x: x, self.splitter)
    assert [n.get_text() for n in trans(docs[0])] == expected_texts

0 comments on commit 58b0a30

Please sign in to comment.