Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add unit tests for Cohere and handle curl stream stop #1856

Merged
merged 5 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions engine/extensions/remote-engine/remote_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb,
status["has_error"] = true;
status["is_stream"] = true;
status["status_code"] = k400BadRequest;
context->need_stop = false;
(*context->callback)(std::move(status), std::move(check_error));
return size * nmemb;
}
Expand All @@ -58,7 +59,8 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb,
status["is_done"] = true;
status["has_error"] = false;
status["is_stream"] = true;
status["status_code"] = 200;
status["status_code"] = k200OK;
context->need_stop = false;
(*context->callback)(std::move(status), Json::Value());
break;
}
Expand Down Expand Up @@ -169,6 +171,15 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest(

curl_slist_free_all(headers);
curl_easy_cleanup(curl);
if (context.need_stop) {
CTL_DBG("No stop message received, need to stop");
Json::Value status;
status["is_done"] = true;
status["has_error"] = false;
status["is_stream"] = true;
status["status_code"] = k200OK;
(*context.callback)(std::move(status), Json::Value());
}
return response;
}

Expand Down Expand Up @@ -602,6 +613,7 @@ void RemoteEngine::HandleChatCompletion(
status["status_code"] = k500InternalServerError;
Json::Value error;
error["error"] = "Failed to parse response";
LOG_WARN << "Failed to parse response: " << response.body;
callback(std::move(status), std::move(error));
return;
}
Expand All @@ -626,15 +638,19 @@ void RemoteEngine::HandleChatCompletion(

try {
response_json["stream"] = false;
if (!response_json.isMember("model")) {
response_json["model"] = model;
}
response_str = renderer_.Render(template_str, response_json);
} catch (const std::exception& e) {
throw std::runtime_error("Template rendering error: " +
std::string(e.what()));
}
} catch (const std::exception& e) {
// Log error and potentially rethrow or handle accordingly
LOG_WARN << "Error in TransformRequest: " << e.what();
LOG_WARN << "Using original request body";
LOG_WARN << "Error: " << e.what();
LOG_WARN << "Response: " << response.body;
LOG_WARN << "Using original body";
response_str = response_json.toStyledString();
}

Expand All @@ -649,6 +665,7 @@ void RemoteEngine::HandleChatCompletion(
Json::Value error;
error["error"] = "Failed to parse response";
callback(std::move(status), std::move(error));
LOG_WARN << "Failed to parse response: " << response_str;
return;
}

