Skip to content

Commit

Permalink
Add an optional `seed` parameter to chat completion (`create`/`create_async`), replacing the previously hard-coded random seed
Browse files Browse the repository at this point in the history
  • Loading branch information
sunjiayu committed Aug 7, 2024
1 parent a8101f1 commit 4975e68
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
16 changes: 6 additions & 10 deletions liboai/components/chat.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
#include <ctime>
#include <random>

#include "../include/components/chat.h"

liboai::Conversation::Conversation() {
Expand Down Expand Up @@ -592,7 +589,7 @@ bool liboai::Conversation::ParseStreamData(std::string data, std::string& delta_



liboai::Response liboai::ChatCompletion::create(const std::string& model, Conversation& conversation, std::optional<std::string> function_call, std::optional<float> temperature, std::optional<float> top_p, std::optional<uint16_t> n, std::optional<ChatStreamCallback> stream, std::optional<std::vector<std::string>> stop, std::optional<uint16_t> max_tokens, std::optional<float> presence_penalty, std::optional<float> frequency_penalty, std::optional<std::unordered_map<std::string, int8_t>> logit_bias, std::optional<std::string> user) const& noexcept(false) {
liboai::Response liboai::ChatCompletion::create(const std::string& model, Conversation& conversation, std::optional<std::string> function_call, std::optional<float> temperature, std::optional<float> top_p, std::optional<uint16_t> n, std::optional<ChatStreamCallback> stream, std::optional<int32_t> seed, std::optional<std::vector<std::string>> stop, std::optional<uint16_t> max_tokens, std::optional<float> presence_penalty, std::optional<float> frequency_penalty, std::optional<std::unordered_map<std::string, int8_t>> logit_bias, std::optional<std::string> user) const& noexcept(false) {
liboai::JsonConstructor jcon;
jcon.push_back("model", model);
jcon.push_back("temperature", std::move(temperature));
Expand All @@ -605,11 +602,6 @@ liboai::Response liboai::ChatCompletion::create(const std::string& model, Conver
jcon.push_back("logit_bias", std::move(logit_bias));
jcon.push_back("user", std::move(user));

// generate seed
std::mt19937 rng(static_cast<uint32_t>(std::time(nullptr)));
std::uniform_int_distribution<int32_t> dist(1, 10000);
jcon.push_back("seed", dist(rng));

if (function_call) {
if (function_call.value() == "none" || function_call.value() == "auto") {
nlohmann::json j; j["function_call"] = function_call.value();
Expand All @@ -631,6 +623,8 @@ liboai::Response liboai::ChatCompletion::create(const std::string& model, Conver
jcon.push_back("stream", _sscb);
}

jcon.push_back("seed", std::move(seed));

if (conversation.GetJSON().contains("messages")) {
jcon.push_back("messages", conversation.GetJSON()["messages"]);
}
Expand All @@ -655,7 +649,7 @@ liboai::Response liboai::ChatCompletion::create(const std::string& model, Conver
return res;
}

liboai::FutureResponse liboai::ChatCompletion::create_async(const std::string& model, Conversation& conversation, std::optional<std::string> function_call, std::optional<float> temperature, std::optional<float> top_p, std::optional<uint16_t> n, std::optional<ChatStreamCallback> stream, std::optional<std::vector<std::string>> stop, std::optional<uint16_t> max_tokens, std::optional<float> presence_penalty, std::optional<float> frequency_penalty, std::optional<std::unordered_map<std::string, int8_t>> logit_bias, std::optional<std::string> user) const& noexcept(false) {
liboai::FutureResponse liboai::ChatCompletion::create_async(const std::string& model, Conversation& conversation, std::optional<std::string> function_call, std::optional<float> temperature, std::optional<float> top_p, std::optional<uint16_t> n, std::optional<ChatStreamCallback> stream, std::optional<int32_t> seed, std::optional<std::vector<std::string>> stop, std::optional<uint16_t> max_tokens, std::optional<float> presence_penalty, std::optional<float> frequency_penalty, std::optional<std::unordered_map<std::string, int8_t>> logit_bias, std::optional<std::string> user) const& noexcept(false) {
liboai::JsonConstructor jcon;
jcon.push_back("model", model);
jcon.push_back("temperature", std::move(temperature));
Expand Down Expand Up @@ -693,6 +687,8 @@ liboai::FutureResponse liboai::ChatCompletion::create_async(const std::string& m

jcon.push_back("stream", _sscb);
}

jcon.push_back("seed", std::move(seed));

if (conversation.GetJSON().contains("messages")) {
jcon.push_back("messages", conversation.GetJSON()["messages"]);
Expand Down
2 changes: 2 additions & 0 deletions liboai/include/components/chat.h
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,7 @@ namespace liboai {
std::optional<float> top_p = std::nullopt,
std::optional<uint16_t> n = std::nullopt,
std::optional<ChatStreamCallback> stream = std::nullopt,
std::optional<int32_t> seed = std::nullopt,
std::optional<std::vector<std::string>> stop = std::nullopt,
std::optional<uint16_t> max_tokens = std::nullopt,
std::optional<float> presence_penalty = std::nullopt,
Expand Down Expand Up @@ -935,6 +936,7 @@ namespace liboai {
std::optional<float> top_p = std::nullopt,
std::optional<uint16_t> n = std::nullopt,
std::optional<ChatStreamCallback> stream = std::nullopt,
std::optional<int32_t> seed = std::nullopt,
std::optional<std::vector<std::string>> stop = std::nullopt,
std::optional<uint16_t> max_tokens = std::nullopt,
std::optional<float> presence_penalty = std::nullopt,
Expand Down

0 comments on commit 4975e68

Please sign in to comment.