Messages in a deserialized `CreateChatCompletionRequest` are all `SystemMessage` #216
Comments
I also have this issue, using actix_web.
…quest all SystemMessage. Turns out we don't even need `role` in the child. SIMPLIFY REQUIREMENTS. Also, just use the serde tag; it handles the serialization for us too. Thanks coca.codes. Closes 64bit#216
I was banging my head on this for a bit, but I just pushed a fix on my branch. Thanks to coco.codes from the NAMTAO Discord! To solve the parent issue (the messages always being `System`), we use serde's tag macro … This maps the `role` key to the appropriate enum variant under … I solved this by deleting the `role` field in the child and implementing it in the parent as a method that runs a `match` on the enum variant (not even really needed, it turns out the … I've verified this works in production across a bunch of different model providers. I'm happy with this solution, though I don't know if it will be merged; you're free to merge from my fork if you like.
Thank you, I will give it a try.
I wrote a custom wrapper for serialization and deserialization:

```rust
use async_openai::types::ChatCompletionRequestMessage;
use serde::de::Error as _; // brings `D::Error::custom` / `unknown_variant` into scope
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::Value;

/// Wraps a message and (de)serializes it as `{"type": "<role>", "content": <message>}`,
/// so the variant is chosen from the explicit `type` tag instead of being guessed from the fields.
#[derive(Debug)]
pub struct Message(ChatCompletionRequestMessage);

impl Message {
    pub fn from_original(enum_val: ChatCompletionRequestMessage) -> Self {
        Message(enum_val)
    }

    pub fn into_original(self) -> ChatCompletionRequestMessage {
        self.0
    }
}

impl Serialize for Message {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut state = serializer.serialize_struct("Message", 2)?;
        // Write the tag, then the inner message directly; errors propagate via `?`
        // instead of panicking through `unwrap()`.
        match &self.0 {
            ChatCompletionRequestMessage::System(msg) => {
                state.serialize_field("type", "system")?;
                state.serialize_field("content", msg)?;
            }
            ChatCompletionRequestMessage::User(msg) => {
                state.serialize_field("type", "user")?;
                state.serialize_field("content", msg)?;
            }
            ChatCompletionRequestMessage::Assistant(msg) => {
                state.serialize_field("type", "assistant")?;
                state.serialize_field("content", msg)?;
            }
            ChatCompletionRequestMessage::Tool(msg) => {
                state.serialize_field("type", "tool")?;
                state.serialize_field("content", msg)?;
            }
            ChatCompletionRequestMessage::Function(msg) => {
                state.serialize_field("type", "function")?;
                state.serialize_field("content", msg)?;
            }
        }
        state.end()
    }
}

impl<'de> Deserialize<'de> for Message {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let value = Value::deserialize(deserializer)?;
        let msg_type = value
            .get("type")
            .and_then(Value::as_str)
            .ok_or_else(|| D::Error::custom("missing or invalid `type` field"))?;
        let content = value
            .get("content")
            .cloned()
            .ok_or_else(|| D::Error::custom("missing `content` field"))?;
        // Dispatch on the tag; inner failures become deserializer errors, not panics.
        let inner = match msg_type {
            "system" => ChatCompletionRequestMessage::System(
                serde_json::from_value(content).map_err(D::Error::custom)?,
            ),
            "user" => ChatCompletionRequestMessage::User(
                serde_json::from_value(content).map_err(D::Error::custom)?,
            ),
            "assistant" => ChatCompletionRequestMessage::Assistant(
                serde_json::from_value(content).map_err(D::Error::custom)?,
            ),
            "tool" => ChatCompletionRequestMessage::Tool(
                serde_json::from_value(content).map_err(D::Error::custom)?,
            ),
            "function" => ChatCompletionRequestMessage::Function(
                serde_json::from_value(content).map_err(D::Error::custom)?,
            ),
            other => {
                return Err(D::Error::unknown_variant(
                    other,
                    &["system", "user", "assistant", "tool", "function"],
                ))
            }
        };
        Ok(Message(inner))
    }
}
```
Instead of complex ser-de implementations, the types have been updated for proper serialization and deserialization in v0.23.0. Thank you @sontallive for contributing the test too; it's included as part of the tests in https://github.com/64bit/async-openai/blob/main/async-openai/tests/ser_de.rs
I want to deserialize request JSON into `CreateChatCompletionRequest`, but I found that the messages are all `System`.