Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add proxy support
Browse files Browse the repository at this point in the history
namchuai committed Nov 18, 2024

Verified

This commit was signed with the committer’s verified signature.
reneme René Meusel
1 parent 24bebed commit 744ace2
Showing 12 changed files with 506 additions and 703 deletions.
850 changes: 205 additions & 645 deletions docs/static/openapi/cortex.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions engine/cli/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -73,6 +73,7 @@ add_executable(${TARGET_NAME} main.cc
${CMAKE_CURRENT_SOURCE_DIR}/../utils/cpuid/cpu_info.cc
${CMAKE_CURRENT_SOURCE_DIR}/../utils/file_logger.cc
${CMAKE_CURRENT_SOURCE_DIR}/command_line_parser.cc
${CMAKE_CURRENT_SOURCE_DIR}/../services/config_service.cc
${CMAKE_CURRENT_SOURCE_DIR}/../services/download_service.cc
${CMAKE_CURRENT_SOURCE_DIR}/../services/engine_service.cc
${CMAKE_CURRENT_SOURCE_DIR}/../services/model_service.cc
128 changes: 125 additions & 3 deletions engine/common/api_server_configuration.h
Original file line number Diff line number Diff line change
@@ -5,22 +5,72 @@
#include <unordered_map>
#include <vector>

// current only support basic auth
enum class ProxyAuthMethod {
Basic,
Digest,
DigestIe,
Bearer,
Negotiate,
Ntlm,
NtlmWb,
Any,
AnySafe,
AuthOnly,
AwsSigV4
};

class ApiServerConfiguration {
public:
ApiServerConfiguration(bool cors = true,
std::vector<std::string> allowed_origins = {})
: cors{cors}, allowed_origins{allowed_origins} {}
ApiServerConfiguration(
bool cors = true, std::vector<std::string> allowed_origins = {},
bool verify_proxy_ssl = true, bool verify_proxy_host_ssl = true,
const std::string& proxy_url = "", const std::string& proxy_username = "",
const std::string& proxy_password = "", const std::string& no_proxy = "",
bool verify_peer_ssl = true, bool verify_host_ssl = true)
: cors{cors},
allowed_origins{allowed_origins},
verify_proxy_ssl{verify_proxy_ssl},
verify_proxy_host_ssl{verify_proxy_host_ssl},
proxy_url{proxy_url},
proxy_username{proxy_username},
proxy_password{proxy_password},
no_proxy{no_proxy},
verify_peer_ssl{verify_peer_ssl},
verify_host_ssl{verify_host_ssl} {}

// cors
bool cors{true};
std::vector<std::string> allowed_origins;

// proxy
bool verify_proxy_ssl{true};
bool verify_proxy_host_ssl{true};
ProxyAuthMethod proxy_auth_method{ProxyAuthMethod::Basic};
std::string proxy_url{""};
std::string proxy_username{""};
std::string proxy_password{""};
std::string no_proxy{""};

bool verify_peer_ssl{true};
bool verify_host_ssl{true};

Json::Value ToJson() const {
Json::Value root;
root["cors"] = cors;
root["allowed_origins"] = Json::Value(Json::arrayValue);
for (const auto& origin : allowed_origins) {
root["allowed_origins"].append(origin);
}
root["verify_proxy_ssl"] = verify_proxy_ssl;
root["verify_proxy_host_ssl"] = verify_proxy_host_ssl;
root["proxy_url"] = proxy_url;
root["proxy_username"] = proxy_username;
root["proxy_password"] = proxy_password;
root["no_proxy"] = no_proxy;
root["verify_peer_ssl"] = verify_peer_ssl;
root["verify_host_ssl"] = verify_host_ssl;

return root;
}

@@ -31,6 +81,78 @@ class ApiServerConfiguration {
const std::unordered_map<std::string,
std::function<bool(const Json::Value&)>>
field_updater{
{"verify_peer_ssl",
[this](const Json::Value& value) -> bool {
if (!value.isBool()) {
return false;
}
verify_peer_ssl = value.asBool();
return true;
}},

{"verify_host_ssl",
[this](const Json::Value& value) -> bool {
if (!value.isBool()) {
return false;
}
verify_host_ssl = value.asBool();
return true;
}},

{"verify_proxy_host_ssl",
[this](const Json::Value& value) -> bool {
if (!value.isBool()) {
return false;
}
verify_proxy_host_ssl = value.asBool();
return true;
}},

{"verify_proxy_ssl",
[this](const Json::Value& value) -> bool {
if (!value.isBool()) {
return false;
}
verify_proxy_ssl = value.asBool();
return true;
}},

{"no_proxy",
[this](const Json::Value& value) -> bool {
if (!value.isString()) {
return false;
}
no_proxy = value.asString();
return true;
}},

{"proxy_url",
[this](const Json::Value& value) -> bool {
if (!value.isString()) {
return false;
}
proxy_url = value.asString();
return true;
}},

{"proxy_username",
[this](const Json::Value& value) -> bool {
if (!value.isString()) {
return false;
}
proxy_username = value.asString();
return true;
}},

{"proxy_password",
[this](const Json::Value& value) -> bool {
if (!value.isString()) {
return false;
}
proxy_password = value.asString();
return true;
}},

{"cors",
[this](const Json::Value& value) -> bool {
if (!value.isBool()) {
7 changes: 5 additions & 2 deletions engine/config/yaml_config.cc
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@
#include "utils/file_manager_utils.h"
#include "utils/format_utils.h"
#include "yaml_config.h"

namespace config {
// Method to read YAML file
void YamlHandler::Reset() {
@@ -44,6 +45,7 @@ void YamlHandler::ReadYamlFile(const std::string& file_path) {
throw;
}
}

void YamlHandler::SplitPromptTemplate(ModelConfig& mc) {
if (mc.prompt_template.size() > 0) {
auto& pt = mc.prompt_template;
@@ -220,7 +222,7 @@ void YamlHandler::UpdateModelConfig(ModelConfig new_model_config) {
yaml_node_["ngl"] = model_config_.ngl;
if (!std::isnan(static_cast<double>(model_config_.ctx_len)))
yaml_node_["ctx_len"] = model_config_.ctx_len;
if (!std::isnan(static_cast<double>(model_config_.n_parallel)))
if (!std::isnan(static_cast<double>(model_config_.n_parallel)))
yaml_node_["n_parallel"] = model_config_.n_parallel;
if (!std::isnan(static_cast<double>(model_config_.tp)))
yaml_node_["tp"] = model_config_.tp;
@@ -377,7 +379,8 @@ void YamlHandler::WriteYamlFile(const std::string& file_path) const {
outFile << format_utils::writeKeyValue(
"ctx_len", yaml_node_["ctx_len"],
"llama.context_length | 0 or undefined = loaded from model");
outFile << format_utils::writeKeyValue("n_parallel", yaml_node_["n_parallel"]);
outFile << format_utils::writeKeyValue("n_parallel",
yaml_node_["n_parallel"]);
outFile << format_utils::writeKeyValue("ngl", yaml_node_["ngl"],
"Undefined = loaded from model");
outFile << "# END OPTIONAL\n";
5 changes: 3 additions & 2 deletions engine/main.cc
Original file line number Diff line number Diff line change
@@ -106,13 +106,14 @@ void RunServer(std::optional<int> port, bool ignore_cout) {
auto event_queue_ptr = std::make_shared<EventQueue>();
cortex::event::EventProcessor event_processor(event_queue_ptr);

auto download_service = std::make_shared<DownloadService>(event_queue_ptr);
auto config_service = std::make_shared<ConfigService>();
auto download_service =
std::make_shared<DownloadService>(event_queue_ptr, config_service);
auto engine_service = std::make_shared<EngineService>(download_service);
auto inference_svc =
std::make_shared<services::InferenceService>(engine_service);
auto model_service = std::make_shared<ModelService>(
download_service, inference_svc, engine_service);
auto config_service = std::make_shared<ConfigService>();

// initialize custom controllers
auto engine_ctl = std::make_shared<Engines>(engine_service);
26 changes: 22 additions & 4 deletions engine/services/config_service.cc
Original file line number Diff line number Diff line change
@@ -5,8 +5,12 @@
cpp::result<ApiServerConfiguration, std::string>
ConfigService::UpdateApiServerConfiguration(const Json::Value& json) {
auto config = file_manager_utils::GetCortexConfig();
ApiServerConfiguration api_server_config{config.enableCors,
config.allowedOrigins};
ApiServerConfiguration api_server_config{
config.enableCors, config.allowedOrigins, config.verifyProxySsl,
config.verifyProxyHostSsl, config.proxyUrl, config.proxyUsername,
config.proxyPassword, config.noProxy, config.verifyPeerSsl,
config.verifyHostSsl};

std::vector<std::string> updated_fields;
std::vector<std::string> invalid_fields;
std::vector<std::string> unknown_fields;
@@ -20,13 +24,27 @@ ConfigService::UpdateApiServerConfiguration(const Json::Value& json) {

config.enableCors = api_server_config.cors;
config.allowedOrigins = api_server_config.allowed_origins;
config.verifyProxySsl = api_server_config.verify_proxy_ssl;
config.verifyProxyHostSsl = api_server_config.verify_proxy_host_ssl;

config.proxyUrl = api_server_config.proxy_url;
config.proxyUsername = api_server_config.proxy_username;
config.proxyPassword = api_server_config.proxy_password;
config.noProxy = api_server_config.no_proxy;

config.verifyPeerSsl = api_server_config.verify_peer_ssl;
config.verifyHostSsl = api_server_config.verify_host_ssl;

auto result = file_manager_utils::UpdateCortexConfig(config);
return api_server_config;
}

cpp::result<ApiServerConfiguration, std::string>
ConfigService::GetApiServerConfiguration() const {
ConfigService::GetApiServerConfiguration() {
auto config = file_manager_utils::GetCortexConfig();
return ApiServerConfiguration{config.enableCors, config.allowedOrigins};
return ApiServerConfiguration{
config.enableCors, config.allowedOrigins, config.verifyProxySsl,
config.verifyProxyHostSsl, config.proxyUrl, config.proxyUsername,
config.proxyPassword, config.noProxy, config.verifyPeerSsl,
config.verifyHostSsl};
}
3 changes: 1 addition & 2 deletions engine/services/config_service.h
Original file line number Diff line number Diff line change
@@ -8,6 +8,5 @@ class ConfigService {
cpp::result<ApiServerConfiguration, std::string> UpdateApiServerConfiguration(
const Json::Value& json);

cpp::result<ApiServerConfiguration, std::string> GetApiServerConfiguration()
const;
cpp::result<ApiServerConfiguration, std::string> GetApiServerConfiguration();
};
47 changes: 47 additions & 0 deletions engine/services/download_service.cc
Original file line number Diff line number Diff line change
@@ -87,6 +87,7 @@ cpp::result<uint64_t, std::string> DownloadService::GetFileSize(
return cpp::fail(static_cast<std::string>("Failed to init CURL"));
}

// TODO: namh add header here
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_NOBODY, 1L);
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
@@ -189,6 +190,7 @@ cpp::result<bool, std::string> DownloadService::Download(

curl_easy_setopt(curl, CURLOPT_HTTPHEADER, curl_headers);
}
// TODO: namh add proxy setting here
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, &WriteCallback);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, file);
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
@@ -337,6 +339,7 @@ void DownloadService::ProcessTask(DownloadTask& task, int worker_id) {
});
worker_data->downloading_data_map[item.id] = dl_data_ptr;

CTL_ERR("Namh Setup curl");
SetUpCurlHandle(handle, item, file, dl_data_ptr.get());
curl_multi_add_handle(worker_data->multi_handle, handle);
task_handles.push_back(std::make_pair(handle, file));
@@ -407,13 +410,57 @@ cpp::result<void, ProcessDownloadFailed> DownloadService::ProcessMultiDownload(

void DownloadService::SetUpCurlHandle(CURL* handle, const DownloadItem& item,
FILE* file, DownloadingData* dl_data) {
auto configuration = config_service_->GetApiServerConfiguration();
if (configuration.has_value()) {
if (!configuration->proxy_url.empty()) {
auto proxy_url = configuration->proxy_url;
auto verify_proxy_ssl = configuration->verify_proxy_ssl;
auto verify_proxy_host_ssl = configuration->verify_proxy_host_ssl;

auto verify_ssl = configuration->verify_peer_ssl;
auto verify_host_ssl = configuration->verify_host_ssl;

auto proxy_username = configuration->proxy_username;
auto proxy_password = configuration->proxy_password;

CTL_ERR("=== Proxy configuration ===");
CTL_ERR("Proxy url: " << proxy_url);
CTL_ERR("Verify proxy ssl: " << verify_proxy_ssl);
CTL_ERR("Verify proxy host ssl: " << verify_proxy_host_ssl);
CTL_ERR("Verify ssl: " << verify_ssl);
CTL_ERR("Verify host ssl: " << verify_host_ssl);

curl_easy_setopt(handle, CURLOPT_PROXY, proxy_url.c_str());
curl_easy_setopt(handle, CURLOPT_SSL_VERIFYPEER, verify_ssl ? 1L : 0L);
curl_easy_setopt(handle, CURLOPT_SSL_VERIFYHOST,
verify_host_ssl ? 2L : 0L);

curl_easy_setopt(handle, CURLOPT_PROXY_SSL_VERIFYPEER,
verify_proxy_ssl ? 1L : 0L);
curl_easy_setopt(handle, CURLOPT_PROXY_SSL_VERIFYHOST,
verify_proxy_host_ssl ? 2L : 0L);

if (!proxy_username.empty()) {
std::string proxy_auth = proxy_username + ":" + proxy_password;
CTL_ERR("Proxy auth: " << proxy_auth);
curl_easy_setopt(handle, CURLOPT_PROXYUSERPWD, proxy_auth.c_str());
}

curl_easy_setopt(handle, CURLOPT_NOPROXY,
configuration->no_proxy.c_str());
}
} else {
CTL_ERR("Failed to get configuration");
}

curl_easy_setopt(handle, CURLOPT_URL, item.downloadUrl.c_str());
curl_easy_setopt(handle, CURLOPT_WRITEFUNCTION, WriteCallback);
curl_easy_setopt(handle, CURLOPT_WRITEDATA, file);
curl_easy_setopt(handle, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(handle, CURLOPT_NOPROGRESS, 0L);
curl_easy_setopt(handle, CURLOPT_XFERINFOFUNCTION, ProgressCallback);
curl_easy_setopt(handle, CURLOPT_XFERINFODATA, dl_data);
curl_easy_setopt(handle, CURLOPT_VERBOSE, 1L);

auto headers = curl_utils::GetHeaders(item.downloadUrl);
if (headers) {
8 changes: 6 additions & 2 deletions engine/services/download_service.h
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@
#include <unordered_set>
#include "common/download_task_queue.h"
#include "common/event.h"
#include "services/config_service.h"
#include "utils/result.hpp"

struct ProcessDownloadFailed {
@@ -20,6 +21,8 @@ class DownloadService {
private:
static constexpr int MAX_CONCURRENT_TASKS = 4;

std::shared_ptr<ConfigService> config_service_;

struct DownloadingData {
std::string task_id;
std::string item_id;
@@ -82,8 +85,9 @@ class DownloadService {

explicit DownloadService() = default;

explicit DownloadService(std::shared_ptr<EventQueue> event_queue)
: event_queue_{event_queue} {
explicit DownloadService(std::shared_ptr<EventQueue> event_queue,
std::shared_ptr<ConfigService> config_service)
: event_queue_{event_queue}, config_service_{config_service} {
InitializeWorkers();
};

1 change: 1 addition & 0 deletions engine/test/components/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@ add_executable(${PROJECT_NAME}
${CMAKE_CURRENT_SOURCE_DIR}/../../config/gguf_parser.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../cli/commands/cortex_upd_cmd.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../cli/commands/server_stop_cmd.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../services/config_service.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../services/download_service.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../database/models.cc
)
Loading

0 comments on commit 744ace2

Please sign in to comment.