diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index bdbb67ed8..bf2538115 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -167,6 +167,7 @@ else() endif() aux_source_directory(controllers CTL_SRC) +aux_source_directory(repositories REPO_SRC) aux_source_directory(services SERVICES_SRC) aux_source_directory(common COMMON_SRC) aux_source_directory(models MODEL_SRC) @@ -177,7 +178,7 @@ aux_source_directory(migrations MIGR_SRC) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} ) -target_sources(${TARGET_NAME} PRIVATE ${CONFIG_SRC} ${CTL_SRC} ${COMMON_SRC} ${SERVICES_SRC} ${DB_SRC} ${MIGR_SRC}) +target_sources(${TARGET_NAME} PRIVATE ${CONFIG_SRC} ${CTL_SRC} ${COMMON_SRC} ${SERVICES_SRC} ${DB_SRC} ${MIGR_SRC} ${REPO_SRC}) set_target_properties(${TARGET_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG ${CMAKE_BINARY_DIR} diff --git a/engine/cli/commands/engine_install_cmd.cc b/engine/cli/commands/engine_install_cmd.cc index 477e38ee2..954d5b5f1 100644 --- a/engine/cli/commands/engine_install_cmd.cc +++ b/engine/cli/commands/engine_install_cmd.cc @@ -177,7 +177,6 @@ bool EngineInstallCmd::Exec(const std::string& engine, auto response = curl_utils::SimplePostJson(install_url.ToFullPath(), body.toStyledString()); if (response.has_error()) { - // TODO: namh refactor later Json::Value root; Json::Reader reader; if (!reader.parse(response.error(), root)) { diff --git a/engine/common/api-dto/messages/delete_message_response.h b/engine/common/api-dto/messages/delete_message_response.h new file mode 100644 index 000000000..37243294e --- /dev/null +++ b/engine/common/api-dto/messages/delete_message_response.h @@ -0,0 +1,19 @@ +#pragma once + +#include "common/json_serializable.h" + +namespace ApiResponseDto { +struct DeleteMessageResponse : JsonSerializable { + std::string id; + std::string object; + bool deleted; + + cpp::result ToJson() override { + Json::Value json; + json["id"] = id; + json["object"] = object; + json["deleted"] = deleted; + return json; + } +}; +} // namespace ApiResponseDto diff --git a/engine/common/json_serializable.h b/engine/common/json_serializable.h new file mode 100644 index 000000000..4afec92c5 --- /dev/null +++ b/engine/common/json_serializable.h @@ -0,0 +1,11 @@ +#pragma once + +#include +#include "utils/result.hpp" + +struct JsonSerializable { + + virtual cpp::result ToJson() = 0; + + virtual ~JsonSerializable() = default; +}; diff --git a/engine/common/message.h b/engine/common/message.h new file mode 100644 index 000000000..dbc61c6fd --- /dev/null +++ b/engine/common/message.h @@ -0,0 +1,204 @@ +#pragma once + +#include +#include +#include +#include +#include +#include "common/message_attachment.h" +#include "common/message_attachment_factory.h" +#include "common/message_content.h" +#include "common/message_content_factory.h" +#include "common/message_incomplete_detail.h" +#include "common/message_role.h" +#include "common/message_status.h" +#include "common/variant_map.h" +#include "json_serializable.h" +#include "utils/logging_utils.h" +#include "utils/result.hpp" + +namespace ThreadMessage { + +// Represents a message within a thread. +struct Message : JsonSerializable { + + // The identifier, which can be referenced in API endpoints. + std::string id; + + // The object type, which is always thread.message. + std::string object = "thread.message"; + + // The Unix timestamp (in seconds) for when the message was created. + uint32_t created_at; + + // The thread ID that this message belongs to. + std::string thread_id; + + // The status of the message, which can be either in_progress, incomplete, or completed. + Status status; + + // On an incomplete message, details about why the message is incomplete. + std::optional incomplete_details; + + // The Unix timestamp (in seconds) for when the message was completed. + std::optional completed_at; + + // The Unix timestamp (in seconds) for when the message was marked as incomplete. + std::optional incomplete_at; + + Role role; + + // The content of the message in array of text and/or images. + std::vector> content; + + // If applicable, the ID of the assistant that authored this message. + std::optional assistant_id; + + // The ID of the run associated with the creation of this message. Value is null when messages are created manually using the create message or create thread endpoints. + std::optional run_id; + + // A list of files attached to the message, and the tools they were added to. + std::optional> attachments; + + // Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long. + Cortex::VariantMap metadata; + + static cpp::result FromJsonString( + std::string&& json_str) { + Json::Value root; + Json::Reader reader; + if (!reader.parse(json_str, root)) { + return cpp::fail("Failed to parse JSON: " + + reader.getFormattedErrorMessages()); + } + + Message message; + + try { + message.id = std::move(root["id"].asString()); + message.object = + std::move(root.get("object", "thread.message").asString()); + message.created_at = root["created_at"].asUInt(); + if (message.created_at == 0 && root["created"].asUInt64() != 0) { + message.created_at = root["created"].asUInt64() / 1000; + } + message.thread_id = root["thread_id"].asString(); + message.status = StatusFromString(root["status"].asString()); + + message.incomplete_details = + IncompleteDetail::FromJson(std::move(root["incomplete_details"])) + .value(); + message.completed_at = root["completed_at"].asUInt(); + message.incomplete_at = root["incomplete_at"].asUInt(); + message.role = RoleFromString(root["role"].asString()); + message.content = ParseContents(std::move(root["content"])).value(); + + message.assistant_id = root["assistant_id"].asString(); + message.run_id = root["run_id"].asString(); + message.attachments = + ParseAttachments(std::move(root["attachments"])).value(); + + if (root["metadata"].isObject() && !root["metadata"].empty()) { + auto res = Cortex::ConvertJsonValueToMap(root["metadata"]); + if (res.has_error()) { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } else { + message.metadata = res.value(); + } + } + + return message; + } catch (const std::exception& e) { + return cpp::fail(std::string("FromJsonString failed: ") + e.what()); + } + } + + cpp::result ToSingleLineJsonString() { + auto json_result = ToJson(); + if (json_result.has_error()) { + return cpp::fail(json_result.error()); + } + + Json::FastWriter writer; + try { + return writer.write(json_result.value()); + } catch (const std::exception& e) { + return cpp::fail(std::string("Failed to write JSON: ") + e.what()); + } + } + + cpp::result ToJson() override { + try { + Json::Value json; + + json["id"] = id; + json["object"] = object; + json["created_at"] = created_at; + json["thread_id"] = thread_id; + json["status"] = StatusToString(status); + + if (incomplete_details.has_value()) { + if (auto it = incomplete_details->ToJson(); it.has_value()) { + json["incomplete_details"] = it.value(); + } else { + CTL_WRN("Failed to convert incomplete_details to json: " + + it.error()); + } + } + if (completed_at.has_value() && completed_at.value() != 0) { + json["completed_at"] = *completed_at; + } + if (incomplete_at.has_value() && incomplete_at.value() != 0) { + json["incomplete_at"] = *incomplete_at; + } + + json["role"] = RoleToString(role); + + Json::Value content_json_arr{Json::arrayValue}; + for (auto& child_content : content) { + if (auto it = child_content->ToJson(); it.has_value()) { + content_json_arr.append(it.value()); + } else { + CTL_WRN("Failed to convert content to json: " + it.error()); + } + } + json["content"] = content_json_arr; + if (assistant_id.has_value() && !assistant_id->empty()) { + json["assistant_id"] = *assistant_id; + } + if (run_id.has_value() && !run_id->empty()) { + json["run_id"] = *run_id; + } + if (attachments.has_value()) { + Json::Value attachments_json_arr{Json::arrayValue}; + for (auto& attachment : *attachments) { + if (auto it = attachment.ToJson(); it.has_value()) { + attachments_json_arr.append(it.value()); + } else { + CTL_WRN("Failed to convert attachment to json: " + it.error()); + } + } + json["attachments"] = attachments_json_arr; + } + + Json::Value metadata_json{Json::objectValue}; + for (const auto& [key, value] : metadata) { + if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else { + metadata_json[key] = std::get(value); + } + } + json["metadata"] = metadata_json; + + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; +}; // namespace ThreadMessage diff --git a/engine/common/message_attachment.h b/engine/common/message_attachment.h new file mode 100644 index 000000000..ea809990e --- /dev/null +++ b/engine/common/message_attachment.h @@ -0,0 +1,50 @@ +#pragma once + +#include +#include "common/json_serializable.h" + +namespace ThreadMessage { + +// The tools to add this file to. +struct Tool { + std::string type; + + Tool(const std::string& type) : type{type} {} +}; + +// The type of tool being defined: code_interpreter +struct CodeInterpreter : Tool { + CodeInterpreter() : Tool{"code_interpreter"} {} +}; + +// The type of tool being defined: file_search +struct FileSearch : Tool { + FileSearch() : Tool{"file_search"} {} +}; + +// A list of files attached to the message, and the tools they were added to. +struct Attachment : JsonSerializable { + + // The ID of the file to attach to the message. + std::string file_id; + + std::vector tools; + + cpp::result ToJson() override { + try { + Json::Value json; + json["file_id"] = file_id; + Json::Value tools_json_arr{Json::arrayValue}; + for (auto& tool : tools) { + Json::Value tool_json; + tool_json["type"] = tool.type; + tools_json_arr.append(tool_json); + } + json["tools"] = tools_json_arr; + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; +}; // namespace ThreadMessage diff --git a/engine/common/message_attachment_factory.h b/engine/common/message_attachment_factory.h new file mode 100644 index 000000000..d9f1b8d2e --- /dev/null +++ b/engine/common/message_attachment_factory.h @@ -0,0 +1,48 @@ +#include +#include "common/message_attachment.h" +#include "utils/result.hpp" + +namespace ThreadMessage { +inline cpp::result ParseAttachment( + Json::Value&& json) { + if (json.empty()) { + return cpp::fail("Json string is empty"); + } + + Attachment attachment; + attachment.file_id = json["file_id"].asString(); + + std::vector tools{}; + if (json["tools"].isArray()) { + for (auto& tool_json : json["tools"]) { + Tool tool{tool_json["type"].asString()}; + tools.push_back(tool); + } + } + attachment.tools = tools; + + return attachment; +} + +inline cpp::result>, std::string> +ParseAttachments(Json::Value&& json) { + if (json.empty()) { + // still count as success + return std::nullopt; + } + if (!json.isArray()) { + return cpp::fail("Json is not an array"); + } + + std::vector attachments; + for (auto& attachment_json : json) { + auto attachment = ParseAttachment(std::move(attachment_json)); + if (attachment.has_error()) { + return cpp::fail(attachment.error()); + } + attachments.push_back(attachment.value()); + } + + return attachments; +} +}; // namespace ThreadMessage diff --git a/engine/common/message_content.h b/engine/common/message_content.h new file mode 100644 index 000000000..7f3ec8e59 --- /dev/null +++ b/engine/common/message_content.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include "common/json_serializable.h" + +namespace ThreadMessage { + +struct Content : JsonSerializable { + std::string type; + + Content(const std::string& type) : type{type} {} + + virtual ~Content() = default; +}; +}; // namespace ThreadMessage diff --git a/engine/common/message_content_factory.h b/engine/common/message_content_factory.h new file mode 100644 index 000000000..854f6efd8 --- /dev/null +++ b/engine/common/message_content_factory.h @@ -0,0 +1,77 @@ +#pragma once + +#include +#include "common/message_content_image_file.h" +#include "common/message_content_image_url.h" +#include "common/message_content_refusal.h" +#include "common/message_content_text.h" +#include "utils/logging_utils.h" +#include "utils/result.hpp" + +namespace ThreadMessage { +inline cpp::result, std::string> ParseContent( + Json::Value&& json) { + if (json.empty()) { + return cpp::fail("Json string is empty"); + } + + try { + auto type = json["type"].asString(); + + if (type == "image_file") { + auto result = ImageFileContent::FromJson(std::move(json)); + if (result.has_error()) { + return cpp::fail(result.error()); + } + return std::make_unique(std::move(result.value())); + } else if (type == "image_url") { + auto result = ImageUrlContent::FromJson(std::move(json)); + if (result.has_error()) { + return cpp::fail(result.error()); + } + return std::make_unique(std::move(result.value())); + } else if (type == "text") { + auto result = TextContent::FromJson(std::move(json)); + if (result.has_error()) { + return cpp::fail(result.error()); + } + return std::make_unique(std::move(result.value())); + } else if (type == "refusal") { + auto result = Refusal::FromJson(std::move(json)); + if (result.has_error()) { + return cpp::fail(result.error()); + } + return std::make_unique(std::move(result.value())); + } else { + return cpp::fail("Unknown content type: " + type); + } + + return cpp::fail("Unknown content type: " + type); + } catch (const std::exception& e) { + return cpp::fail(std::string("ParseContent failed: ") + e.what()); + } +} + +inline cpp::result>, std::string> +ParseContents(Json::Value&& json) { + if (json.empty()) { + return cpp::fail("Json string is empty"); + } + if (!json.isArray()) { + return cpp::fail("Json is not an array"); + } + + std::vector> contents; + Json::Value mutable_json = std::move(json); + + for (auto& content_json : mutable_json) { + auto content = ParseContent(std::move(content_json)); + if (content.has_error()) { + CTL_WRN(content.error()); + continue; + } + contents.push_back(std::move(content.value())); + } + return contents; +} +} // namespace ThreadMessage diff --git a/engine/common/message_content_image_file.h b/engine/common/message_content_image_file.h new file mode 100644 index 000000000..83cf62b9e --- /dev/null +++ b/engine/common/message_content_image_file.h @@ -0,0 +1,51 @@ +#pragma once + +#include "common/message_content.h" + +namespace ThreadMessage { +struct ImageFile { + // The File ID of the image in the message content. Set purpose="vision" when uploading the File if you need to later display the file content. + std::string file_id; + + // Specifies the detail level of the image if specified by the user. low uses fewer tokens, you can opt in to high resolution using high. + std::string detail; +}; + +// References an image File in the content of a message. +struct ImageFileContent : Content { + + ImageFileContent() : Content("image_file") {} + + ImageFile image_file; + + static cpp::result FromJson( + Json::Value&& json) { + if (json.empty()) { + return cpp::fail("Json string is empty"); + } + + try { + ImageFileContent content; + ImageFile image_file; + image_file.detail = json["image_file"]["detail"].asString(); + image_file.file_id = json["image_file"]["file_id"].asString(); + content.image_file = image_file; + return content; + } catch (const std::exception& e) { + return cpp::fail(std::string("FromJson failed: ") + e.what()); + } + } + + cpp::result ToJson() override { + try { + Json::Value json; + json["type"] = type; + json["image_file"]["file_id"] = image_file.file_id; + json["image_file"]["detail"] = image_file.detail; + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; +} // namespace ThreadMessage diff --git a/engine/common/message_content_image_url.h b/engine/common/message_content_image_url.h new file mode 100644 index 000000000..f52d299a4 --- /dev/null +++ b/engine/common/message_content_image_url.h @@ -0,0 +1,53 @@ +#pragma once + +#include "common/message_content.h" + +namespace ThreadMessage { + +struct ImageUrl { + // The external URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp. + std::string url; + + // Specifies the detail level of the image. low uses fewer tokens, you can opt in to high resolution using high. Default value is auto + std::string detail; +}; + +// References an image URL in the content of a message. +struct ImageUrlContent : Content { + + // The type of the content part. + ImageUrlContent(const std::string& type) : Content(type) {} + + ImageUrl image_url; + + static cpp::result FromJson( + Json::Value&& json) { + if (json.empty()) { + return cpp::fail("Json string is empty"); + } + + try { + ImageUrlContent content{"image_url"}; + ImageUrl image_url; + image_url.url = json["image_url"]["url"].asString(); + image_url.detail = json["image_url"]["detail"].asString(); + content.image_url = image_url; + return content; + } catch (const std::exception& e) { + return cpp::fail(std::string("FromJson failed: ") + e.what()); + } + } + + cpp::result ToJson() override { + try { + Json::Value json; + json["type"] = type; + json["image_url"]["url"] = image_url.url; + json["image_url"]["detail"] = image_url.detail; + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; +} // namespace ThreadMessage diff --git a/engine/common/message_content_refusal.h b/engine/common/message_content_refusal.h new file mode 100644 index 000000000..3fb7f78b6 --- /dev/null +++ b/engine/common/message_content_refusal.h @@ -0,0 +1,39 @@ +#pragma once + +#include "common/message_content.h" + +namespace ThreadMessage { +// The refusal content generated by the assistant. +struct Refusal : Content { + + // Always refusal. + Refusal() : Content("refusal") {} + + std::string refusal; + + static cpp::result FromJson(Json::Value&& json) { + if (json.empty()) { + return cpp::fail("Json string is empty"); + } + + try { + Refusal content; + content.refusal = json["refusal"].asString(); + return content; + } catch (const std::exception& e) { + return cpp::fail(std::string("FromJson failed: ") + e.what()); + } + } + + cpp::result ToJson() override { + try { + Json::Value json; + json["type"] = type; + json["refusal"] = refusal; + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; +} // namespace ThreadMessage diff --git a/engine/common/message_content_text.h b/engine/common/message_content_text.h new file mode 100644 index 000000000..0704ac7d5 --- /dev/null +++ b/engine/common/message_content_text.h @@ -0,0 +1,187 @@ +#pragma once + +#include "common/message_content.h" +#include "utils/logging_utils.h" + +namespace ThreadMessage { + +struct Annotation : JsonSerializable { + std::string type; + + // The text in the message content that needs to be replaced. + std::string text; + + uint32_t start_index; + + uint32_t end_index; + + Annotation(const std::string& type, const std::string& text, + uint32_t start_index, uint32_t end_index) + : type{type}, + text{text}, + start_index{start_index}, + end_index{end_index} {} + + virtual ~Annotation() = default; +}; + +// A citation within the message that points to a specific quote from a specific File associated with the assistant or the message. Generated when the assistant uses the "file_search" tool to search files. +struct FileCitationWrapper : Annotation { + + // Always file_citation. + FileCitationWrapper(const std::string& text, uint32_t start_index, + uint32_t end_index) + : Annotation("file_citation", text, start_index, end_index) {} + + struct FileCitation { + // The ID of the specific File the citation is from. + std::string file_id; + }; + + FileCitation file_citation; + + cpp::result ToJson() override { + try { + Json::Value json; + json["text"] = text; + json["type"] = type; + json["file_citation"]["file_id"] = file_citation.file_id; + json["start_index"] = start_index; + json["end_index"] = end_index; + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; + +// A URL for the file that's generated when the assistant used the code_interpreter tool to generate a file. +struct FilePathWrapper : Annotation { + // Always file_path. + FilePathWrapper(const std::string& text, uint32_t start_index, + uint32_t end_index) + : Annotation("file_path", text, start_index, end_index) {} + + struct FilePath { + // The ID of the file that was generated. + std::string file_id; + }; + + FilePath file_path; + + cpp::result ToJson() override { + try { + Json::Value json; + json["text"] = text; + json["type"] = type; + json["file_path"]["file_id"] = file_path.file_id; + json["start_index"] = start_index; + json["end_index"] = end_index; + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; + +struct Text : JsonSerializable { + // The data that makes up the text. + std::string value; + + std::vector> annotations; + + static cpp::result FromJson(Json::Value&& json) { + if (json.empty()) { + return cpp::fail("Json string is empty"); + } + + try { + Text text; + text.value = json["value"].asString(); + + // Parse annotations array + if (json.isMember("annotations") && json["annotations"].isArray()) { + for (const auto& annotation_json : json["annotations"]) { + std::string type = annotation_json["type"].asString(); + std::string annotation_text = annotation_json["text"].asString(); + uint32_t start_index = annotation_json["start_index"].asUInt(); + uint32_t end_index = annotation_json["end_index"].asUInt(); + + if (type == "file_citation") { + auto citation = std::make_unique( + annotation_text, start_index, end_index); + citation->file_citation.file_id = + annotation_json["file_citation"]["file_id"].asString(); + text.annotations.push_back(std::move(citation)); + } else if (type == "file_path") { + auto file_path = std::make_unique( + annotation_text, start_index, end_index); + file_path->file_path.file_id = + annotation_json["file_path"]["file_id"].asString(); + text.annotations.push_back(std::move(file_path)); + } else { + CTL_WRN("Unknown annotation type: " + type); + } + } + } + + return text; + } catch (const std::exception& e) { + return cpp::fail(std::string("FromJson failed: ") + e.what()); + } + } + + cpp::result ToJson() override { + try { + Json::Value json; + json["value"] = value; + Json::Value annotations_json_arr{Json::arrayValue}; + for (auto& annotation : annotations) { + if (auto it = annotation->ToJson(); it.has_value()) { + annotations_json_arr.append(it.value()); + } else { + CTL_WRN("Failed to convert annotation to json: " + it.error()); + } + } + json["annotations"] = annotations_json_arr; + return json; + } catch (const std::exception e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + }; +}; + +// The text content that is part of a message. +struct TextContent : Content { + // Always text. + TextContent() : Content("text") {} + + Text text; + + static cpp::result FromJson(Json::Value&& json) { + if (json.empty()) { + return cpp::fail("Json string is empty"); + } + + try { + TextContent content; + content.type = json["type"].asString(); + content.text = Text::FromJson(std::move(json["text"])).value(); + return content; + } catch (const std::exception& e) { + return cpp::fail(std::string("FromJson failed: ") + e.what()); + } + } + + cpp::result ToJson() override { + try { + Json::Value json; + json["type"] = type; + json["text"] = text.ToJson().value(); + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; +} // namespace ThreadMessage diff --git a/engine/common/message_incomplete_detail.h b/engine/common/message_incomplete_detail.h new file mode 100644 index 000000000..25e9c1169 --- /dev/null +++ b/engine/common/message_incomplete_detail.h @@ -0,0 +1,32 @@ +#pragma once + +#include "common/json_serializable.h" + +namespace ThreadMessage { + +// On an incomplete message, details about why the message is incomplete. +struct IncompleteDetail : JsonSerializable { + // The reason the message is incomplete. + std::string reason; + + static cpp::result, std::string> FromJson( + Json::Value&& json) { + if (json.empty()) { + return std::nullopt; + } + IncompleteDetail incomplete_detail; + incomplete_detail.reason = json["reason"].asString(); + return incomplete_detail; + } + + cpp::result ToJson() override { + try { + Json::Value json; + json["reason"] = reason; + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; +} // namespace ThreadMessage diff --git a/engine/common/message_role.h b/engine/common/message_role.h new file mode 100644 index 000000000..9d428eddc --- /dev/null +++ b/engine/common/message_role.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include "utils/string_utils.h" + +namespace ThreadMessage { +// The entity that produced the message. One of user or assistant. +enum class Role { USER, ASSISTANT }; + +inline std::string RoleToString(Role role) { + switch (role) { + case Role::USER: + return "user"; + case Role::ASSISTANT: + return "assistant"; + default: + throw new std::invalid_argument("Invalid role: " + + std::to_string((int)role)); + } +} + +inline Role RoleFromString(const std::string& input) { + if (string_utils::EqualsIgnoreCase(input, "user")) { + return Role::USER; + } else { + // for backward compatible with jan. Before, jan was mark text with `ready` + return Role::ASSISTANT; + } +} +}; // namespace ThreadMessage diff --git a/engine/common/message_status.h b/engine/common/message_status.h new file mode 100644 index 000000000..e8844ee13 --- /dev/null +++ b/engine/common/message_status.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include "utils/string_utils.h" + +namespace ThreadMessage { +// The status of the message, which can be either in_progress, incomplete, or completed. +enum class Status { IN_PROGRESS, INCOMPLETE, COMPLETED }; + +// Convert a Status enum to a string. +inline std::string StatusToString(Status status) { + switch (status) { + case Status::IN_PROGRESS: + return "in_progress"; + case Status::INCOMPLETE: + return "incomplete"; + // default as completed for backward compatible with jan + default: + return "completed"; + } +} + +// Convert a string to a Status enum. +inline Status StatusFromString(const std::string& input) { + if (string_utils::EqualsIgnoreCase(input, "in_progress")) { + return Status::IN_PROGRESS; + } else if (string_utils::EqualsIgnoreCase(input, "incomplete")) { + return Status::INCOMPLETE; + } else { + // for backward compatible with jan. Before, jan was mark text with `ready` + return Status::COMPLETED; + } +} +}; // namespace ThreadMessage diff --git a/engine/common/repository/message_repository.h b/engine/common/repository/message_repository.h new file mode 100644 index 000000000..cffc73675 --- /dev/null +++ b/engine/common/repository/message_repository.h @@ -0,0 +1,27 @@ +#pragma once + +#include "common/message.h" +#include "utils/result.hpp" + +class MessageRepository { + public: + virtual cpp::result CreateMessage( + ThreadMessage::Message& message) = 0; + + virtual cpp::result, std::string> + ListMessages(const std::string& thread_id, uint8_t limit = 20, + const std::string& order = "desc", const std::string& after = "", + const std::string& before = "", + const std::string& run_id = "") const = 0; + + virtual cpp::result RetrieveMessage( + const std::string& thread_id, const std::string& message_id) const = 0; + + virtual cpp::result ModifyMessage( + ThreadMessage::Message& message) = 0; + + virtual cpp::result DeleteMessage( + const std::string& thread_id, const std::string& message_id) = 0; + + virtual ~MessageRepository() = default; +}; diff --git a/engine/common/variant_map.h b/engine/common/variant_map.h new file mode 100644 index 000000000..c8da77317 --- /dev/null +++ b/engine/common/variant_map.h @@ -0,0 +1,62 @@ +#pragma once + +#include +#include +#include +#include +#include "utils/result.hpp" + +namespace Cortex { + +using ValueVariant = std::variant; +using VariantMap = std::unordered_map; + +inline cpp::result ConvertJsonValueToMap( + const Json::Value& json) { + VariantMap result; + + if (!json.isObject()) { + return cpp::fail("Input json is not an object"); + } + + for (const auto& key : json.getMemberNames()) { + const Json::Value& value = json[key]; + + switch (value.type()) { + case Json::nullValue: + // Skip null values + break; + + case Json::stringValue: + result.emplace(key, value.asString()); + break; + + case Json::booleanValue: + result.emplace(key, value.asBool()); + break; + + case Json::uintValue: + case Json::intValue: + // Handle both signed and unsigned integers + if (value.isUInt64()) { + result.emplace(key, value.asUInt64()); + } else { + // Convert to double if the integer is negative or too large + result.emplace(key, value.asDouble()); + } + break; + + case Json::realValue: + result.emplace(key, value.asDouble()); + break; + + case Json::arrayValue: + case Json::objectValue: + // currently does not handle complex type + break; + } + } + + return result; +} +}; // namespace Cortex diff --git a/engine/config/model_config.h b/engine/config/model_config.h index 044fd8dd3..f5eae72a4 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -1,13 +1,12 @@ #pragma once #include -#include -#include #include #include #include #include #include "utils/format_utils.h" + namespace config { struct ModelConfig { std::string name; diff --git a/engine/controllers/messages.cc b/engine/controllers/messages.cc new file mode 100644 index 000000000..8551ec03d --- /dev/null +++ b/engine/controllers/messages.cc @@ -0,0 +1,298 @@ +#include "messages.h" +#include "common/api-dto/messages/delete_message_response.h" +#include "common/message_content.h" +#include "common/message_role.h" +#include "common/variant_map.h" +#include "utils/cortex_utils.h" +#include "utils/logging_utils.h" +#include "utils/string_utils.h" + +void Messages::ListMessages( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id, std::optional limit, + std::optional order, std::optional after, + std::optional before, + std::optional run_id) const { + auto res = message_service_->ListMessages( + thread_id, limit.value_or(20), order.value_or("desc"), after.value_or(""), + before.value_or(""), run_id.value_or("")); + + Json::Value root; + if (res.has_error()) { + root["message"] = res.error(); + auto response = cortex_utils::CreateCortexHttpJsonResponse(root); + response->setStatusCode(k400BadRequest); + callback(response); + return; + } + Json::Value msg_arr(Json::arrayValue); + for (auto& msg : res.value()) { + if (auto it = msg.ToJson(); it.has_value()) { + msg_arr.append(it.value()); + } else { + CTL_WRN("Failed to convert message to json: " + it.error()); + } + } + + root["object"] = "list"; + root["data"] = msg_arr; + auto response = cortex_utils::CreateCortexHttpJsonResponse(root); + response->setStatusCode(k200OK); + callback(response); +} + +void Messages::CreateMessage( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id) { + auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Request body can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + // role + auto role_str = json_body->get("role", "").asString(); + if (role_str.empty()) { + Json::Value ret; + ret["message"] = "Role is required"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + if (role_str != "user" && role_str != "assistant") { + Json::Value ret; + ret["message"] = "Role must be either 'user' or 'assistant'"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + ThreadMessage::Role role = role_str == "user" + ? ThreadMessage::Role::USER + : ThreadMessage::Role::ASSISTANT; + + std::variant>> + content; + + if (json_body->get("content", "").isArray()) { + auto result = ThreadMessage::ParseContents(json_body->get("content", "")); + if (result.has_error()) { + Json::Value ret; + ret["message"] = "Failed to parse content array: " + result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + if (result.value().empty()) { + Json::Value ret; + ret["message"] = "Content array cannot be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + content = std::move(result.value()); + } else if (json_body->get("content", "").isString()) { + auto content_str = json_body->get("content", "").asString(); + string_utils::Trim(content_str); + if (content_str.empty()) { + Json::Value ret; + ret["message"] = "Content can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + // success get content as string + content = content_str; + } else { + Json::Value ret; + ret["message"] = "Content must be either a string or an array"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + // attachments + std::optional> attachments = + std::nullopt; + if (json_body->get("attachments", "").isArray()) { + attachments = ThreadMessage::ParseAttachments( + std::move(json_body->get("attachments", ""))) + .value(); + } + + std::optional metadata = std::nullopt; + if (json_body->get("metadata", "").isObject()) { + auto res = Cortex::ConvertJsonValueToMap(json_body->get("metadata", "")); + if (res.has_error()) { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } else { + metadata = res.value(); + } + } + + auto res = message_service_->CreateMessage( + thread_id, role, std::move(content), attachments, metadata); + if (res.has_error()) { + Json::Value ret; + ret["message"] = "Content must be either a string or an array"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto message_to_json = res->ToJson(); + if (message_to_json.has_error()) { + CTL_ERR("Failed to convert message to json: " + message_to_json.error()); + Json::Value ret; + ret["message"] = message_to_json.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); + } + } +} + +void Messages::RetrieveMessage( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id, const std::string& message_id) const { + auto res = message_service_->RetrieveMessage(thread_id, message_id); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto message_to_json = res->ToJson(); + if (message_to_json.has_error()) { + CTL_ERR("Failed to convert message to json: " + message_to_json.error()); + Json::Value ret; + ret["message"] = message_to_json.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); + } + } +} + +void Messages::ModifyMessage( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id, const std::string& message_id) { + auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Request body can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + std::optional metadata = std::nullopt; + if (auto it = json_body->get("metadata", ""); it) { + if (it.empty()) { + Json::Value ret; + ret["message"] = "Metadata can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + auto convert_res = Cortex::ConvertJsonValueToMap(it); + if (convert_res.has_error()) { + Json::Value ret; + ret["message"] = + "Failed to convert metadata to map: " + convert_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + metadata = convert_res.value(); + } + + if (!metadata.has_value()) { + Json::Value ret; + ret["message"] = "Metadata is mandatory"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto res = + message_service_->ModifyMessage(thread_id, message_id, metadata.value()); + if (res.has_error()) { + Json::Value ret; + ret["message"] = "Failed to modify message: " + res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto message_to_json = res->ToJson(); + if (message_to_json.has_error()) { + CTL_ERR("Failed to convert message to json: " + message_to_json.error()); + Json::Value ret; + ret["message"] = message_to_json.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); + } + } +} + +void Messages::DeleteMessage( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id, const std::string& message_id) { + auto res = message_service_->DeleteMessage(thread_id, message_id); + if (res.has_error()) { + Json::Value ret; + ret["message"] = "Failed to delete message: " + res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + ApiResponseDto::DeleteMessageResponse response; + response.id = message_id; + response.object = "thread.message.deleted"; + response.deleted = true; + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(response.ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); +} diff --git a/engine/controllers/messages.h b/engine/controllers/messages.h new file mode 100644 index 000000000..340317eb8 --- /dev/null +++ b/engine/controllers/messages.h @@ -0,0 +1,60 @@ +#pragma once + +#include +#include +#include "services/message_service.h" + +using namespace drogon; + +class Messages : public drogon::HttpController { + public: + METHOD_LIST_BEGIN + ADD_METHOD_TO(Messages::CreateMessage, "/v1/threads/{1}/messages", Options, + Post); + + ADD_METHOD_TO(Messages::ListMessages, + "/v1/threads/{thread_id}/" + "messages?limit={limit}&order={order}&after={after}&before={" + "before}&run_id={run_id}", + Get); + + ADD_METHOD_TO(Messages::RetrieveMessage, "/v1/threads/{1}/messages/{2}", Get); + ADD_METHOD_TO(Messages::ModifyMessage, "/v1/threads/{1}/messages/{2}", + Options, Post); + ADD_METHOD_TO(Messages::DeleteMessage, "/v1/threads/{1}/messages/{2}", + Options, Delete); + METHOD_LIST_END + + Messages(std::shared_ptr msg_srv) + : message_service_{msg_srv} {} + + void CreateMessage(const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id); + + void ListMessages(const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id, std::optional limit, + std::optional order, + std::optional after, + std::optional before, + std::optional run_id) const; + + void RetrieveMessage(const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id, + const std::string& message_id) const; + + void ModifyMessage(const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id, + const std::string& message_id); + + void DeleteMessage(const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id, + const std::string& message_id); + + private: + std::shared_ptr message_service_; +}; diff --git a/engine/main.cc b/engine/main.cc index 1aa024a10..1224b07d2 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -1,18 +1,22 @@ #include #include #include +#include "common/repository/message_repository.h" #include "controllers/configs.h" #include "controllers/engines.h" #include "controllers/events.h" #include "controllers/hardware.h" +#include "controllers/messages.h" #include "controllers/models.h" #include "controllers/process_manager.h" #include "controllers/server.h" #include "cortex-common/cortexpythoni.h" #include "database/database.h" #include "migrations/migration_manager.h" +#include "repositories/message_fs_repository.h" #include "services/config_service.h" #include "services/file_watcher_service.h" +#include "services/message_service.h" #include "services/model_service.h" #include "utils/archive_utils.h" #include "utils/cortex_utils.h" @@ -110,6 +114,9 @@ void RunServer(std::optional port, bool ignore_cout) { auto event_queue_ptr = std::make_shared(); cortex::event::EventProcessor event_processor(event_queue_ptr); + std::shared_ptr msg_repo = + std::make_shared(); + auto message_srv = std::make_shared(msg_repo); auto model_dir_path = file_manager_utils::GetModelsContainerPath(); auto config_service = std::make_shared(); auto download_service = @@ -125,6 +132,7 @@ void RunServer(std::optional port, bool ignore_cout) { file_watcher_srv->start(); // initialize custom controllers + auto message_ctl = std::make_shared(message_srv); auto engine_ctl = std::make_shared(engine_service); auto model_ctl = std::make_shared(model_service, engine_service); auto event_ctl = std::make_shared(event_queue_ptr); @@ -134,6 +142,7 @@ void RunServer(std::optional port, bool ignore_cout) { std::make_shared(inference_svc, engine_service); auto config_ctl = std::make_shared(config_service); + drogon::app().registerController(message_ctl); drogon::app().registerController(engine_ctl); drogon::app().registerController(model_ctl); drogon::app().registerController(event_ctl); diff --git a/engine/repositories/message_fs_repository.cc b/engine/repositories/message_fs_repository.cc new file mode 100644 index 000000000..35be9039b --- /dev/null +++ b/engine/repositories/message_fs_repository.cc @@ -0,0 +1,229 @@ +#include "message_fs_repository.h" +#include "utils/file_manager_utils.h" +#include "utils/result.hpp" + +namespace { +constexpr static const std::string_view kMessageFile = "messages.jsonl"; + +inline cpp::result GetMessageFileAbsPath( + const std::string& thread_id) { + auto path = + file_manager_utils::GetThreadsContainerPath() / thread_id / kMessageFile; + if (!std::filesystem::exists(path)) { + return cpp::fail("Message file not exist at path: " + path.string()); + } + return path; +} +} // namespace + +cpp::result MessageFsRepository::CreateMessage( + ThreadMessage::Message& message) { + CTL_INF("CreateMessage for thread " + message.thread_id); + auto path = GetMessageFileAbsPath(message.thread_id); + if (path.has_error()) { + return cpp::fail(path.error()); + } + + std::ofstream file(path->string(), std::ios::app); + if (!file) { + return cpp::failure("Failed to open file for writing: " + path->string()); + } + + auto mutex = GrabMutex(message.thread_id); + std::shared_lock lock(*mutex); + + auto json_str = message.ToSingleLineJsonString(); + if (json_str.has_error()) { + return cpp::fail(json_str.error()); + } + file << json_str.value(); + + file.flush(); + if (file.fail()) { + return cpp::failure("Failed to write to file: " + path->string()); + } + file.close(); + if (file.fail()) { + return cpp::failure("Failed to close file after writing: " + + path->string()); + } + + return {}; +} + +cpp::result, std::string> +MessageFsRepository::ListMessages(const std::string& thread_id, uint8_t limit, + const std::string& order, + const std::string& after, + const std::string& before, + const std::string& run_id) const { + CTL_INF("Listing messages for thread " + thread_id); + auto path = GetMessageFileAbsPath(thread_id); + if (path.has_error()) { + return cpp::fail(path.error()); + } + + auto mutex = GrabMutex(thread_id); + std::shared_lock lock(*mutex); + + return ReadMessageFromFile(thread_id); +} + +cpp::result +MessageFsRepository::RetrieveMessage(const std::string& thread_id, + const std::string& message_id) const { + auto path = GetMessageFileAbsPath(thread_id); + if (path.has_error()) { + return cpp::fail(path.error()); + } + + auto mutex = GrabMutex(thread_id); + std::unique_lock lock(*mutex); + + auto messages = ReadMessageFromFile(thread_id); + if (messages.has_error()) { + return cpp::fail(messages.error()); + } + + for (auto& msg : messages.value()) { + if (msg.id == message_id) { + return std::move(msg); + } + } + + return cpp::failure("Message not found"); +} + +cpp::result MessageFsRepository::ModifyMessage( + ThreadMessage::Message& message) { + auto path = GetMessageFileAbsPath(message.thread_id); + if (path.has_error()) { + return cpp::fail(path.error()); + } + + auto mutex = GrabMutex(message.thread_id); + std::unique_lock lock(*mutex); + + auto messages = ReadMessageFromFile(message.thread_id); + if (messages.has_error()) { + return cpp::fail(messages.error()); + } + + std::ofstream file(path.value().string(), std::ios::trunc); + if (!file) { + return cpp::failure("Failed to open file for writing: " + + path.value().string()); + } + + bool found = false; + for (auto& msg : messages.value()) { + if (msg.id == message.id) { + file << message.ToSingleLineJsonString().value(); + found = true; + } else { + file << msg.ToSingleLineJsonString().value(); + } + } + + file.flush(); + if (file.fail()) { + return cpp::failure("Failed to write to file: " + path->string()); + } + file.close(); + if (file.fail()) { + return cpp::failure("Failed to close file after writing: " + + path->string()); + } + + if (!found) { + return cpp::failure("Message not found"); + } + return {}; +} + +cpp::result MessageFsRepository::DeleteMessage( + const std::string& thread_id, const std::string& message_id) { + auto path = GetMessageFileAbsPath(thread_id); + if (path.has_error()) { + return cpp::fail(path.error()); + } + + auto mutex = GrabMutex(thread_id); + std::unique_lock lock(*mutex); + auto messages = ReadMessageFromFile(thread_id); + if (messages.has_error()) { + return cpp::fail(messages.error()); + } + + std::ofstream file(path.value().string(), std::ios::trunc); + if (!file) { + return cpp::failure("Failed to open file for writing: " + + path.value().string()); + } + + bool found = false; + for (auto& msg : messages.value()) { + if (msg.id != message_id) { + file << msg.ToSingleLineJsonString().value(); + } else { + found = true; + } + } + + file.flush(); + if (file.fail()) { + return cpp::failure("Failed to write to file: " + path->string()); + } + file.close(); + if (file.fail()) { + return cpp::failure("Failed to close file after writing: " + + path->string()); + } + + if (!found) { + return cpp::failure("Message not found"); + } + + return {}; +} + +cpp::result, std::string> +MessageFsRepository::ReadMessageFromFile(const std::string& thread_id) const { + LOG_TRACE << "Reading messages from file for thread " << thread_id; + auto path = GetMessageFileAbsPath(thread_id); + if (path.has_error()) { + return cpp::fail(path.error()); + } + + std::ifstream file(path.value()); + if (!file) { + return cpp::failure("Failed to open file: " + path->string()); + } + + std::vector messages; + std::string line; + while (std::getline(file, line)) { + if (line.empty()) + continue; + auto msg_parse_result = + ThreadMessage::Message::FromJsonString(std::move(line)); + if (msg_parse_result.has_error()) { + CTL_WRN("Failed to parse message: " + msg_parse_result.error()); + continue; + } + + messages.push_back(std::move(msg_parse_result.value())); + } + + return messages; +} + +std::shared_mutex* MessageFsRepository::GrabMutex( + const std::string& thread_id) const { + std::lock_guard lock(mutex_map_mutex_); + auto& thread_mutex = thread_mutexes_[thread_id]; + if (!thread_mutex) { + thread_mutex = std::make_unique(); + } + return thread_mutex.get(); +} diff --git a/engine/repositories/message_fs_repository.h b/engine/repositories/message_fs_repository.h new file mode 100644 index 000000000..d8bcd02a7 --- /dev/null +++ b/engine/repositories/message_fs_repository.h @@ -0,0 +1,39 @@ +#pragma once + +#include +#include +#include "common/repository/message_repository.h" + +class MessageFsRepository : public MessageRepository { + public: + cpp::result CreateMessage( + ThreadMessage::Message& message) override; + + cpp::result, std::string> ListMessages( + const std::string& thread_id, uint8_t limit = 20, + const std::string& order = "desc", const std::string& after = "", + const std::string& before = "", + const std::string& run_id = "") const override; + + cpp::result RetrieveMessage( + const std::string& thread_id, + const std::string& message_id) const override; + + cpp::result ModifyMessage( + ThreadMessage::Message& message) override; + + cpp::result DeleteMessage( + const std::string& thread_id, const std::string& message_id) override; + + ~MessageFsRepository() = default; + + private: + cpp::result, std::string> + ReadMessageFromFile(const std::string& thread_id) const; + + std::shared_mutex* GrabMutex(const std::string& thread_id) const; + + mutable std::unordered_map> + thread_mutexes_; + mutable std::mutex mutex_map_mutex_; +}; diff --git a/engine/services/hardware_service.cc b/engine/services/hardware_service.cc index 02693a48d..10e563bd3 100644 --- a/engine/services/hardware_service.cc +++ b/engine/services/hardware_service.cc @@ -8,7 +8,6 @@ #endif #include "cli/commands/cortex_upd_cmd.h" #include "database/hardware.h" -#include "services/engine_service.h" #include "utils/cortex_utils.h" namespace services { @@ -316,4 +315,4 @@ bool HardwareService::IsValidConfig( } return true; } -} // namespace services \ No newline at end of file +} // namespace services diff --git a/engine/services/message_service.cc b/engine/services/message_service.cc new file mode 100644 index 000000000..31ae38420 --- /dev/null +++ b/engine/services/message_service.cc @@ -0,0 +1,105 @@ +#include "services/message_service.h" +#include "utils/logging_utils.h" +#include "utils/result.hpp" +#include "utils/ulid/ulid.hh" + +cpp::result MessageService::CreateMessage( + const std::string& thread_id, const ThreadMessage::Role& role, + std::variant>>&& + content, + std::optional> attachments, + std::optional metadata) { + LOG_TRACE << "CreateMessage for thread " << thread_id; + auto now = std::chrono::system_clock::now(); + auto seconds_since_epoch = + std::chrono::duration_cast(now.time_since_epoch()) + .count(); + std::vector> content_list{}; + // if content is string + if (std::holds_alternative(content)) { + auto text_content = std::make_unique(); + text_content->text.value = std::get(content); + content_list.push_back(std::move(text_content)); + } else { + content_list = std::move( + std::get>>( + content)); + } + + ulid::ULID ulid = ulid::Create(seconds_since_epoch, []() { return 4; }); + std::string str = ulid::Marshal(ulid); + LOG_TRACE << "Generated message ID: " << str; + + ThreadMessage::Message msg; + msg.id = str; + msg.object = "thread.message"; + msg.created_at = 0; + msg.thread_id = thread_id; + msg.status = ThreadMessage::Status::COMPLETED; + msg.completed_at = seconds_since_epoch; + msg.incomplete_at = std::nullopt; + msg.incomplete_details = std::nullopt; + msg.role = role; + msg.content = std::move(content_list); + msg.assistant_id = std::nullopt; + msg.run_id = std::nullopt; + msg.attachments = attachments; + msg.metadata = metadata.value_or(Cortex::VariantMap{}); + auto res = message_repository_->CreateMessage(msg); + if (res.has_error()) { + return cpp::fail("Failed to create message: " + res.error()); + } else { + return msg; + } +} + +cpp::result, std::string> +MessageService::ListMessages(const std::string& thread_id, uint8_t limit, + const std::string& order, const std::string& after, + const std::string& before, + const std::string& run_id) const { + CTL_INF("ListMessages for thread " + thread_id); + return message_repository_->ListMessages(thread_id); +} + +cpp::result +MessageService::RetrieveMessage(const std::string& thread_id, + const std::string& message_id) const { + CTL_INF("RetrieveMessage for thread " + thread_id); + return message_repository_->RetrieveMessage(thread_id, message_id); +} + +cpp::result MessageService::ModifyMessage( + const std::string& thread_id, const std::string& message_id, + std::optional metadata) { + LOG_TRACE << "ModifyMessage for thread " << thread_id << ", message " + << message_id; + auto msg = RetrieveMessage(thread_id, message_id); + if (msg.has_error()) { + return cpp::fail("Failed to retrieve message: " + msg.error()); + } + + msg->metadata = metadata.value(); + auto ptr = &msg.value(); + + auto res = message_repository_->ModifyMessage(msg.value()); + if (res.has_error()) { + CTL_ERR("Failed to modify message: " + res.error()); + return cpp::fail("Failed to modify message: " + res.error()); + } else { + return RetrieveMessage(thread_id, message_id); + } +} + +cpp::result MessageService::DeleteMessage( + const std::string& thread_id, const std::string& message_id) { + LOG_TRACE << "DeleteMessage for thread " + thread_id; + auto res = message_repository_->DeleteMessage(thread_id, message_id); + if (res.has_error()) { + LOG_ERROR << "Failed to delete message: " + res.error(); + return cpp::fail("Failed to delete message: " + res.error()); + } else { + return message_id; + } +} diff --git a/engine/services/message_service.h b/engine/services/message_service.h new file mode 100644 index 000000000..e62970b54 --- /dev/null +++ b/engine/services/message_service.h @@ -0,0 +1,39 @@ +#pragma once + +#include "common/repository/message_repository.h" +#include "common/variant_map.h" +#include "utils/result.hpp" + +class MessageService { + public: + explicit MessageService(std::shared_ptr message_repository) + : message_repository_{message_repository} {} + + cpp::result CreateMessage( + const std::string& thread_id, const ThreadMessage::Role& role, + std::variant>>&& + content, + std::optional> attachments, + std::optional metadata); + + cpp::result, std::string> ListMessages( + const std::string& thread_id, uint8_t limit = 20, + const std::string& order = "desc", const std::string& after = "", + const std::string& before = "", const std::string& run_id = "") const; + + cpp::result RetrieveMessage( + const std::string& thread_id, const std::string& message_id) const; + + cpp::result ModifyMessage( + const std::string& thread_id, const std::string& message_id, + std::optional>> + metadata); + + cpp::result DeleteMessage( + const std::string& thread_id, const std::string& message_id); + + private: + std::shared_ptr message_repository_; +}; diff --git a/engine/utils/file_manager_utils.h b/engine/utils/file_manager_utils.h index 399afcfa6..8ff5fe971 100644 --- a/engine/utils/file_manager_utils.h +++ b/engine/utils/file_manager_utils.h @@ -265,6 +265,11 @@ inline void CreateDirectoryRecursively(const std::string& path) { } } +inline std::filesystem::path GetThreadsContainerPath() { + auto cortex_path = GetCortexDataPath(); + return cortex_path / "threads"; +} + inline std::filesystem::path GetModelsContainerPath() { auto result = CreateConfigFileIfNotExist(); if (result.has_error()) { diff --git a/engine/utils/ulid/ulid.hh b/engine/utils/ulid/ulid.hh new file mode 100644 index 000000000..22b6f19b5 --- /dev/null +++ b/engine/utils/ulid/ulid.hh @@ -0,0 +1,16 @@ +#ifndef ULID_HH +#define ULID_HH + +// https://github.com/suyash/ulid +// http://stackoverflow.com/a/23981011 +#ifdef __SIZEOF_INT128__ +#define ULIDUINT128 +#endif + +#ifdef ULIDUINT128 +#include "ulid_uint128.hh" +#else +#include "ulid_struct.hh" +#endif // ULIDUINT128 + +#endif // ULID_HH diff --git a/engine/utils/ulid/ulid_struct.hh b/engine/utils/ulid/ulid_struct.hh new file mode 100644 index 000000000..ad0da59ec --- /dev/null +++ b/engine/utils/ulid/ulid_struct.hh @@ -0,0 +1,710 @@ +#ifndef ULID_STRUCT_HH +#define ULID_STRUCT_HH + +#include +#include +#include +#include +#include +#include + +#if _MSC_VER > 0 +typedef uint32_t rand_t; +#else +typedef uint8_t rand_t; +#endif + +namespace ulid { + +/** + * ULID is a 16 byte Universally Unique Lexicographically Sortable Identifier + * */ +struct ULID { + uint8_t data[16]; + + ULID() { + // for (int i = 0 ; i < 16 ; i++) { + // data[i] = 0; + // } + + // unrolled loop + data[0] = 0; + data[1] = 0; + data[2] = 0; + data[3] = 0; + data[4] = 0; + data[5] = 0; + data[6] = 0; + data[7] = 0; + data[8] = 0; + data[9] = 0; + data[10] = 0; + data[11] = 0; + data[12] = 0; + data[13] = 0; + data[14] = 0; + data[15] = 0; + } + + ULID(uint64_t val) { + // for (int i = 0 ; i < 16 ; i++) { + // data[15 - i] = static_cast(val); + // val >>= 8; + // } + + // unrolled loop + data[15] = static_cast(val); + + val >>= 8; + data[14] = static_cast(val); + + val >>= 8; + data[13] = static_cast(val); + + val >>= 8; + data[12] = static_cast(val); + + val >>= 8; + data[11] = static_cast(val); + + val >>= 8; + data[10] = static_cast(val); + + val >>= 8; + data[9] = static_cast(val); + + val >>= 8; + data[8] = static_cast(val); + + data[7] = 0; + data[6] = 0; + data[5] = 0; + data[4] = 0; + data[3] = 0; + data[2] = 0; + data[1] = 0; + data[0] = 0; + } + + ULID(const ULID& other) { + // for (int i = 0 ; i < 16 ; i++) { + // data[i] = other.data[i]; + // } + + // unrolled loop + data[0] = other.data[0]; + data[1] = other.data[1]; + data[2] = other.data[2]; + data[3] = other.data[3]; + data[4] = other.data[4]; + data[5] = other.data[5]; + data[6] = other.data[6]; + data[7] = other.data[7]; + data[8] = other.data[8]; + data[9] = other.data[9]; + data[10] = other.data[10]; + data[11] = other.data[11]; + data[12] = other.data[12]; + data[13] = other.data[13]; + data[14] = other.data[14]; + data[15] = other.data[15]; + } + + ULID& operator=(const ULID& other) { + // for (int i = 0 ; i < 16 ; i++) { + // data[i] = other.data[i]; + // } + + // unrolled loop + data[0] = other.data[0]; + data[1] = other.data[1]; + data[2] = other.data[2]; + data[3] = other.data[3]; + data[4] = other.data[4]; + data[5] = other.data[5]; + data[6] = other.data[6]; + data[7] = other.data[7]; + data[8] = other.data[8]; + data[9] = other.data[9]; + data[10] = other.data[10]; + data[11] = other.data[11]; + data[12] = other.data[12]; + data[13] = other.data[13]; + data[14] = other.data[14]; + data[15] = other.data[15]; + + return *this; + } + + ULID(ULID&& other) { + // for (int i = 0 ; i < 16 ; i++) { + // data[i] = other.data[i]; + // other.data[i] = 0; + // } + + // unrolled loop + data[0] = other.data[0]; + other.data[0] = 0; + + data[1] = other.data[1]; + other.data[1] = 0; + + data[2] = other.data[2]; + other.data[2] = 0; + + data[3] = other.data[3]; + other.data[3] = 0; + + data[4] = other.data[4]; + other.data[4] = 0; + + data[5] = other.data[5]; + other.data[5] = 0; + + data[6] = other.data[6]; + other.data[6] = 0; + + data[7] = other.data[7]; + other.data[7] = 0; + + data[8] = other.data[8]; + other.data[8] = 0; + + data[9] = other.data[9]; + other.data[9] = 0; + + data[10] = other.data[10]; + other.data[10] = 0; + + data[11] = other.data[11]; + other.data[11] = 0; + + data[12] = other.data[12]; + other.data[12] = 0; + + data[13] = other.data[13]; + other.data[13] = 0; + + data[14] = other.data[14]; + other.data[14] = 0; + + data[15] = other.data[15]; + other.data[15] = 0; + } + + ULID& operator=(ULID&& other) { + // for (int i = 0 ; i < 16 ; i++) { + // data[i] = other.data[i]; + // other.data[i] = 0; + // } + + // unrolled loop + data[0] = other.data[0]; + other.data[0] = 0; + + data[1] = other.data[1]; + other.data[1] = 0; + + data[2] = other.data[2]; + other.data[2] = 0; + + data[3] = other.data[3]; + other.data[3] = 0; + + data[4] = other.data[4]; + other.data[4] = 0; + + data[5] = other.data[5]; + other.data[5] = 0; + + data[6] = other.data[6]; + other.data[6] = 0; + + data[7] = other.data[7]; + other.data[7] = 0; + + data[8] = other.data[8]; + other.data[8] = 0; + + data[9] = other.data[9]; + other.data[9] = 0; + + data[10] = other.data[10]; + other.data[10] = 0; + + data[11] = other.data[11]; + other.data[11] = 0; + + data[12] = other.data[12]; + other.data[12] = 0; + + data[13] = other.data[13]; + other.data[13] = 0; + + data[14] = other.data[14]; + other.data[14] = 0; + + data[15] = other.data[15]; + other.data[15] = 0; + + return *this; + } +}; + +/** + * EncodeTime will encode the first 6 bytes of a uint8_t array to the passed + * timestamp + * */ +inline void EncodeTime(time_t timestamp, ULID& ulid) { + ulid.data[0] = static_cast(timestamp >> 40); + ulid.data[1] = static_cast(timestamp >> 32); + ulid.data[2] = static_cast(timestamp >> 24); + ulid.data[3] = static_cast(timestamp >> 16); + ulid.data[4] = static_cast(timestamp >> 8); + ulid.data[5] = static_cast(timestamp); +} + +/** + * EncodeTimeNow will encode a ULID using the time obtained using std::time(nullptr) + * */ +inline void EncodeTimeNow(ULID& ulid) { + EncodeTime(std::time(nullptr), ulid); +} + +/** + * EncodeTimeSystemClockNow will encode a ULID using the time obtained using + * std::chrono::system_clock::now() by taking the timestamp in milliseconds. + * */ +inline void EncodeTimeSystemClockNow(ULID& ulid) { + auto now = std::chrono::system_clock::now(); + auto ms = std::chrono::duration_cast( + now.time_since_epoch()); + EncodeTime(ms.count(), ulid); +} + +/** + * EncodeEntropy will encode the last 10 bytes of the passed uint8_t array with + * the values generated using the passed random number generator. + * */ +inline void EncodeEntropy(const std::function& rng, ULID& ulid) { + ulid.data[6] = rng(); + ulid.data[7] = rng(); + ulid.data[8] = rng(); + ulid.data[9] = rng(); + ulid.data[10] = rng(); + ulid.data[11] = rng(); + ulid.data[12] = rng(); + ulid.data[13] = rng(); + ulid.data[14] = rng(); + ulid.data[15] = rng(); +} + +/** + * EncodeEntropyRand will encode a ulid using std::rand + * + * std::rand returns values in [0, RAND_MAX] + * */ +inline void EncodeEntropyRand(ULID& ulid) { + ulid.data[6] = (uint8_t)(std::rand() * 255ull) / RAND_MAX; + ulid.data[7] = (uint8_t)(std::rand() * 255ull) / RAND_MAX; + ulid.data[8] = (uint8_t)(std::rand() * 255ull) / RAND_MAX; + ulid.data[9] = (uint8_t)(std::rand() * 255ull) / RAND_MAX; + ulid.data[10] = (uint8_t)(std::rand() * 255ull) / RAND_MAX; + ulid.data[11] = (uint8_t)(std::rand() * 255ull) / RAND_MAX; + ulid.data[12] = (uint8_t)(std::rand() * 255ull) / RAND_MAX; + ulid.data[13] = (uint8_t)(std::rand() * 255ull) / RAND_MAX; + ulid.data[14] = (uint8_t)(std::rand() * 255ull) / RAND_MAX; + ulid.data[15] = (uint8_t)(std::rand() * 255ull) / RAND_MAX; +} + +static std::uniform_int_distribution Distribution_0_255(0, 255); + +/** + * EncodeEntropyMt19937 will encode a ulid using std::mt19937 + * + * It also creates a std::uniform_int_distribution to generate values in [0, 255] + * */ +inline void EncodeEntropyMt19937(std::mt19937& generator, ULID& ulid) { + ulid.data[6] = Distribution_0_255(generator); + ulid.data[7] = Distribution_0_255(generator); + ulid.data[8] = Distribution_0_255(generator); + ulid.data[9] = Distribution_0_255(generator); + ulid.data[10] = Distribution_0_255(generator); + ulid.data[11] = Distribution_0_255(generator); + ulid.data[12] = Distribution_0_255(generator); + ulid.data[13] = Distribution_0_255(generator); + ulid.data[14] = Distribution_0_255(generator); + ulid.data[15] = Distribution_0_255(generator); +} + +/** + * Encode will create an encoded ULID with a timestamp and a generator. + * */ +inline void Encode(time_t timestamp, const std::function& rng, + ULID& ulid) { + EncodeTime(timestamp, ulid); + EncodeEntropy(rng, ulid); +} + +/** + * EncodeNowRand = EncodeTimeNow + EncodeEntropyRand. + * */ +inline void EncodeNowRand(ULID& ulid) { + EncodeTimeNow(ulid); + EncodeEntropyRand(ulid); +} + +/** + * Create will create a ULID with a timestamp and a generator. + * */ +inline ULID Create(time_t timestamp, const std::function& rng) { + ULID ulid; + Encode(timestamp, rng, ulid); + return ulid; +} + +/** + * CreateNowRand:EncodeNowRand = Create:Encode. + * */ +inline ULID CreateNowRand() { + ULID ulid; + EncodeNowRand(ulid); + return ulid; +} + +/** + * Crockford's Base32 + * */ +static const char Encoding[33] = "0123456789ABCDEFGHJKMNPQRSTVWXYZ"; + +/** + * MarshalTo will marshal a ULID to the passed character array. + * + * Implementation taken directly from oklog/ulid + * (https://sourcegraph.com/github.com/oklog/ulid@0774f81f6e44af5ce5e91c8d7d76cf710e889ebb/-/blob/ulid.go#L162-190) + * + * timestamp:
+ * dst[0]: first 3 bits of data[0]
+ * dst[1]: last 5 bits of data[0]
+ * dst[2]: first 5 bits of data[1]
+ * dst[3]: last 3 bits of data[1] + first 2 bits of data[2]
+ * dst[4]: bits 3-7 of data[2]
+ * dst[5]: last bit of data[2] + first 4 bits of data[3]
+ * dst[6]: last 4 bits of data[3] + first bit of data[4]
+ * dst[7]: bits 2-6 of data[4]
+ * dst[8]: last 2 bits of data[4] + first 3 bits of data[5]
+ * dst[9]: last 5 bits of data[5]
+ * + * entropy: + * follows similarly, except now all components are set to 5 bits. + * */ +inline void MarshalTo(const ULID& ulid, char dst[26]) { + // 10 byte timestamp + dst[0] = Encoding[(ulid.data[0] & 224) >> 5]; + dst[1] = Encoding[ulid.data[0] & 31]; + dst[2] = Encoding[(ulid.data[1] & 248) >> 3]; + dst[3] = Encoding[((ulid.data[1] & 7) << 2) | ((ulid.data[2] & 192) >> 6)]; + dst[4] = Encoding[(ulid.data[2] & 62) >> 1]; + dst[5] = Encoding[((ulid.data[2] & 1) << 4) | ((ulid.data[3] & 240) >> 4)]; + dst[6] = Encoding[((ulid.data[3] & 15) << 1) | ((ulid.data[4] & 128) >> 7)]; + dst[7] = Encoding[(ulid.data[4] & 124) >> 2]; + dst[8] = Encoding[((ulid.data[4] & 3) << 3) | ((ulid.data[5] & 224) >> 5)]; + dst[9] = Encoding[ulid.data[5] & 31]; + + // 16 bytes of entropy + dst[10] = Encoding[(ulid.data[6] & 248) >> 3]; + dst[11] = Encoding[((ulid.data[6] & 7) << 2) | ((ulid.data[7] & 192) >> 6)]; + dst[12] = Encoding[(ulid.data[7] & 62) >> 1]; + dst[13] = Encoding[((ulid.data[7] & 1) << 4) | ((ulid.data[8] & 240) >> 4)]; + dst[14] = Encoding[((ulid.data[8] & 15) << 1) | ((ulid.data[9] & 128) >> 7)]; + dst[15] = Encoding[(ulid.data[9] & 124) >> 2]; + dst[16] = Encoding[((ulid.data[9] & 3) << 3) | ((ulid.data[10] & 224) >> 5)]; + dst[17] = Encoding[ulid.data[10] & 31]; + dst[18] = Encoding[(ulid.data[11] & 248) >> 3]; + dst[19] = Encoding[((ulid.data[11] & 7) << 2) | ((ulid.data[12] & 192) >> 6)]; + dst[20] = Encoding[(ulid.data[12] & 62) >> 1]; + dst[21] = Encoding[((ulid.data[12] & 1) << 4) | ((ulid.data[13] & 240) >> 4)]; + dst[22] = + Encoding[((ulid.data[13] & 15) << 1) | ((ulid.data[14] & 128) >> 7)]; + dst[23] = Encoding[(ulid.data[14] & 124) >> 2]; + dst[24] = Encoding[((ulid.data[14] & 3) << 3) | ((ulid.data[15] & 224) >> 5)]; + dst[25] = Encoding[ulid.data[15] & 31]; +} + +/** + * Marshal will marshal a ULID to a std::string. + * */ +inline std::string Marshal(const ULID& ulid) { + char data[27]; + data[26] = '\0'; + MarshalTo(ulid, data); + return std::string(data); +} + +/** + * MarshalBinaryTo will Marshal a ULID to the passed byte array + * */ +inline void MarshalBinaryTo(const ULID& ulid, uint8_t dst[16]) { + // timestamp + dst[0] = ulid.data[0]; + dst[1] = ulid.data[1]; + dst[2] = ulid.data[2]; + dst[3] = ulid.data[3]; + dst[4] = ulid.data[4]; + dst[5] = ulid.data[5]; + + // entropy + dst[6] = ulid.data[6]; + dst[7] = ulid.data[7]; + dst[8] = ulid.data[8]; + dst[9] = ulid.data[9]; + dst[10] = ulid.data[10]; + dst[11] = ulid.data[11]; + dst[12] = ulid.data[12]; + dst[13] = ulid.data[13]; + dst[14] = ulid.data[14]; + dst[15] = ulid.data[15]; +} + +/** + * MarshalBinary will Marshal a ULID to a byte vector. + * */ +inline std::vector MarshalBinary(const ULID& ulid) { + std::vector dst(16); + MarshalBinaryTo(ulid, dst.data()); + return dst; +} + +/** + * dec storesdecimal encodings for characters. + * 0xFF indicates invalid character. + * 48-57 are digits. + * 65-90 are capital alphabets. + * */ +static const uint8_t dec[256] = { + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, + /* 0 1 2 3 4 5 6 7 */ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + /* 8 9 */ + 0x08, 0x09, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + /* 10(A) 11(B) 12(C) 13(D) 14(E) 15(F) 16(G) */ + 0xFF, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, + /*17(H) 18(J) 19(K) 20(M) 21(N) */ + 0x11, 0xFF, 0x12, 0x13, 0xFF, 0x14, 0x15, 0xFF, + /*22(P)23(Q)24(R) 25(S) 26(T) 27(V) 28(W) */ + 0x16, 0x17, 0x18, 0x19, 0x1A, 0xFF, 0x1B, 0x1C, + /*29(X)30(Y)31(Z) */ + 0x1D, 0x1E, 0x1F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}; + +/** + * UnmarshalFrom will unmarshal a ULID from the passed character array. + * */ +inline void UnmarshalFrom(const char str[26], ULID& ulid) { + // timestamp + ulid.data[0] = (dec[int(str[0])] << 5) | dec[int(str[1])]; + ulid.data[1] = (dec[int(str[2])] << 3) | (dec[int(str[3])] >> 2); + ulid.data[2] = (dec[int(str[3])] << 6) | (dec[int(str[4])] << 1) | + (dec[int(str[5])] >> 4); + ulid.data[3] = (dec[int(str[5])] << 4) | (dec[int(str[6])] >> 1); + ulid.data[4] = (dec[int(str[6])] << 7) | (dec[int(str[7])] << 2) | + (dec[int(str[8])] >> 3); + ulid.data[5] = (dec[int(str[8])] << 5) | dec[int(str[9])]; + + // entropy + ulid.data[6] = (dec[int(str[10])] << 3) | (dec[int(str[11])] >> 2); + ulid.data[7] = (dec[int(str[11])] << 6) | (dec[int(str[12])] << 1) | + (dec[int(str[13])] >> 4); + ulid.data[8] = (dec[int(str[13])] << 4) | (dec[int(str[14])] >> 1); + ulid.data[9] = (dec[int(str[14])] << 7) | (dec[int(str[15])] << 2) | + (dec[int(str[16])] >> 3); + ulid.data[10] = (dec[int(str[16])] << 5) | dec[int(str[17])]; + ulid.data[11] = (dec[int(str[18])] << 3) | (dec[int(str[19])] >> 2); + ulid.data[12] = (dec[int(str[19])] << 6) | (dec[int(str[20])] << 1) | + (dec[int(str[21])] >> 4); + ulid.data[13] = (dec[int(str[21])] << 4) | (dec[int(str[22])] >> 1); + ulid.data[14] = (dec[int(str[22])] << 7) | (dec[int(str[23])] << 2) | + (dec[int(str[24])] >> 3); + ulid.data[15] = (dec[int(str[24])] << 5) | dec[int(str[25])]; +} + +/** + * Unmarshal will create a new ULID by unmarshaling the passed string. + * */ +inline ULID Unmarshal(const std::string& str) { + ULID ulid; + UnmarshalFrom(str.c_str(), ulid); + return ulid; +} + +/** + * UnmarshalBinaryFrom will unmarshal a ULID from the passed byte array. + * */ +inline void UnmarshalBinaryFrom(const uint8_t b[16], ULID& ulid) { + // timestamp + ulid.data[0] = b[0]; + ulid.data[1] = b[1]; + ulid.data[2] = b[2]; + ulid.data[3] = b[3]; + ulid.data[4] = b[4]; + ulid.data[5] = b[5]; + + // entropy + ulid.data[6] = b[6]; + ulid.data[7] = b[7]; + ulid.data[8] = b[8]; + ulid.data[9] = b[9]; + ulid.data[10] = b[10]; + ulid.data[11] = b[11]; + ulid.data[12] = b[12]; + ulid.data[13] = b[13]; + ulid.data[14] = b[14]; + ulid.data[15] = b[15]; +} + +/** + * Unmarshal will create a new ULID by unmarshaling the passed byte vector. + * */ +inline ULID UnmarshalBinary(const std::vector& b) { + ULID ulid; + UnmarshalBinaryFrom(b.data(), ulid); + return ulid; +} + +/** + * CompareULIDs will compare two ULIDs. + * returns: + * -1 if ulid1 is Lexicographically before ulid2 + * 1 if ulid1 is Lexicographically after ulid2 + * 0 if ulid1 is same as ulid2 + * */ +inline int CompareULIDs(const ULID& ulid1, const ULID& ulid2) { + // for (int i = 0 ; i < 16 ; i++) { + // if (ulid1.data[i] != ulid2.data[i]) { + // return (ulid1.data[i] < ulid2.data[i]) * -2 + 1; + // } + // } + + // unrolled loop + + if (ulid1.data[0] != ulid2.data[0]) { + return (ulid1.data[0] < ulid2.data[0]) * -2 + 1; + } + + if (ulid1.data[1] != ulid2.data[1]) { + return (ulid1.data[1] < ulid2.data[1]) * -2 + 1; + } + + if (ulid1.data[2] != ulid2.data[2]) { + return (ulid1.data[2] < ulid2.data[2]) * -2 + 1; + } + + if (ulid1.data[3] != ulid2.data[3]) { + return (ulid1.data[3] < ulid2.data[3]) * -2 + 1; + } + + if (ulid1.data[4] != ulid2.data[4]) { + return (ulid1.data[4] < ulid2.data[4]) * -2 + 1; + } + + if (ulid1.data[5] != ulid2.data[5]) { + return (ulid1.data[5] < ulid2.data[5]) * -2 + 1; + } + + if (ulid1.data[6] != ulid2.data[6]) { + return (ulid1.data[6] < ulid2.data[6]) * -2 + 1; + } + + if (ulid1.data[7] != ulid2.data[7]) { + return (ulid1.data[7] < ulid2.data[7]) * -2 + 1; + } + + if (ulid1.data[8] != ulid2.data[8]) { + return (ulid1.data[8] < ulid2.data[8]) * -2 + 1; + } + + if (ulid1.data[9] != ulid2.data[9]) { + return (ulid1.data[9] < ulid2.data[9]) * -2 + 1; + } + + if (ulid1.data[10] != ulid2.data[10]) { + return (ulid1.data[10] < ulid2.data[10]) * -2 + 1; + } + + if (ulid1.data[11] != ulid2.data[11]) { + return (ulid1.data[11] < ulid2.data[11]) * -2 + 1; + } + + if (ulid1.data[12] != ulid2.data[12]) { + return (ulid1.data[12] < ulid2.data[12]) * -2 + 1; + } + + if (ulid1.data[13] != ulid2.data[13]) { + return (ulid1.data[13] < ulid2.data[13]) * -2 + 1; + } + + if (ulid1.data[14] != ulid2.data[14]) { + return (ulid1.data[14] < ulid2.data[14]) * -2 + 1; + } + + if (ulid1.data[15] != ulid2.data[15]) { + return (ulid1.data[15] < ulid2.data[15]) * -2 + 1; + } + + return 0; +} + +/** + * Time will extract the timestamp used to generate a ULID + * */ +inline time_t Time(const ULID& ulid) { + time_t ans = 0; + + ans |= ulid.data[0]; + + ans <<= 8; + ans |= ulid.data[1]; + + ans <<= 8; + ans |= ulid.data[2]; + + ans <<= 8; + ans |= ulid.data[3]; + + ans <<= 8; + ans |= ulid.data[4]; + + ans <<= 8; + ans |= ulid.data[5]; + + return ans; +} + +}; // namespace ulid + +#endif // ULID_STRUCT_HH diff --git a/engine/utils/ulid/ulid_uint128.hh b/engine/utils/ulid/ulid_uint128.hh new file mode 100644 index 000000000..b3f200141 --- /dev/null +++ b/engine/utils/ulid/ulid_uint128.hh @@ -0,0 +1,561 @@ +#ifndef ULID_UINT128_HH +#define ULID_UINT128_HH + +#include +#include +#include +#include +#include +#include + +#if _MSC_VER > 0 +typedef uint32_t rand_t; +#else +typedef uint8_t rand_t; +#endif + +namespace ulid { + +/** + * ULID is a 16 byte Universally Unique Lexicographically Sortable Identifier + * */ +typedef __uint128_t ULID; + +/** + * EncodeTime will encode the first 6 bytes of a uint8_t array to the passed + * timestamp + * */ +inline void EncodeTime(time_t timestamp, ULID& ulid) { + ULID t = static_cast(timestamp >> 40); + + t <<= 8; + t |= static_cast(timestamp >> 32); + + t <<= 8; + t |= static_cast(timestamp >> 24); + + t <<= 8; + t |= static_cast(timestamp >> 16); + + t <<= 8; + t |= static_cast(timestamp >> 8); + + t <<= 8; + t |= static_cast(timestamp); + + t <<= 80; + + ULID mask = 1; + mask <<= 80; + mask--; + + ulid = t | (ulid & mask); +} + +/** + * EncodeTimeNow will encode a ULID using the time obtained using std::time(nullptr) + * */ +inline void EncodeTimeNow(ULID& ulid) { + EncodeTime(std::time(nullptr), ulid); +} + +/** + * EncodeTimeSystemClockNow will encode a ULID using the time obtained using + * std::chrono::system_clock::now() by taking the timestamp in milliseconds. + * */ +inline void EncodeTimeSystemClockNow(ULID& ulid) { + auto now = std::chrono::system_clock::now(); + auto ms = std::chrono::duration_cast( + now.time_since_epoch()); + EncodeTime(ms.count(), ulid); +} + +/** + * EncodeEntropy will encode the last 10 bytes of the passed uint8_t array with + * the values generated using the passed random number generator. + * */ +inline void EncodeEntropy(const std::function& rng, ULID& ulid) { + ulid = (ulid >> 80) << 80; + + ULID e = rng(); + + e <<= 8; + e |= rng(); + + e <<= 8; + e |= rng(); + + e <<= 8; + e |= rng(); + + e <<= 8; + e |= rng(); + + e <<= 8; + e |= rng(); + + e <<= 8; + e |= rng(); + + e <<= 8; + e |= rng(); + + e <<= 8; + e |= rng(); + + e <<= 8; + e |= rng(); + + ulid |= e; +} + +/** + * EncodeEntropyRand will encode a ulid using std::rand + * + * std::rand returns values in [0, RAND_MAX] + * */ +inline void EncodeEntropyRand(ULID& ulid) { + ulid = (ulid >> 80) << 80; + + ULID e = (std::rand() * 255ull) / RAND_MAX; + + e <<= 8; + e |= (std::rand() * 255ull) / RAND_MAX; + + e <<= 8; + e |= (std::rand() * 255ull) / RAND_MAX; + + e <<= 8; + e |= (std::rand() * 255ull) / RAND_MAX; + + e <<= 8; + e |= (std::rand() * 255ull) / RAND_MAX; + + e <<= 8; + e |= (std::rand() * 255ull) / RAND_MAX; + + e <<= 8; + e |= (std::rand() * 255ull) / RAND_MAX; + + e <<= 8; + e |= (std::rand() * 255ull) / RAND_MAX; + + e <<= 8; + e |= (std::rand() * 255ull) / RAND_MAX; + + e <<= 8; + e |= (std::rand() * 255ull) / RAND_MAX; + + ulid |= e; +} + +static std::uniform_int_distribution Distribution_0_255(0, 255); + +/** + * EncodeEntropyMt19937 will encode a ulid using std::mt19937 + * + * It also creates a std::uniform_int_distribution to generate values in [0, 255] + * */ +inline void EncodeEntropyMt19937(std::mt19937& generator, ULID& ulid) { + ulid = (ulid >> 80) << 80; + + ULID e = Distribution_0_255(generator); + + e <<= 8; + e |= Distribution_0_255(generator); + + e <<= 8; + e |= Distribution_0_255(generator); + + e <<= 8; + e |= Distribution_0_255(generator); + + e <<= 8; + e |= Distribution_0_255(generator); + + e <<= 8; + e |= Distribution_0_255(generator); + + e <<= 8; + e |= Distribution_0_255(generator); + + e <<= 8; + e |= Distribution_0_255(generator); + + e <<= 8; + e |= Distribution_0_255(generator); + + e <<= 8; + e |= Distribution_0_255(generator); + + ulid |= e; +} + +/** + * Encode will create an encoded ULID with a timestamp and a generator. + * */ +inline void Encode(time_t timestamp, const std::function& rng, + ULID& ulid) { + EncodeTime(timestamp, ulid); + EncodeEntropy(rng, ulid); +} + +/** + * EncodeNowRand = EncodeTimeNow + EncodeEntropyRand. + * */ +inline void EncodeNowRand(ULID& ulid) { + EncodeTimeNow(ulid); + EncodeEntropyRand(ulid); +} + +/** + * Create will create a ULID with a timestamp and a generator. + * */ +inline ULID Create(time_t timestamp, const std::function& rng) { + ULID ulid = 0; + Encode(timestamp, rng, ulid); + return ulid; +} + +/** + * CreateNowRand:EncodeNowRand = Create:Encode. + * */ +inline ULID CreateNowRand() { + ULID ulid = 0; + EncodeNowRand(ulid); + return ulid; +} + +/** + * Crockford's Base32 + * */ +static const char Encoding[33] = "0123456789ABCDEFGHJKMNPQRSTVWXYZ"; + +/** + * MarshalTo will marshal a ULID to the passed character array. + * + * Implementation taken directly from oklog/ulid + * (https://sourcegraph.com/github.com/oklog/ulid@0774f81f6e44af5ce5e91c8d7d76cf710e889ebb/-/blob/ulid.go#L162-190) + * + * timestamp: + * dst[0]: first 3 bits of data[0] + * dst[1]: last 5 bits of data[0] + * dst[2]: first 5 bits of data[1] + * dst[3]: last 3 bits of data[1] + first 2 bits of data[2] + * dst[4]: bits 3-7 of data[2] + * dst[5]: last bit of data[2] + first 4 bits of data[3] + * dst[6]: last 4 bits of data[3] + first bit of data[4] + * dst[7]: bits 2-6 of data[4] + * dst[8]: last 2 bits of data[4] + first 3 bits of data[5] + * dst[9]: last 5 bits of data[5] + * + * entropy: + * follows similarly, except now all components are set to 5 bits. + * */ +inline void MarshalTo(const ULID& ulid, char dst[26]) { + // 10 byte timestamp + dst[0] = Encoding[(static_cast(ulid >> 120) & 224) >> 5]; + dst[1] = Encoding[static_cast(ulid >> 120) & 31]; + dst[2] = Encoding[(static_cast(ulid >> 112) & 248) >> 3]; + dst[3] = Encoding[((static_cast(ulid >> 112) & 7) << 2) | + ((static_cast(ulid >> 104) & 192) >> 6)]; + dst[4] = Encoding[(static_cast(ulid >> 104) & 62) >> 1]; + dst[5] = Encoding[((static_cast(ulid >> 104) & 1) << 4) | + ((static_cast(ulid >> 96) & 240) >> 4)]; + dst[6] = Encoding[((static_cast(ulid >> 96) & 15) << 1) | + ((static_cast(ulid >> 88) & 128) >> 7)]; + dst[7] = Encoding[(static_cast(ulid >> 88) & 124) >> 2]; + dst[8] = Encoding[((static_cast(ulid >> 88) & 3) << 3) | + ((static_cast(ulid >> 80) & 224) >> 5)]; + dst[9] = Encoding[static_cast(ulid >> 80) & 31]; + + // 16 bytes of entropy + dst[10] = Encoding[(static_cast(ulid >> 72) & 248) >> 3]; + dst[11] = Encoding[((static_cast(ulid >> 72) & 7) << 2) | + ((static_cast(ulid >> 64) & 192) >> 6)]; + dst[12] = Encoding[(static_cast(ulid >> 64) & 62) >> 1]; + dst[13] = Encoding[((static_cast(ulid >> 64) & 1) << 4) | + ((static_cast(ulid >> 56) & 240) >> 4)]; + dst[14] = Encoding[((static_cast(ulid >> 56) & 15) << 1) | + ((static_cast(ulid >> 48) & 128) >> 7)]; + dst[15] = Encoding[(static_cast(ulid >> 48) & 124) >> 2]; + dst[16] = Encoding[((static_cast(ulid >> 48) & 3) << 3) | + ((static_cast(ulid >> 40) & 224) >> 5)]; + dst[17] = Encoding[static_cast(ulid >> 40) & 31]; + dst[18] = Encoding[(static_cast(ulid >> 32) & 248) >> 3]; + dst[19] = Encoding[((static_cast(ulid >> 32) & 7) << 2) | + ((static_cast(ulid >> 24) & 192) >> 6)]; + dst[20] = Encoding[(static_cast(ulid >> 24) & 62) >> 1]; + dst[21] = Encoding[((static_cast(ulid >> 24) & 1) << 4) | + ((static_cast(ulid >> 16) & 240) >> 4)]; + dst[22] = Encoding[((static_cast(ulid >> 16) & 15) << 1) | + ((static_cast(ulid >> 8) & 128) >> 7)]; + dst[23] = Encoding[(static_cast(ulid >> 8) & 124) >> 2]; + dst[24] = Encoding[((static_cast(ulid >> 8) & 3) << 3) | + (((static_cast(ulid)) & 224) >> 5)]; + dst[25] = Encoding[(static_cast(ulid)) & 31]; +} + +/** + * Marshal will marshal a ULID to a std::string. + * */ +inline std::string Marshal(const ULID& ulid) { + char data[27]; + data[26] = '\0'; + MarshalTo(ulid, data); + return std::string(data); +} + +/** + * MarshalBinaryTo will Marshal a ULID to the passed byte array + * */ +inline void MarshalBinaryTo(const ULID& ulid, uint8_t dst[16]) { + // timestamp + dst[0] = static_cast(ulid >> 120); + dst[1] = static_cast(ulid >> 112); + dst[2] = static_cast(ulid >> 104); + dst[3] = static_cast(ulid >> 96); + dst[4] = static_cast(ulid >> 88); + dst[5] = static_cast(ulid >> 80); + + // entropy + dst[6] = static_cast(ulid >> 72); + dst[7] = static_cast(ulid >> 64); + dst[8] = static_cast(ulid >> 56); + dst[9] = static_cast(ulid >> 48); + dst[10] = static_cast(ulid >> 40); + dst[11] = static_cast(ulid >> 32); + dst[12] = static_cast(ulid >> 24); + dst[13] = static_cast(ulid >> 16); + dst[14] = static_cast(ulid >> 8); + dst[15] = static_cast(ulid); +} + +/** + * MarshalBinary will Marshal a ULID to a byte vector. + * */ +inline std::vector MarshalBinary(const ULID& ulid) { + std::vector dst(16); + MarshalBinaryTo(ulid, dst.data()); + return dst; +} + +/** + * dec storesdecimal encodings for characters. + * 0xFF indicates invalid character. + * 48-57 are digits. + * 65-90 are capital alphabets. + * */ +static const uint8_t dec[256] = { + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, + /* 0 1 2 3 4 5 6 7 */ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + /* 8 9 */ + 0x08, 0x09, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + /* 10(A) 11(B) 12(C) 13(D) 14(E) 15(F) 16(G) */ + 0xFF, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, + /*17(H) 18(J) 19(K) 20(M) 21(N) */ + 0x11, 0xFF, 0x12, 0x13, 0xFF, 0x14, 0x15, 0xFF, + /*22(P)23(Q)24(R) 25(S) 26(T) 27(V) 28(W) */ + 0x16, 0x17, 0x18, 0x19, 0x1A, 0xFF, 0x1B, 0x1C, + /*29(X)30(Y)31(Z) */ + 0x1D, 0x1E, 0x1F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}; + +/** + * UnmarshalFrom will unmarshal a ULID from the passed character array. + * */ +inline void UnmarshalFrom(const char str[26], ULID& ulid) { + // timestamp + ulid = (dec[int(str[0])] << 5) | dec[int(str[1])]; + + ulid <<= 8; + ulid |= (dec[int(str[2])] << 3) | (dec[int(str[3])] >> 2); + + ulid <<= 8; + ulid |= (dec[int(str[3])] << 6) | (dec[int(str[4])] << 1) | + (dec[int(str[5])] >> 4); + + ulid <<= 8; + ulid |= (dec[int(str[5])] << 4) | (dec[int(str[6])] >> 1); + + ulid <<= 8; + ulid |= (dec[int(str[6])] << 7) | (dec[int(str[7])] << 2) | + (dec[int(str[8])] >> 3); + + ulid <<= 8; + ulid |= (dec[int(str[8])] << 5) | dec[int(str[9])]; + + // entropy + ulid <<= 8; + ulid |= (dec[int(str[10])] << 3) | (dec[int(str[11])] >> 2); + + ulid <<= 8; + ulid |= (dec[int(str[11])] << 6) | (dec[int(str[12])] << 1) | + (dec[int(str[13])] >> 4); + + ulid <<= 8; + ulid |= (dec[int(str[13])] << 4) | (dec[int(str[14])] >> 1); + + ulid <<= 8; + ulid |= (dec[int(str[14])] << 7) | (dec[int(str[15])] << 2) | + (dec[int(str[16])] >> 3); + + ulid <<= 8; + ulid |= (dec[int(str[16])] << 5) | dec[int(str[17])]; + + ulid <<= 8; + ulid |= (dec[int(str[18])] << 3) | (dec[int(str[19])] >> 2); + + ulid <<= 8; + ulid |= (dec[int(str[19])] << 6) | (dec[int(str[20])] << 1) | + (dec[int(str[21])] >> 4); + + ulid <<= 8; + ulid |= (dec[int(str[21])] << 4) | (dec[int(str[22])] >> 1); + + ulid <<= 8; + ulid |= (dec[int(str[22])] << 7) | (dec[int(str[23])] << 2) | + (dec[int(str[24])] >> 3); + + ulid <<= 8; + ulid |= (dec[int(str[24])] << 5) | dec[int(str[25])]; +} + +/** + * Unmarshal will create a new ULID by unmarshaling the passed string. + * */ +inline ULID Unmarshal(const std::string& str) { + ULID ulid; + UnmarshalFrom(str.c_str(), ulid); + return ulid; +} + +/** + * UnmarshalBinaryFrom will unmarshal a ULID from the passed byte array. + * */ +inline void UnmarshalBinaryFrom(const uint8_t b[16], ULID& ulid) { + // timestamp + ulid = b[0]; + + ulid <<= 8; + ulid |= b[1]; + + ulid <<= 8; + ulid |= b[2]; + + ulid <<= 8; + ulid |= b[3]; + + ulid <<= 8; + ulid |= b[4]; + + ulid <<= 8; + ulid |= b[5]; + + // entropy + ulid <<= 8; + ulid |= b[6]; + + ulid <<= 8; + ulid |= b[7]; + + ulid <<= 8; + ulid |= b[8]; + + ulid <<= 8; + ulid |= b[9]; + + ulid <<= 8; + ulid |= b[10]; + + ulid <<= 8; + ulid |= b[11]; + + ulid <<= 8; + ulid |= b[12]; + + ulid <<= 8; + ulid |= b[13]; + + ulid <<= 8; + ulid |= b[14]; + + ulid <<= 8; + ulid |= b[15]; +} + +/** + * Unmarshal will create a new ULID by unmarshaling the passed byte vector. + * */ +inline ULID UnmarshalBinary(const std::vector& b) { + ULID ulid; + UnmarshalBinaryFrom(b.data(), ulid); + return ulid; +} + +/** + * CompareULIDs will compare two ULIDs. + * returns: + * -1 if ulid1 is Lexicographically before ulid2 + * 1 if ulid1 is Lexicographically after ulid2 + * 0 if ulid1 is same as ulid2 + * */ +inline int CompareULIDs(const ULID& ulid1, const ULID& ulid2) { + return -2 * (ulid1 < ulid2) - 1 * (ulid1 == ulid2) + 1; +} + +/** + * Time will extract the timestamp used to generate a ULID + * */ +inline time_t Time(const ULID& ulid) { + time_t ans = 0; + + ans |= static_cast(ulid >> 120); + + ans <<= 8; + ans |= static_cast(ulid >> 112); + + ans <<= 8; + ans |= static_cast(ulid >> 104); + + ans <<= 8; + ans |= static_cast(ulid >> 96); + + ans <<= 8; + ans |= static_cast(ulid >> 88); + + ans <<= 8; + ans |= static_cast(ulid >> 80); + + return ans; +} + +}; // namespace ulid + +#endif // ULID_UINT128_HH