diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc
index 0d7ecbef1..989961092 100644
--- a/engine/extensions/remote-engine/remote_engine.cc
+++ b/engine/extensions/remote-engine/remote_engine.cc
@@ -35,6 +35,7 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb,
       status["has_error"] = true;
       status["is_stream"] = true;
       status["status_code"] = k400BadRequest;
+      context->need_stop = false;
       (*context->callback)(std::move(status), std::move(check_error));
       return size * nmemb;
     }
@@ -58,7 +59,8 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb,
         status["is_done"] = true;
         status["has_error"] = false;
         status["is_stream"] = true;
-        status["status_code"] = 200;
+        status["status_code"] = k200OK;
+        context->need_stop = false;
         (*context->callback)(std::move(status), Json::Value());
         break;
       }
@@ -169,6 +171,15 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest(
   curl_slist_free_all(headers);
   curl_easy_cleanup(curl);
 
+  if (context.need_stop) {
+    CTL_DBG("No stop message received, need to stop");
+    Json::Value status;
+    status["is_done"] = true;
+    status["has_error"] = false;
+    status["is_stream"] = true;
+    status["status_code"] = k200OK;
+    (*context.callback)(std::move(status), Json::Value());
+  }
   return response;
 }
@@ -602,6 +613,7 @@ void RemoteEngine::HandleChatCompletion(
     status["status_code"] = k500InternalServerError;
     Json::Value error;
     error["error"] = "Failed to parse response";
+    LOG_WARN << "Failed to parse response: " << response.body;
     callback(std::move(status), std::move(error));
     return;
   }
