From 753f649ac2d4a2aea7d355cddd44ccc195cffd5a Mon Sep 17 00:00:00 2001 From: wangzhihong Date: Fri, 14 Jun 2024 19:57:53 +0800 Subject: [PATCH] add prompter API docs and best-practice docs (#7) --- README.ENG.md | 2 + README.md | 3 +- docs/source/api/components.rst | 17 ++ docs/source/api/flow.rst | 2 + docs/source/best_practice/flow.rst | 263 +++++++++++++++++++ docs/source/best_practice/module.rst | 4 + docs/source/best_practice/prompt.rst | 29 ++ docs/source/best_practice/rag.rst | 44 ++++ docs/source/contribution/contribution.rst | 2 +- docs/source/index.rst | 10 +- docs/source/user_guide/FAQ.rst | 2 + docs/source/user_guide/best_practice.rst | 4 - docs/source/user_guide/environment.rst | 6 +- lazyllm/common/common.py | 9 +- lazyllm/components/prompter/__init__.py | 3 +- lazyllm/components/prompter/builtinPrompt.py | 2 +- lazyllm/docs/components.py | 150 ++++++++++- 17 files changed, 533 insertions(+), 19 deletions(-) create mode 100644 docs/source/best_practice/flow.rst create mode 100644 docs/source/best_practice/module.rst create mode 100644 docs/source/best_practice/prompt.rst create mode 100644 docs/source/best_practice/rag.rst delete mode 100644 docs/source/user_guide/best_practice.rst diff --git a/README.ENG.md b/README.ENG.md index f927c238..12a9902c 100644 --- a/README.ENG.md +++ b/README.ENG.md @@ -11,6 +11,8 @@ LazyLLM is a low-code development tool for building multi-agent LLMs(large langu The AI application development process based on LazyLLM follows the **prototype building -> data feedback -> iterative optimization** workflow. This means you can quickly build a prototype application using LazyLLM, then analyze bad cases using task-specific data, and subsequently iterate on algorithms and fine-tune models at critical stages of the application to gradually enhance the overall performance.
+**Tutorials**: https://lazyllm.readthedocs.io/
+
 ## Features
 
 **Convenient AI Application Assembly Process**: Even if you are not familiar with large models, you can still easily assemble AI applications with multiple agents using our built-in data flow and functional modules, just like building with Lego bricks.
diff --git a/README.md b/README.md
index a4bfd0a2..16226bf5 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,8 @@
 ## 一、简介
 LazyLLM是一款低代码构建**多Agent**大模型应用的开发工具,协助开发者用极低的成本构建复杂的AI应用,并可以持续的迭代优化效果。LazyLLM提供了便捷的搭建应用的workflow,并且为应用开发过程中的各个环节提供了大量的标准流程和工具。<br>
-基于LazyLLM的AI应用构建流程是**原型搭建 -> 数据回流 -> 迭代优化**,即您可以先基于LazyLLM快速跑通应用的原型,再结合场景任务数据进行bad-case分析,然后对应用中的关键环节进行算法迭代和模型微调,进而逐步提升整个应用的效果。 +基于LazyLLM的AI应用构建流程是**原型搭建 -> 数据回流 -> 迭代优化**,即您可以先基于LazyLLM快速跑通应用的原型,再结合场景任务数据进行bad-case分析,然后对应用中的关键环节进行算法迭代和模型微调,进而逐步提升整个应用的效果。
+**用户文档**: https://lazyllm.readthedocs.io/ ## 二、特性 diff --git a/docs/source/api/components.rst b/docs/source/api/components.rst index 67f2b839..f1d1b201 100644 --- a/docs/source/api/components.rst +++ b/docs/source/api/components.rst @@ -30,6 +30,23 @@ Launcher :members: :exclude-members: Status, get_idle_nodes +Prompter +========= + +.. autoclass:: lazyllm.components.prompter.LazyLLMPrompterBase + :members: generate_prompt, get_response + :exclude-members: + +.. autoclass:: lazyllm.components.AlpacaPrompter + :members: generate_prompt, get_response + :exclude-members: + +.. autoclass:: lazyllm.components.ChatPrompter + :members: generate_prompt, get_response + :exclude-members: + +.. _api.components.register: + Register ========= diff --git a/docs/source/api/flow.rst b/docs/source/api/flow.rst index 7259f600..b8063e8f 100644 --- a/docs/source/api/flow.rst +++ b/docs/source/api/flow.rst @@ -1,3 +1,5 @@ +.. _api.flow: + lazyllm.Flow ----------------------- diff --git a/docs/source/best_practice/flow.rst b/docs/source/best_practice/flow.rst new file mode 100644 index 00000000..9a77d596 --- /dev/null +++ b/docs/source/best_practice/flow.rst @@ -0,0 +1,263 @@ +LazyLLM中的数据流 +----------------- + +LazyLLM中定义了大量的数据流组件,用于让您像搭积木一样,借助LazyLLM中提供的工具和组件,来搭建复杂的大模型应用。本节会详细介绍数据流的使用方法。 + +定义和API文档 +============ +数据流的定义和基本使用方法如 :ref:`api.flow` 中所述 + +pipeline +============ + +基本使用 +^^^^^^^^ + +Pipeline是顺次执行的数据流,上一个阶段的输出成为下一个阶段的输入。pipeline支持函数和仿函数(或仿函数的type)。一个典型的pipeline如下所示: + +.. code-block:: python + + from lazyllm import pipeline + + class Functor(object): + def __call__(self, x): return x * x + + def f1(input): return input + 1 + f2 = lambda x: x * 2 + f3 = Functor() + + assert pipeline(f1, f2, f3, Functor)(1) == 256 + + +.. note:: + 借助LazyLLM的注册机制 :ref:`api.components.register` 注册的函数,也可以直接被pipeline使用,下面给出一个例子 + + +.. code-block:: python + + import lazyllm + from lazyllm import pipeline, component_register + + component_register.new_group('g1') + + @component_register('g1') + def test1(input): return input + 1 + + @component_register('g1') + def test2(input): return input * 3 + + assert pipeline(lazyllm.g1.test1, lazyllm.g1.test2(launcher=lazyllm.launchers.empty))(1) == 6 + + +with语句 +^^^^^^^^ + +除了基本的用法之外,pipeline还支持一个更为灵活的用法 ``with pipeline() as p`` 来让代码更加的简洁和清晰,示例如下 + +.. code-block:: python + + from lazyllm import pipeline + + class Functor(object): + def __call__(self, x): return x * x + + def f1(input): return input + 1 + f2 = lambda x: x * 2 + f3 = Functor() + + with pipeline() as p: + p.f1 = f1 + p.f2 = f2 + p.f3 = f3 + + assert p(1) == 16 + +.. note:: + ``parallel``, ``diverter`` 等也支持with的用法。 + +参数绑定 +^^^^^^^^ + +很多时候,我们并不希望一成不变的将上级的输出给到下一级作为输入,某一下游环节可以需要很久之前的某环节的输出,甚至是整个pipeline的输入。 +在计算图模式的范式下(例如dify和llamaindex),会把函数作为节点,把数据作为边,通过添加边的方式来实现这一行为。 +但LazyLLM不会让你如此复杂,你仅需要掌握参数绑定,就可以自由的在pipeline中从上游向下游传递参数。 + +假设我们定义了一些函数,本小节会一直使用这些函数,不再重复定义。 + +.. code-block:: python + + def f1(input, input2=0): return input + input2 + 1 + def f2(input): return input + 3 + def f3(input): return f'f3-{input}' + def f4(in1, in2, in3): return f'get [{in1}], [{in2}], [{in3}]' + +下面给出一个参数绑定的具体例子: + +.. code-block:: python + + from lazyllm import pipeline, _0 + with pipeline() as p: + p.f1 = f1 + p.f2 = f2 + p.f3 = f3 + p.f4 = bind(f4, p.input, _0, p.f2) + assert p(1) == 'get [1], [f3-5], [5]' + +上述例子中, ``bind`` 函数用于参数绑定,它的基本使用方法和C++的 ``std::bind`` 一致,其中 ``_0`` 表示新函数的第0个参数在被绑定的函数的参数表中的位置。 +对于上面的案例,整个pipeline的输入会作为f4的第一个参数(此处我们假设从第一个开始计数),f3的输出(即新函数的输入)会作为f4的第二个参数,f2的输出会作为f4的第三个参数。 + +.. 
+
+Summary
+============
+
+This article focused on ``pipeline`` and ``parallel``; you should now have an initial sense of how to build complex applications with LazyLLM's flows. We will not belabor the other data-flow components here; see :ref:`api.flow` for how to use them.
diff --git a/docs/source/best_practice/module.rst b/docs/source/best_practice/module.rst
new file mode 100644
index 00000000..7e1a198d
--- /dev/null
+++ b/docs/source/best_practice/module.rst
@@ -0,0 +1,4 @@
+Module: The Top-Level Core Concept of LazyLLM
+=============================================
+
+
diff --git a/docs/source/best_practice/prompt.rst b/docs/source/best_practice/prompt.rst
new file mode 100644
index 00000000..ff0f0d14
--- /dev/null
+++ b/docs/source/best_practice/prompt.rst
@@ -0,0 +1,29 @@
+Prompter
+============
+
+To give you a consistent experience across different online and local models, and an equally consistent experience across fine-tuning and inference, we define the Prompter.
+
+The Design of the LazyLLM Prompter
+----------------------------------
+
+Basic concepts
+^^^^^^^^^^^^^^
+
+Design approach
+^^^^^^^^^^^^^^^
+
+
+Prompter examples
+------------------
+
+Defining and using a Prompter
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
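+As a first taste, here is a minimal usage sketch that mirrors the ``AlpacaPrompter`` examples in the
+API reference; the instruction text and the extra key below are made up for illustration:
+
+.. code-block:: python
+
+    from lazyllm import AlpacaPrompter
+
+    # one fillable slot ({instruction}) plus one extra key that the user's input fills in
+    p = AlpacaPrompter('Answer the question: {instruction}', extro_keys=['knowledge'])
+    p.generate_prompt(dict(instruction='what is LazyLLM?', knowledge='a low-code LLM tool'))
+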
+Using with OnlineChatModule
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Using with TrainableModule
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Built-in scenario Prompts in LazyLLM
+------------------------------------
\ No newline at end of file
diff --git a/docs/source/best_practice/rag.rst b/docs/source/best_practice/rag.rst
new file mode 100644
index 00000000..136ff43d
--- /dev/null
+++ b/docs/source/best_practice/rag.rst
@@ -0,0 +1,44 @@
+RAG
+==================
+
+Retrieval-Augmented Generation (RAG) is one of the most closely watched frontier techniques for large models. Its working principle is that when the model needs to generate text or answer a question, it first retrieves relevant information from a large collection of documents. The retrieved information then guides the generation process, significantly improving the quality and accuracy of the generated text. In this way, RAG can provide more precise and meaningful answers to complex questions, making it one of the important advances in natural language processing. The strength of this approach is that it combines the advantages of retrieval and generation: the model not only produces fluent text but also gives well-grounded answers based on real data.
+This article shows how to build your own RAG application with LazyLLM and add recall strategies at will.
+
+How RAG Works
+-------------------
+
+Building Your First RAG Application with LazyLLM
+------------------------------------------------
+
+A basic RAG
++++++++++++++++++++
+
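+As an illustration of the overall shape, here is a minimal sketch of the RAG pattern wired with the
+data-flow primitives from the flow chapter. The retriever and model below are stand-in functions for
+illustration only, not LazyLLM's actual RAG API:
+
+.. code-block:: python
+
+    from lazyllm import pipeline, bind, _0
+
+    # stand-ins: a real application would plug in document, retriever and LLM modules here
+    def retrieve(query): return ['doc snippet 1', 'doc snippet 2']
+    def make_prompt(query, docs): return f'Context: {docs}\nQuestion: {query}'
+    def llm(prompt): return f'answer based on: {prompt}'
+
+    with pipeline() as rag:
+        rag.retrieve = retrieve
+        # bind the original question alongside the retrieved documents
+        rag.prompt = make_prompt | bind(rag.input, _0)
+        rag.llm = llm
+
+    print(rag('what is RAG?'))
+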
+
+Document management service
++++++++++++++++++++++++++++
+
+
+Deploying and using a local model
++++++++++++++++++++++++++++++++++
+
+
+Multi-path recall
++++++++++++++++++++
+
+
+Custom parser
++++++++++++++++++++
+
+
+Fine-tuning your model
+++++++++++++++++++++++
+
+Fine-tuning an online model
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Fine-tuning a local model
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+
+The Design Philosophy of LazyLLM's RAG Module
+---------------------------------------------
\ No newline at end of file
diff --git a/docs/source/contribution/contribution.rst b/docs/source/contribution/contribution.rst
index d5835356..e382b042 100644
--- a/docs/source/contribution/contribution.rst
+++ b/docs/source/contribution/contribution.rst
@@ -119,7 +119,7 @@
 如果你非常善于处理冲突,那么可以使用 rebase 的方式来解决冲突,因为这能够保证你的 commit log 的整洁。如果你不太熟悉 ``rebase`` 的使用,那么可以使用 ``merge`` 的方式来解决冲突。
 
 pull request规范
-~~~~~~~~~
+~~~~~~~~~~~~~~~~~
 
 1. 一个 ``pull request`` 对应一个短期分支
 
diff --git a/docs/source/index.rst b/docs/source/index.rst
index d6a26bda..535a716f 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -7,9 +7,17 @@ LazyLLM
 
    user_guide/getting_start
    user_guide/environment
-   user_guide/best_practice
    user_guide/FAQ
 
+.. toctree::
+   :caption: Best Practice Guide
+   :maxdepth: 1
+
+   best_practice/flow
+   best_practice/prompt
+   best_practice/module
+   best_practice/rag
+
 .. toctree::
    :caption: LazyLLM API
    :maxdepth: 1
 
diff --git a/docs/source/user_guide/FAQ.rst b/docs/source/user_guide/FAQ.rst
index e69de29b..3617aa2a 100644
--- a/docs/source/user_guide/FAQ.rst
+++ b/docs/source/user_guide/FAQ.rst
@@ -0,0 +1,2 @@
+FAQ
+============
\ No newline at end of file
diff --git a/docs/source/user_guide/best_practice.rst b/docs/source/user_guide/best_practice.rst
deleted file mode 100644
index 847e905f..00000000
--- a/docs/source/user_guide/best_practice.rst
+++ /dev/null
@@ -1,4 +0,0 @@
-最佳实践
-=======
-
-最佳实践编写中,敬请期待
\ No newline at end of file
diff --git a/docs/source/user_guide/environment.rst b/docs/source/user_guide/environment.rst
index 487c5e1d..6784aaad 100644
--- a/docs/source/user_guide/environment.rst
+++ b/docs/source/user_guide/environment.rst
@@ -1,8 +1,8 @@
 环境依赖
-=======
+========
 
 依赖及场景说明
-~~~~~~~~~~
+~~~~~~~~~~~~~
 
 - 微调(基于alpaca-lora): datasets, deepspeed, faiss-cpu, fire, gradio, numpy, peft, torch, transformers
 - 微调(基于collie): collie-lm, numpy, peft, torch, transformers, datasets, deepspeed, fire
@@ -11,7 +11,7 @@
 - RAG: llama_index, llama-index-retrievers-bm25, llama-index-storage-docstore-redis, llama-index-vector-stores-redis, rank_bm25, redisvl, llama_index, llama-index-embeddings-huggingface
 
 基础依赖
-~~~~~~~
+~~~~~~~~~~
 
 - fastapi: FastAPI 是一个现代、快速(高性能)的Web框架,用于构建API,与Python 3.6+类型提示一起使用。
 - loguru: Loguru 是一个Python日志库,旨在通过简洁、易用的API提供灵活的日志记录功能。
 
diff --git a/lazyllm/common/common.py b/lazyllm/common/common.py
index 613289f5..66d50c25 100644
--- a/lazyllm/common/common.py
+++ b/lazyllm/common/common.py
@@ -120,7 +120,9 @@ class Bind(object):
     class __None: pass
 
     class Input(object):
-        def __init__(self): self._item_key, self._attr_key = None, None
+        class __None: pass  # sentinel: distinguishes "key not set" from falsy keys such as 0 or ''
+
+        def __init__(self): self._item_key, self._attr_key = Bind.Input.__None, Bind.Input.__None
 
         def __getitem__(self, key):
             self._item_key = key
@@ -135,9 +137,8 @@ def get_input(self, input):
             input = input.input if input.input else input.kwargs
         elif isinstance(input, LazyLlmResponse):
             input = input.messages
-        if self._item_key:
-            return input[self._item_key]
-        elif self._attr_key: return getattr(input, self._attr_key)
+        if self._item_key is not Bind.Input.__None: return input[self._item_key]
+        elif self._attr_key is not Bind.Input.__None: return getattr(input, self._attr_key)
         return input
 
     def __init__(self, __bind_func=__None, *args, **kw):
diff --git a/lazyllm/components/prompter/__init__.py b/lazyllm/components/prompter/__init__.py
index e4049f0f..da90ecab 100644
--- a/lazyllm/components/prompter/__init__.py
+++ b/lazyllm/components/prompter/__init__.py
@@ -1,12 +1,13 @@
 from .prompter import Prompter
 from .alpacaPrompter import AlpacaPrompter
 from .chatPrompter import ChatPrompter
-from .builtinPrompt import LazyLLMPrompterBase as PrompterBase, EmptyPrompter
+from .builtinPrompt import LazyLLMPrompterBase, LazyLLMPrompterBase as PrompterBase, EmptyPrompter
 
 __all__ = [
     'Prompter',
     'AlpacaPrompter',
     'ChatPrompter',
+    'LazyLLMPrompterBase',
     'PrompterBase',
     'EmptyPrompter',
 ]
diff --git a/lazyllm/components/prompter/builtinPrompt.py b/lazyllm/components/prompter/builtinPrompt.py
index a7ea8e02..a1424c81 100644
--- a/lazyllm/components/prompter/builtinPrompt.py
+++ b/lazyllm/components/prompter/builtinPrompt.py
@@ -98,7 +98,7 @@ def generate_prompt(self, input: Union[str, Dict[str, str], None] = None,
                         history: List[Union[List[str], Dict[str, Any]]] = None,
                         tools: Union[List[Dict[str, Any]], None] = None,
                         label: Union[str, None] = None,
-                        *, show: bool = False, return_dict: bool = False) -> str:
+                        *, show: bool = False, return_dict: bool = False) -> Union[str, Dict]:
         instruction, input = self._get_instruction_and_input(input)
         history = self._get_histories(history, return_dict=return_dict)
         tools = self._get_tools(tools, return_dict=return_dict)
diff --git a/lazyllm/docs/components.py b/lazyllm/docs/components.py
index 72347524..feded438 100644
--- a/lazyllm/docs/components.py
+++ b/lazyllm/docs/components.py
@@ -512,15 +512,159 @@ def test_prompter():
 ''')
 
 add_example('ModelDownloader', '''\
-    >>> downloader = ModelDownloader(model_source='huggingface')
-    >>> downloader.download('GLM3-6B')
+>>> downloader = ModelDownloader(model_source='huggingface')
+>>> downloader.download('GLM3-6B')
 ''')
 
+
+# ============= Prompter
+
+add_chinese_doc('prompter.PrompterBase', '''\
+Prompter的基类,自定义的Prompter需要继承此基类,并通过基类提供的 ``_init_prompt`` 函数来设置Prompt模板和Instruction的模板,以及截取结果所使用的字符串。可以查看 :doc:`/best_practice/prompt` 进一步了解Prompt的设计思想和使用方式。
+
+Prompt模板和Instruction模板都用 ``{}`` 表示要填充的字段,其中Prompt可包含的字段有 ``system``, ``history``, ``tools``等,而instruction_template可包含的字段为 ``instruction`` 和 ``extro_keys`` 。
+``instruction`` 由应用的开发者传入, ``instruction`` 中也可以带有 ``{}`` 用于定义可填充的字段,方便用户填入额外的信息。
+''')
+
+add_english_doc('prompter.PrompterBase', '''\
+The base class of Prompter. A custom Prompter needs to inherit from this base class and set the Prompt template and the Instruction template using the `_init_prompt` function provided by the base class, as well as the string used to capture results. Refer to :doc:`/best_practice/prompt` for further understanding of the design philosophy and usage of Prompts.
+
+Both the Prompt template and the Instruction template use ``{}`` to indicate the fields to be filled in. The fields that can be included in the Prompt are `system`, `history`, `tools`, etc., while the fields that can be included in the instruction_template are `instruction` and `extro_keys`.
+``instruction`` is passed in by the application developer, and the ``instruction`` can also contain ``{}`` to define fillable fields, making it convenient for users to input additional information.
+''')
+
+add_example('prompter.PrompterBase', '''\
+>>> from lazyllm.components.prompter import PrompterBase
+>>> class MyPrompter(PrompterBase):
+...     def __init__(self, instruction = None, extro_keys = None, show = False):
+...         super(__class__, self).__init__(show)
+...         instruction_template = f'{instruction}\\n{{extro_keys}}\\n'.replace('{extro_keys}', PrompterBase._get_extro_key_template(extro_keys))
+...         self._init_prompt("{system}\\n{instruction}{history}\\n{input}\\n, ## Response::", instruction_template, '## Response::')
+...
+>>> p = MyPrompter('ins {instruction}')
+>>> p.generate_prompt('hello')
+'You are an AI-Agent developed by LazyLLM.\\nins hello\\n\\n\\n\\n, ## Response::'
+>>> p.generate_prompt('hello world', return_dict=True)
+{'messages': [{'role': 'system', 'content': 'You are an AI-Agent developed by LazyLLM.\\nins hello world\\n\\n'}, {'role': 'user', 'content': ''}]}
+''')
+
+add_chinese_doc('prompter.PrompterBase.generate_prompt', '''\
+根据用户输入,生成对应的Prompt。
+
+Args:
+    input (Option[str | Dict]): Prompter的输入,如果是dict,会填充到instruction的槽位中;如果是str,则会作为输入。
+    history (Option[List[List | Dict]]): 历史对话,可以为 ``[[u, s], [u, s]]`` 或 openai的history格式,默认为None。
+    tools (Option[List[Dict]]): 可以使用的工具合集,大模型用作FunctionCall时使用,默认为None
+    label (Option[str]): 标签,训练或微调时使用,默认为None
+    show (bool): 标志是否打印生成的Prompt,默认为False
+    return_dict (bool): 标志是否返回dict,一般情况下使用 ``OnlineChatModule`` 时会设置为True。如果返回dict,则仅填充 ``instruction``。默认为False
+''')
+
+add_english_doc('prompter.PrompterBase.generate_prompt', '''\
+Generate a corresponding Prompt based on user input.
+
+Args:
+    input (Option[str | Dict]): The input from the prompter, if it's a dict, it will be filled into the slots of the instruction; if it's a str, it will be used as input.
+    history (Option[List[List | Dict]]): Historical conversation, can be ``[[u, s], [u, s]]`` or in openai's history format, defaults to None.
+    tools (Option[List[Dict]]): A collection of tools that can be used, used when the large model performs FunctionCall, defaults to None.
+    label (Option[str]): Label, used during fine-tuning or training, defaults to None.
+    show (bool): Flag indicating whether to print the generated Prompt, defaults to False.
+    return_dict (bool): Flag indicating whether to return a dict, generally set to True when using ``OnlineChatModule``. If returning a dict, only the ``instruction`` will be filled. Defaults to False.
+''')
+
+add_chinese_doc('prompter.PrompterBase.get_response', '''\
+用作对Prompt的截断,只保留有价值的输出
+
+Args:
+    output (str): 大模型的输出
+    input (Option[str]): 大模型的输入,若指定此参数,会将输出中包含输入的部分全部截断,默认为None
+''')
+
+add_english_doc('prompter.PrompterBase.get_response', '''\
+Used to truncate the Prompt, keeping only valuable output.
+
+Args:
+    output (str): The output of the large model.
+    input (Option[str]): The input of the large model. If this parameter is specified, any part of the output that includes the input will be completely truncated. Defaults to None.
+''')
+
+
+add_chinese_doc('AlpacaPrompter', '''\
+Alpaca格式的Prompter,支持工具调用,不支持历史对话。
+
+Args:
+    instruction (Option[str]): 大模型的任务指令,至少带一个可填充的槽位(如 ``{instruction}``)。
+    extro_keys (Option[List]): 额外的字段,用户的输入会填充这些字段。
+    show (bool): 标志是否打印生成的Prompt,默认为False
+''')
+
+add_english_doc('AlpacaPrompter', '''\
+Alpaca-style Prompter, supports tool calls, does not support historical dialogue.
+
+Args:
+    instruction (Option[str]): Task instructions for the large model, with at least one fillable slot (e.g. ``{instruction}``).
+    extro_keys (Option[List]): Additional fields that will be filled with user input.
+    show (bool): Flag indicating whether to print the generated Prompt, default is False.
+''')
+
+add_example('AlpacaPrompter', '''\
+>>> from lazyllm import AlpacaPrompter
+>>> p = AlpacaPrompter('hello world {instruction}')
+>>> p.generate_prompt('this is my input')
+'You are an AI-Agent developed by LazyLLM.\\nBelow is an instruction that describes a task, paired with extra messages such as input that provides further context if possible. Write a response that appropriately completes the request.\\n\\n ### Instruction:\\nhello world this is my input\\n\\n\\n### Response:\\n'
+>>> p.generate_prompt('this is my input', return_dict=True)
+{'messages': [{'role': 'system', 'content': 'You are an AI-Agent developed by LazyLLM.\\nBelow is an instruction that describes a task, paired with extra messages such as input that provides further context if possible. Write a response that appropriately completes the request.\\n\\n ### Instruction:\\nhello world this is my input\\n\\n'}, {'role': 'user', 'content': ''}]}
+>>>
+>>> p = AlpacaPrompter('hello world {instruction}, {input}', extro_keys=['knowledge'])
+>>> p.generate_prompt(dict(instruction='hello world', input='my input', knowledge='lazyllm'))
+'You are an AI-Agent developed by LazyLLM.\\nBelow is an instruction that describes a task, paired with extra messages such as input that provides further context if possible. Write a response that appropriately completes the request.\\n\\n ### Instruction:\\nhello world hello world, my input\\n\\nHere are some extra messages you can referred to:\\n\\n### knowledge:\\nlazyllm\\n\\n\\n### Response:\\n'
+>>> p.generate_prompt(dict(instruction='hello world', input='my input', knowledge='lazyllm'), return_dict=True)
+{'messages': [{'role': 'system', 'content': 'You are an AI-Agent developed by LazyLLM.\\nBelow is an instruction that describes a task, paired with extra messages such as input that provides further context if possible. Write a response that appropriately completes the request.\\n\\n ### Instruction:\\nhello world hello world, my input\\n\\nHere are some extra messages you can referred to:\\n\\n### knowledge:\\nlazyllm\\n\\n'}, {'role': 'user', 'content': ''}]}
+''')
+
+add_chinese_doc('ChatPrompter', '''\
+多轮对话的Prompt,支持工具调用和历史对话
+
+Args:
+    instruction (Option[str]): 大模型的任务指令,可以带0到多个待填充的槽位,用 ``{}`` 表示。
+    extro_keys (Option[List]): 额外的字段,用户的输入会填充这些字段。
+    show (bool): 标志是否打印生成的Prompt,默认为False
+''')
+
+add_english_doc('ChatPrompter', '''\
+Chat-style Prompter that supports tool calls and historical dialogue.
+
+Args:
+    instruction (Option[str]): Task instructions for the large model, with zero or more fillable slots, represented by ``{}``.
+    extro_keys (Option[List]): Additional fields that will be filled with user input.
+    show (bool): Flag indicating whether to print the generated Prompt, default is False.
+''')
+
+add_example('ChatPrompter', '''\
+>>> from lazyllm import ChatPrompter
+>>> p = ChatPrompter('hello world')
+>>> p.generate_prompt('this is my input')
+'<|start_system|>You are an AI-Agent developed by LazyLLM.hello world\\n\\n<|end_system|>\\n\\n\\n<|Human|>:\\nthis is my input\\n<|Assistant|>:\\n'
+>>> p.generate_prompt('this is my input', return_dict=True)
+{'messages': [{'role': 'system', 'content': 'You are an AI-Agent developed by LazyLLM.\\nhello world\\n\\n'}, {'role': 'user', 'content': 'this is my input'}]}
+>>>
+>>> p = ChatPrompter('hello world {instruction}', extro_keys=['knowledge'])
+>>> p.generate_prompt(dict(instruction='this is my ins', input='this is my inp', knowledge='LazyLLM-Knowledge'))
+'<|start_system|>You are an AI-Agent developed by LazyLLM.hello world this is my ins\\nHere are some extra messages you can referred to:\\n\\n### knowledge:\\nLazyLLM-Knowledge\\n\\n\\n<|end_system|>\\n\\n\\n<|Human|>:\\nthis is my inp\\n<|Assistant|>:\\n'
+>>> p.generate_prompt(dict(instruction='this is my ins', input='this is my inp', knowledge='LazyLLM-Knowledge'), return_dict=True)
+{'messages': [{'role': 'system', 'content': 'You are an AI-Agent developed by LazyLLM.\\nhello world this is my ins\\nHere are some extra messages you can referred to:\\n\\n### knowledge:\\nLazyLLM-Knowledge\\n\\n\\n'}, {'role': 'user', 'content': 'this is my inp'}]}
+>>> p.generate_prompt(dict(instruction='this is my ins', input='this is my inp', knowledge='LazyLLM-Knowledge'), history=[['s1', 'e1'], ['s2', 'e2']])
+'<|start_system|>You are an AI-Agent developed by LazyLLM.hello world this is my ins\\nHere are some extra messages you can referred to:\\n\\n### knowledge:\\nLazyLLM-Knowledge\\n\\n\\n<|end_system|>\\n\\n<|Human|>:s1<|Assistant|>:e1<|Human|>:s2<|Assistant|>:e2\\n<|Human|>:\\nthis is my inp\\n<|Assistant|>:\\n'
+''')
+
+# ============= Launcher
+
 add_chinese_doc = functools.partial(utils.add_chinese_doc, module=lazyllm.launcher)
 add_english_doc = functools.partial(utils.add_english_doc, module=lazyllm.launcher)
 add_example = functools.partial(utils.add_example, module=lazyllm.launcher)
 
-# ============= Launcher
 # Launcher-EmptyLauncher
 add_chinese_doc('EmptyLauncher', '''\
 此类是 ``LazyLLMLaunchersBase`` 的子类,作为一个本地的启动器。