From 3c9b8ba006880b9dccc12456305c81b018089091 Mon Sep 17 00:00:00 2001 From: Edward Funnekotter Date: Wed, 4 Sep 2024 20:19:41 -0400 Subject: [PATCH] chore: Refactor make_history_start_with_user_message method (#32) Fix the method to not trim the first entry if it is a "system" role --- .../openai/openai_chat_model_with_history.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/solace_ai_connector/components/general/openai/openai_chat_model_with_history.py b/src/solace_ai_connector/components/general/openai/openai_chat_model_with_history.py index cca2e7b..5fde36f 100644 --- a/src/solace_ai_connector/components/general/openai/openai_chat_model_with_history.py +++ b/src/solace_ai_connector/components/general/openai/openai_chat_model_with_history.py @@ -132,11 +132,20 @@ def clear_history_but_keep_depth(self, session_id: str, depth: int, history): def make_history_start_with_user_message(self, session_id, history): if session_id in history: - while ( - history[session_id]["messages"] - and history[session_id]["messages"][0]["role"] != "user" - ): - history[session_id]["messages"].pop(0) + messages = history[session_id]["messages"] + if messages: + if messages[0]["role"] == "system": + # Start from the second message if the first is "system" + start_index = 1 + else: + # Start from the first message otherwise + start_index = 0 + + while ( + start_index < len(messages) + and messages[start_index]["role"] != "user" + ): + messages.pop(start_index) def handle_timer_event(self, timer_data): if timer_data["timer_id"] == "history_cleanup":