From 27bd982ef522cd827f1971e12012c22e7739be4c Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 14 Jan 2025 11:11:10 +0700 Subject: [PATCH 1/5] chore: unit test for cohere and handle stop curl --- .../extensions/remote-engine/remote_engine.cc | 16 +- .../extensions/remote-engine/remote_engine.h | 1 + engine/test/components/test_remote_engine.cc | 233 ++++++++++++++++++ 3 files changed, 249 insertions(+), 1 deletion(-) diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index 0d7ecbef1..abda0d824 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -35,6 +35,7 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, status["has_error"] = true; status["is_stream"] = true; status["status_code"] = k400BadRequest; + context->need_stop = false; (*context->callback)(std::move(status), std::move(check_error)); return size * nmemb; } @@ -58,7 +59,8 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, status["is_done"] = true; status["has_error"] = false; status["is_stream"] = true; - status["status_code"] = 200; + status["status_code"] = k200OK; + context->need_stop = false; (*context->callback)(std::move(status), Json::Value()); break; } @@ -169,6 +171,15 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest( curl_slist_free_all(headers); curl_easy_cleanup(curl); + if (context.need_stop) { + CTL_DBG("No stop message received, need to stop"); + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = true; + status["status_code"] = k200OK; + (*context.callback)(std::move(status), Json::Value()); + } return response; } @@ -626,6 +637,9 @@ void RemoteEngine::HandleChatCompletion( try { response_json["stream"] = false; + if (!response_json.isMember("model")) { + response_json["model"] = model; + } response_str = renderer_.Render(template_str, response_json); } catch (const std::exception& e) { throw std::runtime_error("Template rendering error: " + diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h index bc6d534c5..46222467a 100644 --- a/engine/extensions/remote-engine/remote_engine.h +++ b/engine/extensions/remote-engine/remote_engine.h @@ -24,6 +24,7 @@ struct StreamContext { std::string model; extensions::TemplateRenderer& renderer; std::string stream_template; + bool need_stop = true; }; struct CurlResponse { std::string body; diff --git a/engine/test/components/test_remote_engine.cc b/engine/test/components/test_remote_engine.cc index 3bb3cdca3..9a9c528bb 100644 --- a/engine/test/components/test_remote_engine.cc +++ b/engine/test/components/test_remote_engine.cc @@ -181,6 +181,239 @@ TEST_F(RemoteEngineTest, AnthropicResponse) { EXPECT_TRUE(res_json["choices"][0]["message"]["content"].isNull()); } +TEST_F(RemoteEngineTest, CohereRequest) { + std::string tpl = + R"({ + {% for key, value in input_request %} + {% if key == "messages" %} + {% if input_request.messages.0.role == "system" %} + "preamble": "{{ input_request.messages.0.content }}", + {% if length(input_request.messages) > 2 %} + "chatHistory": [ + {% for message in input_request.messages %} + {% if not loop.is_first and not loop.is_last %} + {"role": {% if message.role == "user" %} "USER" {% else %} "CHATBOT" {% endif %}, "content": "{{ message.content }}" } {% if loop.index < length(input_request.messages) - 2 %},{% endif %} + {% endif %} + {% endfor %} + ], + {% endif %} + "message": "{{ last(input_request.messages).content }}" + {% else %} + {% if length(input_request.messages) > 2 %} + "chatHistory": [ + {% for message in input_request.messages %} + {% if not loop.is_last %} + { "role": {% if message.role == "user" %} "USER" {% else %} "CHATBOT" {% endif %}, "content": "{{ message.content }}" } {% if loop.index < length(input_request.messages) - 2 %},{% endif %} + {% endif %} + {% endfor %} + ], + {% endif %} + "message": "{{ last(input_request.messages).content }}" + {% endif %} + {% if not loop.is_last %},{% endif %} + {% else if key == "system" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %} + "{{ key }}": {{ tojson(value) }} + {% if not loop.is_last %},{% endif %} + {% endif %} + {% endfor %} })"; + { + std::string message_with_system = R"({ + "engine" : "cohere", + "max_tokens" : 1024, + "messages": [ + {"role": "system", "content": "You are a seasoned data scientist at a Fortune 500 company."}, + {"role": "user", "content": "Hello, world"}, + {"role": "assistant", "content": "The man who is widely credited with discovering gravity is Sir Isaac Newton"}, + {"role": "user", "content": "How are you today?"} + ], + "model": "command-r-plus-08-2024", + "stream" : true +})"; + + auto data = json_helper::ParseJsonString(message_with_system); + + extensions::TemplateRenderer rdr; + auto res = rdr.Render(tpl, data); + + auto res_json = json_helper::ParseJsonString(res); + EXPECT_EQ(data["model"].asString(), res_json["model"].asString()); + EXPECT_EQ(data["max_tokens"].asInt(), res_json["max_tokens"].asInt()); + for (auto const& msg : data["messages"]) { + if (msg["role"].asString() == "system") { + EXPECT_EQ(msg["content"].asString(), res_json["preamble"].asString()); + } + } + EXPECT_EQ(res_json["message"].asString(), "How are you today?"); + } + + { + std::string message_without_system = R"({ + "messages": [ + {"role": "user", "content": "Hello, world"} + ], + "model": "command-r-plus-08-2024", + "max_tokens": 1024, + })"; + + auto data = json_helper::ParseJsonString(message_without_system); + + extensions::TemplateRenderer rdr; + auto res = rdr.Render(tpl, data); + + auto res_json = json_helper::ParseJsonString(res); + EXPECT_EQ(data["model"].asString(), res_json["model"].asString()); + EXPECT_EQ(data["max_tokens"].asInt(), res_json["max_tokens"].asInt()); + EXPECT_EQ(data["messages"][0]["content"].asString(), + res_json["message"].asString()); + } +} + +TEST_F(RemoteEngineTest, CohereResponse) { + std::string tpl = R"( + {% if input_request.stream %} + {"object": "chat.completion.chunk", + "model": "{{ input_request.model }}", + "choices": [{"index": 0, "delta": { {% if input_request.event_type == "text-generation" %} "role": "assistant", "content": "{{ input_request.text }}" {% else %} "role": "assistant", "content": null {% endif %} }, + {% if input_request.event_type == "stream-end" %} "finish_reason": "{{ input_request.finish_reason }}" {% else %} "finish_reason": null {% endif %} }] + } + {% else %} + {"id": "{{ input_request.generation_id }}", + "created": null, + "object": "chat.completion", + "model": "{{ input_request.model }}", + "choices": [{ "index": 0, "message": { "role": "assistant", "content": {% if not input_request.text %} null {% else %} "{{input_request.text}}" {% endif %}, "refusal": null }, "logprobs": null, "finish_reason": "{{ input_request.finish_reason }}" } ], "usage": { "prompt_tokens": {{ input_request.meta.tokens.input_tokens }}, "completion_tokens": {{ input_request.meta.tokens.output_tokens }}, "total_tokens": {{ input_request.meta.tokens.input_tokens + input_request.meta.tokens.output_tokens }}, "prompt_tokens_details": { "cached_tokens": 0 }, "completion_tokens_details": { "reasoning_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0 } }, "system_fingerprint": "fp_6b68a8204b"} {% endif %})"; + std::string message = R"({ + "event_type": "text-generation", + "text": " help" +})"; + auto data = json_helper::ParseJsonString(message); + data["stream"] = true; + data["model"] = "cohere"; + extensions::TemplateRenderer rdr; + auto res = rdr.Render(tpl, data); + auto res_json = json_helper::ParseJsonString(res); + EXPECT_EQ(res_json["choices"][0]["delta"]["content"].asString(), " help"); + + message = R"( + { + "event_type": "stream-end", + "response": { + "text": "Hello! How can I help you today?", + "generation_id": "29f14a5a-11de-4cae-9800-25e4747408ea", + "chat_history": [ + { + "role": "USER", + "message": "hello world!" + }, + { + "role": "CHATBOT", + "message": "Hello! How can I help you today?" + } + ], + "finish_reason": "COMPLETE", + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 3, + "output_tokens": 9 + }, + "tokens": { + "input_tokens": 69, + "output_tokens": 9 + } + } + }, + "finish_reason": "COMPLETE" + })"; + data = json_helper::ParseJsonString(message); + data["stream"] = true; + data["model"] = "cohere"; + res = rdr.Render(tpl, data); + res_json = json_helper::ParseJsonString(res); + EXPECT_TRUE(res_json["choices"][0]["delta"]["content"].isNull()); + + // non-stream + message = R"( + { + "text": "Isaac Newton was born on 25 December 1642 (Old Style) or 4 January 1643 (New Style).", + "generation_id": "0385c7cf-4247-43a3-a450-b25b547a31e1", + "citations": [ + { + "start": 25, + "end": 41, + "text": "25 December 1642", + "document_ids": [ + "web-search_0" + ] + } + ], + "search_queries": [ + { + "text": "Isaac Newton birth year", + "generation_id": "9a497980-c3e2-4460-b81c-ef44d293f95d" + } + ], + "search_results": [ + { + "connector": { + "id": "web-search" + }, + "document_ids": [ + "web-search_0" + ], + "search_query": { + "text": "Isaac Newton birth year", + "generation_id": "9a497980-c3e2-4460-b81c-ef44d293f95d" + } + } + ], + "finish_reason": "COMPLETE", + "chat_history": [ + { + "role": "USER", + "message": "Who discovered gravity?" + }, + { + "role": "CHATBOT", + "message": "The man who is widely credited with discovering gravity is Sir Isaac Newton" + }, + { + "role": "USER", + "message": "What year was he born?" + }, + { + "role": "CHATBOT", + "message": "Isaac Newton was born on 25 December 1642 (Old Style) or 4 January 1643 (New Style)." + } + ], + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 31738, + "output_tokens": 35 + }, + "tokens": { + "input_tokens": 32465, + "output_tokens": 205 + } + } +} + )"; + + data = json_helper::ParseJsonString(message); + data["stream"] = false; + data["model"] = "cohere"; + res = rdr.Render(tpl, data); + res_json = json_helper::ParseJsonString(res); + EXPECT_EQ(res_json["choices"][0]["message"]["content"].asString(), + "Isaac Newton was born on 25 December 1642 (Old Style) or 4 " + "January 1643 (New Style)."); +} + TEST_F(RemoteEngineTest, HeaderTemplate) { { std::string header_template = From 3364989b275195a51dca0a313728410c258014ed Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 14 Jan 2025 16:05:14 +0700 Subject: [PATCH 2/5] fix: parse failed nlohmann::json --- engine/extensions/remote-engine/remote_engine.cc | 1 + engine/extensions/template_renderer.cc | 5 +---- engine/test/components/test_remote_engine.cc | 4 ++-- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index abda0d824..226de703f 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -663,6 +663,7 @@ void RemoteEngine::HandleChatCompletion( Json::Value error; error["error"] = "Failed to parse response"; callback(std::move(status), std::move(error)); + LOG_WARN << "Failed to parse response: " << response_str; return; } diff --git a/engine/extensions/template_renderer.cc b/engine/extensions/template_renderer.cc index 32e7d72f5..14be76cc3 100644 --- a/engine/extensions/template_renderer.cc +++ b/engine/extensions/template_renderer.cc @@ -50,9 +50,6 @@ std::string TemplateRenderer::Render(const std::string& tmpl, LOG_DEBUG << "Result: " << result; - // Validate JSON - auto parsed = nlohmann::json::parse(result); - return result; } catch (const std::exception& e) { LOG_ERROR << "Template rendering failed: " << e.what(); @@ -133,4 +130,4 @@ std::string TemplateRenderer::RenderFile(const std::string& template_path, e.what()); } } -} // namespace remote_engine \ No newline at end of file +} // namespace extensions \ No newline at end of file diff --git a/engine/test/components/test_remote_engine.cc b/engine/test/components/test_remote_engine.cc index 9a9c528bb..4f498bf04 100644 --- a/engine/test/components/test_remote_engine.cc +++ b/engine/test/components/test_remote_engine.cc @@ -337,7 +337,7 @@ TEST_F(RemoteEngineTest, CohereResponse) { // non-stream message = R"( { - "text": "Isaac Newton was born on 25 December 1642 (Old Style) or 4 January 1643 (New Style).", + "text": "Isaac Newton was born on 25 December 1642 (Old Style) \n\nor 4 January 1643 (New Style).", "generation_id": "0385c7cf-4247-43a3-a450-b25b547a31e1", "citations": [ { @@ -410,7 +410,7 @@ TEST_F(RemoteEngineTest, CohereResponse) { res = rdr.Render(tpl, data); res_json = json_helper::ParseJsonString(res); EXPECT_EQ(res_json["choices"][0]["message"]["content"].asString(), - "Isaac Newton was born on 25 December 1642 (Old Style) or 4 " + "Isaac Newton was born on 25 December 1642 (Old Style) \n\nor 4 " "January 1643 (New Style)."); } From 6d8d0480de195df4fed000bb2362f7a519715c5d Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 14 Jan 2025 18:51:29 +0700 Subject: [PATCH 3/5] fix: tojson string --- .../extensions/remote-engine/remote_engine.cc | 4 +-- engine/extensions/template_renderer.cc | 6 ++-- engine/test/components/test_remote_engine.cc | 29 ++++++++++--------- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index 226de703f..e8e656070 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -647,8 +647,8 @@ void RemoteEngine::HandleChatCompletion( } } catch (const std::exception& e) { // Log error and potentially rethrow or handle accordingly - LOG_WARN << "Error in TransformRequest: " << e.what(); - LOG_WARN << "Using original request body"; + LOG_WARN << "Error: " << e.what(); + LOG_WARN << "Using original body"; response_str = response_json.toStyledString(); } diff --git a/engine/extensions/template_renderer.cc b/engine/extensions/template_renderer.cc index 14be76cc3..a855a5d7c 100644 --- a/engine/extensions/template_renderer.cc +++ b/engine/extensions/template_renderer.cc @@ -21,8 +21,9 @@ TemplateRenderer::TemplateRenderer() { const auto& value = *args[0]; if (value.is_string()) { - return nlohmann::json(std::string("\"") + value.get() + - "\""); + std::string v = value.get(); + v = std::regex_replace(v, std::regex("\""), "\\\""); + return nlohmann::json(std::string("\"") + v + "\""); } return value; }); @@ -53,6 +54,7 @@ std::string TemplateRenderer::Render(const std::string& tmpl, return result; } catch (const std::exception& e) { LOG_ERROR << "Template rendering failed: " << e.what(); + LOG_ERROR << "Data: " << data.toStyledString(); LOG_ERROR << "Template: " << tmpl; throw std::runtime_error(std::string("Template rendering failed: ") + e.what()); diff --git a/engine/test/components/test_remote_engine.cc b/engine/test/components/test_remote_engine.cc index 4f498bf04..2661a590b 100644 --- a/engine/test/components/test_remote_engine.cc +++ b/engine/test/components/test_remote_engine.cc @@ -18,14 +18,14 @@ TEST_F(RemoteEngineTest, OpenAiToAnthropicRequest) { "messages": [ {% for message in input_request.messages %} {% if not loop.is_first %} - {"role": "{{ message.role }}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %} + {"role": "{{ message.role }}", "content": {{ tojson(message.content) }} } {% if not loop.is_last %},{% endif %} {% endif %} {% endfor %} ] {% else %} "messages": [ {% for message in input_request.messages %} - {"role": " {{ message.role}}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %} + {"role": " {{ message.role}}", "content": {{ tojson(message.content) }} } {% if not loop.is_last %},{% endif %} {% endfor %} ] {% endif %} @@ -187,28 +187,28 @@ TEST_F(RemoteEngineTest, CohereRequest) { {% for key, value in input_request %} {% if key == "messages" %} {% if input_request.messages.0.role == "system" %} - "preamble": "{{ input_request.messages.0.content }}", + "preamble": {{ tojson(input_request.messages.0.content) }}, {% if length(input_request.messages) > 2 %} "chatHistory": [ {% for message in input_request.messages %} {% if not loop.is_first and not loop.is_last %} - {"role": {% if message.role == "user" %} "USER" {% else %} "CHATBOT" {% endif %}, "content": "{{ message.content }}" } {% if loop.index < length(input_request.messages) - 2 %},{% endif %} + {"role": {% if message.role == "user" %} "USER" {% else %} "CHATBOT" {% endif %}, "content": {{ tojson(message.content) }} } {% if loop.index < length(input_request.messages) - 2 %},{% endif %} {% endif %} {% endfor %} ], {% endif %} - "message": "{{ last(input_request.messages).content }}" + "message": {{ tojson(last(input_request.messages).content) }} {% else %} {% if length(input_request.messages) > 2 %} "chatHistory": [ {% for message in input_request.messages %} {% if not loop.is_last %} - { "role": {% if message.role == "user" %} "USER" {% else %} "CHATBOT" {% endif %}, "content": "{{ message.content }}" } {% if loop.index < length(input_request.messages) - 2 %},{% endif %} + { "role": {% if message.role == "user" %} "USER" {% else %} "CHATBOT" {% endif %}, "content": {{ tojson(message.content) }} } {% if loop.index < length(input_request.messages) - 2 %},{% endif %} {% endif %} {% endfor %} ], {% endif %} - "message": "{{ last(input_request.messages).content }}" + "message": {{ tojson(last(input_request.messages).content) }} {% endif %} {% if not loop.is_last %},{% endif %} {% else if key == "system" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %} @@ -249,7 +249,7 @@ TEST_F(RemoteEngineTest, CohereRequest) { { std::string message_without_system = R"({ "messages": [ - {"role": "user", "content": "Hello, world"} + {"role": "user", "content": "Hello, \"the\" world"} ], "model": "command-r-plus-08-2024", "max_tokens": 1024, @@ -273,7 +273,7 @@ TEST_F(RemoteEngineTest, CohereResponse) { {% if input_request.stream %} {"object": "chat.completion.chunk", "model": "{{ input_request.model }}", - "choices": [{"index": 0, "delta": { {% if input_request.event_type == "text-generation" %} "role": "assistant", "content": "{{ input_request.text }}" {% else %} "role": "assistant", "content": null {% endif %} }, + "choices": [{"index": 0, "delta": { {% if input_request.event_type == "text-generation" %} "role": "assistant", "content": {{ tojson(input_request.text) }} {% else %} "role": "assistant", "content": null {% endif %} }, {% if input_request.event_type == "stream-end" %} "finish_reason": "{{ input_request.finish_reason }}" {% else %} "finish_reason": null {% endif %} }] } {% else %} @@ -281,7 +281,7 @@ TEST_F(RemoteEngineTest, CohereResponse) { "created": null, "object": "chat.completion", "model": "{{ input_request.model }}", - "choices": [{ "index": 0, "message": { "role": "assistant", "content": {% if not input_request.text %} null {% else %} "{{input_request.text}}" {% endif %}, "refusal": null }, "logprobs": null, "finish_reason": "{{ input_request.finish_reason }}" } ], "usage": { "prompt_tokens": {{ input_request.meta.tokens.input_tokens }}, "completion_tokens": {{ input_request.meta.tokens.output_tokens }}, "total_tokens": {{ input_request.meta.tokens.input_tokens + input_request.meta.tokens.output_tokens }}, "prompt_tokens_details": { "cached_tokens": 0 }, "completion_tokens_details": { "reasoning_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0 } }, "system_fingerprint": "fp_6b68a8204b"} {% endif %})"; + "choices": [{ "index": 0, "message": { "role": "assistant", "content": {% if not input_request.text %} null {% else %} {{ tojson(input_request.text) }} {% endif %}, "refusal": null }, "logprobs": null, "finish_reason": "{{ input_request.finish_reason }}" } ], "usage": { "prompt_tokens": {{ input_request.meta.tokens.input_tokens }}, "completion_tokens": {{ input_request.meta.tokens.output_tokens }}, "total_tokens": {{ input_request.meta.tokens.input_tokens + input_request.meta.tokens.output_tokens }}, "prompt_tokens_details": { "cached_tokens": 0 }, "completion_tokens_details": { "reasoning_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0 } }, "system_fingerprint": "fp_6b68a8204b"} {% endif %})"; std::string message = R"({ "event_type": "text-generation", "text": " help" @@ -337,7 +337,7 @@ TEST_F(RemoteEngineTest, CohereResponse) { // non-stream message = R"( { - "text": "Isaac Newton was born on 25 December 1642 (Old Style) \n\nor 4 January 1643 (New Style).", + "text": "Isaac Newton was 'born' on 25 \"December\" 1642 (Old Style) \n\nor 4 January 1643 (New Style).", "generation_id": "0385c7cf-4247-43a3-a450-b25b547a31e1", "citations": [ { @@ -409,9 +409,10 @@ TEST_F(RemoteEngineTest, CohereResponse) { data["model"] = "cohere"; res = rdr.Render(tpl, data); res_json = json_helper::ParseJsonString(res); - EXPECT_EQ(res_json["choices"][0]["message"]["content"].asString(), - "Isaac Newton was born on 25 December 1642 (Old Style) \n\nor 4 " - "January 1643 (New Style)."); + EXPECT_EQ( + res_json["choices"][0]["message"]["content"].asString(), + "Isaac Newton was 'born' on 25 \"December\" 1642 (Old Style) \n\nor 4 " + "January 1643 (New Style)."); } TEST_F(RemoteEngineTest, HeaderTemplate) { From ac9894c19461aa2f2eb806838e8c47b71f614889 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 14 Jan 2025 21:24:13 +0700 Subject: [PATCH 4/5] fix: return --- engine/extensions/template_renderer.cc | 1 + engine/test/components/test_remote_engine.cc | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/engine/extensions/template_renderer.cc b/engine/extensions/template_renderer.cc index a855a5d7c..f79ea4a73 100644 --- a/engine/extensions/template_renderer.cc +++ b/engine/extensions/template_renderer.cc @@ -23,6 +23,7 @@ TemplateRenderer::TemplateRenderer() { if (value.is_string()) { std::string v = value.get(); v = std::regex_replace(v, std::regex("\""), "\\\""); + v = std::regex_replace(v, std::regex("\n"), "\\n"); return nlohmann::json(std::string("\"") + v + "\""); } return value; diff --git a/engine/test/components/test_remote_engine.cc b/engine/test/components/test_remote_engine.cc index 2661a590b..c86508771 100644 --- a/engine/test/components/test_remote_engine.cc +++ b/engine/test/components/test_remote_engine.cc @@ -249,7 +249,7 @@ TEST_F(RemoteEngineTest, CohereRequest) { { std::string message_without_system = R"({ "messages": [ - {"role": "user", "content": "Hello, \"the\" world"} + {"role": "user", "content": "Hello, \"the\" \n\nworld"} ], "model": "command-r-plus-08-2024", "max_tokens": 1024, From 2dd0753b80c9ca71d221f9548f1b11f1e4476b5e Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Wed, 15 Jan 2025 05:34:44 +0700 Subject: [PATCH 5/5] fix: escape json --- .../extensions/remote-engine/remote_engine.cc | 2 + engine/extensions/template_renderer.cc | 11 +++--- engine/test/components/test_remote_engine.cc | 4 +- engine/utils/string_utils.h | 38 +++++++++++++++++++ 4 files changed, 48 insertions(+), 7 deletions(-) diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index e8e656070..989961092 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -613,6 +613,7 @@ void RemoteEngine::HandleChatCompletion( status["status_code"] = k500InternalServerError; Json::Value error; error["error"] = "Failed to parse response"; + LOG_WARN << "Failed to parse response: " << response.body; callback(std::move(status), std::move(error)); return; } @@ -648,6 +649,7 @@ void RemoteEngine::HandleChatCompletion( } catch (const std::exception& e) { // Log error and potentially rethrow or handle accordingly LOG_WARN << "Error: " << e.what(); + LOG_WARN << "Response: " << response.body; LOG_WARN << "Using original body"; response_str = response_json.toStyledString(); } diff --git a/engine/extensions/template_renderer.cc b/engine/extensions/template_renderer.cc index f79ea4a73..dfd50f000 100644 --- a/engine/extensions/template_renderer.cc +++ b/engine/extensions/template_renderer.cc @@ -7,7 +7,9 @@ #include #include #include "utils/logging_utils.h" +#include "utils/string_utils.h" namespace extensions { + TemplateRenderer::TemplateRenderer() { // Configure Inja environment env_.set_trim_blocks(true); @@ -21,10 +23,9 @@ TemplateRenderer::TemplateRenderer() { const auto& value = *args[0]; if (value.is_string()) { - std::string v = value.get(); - v = std::regex_replace(v, std::regex("\""), "\\\""); - v = std::regex_replace(v, std::regex("\n"), "\\n"); - return nlohmann::json(std::string("\"") + v + "\""); + return nlohmann::json(std::string("\"") + + string_utils::EscapeJson(value.get()) + + "\""); } return value; }); @@ -48,7 +49,7 @@ std::string TemplateRenderer::Render(const std::string& tmpl, std::string result = env_.render(tmpl, template_data); // Clean up any potential double quotes in JSON strings - result = std::regex_replace(result, std::regex("\\\"\\\""), "\""); + // result = std::regex_replace(result, std::regex("\\\"\\\""), "\""); LOG_DEBUG << "Result: " << result; diff --git a/engine/test/components/test_remote_engine.cc b/engine/test/components/test_remote_engine.cc index c86508771..11b6ae07f 100644 --- a/engine/test/components/test_remote_engine.cc +++ b/engine/test/components/test_remote_engine.cc @@ -337,7 +337,7 @@ TEST_F(RemoteEngineTest, CohereResponse) { // non-stream message = R"( { - "text": "Isaac Newton was 'born' on 25 \"December\" 1642 (Old Style) \n\nor 4 January 1643 (New Style).", + "text": "Isaac \t\tNewton was 'born' on 25 \"December\" 1642 (Old Style) \n\nor 4 January 1643 (New Style).", "generation_id": "0385c7cf-4247-43a3-a450-b25b547a31e1", "citations": [ { @@ -411,7 +411,7 @@ TEST_F(RemoteEngineTest, CohereResponse) { res_json = json_helper::ParseJsonString(res); EXPECT_EQ( res_json["choices"][0]["message"]["content"].asString(), - "Isaac Newton was 'born' on 25 \"December\" 1642 (Old Style) \n\nor 4 " + "Isaac \t\tNewton was 'born' on 25 \"December\" 1642 (Old Style) \n\nor 4 " "January 1643 (New Style)."); } diff --git a/engine/utils/string_utils.h b/engine/utils/string_utils.h index 02d309169..a962109e8 100644 --- a/engine/utils/string_utils.h +++ b/engine/utils/string_utils.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -162,4 +163,41 @@ inline std::string FormatTimeElapsed(uint64_t pastTimestamp) { return oss.str(); } + +inline std::string EscapeJson(const std::string& s) { + std::ostringstream o; + for (auto c = s.cbegin(); c != s.cend(); c++) { + switch (*c) { + case '"': + o << "\\\""; + break; + case '\\': + o << "\\\\"; + break; + case '\b': + o << "\\b"; + break; + case '\f': + o << "\\f"; + break; + case '\n': + o << "\\n"; + break; + case '\r': + o << "\\r"; + break; + case '\t': + o << "\\t"; + break; + default: + if ('\x00' <= *c && *c <= '\x1f') { + o << "\\u" << std::hex << std::setw(4) << std::setfill('0') + << static_cast(*c); + } else { + o << *c; + } + } + } + return o.str(); +} } // namespace string_utils