diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py
index f74bd4c1050f9..40d72953d485b 100644
--- a/libs/core/langchain_core/output_parsers/xml.py
+++ b/libs/core/langchain_core/output_parsers/xml.py
@@ -1,4 +1,5 @@
import re
+import xml
import xml.etree.ElementTree as ET
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
@@ -81,33 +82,52 @@ def _transform(
continue
# feed buffer to parser
parser.feed(buffer)
+
buffer = ""
# yield all events
- for event, elem in parser.read_events():
- if event == "start":
- # update current path
- current_path.append(elem.tag)
- current_path_has_children = False
- elif event == "end":
- # remove last element from current path
- current_path.pop()
- # yield element
- if not current_path_has_children:
- yield nested_element(current_path, elem)
- # prevent yielding of parent element
- if current_path:
- current_path_has_children = True
- else:
- xml_started = False
+ try:
+ for event, elem in parser.read_events():
+ if event == "start":
+ # update current path
+ current_path.append(elem.tag)
+ current_path_has_children = False
+ elif event == "end":
+ # remove last element from current path
+ #
+ current_path.pop()
+ # yield element
+ if not current_path_has_children:
+ yield nested_element(current_path, elem)
+ # prevent yielding of parent element
+ if current_path:
+ current_path_has_children = True
+ else:
+ xml_started = False
+ except xml.etree.ElementTree.ParseError:
+ # This might be junk at the end of the XML input.
+ # Let's check whether the current path is empty.
+ if not current_path:
+ # If it is empty, we can ignore this error.
+ break
+ else:
+ raise
+
# close parser
- parser.close()
+ try:
+ parser.close()
+ except xml.etree.ElementTree.ParseError:
+ # Ignore. This will ignore any incomplete XML at the end of the input
+ pass
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[AddableDict]:
+ xml_start_re = re.compile(r"<[a-zA-Z:_]")
parser = ET.XMLPullParser(["start", "end"])
+ xml_started = False
current_path: List[str] = []
current_path_has_children = False
+ buffer = ""
async for chunk in input:
if isinstance(chunk, BaseMessage):
# extract text
@@ -115,24 +135,54 @@ async def _atransform(
if not isinstance(chunk_content, str):
continue
chunk = chunk_content
- # pass chunk to parser
- parser.feed(chunk)
+ # add chunk to buffer of unprocessed text
+ buffer += chunk
+ # if xml string hasn't started yet, continue to next chunk
+ if not xml_started:
+ if match := xml_start_re.search(buffer):
+ # if xml string has started, remove all text before it
+ buffer = buffer[match.start() :]
+ xml_started = True
+ else:
+ continue
+ # feed buffer to parser
+ parser.feed(buffer)
+
+ buffer = ""
# yield all events
- for event, elem in parser.read_events():
- if event == "start":
- # update current path
- current_path.append(elem.tag)
- current_path_has_children = False
- elif event == "end":
- # remove last element from current path
- current_path.pop()
- # yield element
- if not current_path_has_children:
- yield nested_element(current_path, elem)
- # prevent yielding of parent element
- current_path_has_children = True
+ try:
+ for event, elem in parser.read_events():
+ if event == "start":
+ # update current path
+ current_path.append(elem.tag)
+ current_path_has_children = False
+ elif event == "end":
+ # remove last element from current path
+ #
+ current_path.pop()
+ # yield element
+ if not current_path_has_children:
+ yield nested_element(current_path, elem)
+ # prevent yielding of parent element
+ if current_path:
+ current_path_has_children = True
+ else:
+ xml_started = False
+ except xml.etree.ElementTree.ParseError:
+ # This might be junk at the end of the XML input.
+ # Let's check whether the current path is empty.
+ if not current_path:
+ # If it is empty, we can ignore this error.
+ break
+ else:
+ raise
+
# close parser
- parser.close()
+ try:
+ parser.close()
+ except xml.etree.ElementTree.ParseError:
+ # Ignore. This will ignore any incomplete XML at the end of the input
+ pass
def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]:
"""Converts xml tree to python dictionary."""
diff --git a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py
index 222c8bf759610..48e7372b98a3d 100644
--- a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py
+++ b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py
@@ -1,10 +1,12 @@
"""Test XMLOutputParser"""
+from typing import AsyncIterator, Iterable
+
import pytest
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers.xml import XMLOutputParser
-DEF_RESULT_ENCODING = """
+DATA = """
@@ -13,6 +15,25 @@
tag
"""
+WITH_XML_HEADER = f"""
+{DATA}"""
+
+
+IN_XML_TAGS_WITH_XML_HEADER = f"""
+```xml
+{WITH_XML_HEADER}
+```
+"""
+
+IN_XML_TAGS_WITH_HEADER_AND_TRAILING_JUNK = f"""
+Some random text
+```xml
+{WITH_XML_HEADER}
+```
+More random text
+"""
+
+
DEF_RESULT_EXPECTED = {
"foo": [
{"bar": [{"baz": None}, {"baz": "slim.shady"}]},
@@ -24,23 +45,13 @@
@pytest.mark.parametrize(
"result",
[
- DEF_RESULT_ENCODING,
- DEF_RESULT_ENCODING[DEF_RESULT_ENCODING.find("\n") :],
- f"""
-```xml
-{DEF_RESULT_ENCODING}
-```
-""",
- f"""
-Some random text
-```xml
-{DEF_RESULT_ENCODING}
-```
-More random text
-""",
+ DATA, # has no xml header
+ WITH_XML_HEADER,
+ IN_XML_TAGS_WITH_XML_HEADER,
+ IN_XML_TAGS_WITH_HEADER_AND_TRAILING_JUNK,
],
)
-def test_xml_output_parser(result: str) -> None:
+async def test_xml_output_parser(result: str) -> None:
"""Test XMLOutputParser."""
xml_parser = XMLOutputParser()
@@ -48,12 +59,23 @@ def test_xml_output_parser(result: str) -> None:
xml_result = xml_parser.parse(result)
assert DEF_RESULT_EXPECTED == xml_result
- # TODO(Eugene): Fix this test for newer python version
- # assert list(xml_parser.transform(iter(result))) == [
- # {"foo": [{"bar": [{"baz": None}]}]},
- # {"foo": [{"bar": [{"baz": "slim.shady"}]}]},
- # {"foo": [{"baz": "tag"}]},
- # ]
+ assert list(xml_parser.transform(iter(result))) == [
+ {"foo": [{"bar": [{"baz": None}]}]},
+ {"foo": [{"bar": [{"baz": "slim.shady"}]}]},
+ {"foo": [{"baz": "tag"}]},
+ ]
+
+ async def _as_iter(iterable: Iterable[str]) -> AsyncIterator[str]:
+ for item in iterable:
+ yield item
+
+ chunks = [chunk async for chunk in xml_parser.atransform(_as_iter(result))]
+
+ assert list(chunks) == [
+ {"foo": [{"bar": [{"baz": None}]}]},
+ {"foo": [{"bar": [{"baz": "slim.shady"}]}]},
+ {"foo": [{"baz": "tag"}]},
+ ]
@pytest.mark.parametrize("result", ["foo>", "