Expand Down
1 change: 1 addition & 0 deletions engine/extensions/remote-engine/remote_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ struct StreamContext {
std::string model;
extensions::TemplateRenderer& renderer;
std::string stream_template;
bool need_stop = true;
};
struct CurlResponse {
std::string body;
Expand Down
13 changes: 7 additions & 6 deletions engine/extensions/template_renderer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
#include <regex>
#include <stdexcept>
#include "utils/logging_utils.h"
#include "utils/string_utils.h"
namespace extensions {

TemplateRenderer::TemplateRenderer() {
// Configure Inja environment
env_.set_trim_blocks(true);
Expand All @@ -21,7 +23,8 @@ TemplateRenderer::TemplateRenderer() {
const auto& value = *args[0];

if (value.is_string()) {
return nlohmann::json(std::string("\"") + value.get<std::string>() +
return nlohmann::json(std::string("\"") +
string_utils::EscapeJson(value.get<std::string>()) +
"\"");
}
return value;
Expand All @@ -46,16 +49,14 @@ std::string TemplateRenderer::Render(const std::string& tmpl,
std::string result = env_.render(tmpl, template_data);

// Clean up any potential double quotes in JSON strings
result = std::regex_replace(result, std::regex("\\\"\\\""), "\"");
// result = std::regex_replace(result, std::regex("\\\"\\\""), "\"");

LOG_DEBUG << "Result: " << result;

// Validate JSON
auto parsed = nlohmann::json::parse(result);

return result;
} catch (const std::exception& e) {
LOG_ERROR << "Template rendering failed: " << e.what();
LOG_ERROR << "Data: " << data.toStyledString();
LOG_ERROR << "Template: " << tmpl;
throw std::runtime_error(std::string("Template rendering failed: ") +
e.what());
Expand Down Expand Up @@ -133,4 +134,4 @@ std::string TemplateRenderer::RenderFile(const std::string& template_path,
e.what());
}
}
} // namespace remote_engine
} // namespace extensions
238 changes: 236 additions & 2 deletions engine/test/components/test_remote_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ TEST_F(RemoteEngineTest, OpenAiToAnthropicRequest) {
"messages": [
{% for message in input_request.messages %}
{% if not loop.is_first %}
{"role": "{{ message.role }}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %}
{"role": "{{ message.role }}", "content": {{ tojson(message.content) }} } {% if not loop.is_last %},{% endif %}
{% endif %}
{% endfor %}
]
{% else %}
"messages": [
{% for message in input_request.messages %}
{"role": " {{ message.role}}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %}
{"role": " {{ message.role}}", "content": {{ tojson(message.content) }} } {% if not loop.is_last %},{% endif %}
{% endfor %}
]
{% endif %}
Expand Down Expand Up @@ -181,6 +181,240 @@ TEST_F(RemoteEngineTest, AnthropicResponse) {
EXPECT_TRUE(res_json["choices"][0]["message"]["content"].isNull());
}

// Verifies the OpenAI -> Cohere request transformation template:
//  - a leading "system" message becomes Cohere's "preamble",
//  - intermediate turns become "chatHistory" entries (USER/CHATBOT roles),
//  - the final message becomes "message",
//  - scalar pass-through keys (model, max_tokens, stream, ...) are copied
//    via tojson so values are properly JSON-escaped.
TEST_F(RemoteEngineTest, CohereRequest) {
  std::string tpl =
      R"({
{% for key, value in input_request %}
{% if key == "messages" %}
{% if input_request.messages.0.role == "system" %}
"preamble": {{ tojson(input_request.messages.0.content) }},
{% if length(input_request.messages) > 2 %}
"chatHistory": [
{% for message in input_request.messages %}
{% if not loop.is_first and not loop.is_last %}
{"role": {% if message.role == "user" %} "USER" {% else %} "CHATBOT" {% endif %}, "content": {{ tojson(message.content) }} } {% if loop.index < length(input_request.messages) - 2 %},{% endif %}
{% endif %}
{% endfor %}
],
{% endif %}
"message": {{ tojson(last(input_request.messages).content) }}
{% else %}
{% if length(input_request.messages) > 2 %}
"chatHistory": [
{% for message in input_request.messages %}
{% if not loop.is_last %}
{ "role": {% if message.role == "user" %} "USER" {% else %} "CHATBOT" {% endif %}, "content": {{ tojson(message.content) }} } {% if loop.index < length(input_request.messages) - 2 %},{% endif %}
{% endif %}
{% endfor %}
],
{% endif %}
"message": {{ tojson(last(input_request.messages).content) }}
{% endif %}
{% if not loop.is_last %},{% endif %}
{% else if key == "system" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %}
"{{ key }}": {{ tojson(value) }}
{% if not loop.is_last %},{% endif %}
{% endif %}
{% endfor %} })";

  {
    // Request with a system prompt and four messages: the system content must
    // land in "preamble", the middle turns in "chatHistory", and the last
    // user turn in "message".
    std::string message_with_system = R"({
"engine" : "cohere",
"max_tokens" : 1024,
"messages": [
{"role": "system", "content": "You are a seasoned data scientist at a Fortune 500 company."},
{"role": "user", "content": "Hello, world"},
{"role": "assistant", "content": "The man who is widely credited with discovering gravity is Sir Isaac Newton"},
{"role": "user", "content": "How are you today?"}
],
"model": "command-r-plus-08-2024",
"stream" : true
})";

    auto data = json_helper::ParseJsonString(message_with_system);

    extensions::TemplateRenderer rdr;
    auto res = rdr.Render(tpl, data);

    auto res_json = json_helper::ParseJsonString(res);
    // Pass-through keys survive the transformation unchanged.
    EXPECT_EQ(data["model"].asString(), res_json["model"].asString());
    EXPECT_EQ(data["max_tokens"].asInt(), res_json["max_tokens"].asInt());
    // The (single) system message content becomes the Cohere preamble.
    for (auto const& msg : data["messages"]) {
      if (msg["role"].asString() == "system") {
        EXPECT_EQ(msg["content"].asString(), res_json["preamble"].asString());
      }
    }
    EXPECT_EQ(res_json["message"].asString(), "How are you today?");
  }

  {
    // Single user message without a system prompt; the content carries escaped
    // quotes and newlines to exercise tojson's JSON escaping.
    // NOTE: the fixture previously had a trailing comma after "max_tokens",
    // which is not valid strict JSON and could make the parse (and thus the
    // assertions) vacuous depending on parser settings — removed.
    std::string message_without_system = R"({
"messages": [
{"role": "user", "content": "Hello, \"the\" \n\nworld"}
],
"model": "command-r-plus-08-2024",
"max_tokens": 1024
})";

    auto data = json_helper::ParseJsonString(message_without_system);

    extensions::TemplateRenderer rdr;
    auto res = rdr.Render(tpl, data);

    auto res_json = json_helper::ParseJsonString(res);
    EXPECT_EQ(data["model"].asString(), res_json["model"].asString());
    EXPECT_EQ(data["max_tokens"].asInt(), res_json["max_tokens"].asInt());
    // With no system message and a single turn, the sole content becomes
    // "message" and must round-trip through escaping intact.
    EXPECT_EQ(data["messages"][0]["content"].asString(),
              res_json["message"].asString());
  }
}

