Skip to content

Commit

Permalink
fix: using proper xml parser + hack to make it handle mixed text and …
Browse files Browse the repository at this point in the history
…xml+
  • Loading branch information
evilsocket committed Jun 20, 2024
1 parent f202406 commit ddcdaf7
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 71 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ serde_trim = "1.1.0"
serde_yaml = "0.9.34"
tokio = "1.38.0"
urlencoding = "2.1.3"
xml-rs = "0.8.20"
2 changes: 1 addition & 1 deletion src/agent/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl Agent {
let response = self.generator.chat(&options).await?.trim().to_string();

// parse the model response into invocations
let invocations = serialization::xml::parse_model_response(&response)?;
let invocations = serialization::xml::try_parse(&response)?;
let mut prev: Option<Invocation> = None;

// nothing parsed, report the problem to the model
Expand Down
1 change: 0 additions & 1 deletion src/agent/namespaces/planning/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ impl Action for Clear {
pub(crate) fn get_namespace() -> Namespace {
Namespace::new(
"Planning".to_string(),
// TODO: improve this - it should be clear to the model that it should deconstruct complex problems in smaller ones using this tool.
include_str!("ns.prompt").to_string(),
vec![
Box::<AddStep>::default(),
Expand Down
213 changes: 144 additions & 69 deletions src/agent/serialization/xml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::collections::HashMap;
use anyhow::Result;
use lazy_static::lazy_static;
use regex::Regex;
use xml::{reader::XmlEvent, EventReader};

use crate::agent::{
namespaces::Action,
Expand Down Expand Up @@ -116,82 +117,156 @@ pub(crate) fn serialize_storage(storage: &Storage) -> String {
}
}

pub(crate) fn parse_model_response(model_response: &str) -> Result<Vec<Invocation>> {
let mut invocations = vec![];

let model_response_size = model_response.len();
let mut current = 0;

// TODO: replace this with a proper xml parser
while current < model_response_size {
// read until < or end
let mut ptr = &model_response[current..];
if let Some(tag_open_idx) = ptr.find('<') {
current += tag_open_idx;
ptr = &ptr[tag_open_idx..];
// read tag
if let Some(tag_name_term_idx) = ptr.find(|c: char| c == '>' || c == ' ') {
current += tag_name_term_idx;
let tag_name = &ptr[1..tag_name_term_idx];
// println!("tag_name={}", tag_name);
if let Some(tag_close_idx) = ptr.find('>') {
current += tag_close_idx + tag_name.len();
let tag_closing = format!("</{}>", tag_name);
let tag_closing_idx = ptr.find(&tag_closing);

if let Some(tag_closing_idx) = tag_closing_idx {
// parse attributes if any
let attributes = if ptr.as_bytes()[tag_name_term_idx] == b' ' {
let attr_str = &ptr[tag_name_term_idx + 1..tag_close_idx];
let mut attrs = HashMap::new();

// parse as a list of key="value"
let iter = XML_ATTRIBUTES_PARSER.captures_iter(attr_str);
for caps in iter {
if caps.len() == 4 {
let key = caps.get(2).unwrap().as_str().trim();
let value = caps.get(3).unwrap().as_str().trim();
attrs.insert(key.to_string(), value.to_string());
}
}

Some(attrs)
} else {
None
};

// parse payload if any
let after_tag_close = &ptr[tag_close_idx + 1..tag_closing_idx];
let payload = if !after_tag_close.is_empty() {
if after_tag_close.as_bytes()[0] != b'<' {
Some(after_tag_close.trim().to_string())
} else {
None
}
} else {
None
};

invocations.push(Invocation::new(
tag_name.to_string(),
attributes,
payload,
));

continue;
/// Outcome of running the XML parser over one block of text.
#[derive(Debug, Default)]
pub struct Parsed {
    /// Number of input bytes the parser consumed.
    pub processed: usize,
    /// Invocations that were built from the elements found in the block.
    pub invocations: Vec<Invocation>,
}

fn build_invocation(
closing_name: String,
element: &XmlEvent,
payload: &Option<String>,
) -> Result<Invocation> {
let (name, attrs) = match element {
xml::reader::XmlEvent::StartElement {
name,
attributes,
namespace: _,
} => (name.to_string(), attributes),
_ => {
return Err(anyhow!("unexpected element {:?}", element));
}
};

if name != closing_name {
return Err(anyhow!(
"unexpected closing {} while parsing {}",
closing_name,
name
));
}

let action = name.to_string();
let attributes = if !attrs.is_empty() {
let mut map = HashMap::new();
for attr in attrs {
map.insert(attr.name.to_string(), attr.value.to_string());
}

Some(map)
} else {
None
};
let payload = payload.as_ref().map(|data| data.to_owned());

Ok(Invocation::new(action, attributes, payload))
}

fn try_parse_block(ptr: &str) -> Parsed {
let mut parser = EventReader::from_str(ptr);
let mut parsed = Parsed::default();
let src_size = parser.source().len();

let mut curr_element = None;
let mut curr_payload = None;

loop {
let event = parser.next();
if let Ok(event) = event {
// println!("{:?}", &event);
match event {
xml::reader::XmlEvent::StartDocument {
version: _,
encoding: _,
standalone: _,
} => {}
xml::reader::XmlEvent::EndDocument {} => {
break;
}
xml::reader::XmlEvent::StartElement {
name: _,
attributes: _,
namespace: _,
} => {
curr_element = Some(event);
}
xml::reader::XmlEvent::Characters(data) => {
curr_payload = Some(data);
}
xml::reader::XmlEvent::EndElement { name } => {
let ret = build_invocation(
name.to_string(),
curr_element.as_ref().unwrap(),
&curr_payload,
);
if let Ok(inv) = ret {
parsed.invocations.push(inv);
} else {
eprintln!("WARNING: {:?}", ret.err().unwrap());
}
}
_ => {
eprintln!("WARNING: unexpected xml element: {:?}", event);
}
}

// just skip ahead
current += 1;
} else {
// no more tags
break;
}
}

Ok(invocations)
let src_size_now = parser.source().len();

// amount of successfully processed bytes
parsed.processed = src_size - src_size_now;

return parsed;
}

/// Extracts every invocation from `raw`, which may mix free text and XML.
///
/// The input is scanned for `<`; from each candidate position `try_parse_block` is run and
/// the scan resumes right after the bytes it reports as processed. The scan ends when no `<`
/// is left or when a block makes no progress at all.
pub(crate) fn try_parse(raw: &str) -> Result<Vec<Invocation>> {
    let mut invocations = vec![];
    let mut remaining = raw;

    // each iteration starts at the next potential xml opening
    while let Some(xml_start) = remaining.find('<') {
        remaining = &remaining[xml_start..];

        let block = try_parse_block(remaining);
        if block.processed == 0 {
            // nothing consumed: stop instead of looping on the same offset forever
            break;
        }

        invocations.extend(block.invocations);
        // update offset
        remaining = &remaining[block.processed..];
    }

    Ok(invocations)
}

// TODO: add tests for this
// TODO: add waaaaay more tests
#[cfg(test)]
mod tests {
    use super::*;

    /// Two sibling top-level elements separated by a newline are expected to be consumed
    /// entirely by a single `try_parse_block` call, yielding one invocation each.
    #[test]
    fn test_parse_block_infinite_loop() {
        let input = "<clear-plan></clear-plan>
<update-goal>test</update-goal>";
        let expected_actions = ["clear-plan", "update-goal"];

        let result = try_parse_block(input);

        // the whole input must have been consumed
        assert_eq!(input.len(), result.processed);
        assert_eq!(result.invocations.len(), expected_actions.len());

        for (invocation, expected) in result
            .invocations
            .iter()
            .zip(expected_actions.iter().copied())
        {
            assert_eq!(&invocation.action, expected);
        }
    }
}

0 comments on commit ddcdaf7

Please sign in to comment.