From 0128f6a52b484ad06a7596dc9aa304e396c7bb26 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger
Date: Wed, 17 May 2023 16:23:59 -0400
Subject: [PATCH 1/2] Add local agent

---
 src/transformers/__init__.py       |  13 +++-
 src/transformers/tools/__init__.py |   4 +-
 src/transformers/tools/agents.py   | 113 +++++++++++++++++++++++++++++
 3 files changed, 127 insertions(+), 3 deletions(-)

diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index ca476f30c291..c079981490b7 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -614,6 +614,7 @@
     "tools": [
         "Agent",
         "HfAgent",
+        "LocalAgent",
         "OpenAiAgent",
         "PipelineTool",
         "RemoteTool",
@@ -4361,7 +4362,17 @@
     )
 
     # Tools
-    from .tools import Agent, HfAgent, OpenAiAgent, PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool
+    from .tools import (
+        Agent,
+        HfAgent,
+        LocalAgent,
+        OpenAiAgent,
+        PipelineTool,
+        RemoteTool,
+        Tool,
+        launch_gradio_demo,
+        load_tool,
+    )
 
     # Trainer
     from .trainer_callback import (
diff --git a/src/transformers/tools/__init__.py b/src/transformers/tools/__init__.py
index 8465d7370d3c..a5b3c7dc05eb 100644
--- a/src/transformers/tools/__init__.py
+++ b/src/transformers/tools/__init__.py
@@ -24,7 +24,7 @@
 
 
 _import_structure = {
-    "agents": ["Agent", "HfAgent", "OpenAiAgent"],
+    "agents": ["Agent", "HfAgent", "LocalAgent", "OpenAiAgent"],
     "base": ["PipelineTool", "RemoteTool", "Tool", "launch_gradio_demo", "load_tool"],
 }
 
@@ -46,7 +46,7 @@
     _import_structure["translation"] = ["TranslationTool"]
 
 if TYPE_CHECKING:
-    from .agents import Agent, HfAgent, OpenAiAgent
+    from .agents import Agent, HfAgent, LocalAgent, OpenAiAgent
     from .base import PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool
 
 try:
diff --git a/src/transformers/tools/agents.py b/src/transformers/tools/agents.py
index 79413954df10..a41e93398dd7 100644
--- a/src/transformers/tools/agents.py
+++ b/src/transformers/tools/agents.py
@@ -24,6 +24,8 @@
 import requests
 from huggingface_hub import HfFolder, hf_hub_download, list_spaces
 
+from ..generation import StoppingCriteria, StoppingCriteriaList
+from ..models.auto import AutoModelForCausalLM, AutoTokenizer
 from ..utils import is_openai_available, logging
 from .base import TASK_MAPPING, TOOL_CONFIG_FILE, Tool, load_tool, supports_remote
 from .prompts import CHAT_MESSAGE_PROMPT, CHAT_PROMPT_TEMPLATE, RUN_PROMPT_TEMPLATE
@@ -492,3 +494,114 @@ def generate_one(self, prompt, stop):
             if result.endswith(stop_seq):
                 return result[: -len(stop_seq)]
         return result
+
+
+class LocalAgent(Agent):
+    """
+    Agent that uses a local model and tokenizer to generate code.
+
+    Args:
+        model ([`PreTrainedModel`]):
+            The model to use for the agent.
+        tokenizer ([`PreTrainedTokenizer`]):
+            The tokenizer to use for the agent.
+        chat_prompt_template (`str`, *optional*):
+            Pass along your own prompt if you want to override the default template for the `chat` method.
+        run_prompt_template (`str`, *optional*):
+            Pass along your own prompt if you want to override the default template for the `run` method.
+        additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
+            Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
+            one of the default tools, that default tool will be overridden.
+
+    Example:
+
+    ```py
+    import torch
+    from transformers import AutoModelForCausalLM, AutoTokenizer, LocalAgent
+
+    checkpoint = "bigcode/starcoder"
+    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
+    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+
+    agent = LocalAgent(model, tokenizer)
+    agent.run("Draw me a picture of rivers and lakes.")
+    ```
+    """
+
+    def __init__(self, model, tokenizer, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
+        self.model = model
+        self.tokenizer = tokenizer
+        super().__init__(
+            chat_prompt_template=chat_prompt_template,
+            run_prompt_template=run_prompt_template,
+            additional_tools=additional_tools,
+        )
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        """
+        Convenience method to build a `LocalAgent` from a pretrained checkpoint.
+
+        Args:
+            pretrained_model_name_or_path (`str` or `os.PathLike`):
+                The name of a repo on the Hub or a local path to a folder containing both model and tokenizer.
+            kwargs:
+                Keyword arguments passed along to [`~PreTrainedModel.from_pretrained`].
+
+        Example:
+
+        ```py
+        import torch
+        from transformers import LocalAgent
+
+        agent = LocalAgent.from_pretrained("bigcode/starcoder", device_map="auto", torch_dtype=torch.bfloat16)
+        agent.run("Draw me a picture of rivers and lakes.")
+        ```
+        """
+        model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        return cls(model, tokenizer)
+
+    @property
+    def _model_device(self):
+        if hasattr(self.model, "hf_device_map"):
+            return list(self.model.hf_device_map.values())[0]
+        for param in self.model.parameters():
+            return param.device
+
+    def generate_one(self, prompt, stop):
+        encoded_inputs = self.tokenizer(prompt, return_tensors="pt").to(self._model_device)
+        src_len = encoded_inputs["input_ids"].shape[1]
+        stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop, self.tokenizer)])
+        outputs = self.model.generate(
+            encoded_inputs["input_ids"], max_new_tokens=200, stopping_criteria=stopping_criteria
+        )
+
+        result = self.tokenizer.decode(outputs[0].tolist()[src_len:])
+        # Generation only stops once the stop sequence has been produced, so trim it off the result
+        for stop_seq in stop:
+            if result.endswith(stop_seq):
+                result = result[: -len(stop_seq)]
+        return result
+
+
+class StopSequenceCriteria(StoppingCriteria):
+    """
+    This class can be used to stop generation whenever a sequence of tokens is encountered.
+
+    Args:
+        stop_sequences (`str` or `List[str]`):
+            The sequence (or list of sequences) on which to stop execution.
+        tokenizer:
+            The tokenizer used to decode the model outputs.
+ """ + + def __init__(self, stop_sequences, tokenizer): + if isinstance(stop_sequences, str): + stop_sequences = [stop_sequences] + self.stop_sequences = stop_sequences + self.tokenizer = tokenizer + + def __call__(self, input_ids, scores, **kwargs) -> bool: + decoded_output = self.tokenizer.decode(input_ids.tolist()[0]) + return any(decoded_output.endswith(stop_sequence) for stop_sequence in self.stop_sequences) From b69e32c4afd259f68f0e70f0c5bf998ca188d601 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Wed, 17 May 2023 16:26:50 -0400 Subject: [PATCH 2/2] Document LocalAgent --- docs/source/en/main_classes/agent.mdx | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/source/en/main_classes/agent.mdx b/docs/source/en/main_classes/agent.mdx index 953857c410cb..ecf1747d12bd 100644 --- a/docs/source/en/main_classes/agent.mdx +++ b/docs/source/en/main_classes/agent.mdx @@ -24,12 +24,16 @@ contains the API docs for the underlying classes. ## Agents -We provide two types of agents: [`HfAgent`] uses inference endpoints for opensource models and [`OpenAiAgent`] uses OpenAI closed models. +We provide three types of agents: [`HfAgent`] uses inference endpoints for opensource models, [`LocalAgent`] uses a model of your choice locally and [`OpenAiAgent`] uses OpenAI closed models. ### HfAgent [[autodoc]] HfAgent +### LocalAgent + +[[autodoc]] LocalAgent + ### OpenAiAgent [[autodoc]] OpenAiAgent