@@ -626,6 +638,9 @@ void RemoteEngine::HandleChatCompletion(
       try {
         response_json["stream"] = false;
+        if (!response_json.isMember("model")) {
+          response_json["model"] = model;
+        }
         response_str = renderer_.Render(template_str, response_json);
       } catch (const std::exception& e) {
         throw std::runtime_error("Template rendering error: " +
@@ -633,8 +648,9 @@ void RemoteEngine::HandleChatCompletion(
       }
     } catch (const std::exception& e) {
       // Log error and potentially rethrow or handle accordingly
-      LOG_WARN << "Error in TransformRequest: " << e.what();
-      LOG_WARN << "Using original request body";
+      LOG_WARN << "Error: " << e.what();
+      LOG_WARN << "Response: " << response.body;
+      LOG_WARN << "Using original body";
       response_str = response_json.toStyledString();
     }
@@ -649,6 +665,7 @@ void RemoteEngine::HandleChatCompletion(
     Json::Value error;
     error["error"] = "Failed to parse response";
     callback(std::move(status), std::move(error));
+    LOG_WARN << "Failed to parse response: " << response_str;
     return;
   }
diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h
index bc6d534c5..46222467a 100644
--- a/engine/extensions/remote-engine/remote_engine.h
+++ b/engine/extensions/remote-engine/remote_engine.h
@@ -24,6 +24,7 @@ struct StreamContext {
   std::string model;
   extensions::TemplateRenderer& renderer;
   std::string stream_template;
+  bool need_stop = true;
 };
 struct CurlResponse {
   std::string body;
diff --git a/engine/extensions/template_renderer.cc b/engine/extensions/template_renderer.cc
index 32e7d72f5..dfd50f000 100644
--- a/engine/extensions/template_renderer.cc
+++ b/engine/extensions/template_renderer.cc
@@ -7,7 +7,9 @@
 #include
 #include
 #include "utils/logging_utils.h"
+#include "utils/string_utils.h"
 namespace extensions {
+
 TemplateRenderer::TemplateRenderer() {
   // Configure Inja environment
   env_.set_trim_blocks(true);
@@ -21,7 +23,8 @@ TemplateRenderer::TemplateRenderer() {
       const auto& value = *args[0];
       if (value.is_string()) {
-        return nlohmann::json(std::string("\"") + value.get<std::string>() +
+        return nlohmann::json(std::string("\"") +
+                              string_utils::EscapeJson(value.get<std::string>()) +
                               "\"");
       }
       return value;
@@ -46,16 +49,14 @@ std::string TemplateRenderer::Render(const std::string& tmpl,
     std::string result = env_.render(tmpl, template_data);
 
     // Clean up any potential double quotes in JSON strings
-    result = std::regex_replace(result, std::regex("\\\"\\\""), "\"");
+    // result = std::regex_replace(result, std::regex("\\\"\\\""), "\"");
 
     LOG_DEBUG << "Result: " << result;
 
-    // Validate JSON
-    auto parsed = nlohmann::json::parse(result);
-
     return result;
   } catch (const std::exception& e) {
     LOG_ERROR << "Template rendering failed: " << e.what();
+    LOG_ERROR << "Data: " << data.toStyledString();
     LOG_ERROR << "Template: " << tmpl;
     throw std::runtime_error(std::string("Template rendering failed: ") +
                              e.what());
@@ -133,4 +134,4 @@ std::string TemplateRenderer::RenderFile(const std::string& template_path,
                              e.what());
   }
 }
-}  // namespace remote_engine
\ No newline at end of file
+}  // namespace extensions
\ No newline at end of file
diff --git a/engine/test/components/test_remote_engine.cc b/engine/test/components/test_remote_engine.cc
index 3bb3cdca3..11b6ae07f 100644
--- a/engine/test/components/test_remote_engine.cc
+++ b/engine/test/components/test_remote_engine.cc
@@ -18,14 +18,14 @@ TEST_F(RemoteEngineTest, OpenAiToAnthropicRequest) {
     "messages": [
       {% for message in input_request.messages %}
         {% if not loop.is_first %}
-          {"role": "{{ message.role }}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %}
+          {"role": "{{ message.role }}", "content": {{ tojson(message.content) }} } {% if not loop.is_last %},{% endif %}
        {% endif %}
      {% endfor %}
    ]
  {% else %}
    "messages": [
      {% for message in input_request.messages %}
-      {"role": " {{ message.role}}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %}
+      {"role": " {{ message.role}}", "content": {{ tojson(message.content) }} } {% if not loop.is_last %},{% endif %}
      {% endfor %}
    ]
  {% endif %}
@@ -181,6 +181,240 @@ TEST_F(RemoteEngineTest, AnthropicResponse) {
   EXPECT_TRUE(res_json["choices"][0]["message"]["content"].isNull());
 }
+TEST_F(RemoteEngineTest, CohereRequest) {
+  std::string tpl =
+      R"({
+    {% for key, value in input_request %}
+      {% if key == "messages" %}
+        {% if input_request.messages.0.role == "system" %}
+          "preamble": {{ tojson(input_request.messages.0.content) }},
+          {% if length(input_request.messages) > 2 %}
+            "chatHistory": [
+              {% for message in input_request.messages %}
+                {% if not loop.is_first and not loop.is_last %}
+                  {"role": {% if message.role == "user" %} "USER" {% else %} "CHATBOT" {% endif %}, "content": {{ tojson(message.content) }} } {% if loop.index < length(input_request.messages) - 2 %},{% endif %}
+                {% endif %}
+              {% endfor %}
+            ],
+          {% endif %}
+          "message": {{ tojson(last(input_request.messages).content) }}
+        {% else %}
+          {% if length(input_request.messages) > 2 %}
+            "chatHistory": [
+              {% for message in input_request.messages %}
+                {% if not loop.is_last %}
+                  { "role": {% if message.role == "user" %} "USER" {% else %} "CHATBOT" {% endif %}, "content": {{ tojson(message.content) }} } {% if loop.index < length(input_request.messages) - 2 %},{% endif %}
+                {% endif %}
+              {% endfor %}
+            ],
+          {% endif %}
+          "message": {{ tojson(last(input_request.messages).content) }}
+        {% endif %}
+        {% if not loop.is_last %},{% endif %}
== "system" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %} + "{{ key }}": {{ tojson(value) }} + {% if not loop.is_last %},{% endif %} + {% endif %} + {% endfor %} })"; + { + std::string message_with_system = R"({ + "engine" : "cohere", + "max_tokens" : 1024, + "messages": [ + {"role": "system", "content": "You are a seasoned data scientist at a Fortune 500 company."}, + {"role": "user", "content": "Hello, world"}, + {"role": "assistant", "content": "The man who is widely credited with discovering gravity is Sir Isaac Newton"}, + {"role": "user", "content": "How are you today?"} + ], + "model": "command-r-plus-08-2024", + "stream" : true +})"; + + auto data = json_helper::ParseJsonString(message_with_system); + + extensions::TemplateRenderer rdr; + auto res = rdr.Render(tpl, data); + + auto res_json = json_helper::ParseJsonString(res); + EXPECT_EQ(data["model"].asString(), res_json["model"].asString()); + EXPECT_EQ(data["max_tokens"].asInt(), res_json["max_tokens"].asInt()); + for (auto const& msg : data["messages"]) { + if (msg["role"].asString() == "system") { + EXPECT_EQ(msg["content"].asString(), res_json["preamble"].asString()); + } + } + EXPECT_EQ(res_json["message"].asString(), "How are you today?"); + } + + { + std::string message_without_system = R"({ + "messages": [ + {"role": "user", "content": "Hello, \"the\" \n\nworld"} + ], + "model": "command-r-plus-08-2024", + "max_tokens": 1024, + })"; + + auto data = json_helper::ParseJsonString(message_without_system); + + extensions::TemplateRenderer rdr; + auto res = rdr.Render(tpl, data); + + auto res_json = json_helper::ParseJsonString(res); + EXPECT_EQ(data["model"].asString(), res_json["model"].asString()); + EXPECT_EQ(data["max_tokens"].asInt(), res_json["max_tokens"].asInt()); + EXPECT_EQ(data["messages"][0]["content"].asString(), + res_json["message"].asString()); + } +} + +TEST_F(RemoteEngineTest, CohereResponse) { + std::string tpl = R"( + {% if input_request.stream %} + {"object": "chat.completion.chunk", + "model": "{{ input_request.model }}", + "choices": [{"index": 0, "delta": { {% if input_request.event_type == "text-generation" %} "role": "assistant", "content": {{ tojson(input_request.text) }} {% else %} "role": "assistant", "content": null {% endif %} }, + {% if input_request.event_type == "stream-end" %} "finish_reason": "{{ input_request.finish_reason }}" {% else %} "finish_reason": null {% endif %} }] + } + {% else %} + {"id": "{{ input_request.generation_id }}", + "created": null, + "object": "chat.completion", + "model": "{{ input_request.model }}", + "choices": [{ "index": 0, "message": { "role": "assistant", "content": {% if not input_request.text %} null {% else %} {{ tojson(input_request.text) }} {% endif %}, "refusal": null }, "logprobs": null, "finish_reason": "{{ input_request.finish_reason }}" } ], "usage": { "prompt_tokens": {{ input_request.meta.tokens.input_tokens }}, "completion_tokens": {{ input_request.meta.tokens.output_tokens }}, "total_tokens": {{ input_request.meta.tokens.input_tokens + 
+      "choices": [{ "index": 0, "message": { "role": "assistant", "content": {% if not input_request.text %} null {% else %} {{ tojson(input_request.text) }} {% endif %}, "refusal": null }, "logprobs": null, "finish_reason": "{{ input_request.finish_reason }}" } ], "usage": { "prompt_tokens": {{ input_request.meta.tokens.input_tokens }}, "completion_tokens": {{ input_request.meta.tokens.output_tokens }}, "total_tokens": {{ input_request.meta.tokens.input_tokens + input_request.meta.tokens.output_tokens }}, "prompt_tokens_details": { "cached_tokens": 0 }, "completion_tokens_details": { "reasoning_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0 } }, "system_fingerprint": "fp_6b68a8204b"} {% endif %})";
+  std::string message = R"({
+  "event_type": "text-generation",
+  "text": " help"
+})";
+  auto data = json_helper::ParseJsonString(message);
+  data["stream"] = true;
+  data["model"] = "cohere";
+  extensions::TemplateRenderer rdr;
+  auto res = rdr.Render(tpl, data);
+  auto res_json = json_helper::ParseJsonString(res);
+  EXPECT_EQ(res_json["choices"][0]["delta"]["content"].asString(), " help");
+
+  message = R"(
+  {
+    "event_type": "stream-end",
+    "response": {
+      "text": "Hello! How can I help you today?",
+      "generation_id": "29f14a5a-11de-4cae-9800-25e4747408ea",
+      "chat_history": [
+        {
+          "role": "USER",
+          "message": "hello world!"
+        },
+        {
+          "role": "CHATBOT",
+          "message": "Hello! How can I help you today?"
+        }
+      ],
+      "finish_reason": "COMPLETE",
+      "meta": {
+        "api_version": {
+          "version": "1"
+        },
+        "billed_units": {
+          "input_tokens": 3,
+          "output_tokens": 9
+        },
+        "tokens": {
+          "input_tokens": 69,
+          "output_tokens": 9
+        }
+      }
+    },
+    "finish_reason": "COMPLETE"
+  })";
+  data = json_helper::ParseJsonString(message);
+  data["stream"] = true;
+  data["model"] = "cohere";
+  res = rdr.Render(tpl, data);
+  res_json = json_helper::ParseJsonString(res);
+  EXPECT_TRUE(res_json["choices"][0]["delta"]["content"].isNull());
+
+  // non-stream
+  message = R"(
+  {
+    "text": "Isaac \t\tNewton was 'born' on 25 \"December\" 1642 (Old Style) \n\nor 4 January 1643 (New Style).",
+    "generation_id": "0385c7cf-4247-43a3-a450-b25b547a31e1",
+    "citations": [
+      {
+        "start": 25,
+        "end": 41,
+        "text": "25 December 1642",
+        "document_ids": [
+          "web-search_0"
+        ]
+      }
+    ],
+    "search_queries": [
+      {
+        "text": "Isaac Newton birth year",
+        "generation_id": "9a497980-c3e2-4460-b81c-ef44d293f95d"
+      }
+    ],
+    "search_results": [
+      {
+        "connector": {
+          "id": "web-search"
+        },
+        "document_ids": [
+          "web-search_0"
+        ],
+        "search_query": {
+          "text": "Isaac Newton birth year",
+          "generation_id": "9a497980-c3e2-4460-b81c-ef44d293f95d"
+        }
+      }
+    ],
+    "finish_reason": "COMPLETE",
+    "chat_history": [
+      {
+        "role": "USER",
+        "message": "Who discovered gravity?"
+      },
+      {
+        "role": "CHATBOT",
+        "message": "The man who is widely credited with discovering gravity is Sir Isaac Newton"
+      },
+      {
+        "role": "USER",
+        "message": "What year was he born?"
+      },
+      {
+        "role": "CHATBOT",
+        "message": "Isaac Newton was born on 25 December 1642 (Old Style) or 4 January 1643 (New Style)."
+      }
+    ],
+    "meta": {
+      "api_version": {
+        "version": "1"
+      },
+      "billed_units": {
+        "input_tokens": 31738,
+        "output_tokens": 35
+      },
+      "tokens": {
+        "input_tokens": 32465,
+        "output_tokens": 205
+      }
+    }
+  }
+  )";
+
+  data = json_helper::ParseJsonString(message);
+  data["stream"] = false;
+  data["model"] = "cohere";
+  res = rdr.Render(tpl, data);
+  res_json = json_helper::ParseJsonString(res);
+  EXPECT_EQ(
+      res_json["choices"][0]["message"]["content"].asString(),
+      "Isaac \t\tNewton was 'born' on 25 \"December\" 1642 (Old Style) \n\nor 4 "
+      "January 1643 (New Style).");
+}
+
 TEST_F(RemoteEngineTest, HeaderTemplate) {
   {
     std::string header_template =
diff --git a/engine/utils/string_utils.h b/engine/utils/string_utils.h
index 02d309169..a962109e8 100644
--- a/engine/utils/string_utils.h
+++ b/engine/utils/string_utils.h
@@ -3,6 +3,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -162,4 +163,41 @@ inline std::string FormatTimeElapsed(uint64_t pastTimestamp) {
   return oss.str();
 }
+
+inline std::string EscapeJson(const std::string& s) {
+  std::ostringstream o;
+  for (auto c = s.cbegin(); c != s.cend(); c++) {
+    switch (*c) {
+      case '"':
+        o << "\\\"";
+        break;
+      case '\\':
+        o << "\\\\";
+        break;
+      case '\b':
+        o << "\\b";
+        break;
+      case '\f':
+        o << "\\f";
+        break;
+      case '\n':
+        o << "\\n";
+        break;
+      case '\r':
+        o << "\\r";
+        break;
+      case '\t':
+        o << "\\t";
+        break;
+      default:
+        if ('\x00' <= *c && *c <= '\x1f') {
+          o << "\\u" << std::hex << std::setw(4) << std::setfill('0')
+            << static_cast<int>(*c);
+        } else {
+          o << *c;
+        }
+    }
+  }
+  return o.str();
+}
 }  // namespace string_utils