// Verifies the Cohere -> OpenAI response transformation template in three
// modes: a streaming "text-generation" event, a streaming "stream-end" event,
// and a full non-streaming response (mapped to an OpenAI chat.completion).
TEST_F(RemoteEngineTest, CohereResponse) {
std::string tpl = R"(
{% if input_request.stream %}
{"object": "chat.completion.chunk",
"model": "{{ input_request.model }}",
"choices": [{"index": 0, "delta": { {% if input_request.event_type == "text-generation" %} "role": "assistant", "content": {{ tojson(input_request.text) }} {% else %} "role": "assistant", "content": null {% endif %} },
{% if input_request.event_type == "stream-end" %} "finish_reason": "{{ input_request.finish_reason }}" {% else %} "finish_reason": null {% endif %} }]
}
{% else %}
{"id": "{{ input_request.generation_id }}",
"created": null,
"object": "chat.completion",
"model": "{{ input_request.model }}",
"choices": [{ "index": 0, "message": { "role": "assistant", "content": {% if not input_request.text %} null {% else %} {{ tojson(input_request.text) }} {% endif %}, "refusal": null }, "logprobs": null, "finish_reason": "{{ input_request.finish_reason }}" } ], "usage": { "prompt_tokens": {{ input_request.meta.tokens.input_tokens }}, "completion_tokens": {{ input_request.meta.tokens.output_tokens }}, "total_tokens": {{ input_request.meta.tokens.input_tokens + input_request.meta.tokens.output_tokens }}, "prompt_tokens_details": { "cached_tokens": 0 }, "completion_tokens_details": { "reasoning_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0 } }, "system_fingerprint": "fp_6b68a8204b"} {% endif %})";
// Streaming "text-generation" event: the event's "text" must surface as the
// OpenAI delta content.
std::string message = R"({
"event_type": "text-generation",
"text": " help"
})";
auto data = json_helper::ParseJsonString(message);
data["stream"] = true;
data["model"] = "cohere";
extensions::TemplateRenderer rdr;
auto res = rdr.Render(tpl, data);
auto res_json = json_helper::ParseJsonString(res);
EXPECT_EQ(res_json["choices"][0]["delta"]["content"].asString(), " help");

// Streaming "stream-end" event: there is no top-level "text" field (the text
// lives under "response"), so the template's else-branch renders a null
// delta content.
message = R"(
{
"event_type": "stream-end",
"response": {
"text": "Hello! How can I help you today?",
"generation_id": "29f14a5a-11de-4cae-9800-25e4747408ea",
"chat_history": [
{
"role": "USER",
"message": "hello world!"
},
{
"role": "CHATBOT",
"message": "Hello! How can I help you today?"
}
],
"finish_reason": "COMPLETE",
"meta": {
"api_version": {
"version": "1"
},
"billed_units": {
"input_tokens": 3,
"output_tokens": 9
},
"tokens": {
"input_tokens": 69,
"output_tokens": 9
}
}
},
"finish_reason": "COMPLETE"
})";
data = json_helper::ParseJsonString(message);
data["stream"] = true;
data["model"] = "cohere";
res = rdr.Render(tpl, data);
res_json = json_helper::ParseJsonString(res);
EXPECT_TRUE(res_json["choices"][0]["delta"]["content"].isNull());

// non-stream
// Full non-streaming Cohere response; "text" contains tabs, single and
// escaped double quotes, and newlines to exercise tojson's JSON escaping.
message = R"(
{
"text": "Isaac \t\tNewton was 'born' on 25 \"December\" 1642 (Old Style) \n\nor 4 January 1643 (New Style).",
"generation_id": "0385c7cf-4247-43a3-a450-b25b547a31e1",
"citations": [
{
"start": 25,
"end": 41,
"text": "25 December 1642",
"document_ids": [
"web-search_0"
]
}
],
"search_queries": [
{
"text": "Isaac Newton birth year",
"generation_id": "9a497980-c3e2-4460-b81c-ef44d293f95d"
}
],
"search_results": [
{
"connector": {
"id": "web-search"
},
"document_ids": [
"web-search_0"
],
"search_query": {
"text": "Isaac Newton birth year",
"generation_id": "9a497980-c3e2-4460-b81c-ef44d293f95d"
}
}
],
"finish_reason": "COMPLETE",
"chat_history": [
{
"role": "USER",
"message": "Who discovered gravity?"
},
{
"role": "CHATBOT",
"message": "The man who is widely credited with discovering gravity is Sir Isaac Newton"
},
{
"role": "USER",
"message": "What year was he born?"
},
{
"role": "CHATBOT",
"message": "Isaac Newton was born on 25 December 1642 (Old Style) or 4 January 1643 (New Style)."
}
],
"meta": {
"api_version": {
"version": "1"
},
"billed_units": {
"input_tokens": 31738,
"output_tokens": 35
},
"tokens": {
"input_tokens": 32465,
"output_tokens": 205
}
}
}
)";

data = json_helper::ParseJsonString(message);
data["stream"] = false;
data["model"] = "cohere";
res = rdr.Render(tpl, data);
res_json = json_helper::ParseJsonString(res);
// The special characters must survive the render/parse round-trip intact.
EXPECT_EQ(
res_json["choices"][0]["message"]["content"].asString(),
"Isaac \t\tNewton was 'born' on 25 \"December\" 1642 (Old Style) \n\nor 4 "
"January 1643 (New Style).");
}

TEST_F(RemoteEngineTest, HeaderTemplate) {
{
std::string header_template =
Expand Down
Loading
Loading