forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add enum output parser (langchain-ai#5165)
- Loading branch information
Showing
3 changed files
with
237 additions
and
0 deletions.
There are no files selected for viewing
173 changes: 173 additions & 0 deletions
173
docs/modules/prompts/output_parsers/examples/enum.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "0360be02", | ||
"metadata": {}, | ||
"source": [ | ||
"# Enum Output Parser\n", | ||
"\n", | ||
"This notebook shows how to use an Enum output parser" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "2f039b4b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain.output_parsers.enum import EnumOutputParser" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "9a35d1a7", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from enum import Enum\n", | ||
"\n", | ||
"class Colors(Enum):\n", | ||
" RED = \"red\"\n", | ||
" GREEN = \"green\"\n", | ||
" BLUE = \"blue\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "a90a66f5", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"parser = EnumOutputParser(enum=Colors)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "c48b88cb", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"<Colors.RED: 'red'>" | ||
] | ||
}, | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"parser.parse(\"red\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"id": "7d313e41", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"<Colors.GREEN: 'green'>" | ||
] | ||
}, | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"# Can handle spaces\n", | ||
"parser.parse(\" green\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"id": "976ae42d", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"<Colors.BLUE: 'blue'>" | ||
] | ||
}, | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"# And new lines\n", | ||
"parser.parse(\"blue\\n\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"id": "636a48ab", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"ename": "OutputParserException", | ||
"evalue": "Response 'yellow' is not one of the expected values: ['red', 'green', 'blue']", | ||
"output_type": "error", | ||
"traceback": [ | ||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | ||
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", | ||
"File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/enum.py:25\u001b[0m, in \u001b[0;36mEnumOutputParser.parse\u001b[0;34m(self, response)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m---> 25\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menum\u001b[49m\u001b[43m(\u001b[49m\u001b[43mresponse\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstrip\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m:\n", | ||
"File \u001b[0;32m~/.pyenv/versions/3.9.1/lib/python3.9/enum.py:315\u001b[0m, in \u001b[0;36mEnumMeta.__call__\u001b[0;34m(cls, value, names, module, qualname, type, start)\u001b[0m\n\u001b[1;32m 314\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m names \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m: \u001b[38;5;66;03m# simple value lookup\u001b[39;00m\n\u001b[0;32m--> 315\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__new__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 316\u001b[0m \u001b[38;5;66;03m# otherwise, functional API: we're creating a new Enum type\u001b[39;00m\n", | ||
"File \u001b[0;32m~/.pyenv/versions/3.9.1/lib/python3.9/enum.py:611\u001b[0m, in \u001b[0;36mEnum.__new__\u001b[0;34m(cls, value)\u001b[0m\n\u001b[1;32m 610\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m result \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m exc \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 611\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m ve_exc\n\u001b[1;32m 612\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m exc \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", | ||
"\u001b[0;31mValueError\u001b[0m: 'yellow' is not a valid Colors", | ||
"\nDuring handling of the above exception, another exception occurred:\n", | ||
"\u001b[0;31mOutputParserException\u001b[0m Traceback (most recent call last)", | ||
"Cell \u001b[0;32mIn[8], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# And raises errors when appropriate\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mparser\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparse\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43myellow\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", | ||
"File \u001b[0;32m~/workplace/langchain/langchain/output_parsers/enum.py:27\u001b[0m, in \u001b[0;36mEnumOutputParser.parse\u001b[0;34m(self, response)\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39menum(response\u001b[38;5;241m.\u001b[39mstrip())\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m:\n\u001b[0;32m---> 27\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m OutputParserException(\n\u001b[1;32m 28\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mResponse \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresponse\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m is not one of the \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 29\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexpected values: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_valid_values\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 30\u001b[0m )\n", | ||
"\u001b[0;31mOutputParserException\u001b[0m: Response 'yellow' is not one of the expected values: ['red', 'green', 'blue']" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# And raises errors when appropriate\n", | ||
"parser.parse(\"yellow\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "c517f447", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.1" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from enum import Enum | ||
from typing import Any, Dict, List, Type | ||
|
||
from pydantic import root_validator | ||
|
||
from langchain.schema import BaseOutputParser, OutputParserException | ||
|
||
|
||
class EnumOutputParser(BaseOutputParser): | ||
enum: Type[Enum] | ||
|
||
@root_validator() | ||
def raise_deprecation(cls, values: Dict) -> Dict: | ||
enum = values["enum"] | ||
if not all(isinstance(e.value, str) for e in enum): | ||
raise ValueError("Enum values must be strings") | ||
return values | ||
|
||
@property | ||
def _valid_values(self) -> List[str]: | ||
return [e.value for e in self.enum] | ||
|
||
def parse(self, response: str) -> Any: | ||
try: | ||
return self.enum(response.strip()) | ||
except ValueError: | ||
raise OutputParserException( | ||
f"Response '{response}' is not one of the " | ||
f"expected values: {self._valid_values}" | ||
) | ||
|
||
def get_format_instructions(self) -> str: | ||
return f"Select one of the following options: {', '.join(self._valid_values)}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from enum import Enum | ||
|
||
from langchain.output_parsers.enum import EnumOutputParser | ||
from langchain.schema import OutputParserException | ||
|
||
|
||
class Colors(Enum): | ||
RED = "red" | ||
GREEN = "green" | ||
BLUE = "blue" | ||
|
||
|
||
def test_enum_output_parser_parse() -> None: | ||
parser = EnumOutputParser(enum=Colors) | ||
|
||
# Test valid inputs | ||
result = parser.parse("red") | ||
assert result == Colors.RED | ||
|
||
result = parser.parse("green") | ||
assert result == Colors.GREEN | ||
|
||
result = parser.parse("blue") | ||
assert result == Colors.BLUE | ||
|
||
# Test invalid input | ||
try: | ||
parser.parse("INVALID") | ||
assert False, "Should have raised OutputParserException" | ||
except OutputParserException: | ||
pass |