output parser serialization #758

Merged 1 commit on Jan 27, 2023
langchain/prompts/base.py (18 additions, 2 deletions)
@@ -48,13 +48,24 @@ def check_valid_template(
         raise ValueError("Invalid prompt schema.")
 
 
-class BaseOutputParser(ABC):
+class BaseOutputParser(BaseModel, ABC):
     """Class to parse the output of an LLM call."""
 
     @abstractmethod
     def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
         """Parse the output of an LLM call."""
 
+    @property
+    def _type(self) -> str:
+        """Return the type key."""
+        raise NotImplementedError
+
+    def dict(self, **kwargs: Any) -> Dict:
+        """Return dictionary representation of output parser."""
+        output_parser_dict = super().dict()
+        output_parser_dict["_type"] = self._type
+        return output_parser_dict
+
 
 class ListOutputParser(BaseOutputParser):
     """Class to parse the output of an LLM call to a list."""
@@ -79,6 +90,11 @@ class RegexParser(BaseOutputParser, BaseModel):
     output_keys: List[str]
     default_output_key: Optional[str] = None
 
+    @property
+    def _type(self) -> str:
+        """Return the type key."""
+        return "regex_parser"
+
     def parse(self, text: str) -> Dict[str, str]:
         """Parse the output of an LLM call."""
         match = re.search(self.regex, text)
@@ -142,7 +158,7 @@ def _prompt_type(self) -> str:
 
     def dict(self, **kwargs: Any) -> Dict:
         """Return dictionary representation of prompt."""
-        prompt_dict = super().dict()
+        prompt_dict = super().dict(**kwargs)
         prompt_dict["_type"] = self._prompt_type
         return prompt_dict
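With the base.py change, any BaseOutputParser subclass that defines _type can be serialized next to its prompt. A minimal sketch of the new behaviour, assuming the RegexParser fields shown above (the field values and the key ordering in the printed dict are illustrative only):

# Sketch only: illustrates BaseOutputParser.dict() adding the `_type` key.
from langchain.prompts.base import RegexParser

parser = RegexParser(regex=r"Answer: (.*)", output_keys=["answer"])

# super().dict() serializes the pydantic fields; the override then tags the
# result with the parser's type key so it can be reconstructed on load.
print(parser.dict())
# -> {'regex': 'Answer: (.*)', 'output_keys': ['answer'],
#     'default_output_key': None, '_type': 'regex_parser'}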
langchain/prompts/loading.py (17 additions, 1 deletion)
@@ -9,7 +9,7 @@
 import requests
 import yaml
 
-from langchain.prompts.base import BasePromptTemplate
+from langchain.prompts.base import BasePromptTemplate, RegexParser
 from langchain.prompts.few_shot import FewShotPromptTemplate
 from langchain.prompts.prompt import PromptTemplate
 
@@ -69,6 +69,20 @@ def _load_examples(config: dict) -> dict:
     return config
 
 
+def _load_output_parser(config: dict) -> dict:
+    """Load output parser."""
+    if "output_parser" in config:
+        if config["output_parser"] is not None:
+            _config = config["output_parser"]
+            output_parser_type = _config["_type"]
+            if output_parser_type == "regex_parser":
+                output_parser = RegexParser(**_config)
+            else:
+                raise ValueError(f"Unsupported output parser {output_parser_type}")
+            config["output_parser"] = output_parser
+    return config
+
+
 def _load_few_shot_prompt(config: dict) -> FewShotPromptTemplate:
     """Load the few shot prompt from the config."""
     # Load the suffix and prefix templates.
@@ -86,13 +100,15 @@ def _load_few_shot_prompt(config: dict) -> FewShotPromptTemplate:
     config["example_prompt"] = load_prompt_from_config(config["example_prompt"])
     # Load the examples.
     config = _load_examples(config)
+    config = _load_output_parser(config)
     return FewShotPromptTemplate(**config)
 
 
 def _load_prompt(config: dict) -> PromptTemplate:
     """Load the prompt template from config."""
     # Load the template from disk if necessary.
     config = _load_template("template", config)
+    config = _load_output_parser(config)
     return PromptTemplate(**config)
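Taken together, the two files let a prompt config carry its output parser through a save/load cycle. A hedged sketch of the loading side, assuming load_prompt_from_config dispatches on a top-level "_type": "prompt" key as elsewhere in this module and that the prompt already exposes an output_parser field (the config values are illustrative):

# Sketch only: reconstructing a prompt whose config embeds a regex parser.
from langchain.prompts.base import RegexParser
from langchain.prompts.loading import load_prompt_from_config

config = {
    "_type": "prompt",
    "input_variables": ["question"],
    "template": "Rate this question: {question}\nScore:",
    "output_parser": {
        "_type": "regex_parser",
        "regex": r"Score: (\d+)",
        "output_keys": ["score"],
    },
}

# _load_prompt now routes the nested config through _load_output_parser,
# which rebuilds a RegexParser from the `_type` key.
prompt = load_prompt_from_config(config)
assert isinstance(prompt.output_parser, RegexParser)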