Skip to content

Commit

Permalink
feat: support pre_prompt (#9)
Browse files: browse the repository at this point in the history
* feat: support pre_prompt

* fix: move string

---------

Co-authored-by: vansangpfiev <[email protected]>
  • Loading branch information
vansangpfiev and sangjanai authored Jun 14, 2024
1 parent 977b3d3 commit 79f1299
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions src/onnx_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ void OnnxEngine::HandleChatCompletion(
auto req = onnx::inferences::fromJson(json_body);
auto is_stream = json_body->get("stream", false).asBool();

std::string formatted_output;
std::string formatted_output = pre_prompt_;
for (const auto& message : req.messages) {
std::string input_role = message["role"].asString();
std::string role;
Expand All @@ -205,13 +205,12 @@ void OnnxEngine::HandleChatCompletion(
formatted_output += ai_prompt_;

// LOG_DEBUG << formatted_output;
// TODO(sang)
q_->runTaskInQueue([this, cb = std::move(callback), formatted_output, req] {
q_->runTaskInQueue([this, cb = std::move(callback),
fo = std::move(formatted_output), req] {
try {
if (req.stream) {

auto sequences = OgaSequences::Create();
tokenizer_->Encode(formatted_output.c_str(), *sequences);
tokenizer_->Encode(fo.c_str(), *sequences);

auto params = OgaGeneratorParams::Create(*oga_model_);
// TODO(sang)
Expand Down Expand Up @@ -273,7 +272,7 @@ void OnnxEngine::HandleChatCompletion(

} else {
auto sequences = OgaSequences::Create();
tokenizer_->Encode(formatted_output.c_str(), *sequences);
tokenizer_->Encode(fo.c_str(), *sequences);

auto params = OgaGeneratorParams::Create(*oga_model_);
params->SetSearchOption("max_length", req.max_tokens);
Expand Down

0 comments on commit 79f1299

Please sign in to comment.