diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py index f74bd4c1050f9..40d72953d485b 100644 --- a/libs/core/langchain_core/output_parsers/xml.py +++ b/libs/core/langchain_core/output_parsers/xml.py @@ -1,4 +1,5 @@ import re +import xml import xml.etree.ElementTree as ET from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union @@ -81,33 +82,52 @@ def _transform( continue # feed buffer to parser parser.feed(buffer) + buffer = "" # yield all events - for event, elem in parser.read_events(): - if event == "start": - # update current path - current_path.append(elem.tag) - current_path_has_children = False - elif event == "end": - # remove last element from current path - current_path.pop() - # yield element - if not current_path_has_children: - yield nested_element(current_path, elem) - # prevent yielding of parent element - if current_path: - current_path_has_children = True - else: - xml_started = False + try: + for event, elem in parser.read_events(): + if event == "start": + # update current path + current_path.append(elem.tag) + current_path_has_children = False + elif event == "end": + # remove last element from current path + # + current_path.pop() + # yield element + if not current_path_has_children: + yield nested_element(current_path, elem) + # prevent yielding of parent element + if current_path: + current_path_has_children = True + else: + xml_started = False + except xml.etree.ElementTree.ParseError: + # This might be junk at the end of the XML input. + # Let's check whether the current path is empty. + if not current_path: + # If it is empty, we can ignore this error. + break + else: + raise + # close parser - parser.close() + try: + parser.close() + except xml.etree.ElementTree.ParseError: + # Ignore. This will ignore any incomplete XML at the end of the input + pass async def _atransform( self, input: AsyncIterator[Union[str, BaseMessage]] ) -> AsyncIterator[AddableDict]: + xml_start_re = re.compile(r"<[a-zA-Z:_]") parser = ET.XMLPullParser(["start", "end"]) + xml_started = False current_path: List[str] = [] current_path_has_children = False + buffer = "" async for chunk in input: if isinstance(chunk, BaseMessage): # extract text @@ -115,24 +135,54 @@ async def _atransform( if not isinstance(chunk_content, str): continue chunk = chunk_content - # pass chunk to parser - parser.feed(chunk) + # add chunk to buffer of unprocessed text + buffer += chunk + # if xml string hasn't started yet, continue to next chunk + if not xml_started: + if match := xml_start_re.search(buffer): + # if xml string has started, remove all text before it + buffer = buffer[match.start() :] + xml_started = True + else: + continue + # feed buffer to parser + parser.feed(buffer) + + buffer = "" # yield all events - for event, elem in parser.read_events(): - if event == "start": - # update current path - current_path.append(elem.tag) - current_path_has_children = False - elif event == "end": - # remove last element from current path - current_path.pop() - # yield element - if not current_path_has_children: - yield nested_element(current_path, elem) - # prevent yielding of parent element - current_path_has_children = True + try: + for event, elem in parser.read_events(): + if event == "start": + # update current path + current_path.append(elem.tag) + current_path_has_children = False + elif event == "end": + # remove last element from current path + # + current_path.pop() + # yield element + if not current_path_has_children: + yield nested_element(current_path, elem) + # prevent yielding of parent element + if current_path: + current_path_has_children = True + else: + xml_started = False + except xml.etree.ElementTree.ParseError: + # This might be junk at the end of the XML input. + # Let's check whether the current path is empty. + if not current_path: + # If it is empty, we can ignore this error. + break + else: + raise + # close parser - parser.close() + try: + parser.close() + except xml.etree.ElementTree.ParseError: + # Ignore. This will ignore any incomplete XML at the end of the input + pass def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]: """Converts xml tree to python dictionary.""" diff --git a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py index 222c8bf759610..48e7372b98a3d 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py @@ -1,10 +1,12 @@ """Test XMLOutputParser""" +from typing import AsyncIterator, Iterable + import pytest from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers.xml import XMLOutputParser -DEF_RESULT_ENCODING = """ +DATA = """ @@ -13,6 +15,25 @@ tag """ +WITH_XML_HEADER = f""" +{DATA}""" + + +IN_XML_TAGS_WITH_XML_HEADER = f""" +```xml +{WITH_XML_HEADER} +``` +""" + +IN_XML_TAGS_WITH_HEADER_AND_TRAILING_JUNK = f""" +Some random text +```xml +{WITH_XML_HEADER} +``` +More random text +""" + + DEF_RESULT_EXPECTED = { "foo": [ {"bar": [{"baz": None}, {"baz": "slim.shady"}]}, @@ -24,23 +45,13 @@ @pytest.mark.parametrize( "result", [ - DEF_RESULT_ENCODING, - DEF_RESULT_ENCODING[DEF_RESULT_ENCODING.find("\n") :], - f""" -```xml -{DEF_RESULT_ENCODING} -``` -""", - f""" -Some random text -```xml -{DEF_RESULT_ENCODING} -``` -More random text -""", + DATA, # has no xml header + WITH_XML_HEADER, + IN_XML_TAGS_WITH_XML_HEADER, + IN_XML_TAGS_WITH_HEADER_AND_TRAILING_JUNK, ], ) -def test_xml_output_parser(result: str) -> None: +async def test_xml_output_parser(result: str) -> None: """Test XMLOutputParser.""" xml_parser = XMLOutputParser() @@ -48,12 +59,23 @@ def test_xml_output_parser(result: str) -> None: xml_result = xml_parser.parse(result) assert DEF_RESULT_EXPECTED == xml_result - # TODO(Eugene): Fix this test for newer python version - # assert list(xml_parser.transform(iter(result))) == [ - # {"foo": [{"bar": [{"baz": None}]}]}, - # {"foo": [{"bar": [{"baz": "slim.shady"}]}]}, - # {"foo": [{"baz": "tag"}]}, - # ] + assert list(xml_parser.transform(iter(result))) == [ + {"foo": [{"bar": [{"baz": None}]}]}, + {"foo": [{"bar": [{"baz": "slim.shady"}]}]}, + {"foo": [{"baz": "tag"}]}, + ] + + async def _as_iter(iterable: Iterable[str]) -> AsyncIterator[str]: + for item in iterable: + yield item + + chunks = [chunk async for chunk in xml_parser.atransform(_as_iter(result))] + + assert list(chunks) == [ + {"foo": [{"bar": [{"baz": None}]}]}, + {"foo": [{"bar": [{"baz": "slim.shady"}]}]}, + {"foo": [{"baz": "tag"}]}, + ] @pytest.mark.parametrize("result", ["foo>", "