fix: check model status before inferencing (#1864)
Co-authored-by: vansangpfiev <[email protected]>
vansangpfiev and sangjanai authored Jan 16, 2025
1 parent 0746ec9 commit d847779
Showing 2 changed files with 52 additions and 34 deletions.
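The gist of the change: before running a chat completion, the service now checks whether the requested model is still loaded and, if it is not, replays the load request it cached when the model was first loaded. Below is a minimal, self-contained sketch of that flow; the stubbed GetModelStatus/LoadModel, the EnsureModelLoaded helper, and the example model and engine names are illustrative stand-ins, not part of the commit.

#include <json/json.h>

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-ins for the real service calls; the actual
// implementations live in InferenceService and talk to the engine.
using InferResult = std::pair<Json::Value, Json::Value>;

InferResult GetModelStatus(std::shared_ptr<Json::Value> /*json_body*/) {
  Json::Value status;
  status["status_code"] = 409;  // pretend the model is not currently loaded
  return {status, Json::Value{}};
}

InferResult LoadModel(std::shared_ptr<Json::Value> json_body) {
  std::cout << "Reloading model: " << (*json_body)["model"].asString() << "\n";
  Json::Value status;
  status["status_code"] = 200;
  return {status, Json::Value{}};
}

// Cache of the original load-model payloads, keyed by model id.
std::unordered_map<std::string, std::shared_ptr<Json::Value>> saved_models_;

void EnsureModelLoaded(const std::string& model_id, const std::string& engine) {
  auto it = saved_models_.find(model_id);
  if (it == saved_models_.end()) return;  // never loaded through this service

  Json::Value req;
  req["model"] = model_id;
  req["engine"] = engine;
  auto ir = GetModelStatus(std::make_shared<Json::Value>(req));
  if (std::get<0>(ir)["status_code"].asInt() != 200 /* drogon::k200OK */) {
    // The model was unloaded (or the engine restarted): replay the cached
    // load request and carry on with inference regardless of the result.
    LoadModel(it->second);
  }
}

int main() {
  auto payload = std::make_shared<Json::Value>();
  (*payload)["model"] = "llama3.2:3b";
  saved_models_["llama3.2:3b"] = payload;
  EnsureModelLoaded("llama3.2:3b", "llama-cpp");
}

In the actual diff below, this check sits at the top of HandleChatCompletion, and the cache is populated at the end of LoadModel for non-remote engines.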
82 changes: 49 additions & 33 deletions engine/services/inference_service.cc
@@ -14,6 +14,21 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
}
function_calling_utils::PreprocessRequest(json_body);
auto tool_choice = json_body->get("tool_choice", Json::Value::null);
auto model_id = json_body->get("model", "").asString();
if (saved_models_.find(model_id) != saved_models_.end()) {
// check if model is started, if not start it first
Json::Value root;
root["model"] = model_id;
root["engine"] = engine_type;
auto ir = GetModelStatus(std::make_shared<Json::Value>(root));
auto status = std::get<0>(ir)["status_code"].asInt();
if (status != drogon::k200OK) {
CTL_INF("Model is not loaded, start loading it: " << model_id);
auto res = LoadModel(saved_models_.at(model_id));
// ignore return result
}
}

auto engine_result = engine_service_->GetLoadedEngine(engine_type);
if (engine_result.has_error()) {
Json::Value res;
@@ -23,45 +38,42 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
LOG_WARN << "Engine is not loaded yet";
return cpp::fail(std::make_pair(stt, res));
}

if (!model_id.empty()) {
if (auto model_service = model_service_.lock()) {
auto metadata_ptr = model_service->GetCachedModelMetadata(model_id);
if (metadata_ptr != nullptr &&
!metadata_ptr->tokenizer->chat_template.empty()) {
auto tokenizer = metadata_ptr->tokenizer;
auto messages = (*json_body)["messages"];
Json::Value messages_jsoncpp(Json::arrayValue);
for (auto message : messages) {
messages_jsoncpp.append(message);
}

{
auto model_id = json_body->get("model", "").asString();
if (!model_id.empty()) {
if (auto model_service = model_service_.lock()) {
auto metadata_ptr = model_service->GetCachedModelMetadata(model_id);
if (metadata_ptr != nullptr &&
!metadata_ptr->tokenizer->chat_template.empty()) {
auto tokenizer = metadata_ptr->tokenizer;
auto messages = (*json_body)["messages"];
Json::Value messages_jsoncpp(Json::arrayValue);
for (auto message : messages) {
messages_jsoncpp.append(message);
}

Json::Value tools(Json::arrayValue);
Json::Value template_data_json;
template_data_json["messages"] = messages_jsoncpp;
// template_data_json["tools"] = tools;

auto prompt_result = jinja::RenderTemplate(
tokenizer->chat_template, template_data_json,
tokenizer->bos_token, tokenizer->eos_token,
tokenizer->add_bos_token, tokenizer->add_eos_token,
tokenizer->add_generation_prompt);
if (prompt_result.has_value()) {
(*json_body)["prompt"] = prompt_result.value();
Json::Value stops(Json::arrayValue);
stops.append(tokenizer->eos_token);
(*json_body)["stop"] = stops;
} else {
CTL_ERR("Failed to render prompt: " + prompt_result.error());
}
Json::Value tools(Json::arrayValue);
Json::Value template_data_json;
template_data_json["messages"] = messages_jsoncpp;
// template_data_json["tools"] = tools;

auto prompt_result = jinja::RenderTemplate(
tokenizer->chat_template, template_data_json, tokenizer->bos_token,
tokenizer->eos_token, tokenizer->add_bos_token,
tokenizer->add_eos_token, tokenizer->add_generation_prompt);
if (prompt_result.has_value()) {
(*json_body)["prompt"] = prompt_result.value();
Json::Value stops(Json::arrayValue);
stops.append(tokenizer->eos_token);
(*json_body)["stop"] = stops;
} else {
CTL_ERR("Failed to render prompt: " + prompt_result.error());
}
}
}
}

CTL_INF("Json body inference: " + json_body->toStyledString());

CTL_DBG("Json body inference: " + json_body->toStyledString());

auto cb = [q, tool_choice](Json::Value status, Json::Value res) {
if (!tool_choice.isNull()) {
@@ -205,6 +217,10 @@ InferResult InferenceService::LoadModel(
std::get<RemoteEngineI*>(engine_result.value())
->LoadModel(json_body, std::move(cb));
}
if (!engine_service_->IsRemoteEngine(engine_type)) {
auto model_id = json_body->get("model", "").asString();
saved_models_[model_id] = json_body;
}
return std::make_pair(stt, r);
}

4 changes: 3 additions & 1 deletion engine/services/inference_service.h
@@ -47,7 +47,7 @@ class InferenceService {

cpp::result<void, InferResult> HandleRouteRequest(
std::shared_ptr<SyncQueue> q, std::shared_ptr<Json::Value> json_body);

InferResult LoadModel(std::shared_ptr<Json::Value> json_body);

InferResult UnloadModel(const std::string& engine,
@@ -74,4 +74,6 @@ class InferenceService {
private:
std::shared_ptr<EngineService> engine_service_;
std::weak_ptr<ModelService> model_service_;
using SavedModel = std::shared_ptr<Json::Value>;
std::unordered_map<std::string, SavedModel> saved_models_;
};
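For reference, the two members the header gains, annotated with explanatory comments (the comments are annotations added here, not part of the committed code):

// Inside class InferenceService (engine/services/inference_service.h).
using SavedModel = std::shared_ptr<Json::Value>;
// Maps a model id to the JSON body originally passed to LoadModel, so that
// HandleChatCompletion can replay the load if the model is no longer running.
std::unordered_map<std::string, SavedModel> saved_models_;

Only non-remote engines populate this map in LoadModel, so requests served by remote engines skip the pre-inference reload path.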
