Skip to content

Commit

Permalink
feat: add proxy support
Browse files Browse the repository at this point in the history
  • Loading branch information
namchuai committed Nov 18, 2024
1 parent 6892823 commit 824de03
Show file tree
Hide file tree
Showing 12 changed files with 506 additions and 703 deletions.
850 changes: 205 additions & 645 deletions docs/static/openapi/cortex.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions engine/cli/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ add_executable(${TARGET_NAME} main.cc
${CMAKE_CURRENT_SOURCE_DIR}/../utils/cpuid/cpu_info.cc
${CMAKE_CURRENT_SOURCE_DIR}/../utils/file_logger.cc
${CMAKE_CURRENT_SOURCE_DIR}/command_line_parser.cc
${CMAKE_CURRENT_SOURCE_DIR}/../services/config_service.cc
${CMAKE_CURRENT_SOURCE_DIR}/../services/download_service.cc
${CMAKE_CURRENT_SOURCE_DIR}/../services/engine_service.cc
${CMAKE_CURRENT_SOURCE_DIR}/../services/model_service.cc
Expand Down
128 changes: 125 additions & 3 deletions engine/common/api_server_configuration.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,72 @@
#include <unordered_map>
#include <vector>

// current only support basic auth
enum class ProxyAuthMethod {
Basic,
Digest,
DigestIe,
Bearer,
Negotiate,
Ntlm,
NtlmWb,
Any,
AnySafe,
AuthOnly,
AwsSigV4
};

class ApiServerConfiguration {
public:
ApiServerConfiguration(bool cors = true,
std::vector<std::string> allowed_origins = {})
: cors{cors}, allowed_origins{allowed_origins} {}
ApiServerConfiguration(
bool cors = true, std::vector<std::string> allowed_origins = {},
bool verify_proxy_ssl = true, bool verify_proxy_host_ssl = true,
const std::string& proxy_url = "", const std::string& proxy_username = "",
const std::string& proxy_password = "", const std::string& no_proxy = "",
bool verify_peer_ssl = true, bool verify_host_ssl = true)
: cors{cors},
allowed_origins{allowed_origins},
verify_proxy_ssl{verify_proxy_ssl},
verify_proxy_host_ssl{verify_proxy_host_ssl},
proxy_url{proxy_url},
proxy_username{proxy_username},
proxy_password{proxy_password},
no_proxy{no_proxy},
verify_peer_ssl{verify_peer_ssl},
verify_host_ssl{verify_host_ssl} {}

// cors
bool cors{true};
std::vector<std::string> allowed_origins;

// proxy
bool verify_proxy_ssl{true};
bool verify_proxy_host_ssl{true};
ProxyAuthMethod proxy_auth_method{ProxyAuthMethod::Basic};
std::string proxy_url{""};
std::string proxy_username{""};
std::string proxy_password{""};
std::string no_proxy{""};

bool verify_peer_ssl{true};
bool verify_host_ssl{true};

Json::Value ToJson() const {
Json::Value root;
root["cors"] = cors;
root["allowed_origins"] = Json::Value(Json::arrayValue);
for (const auto& origin : allowed_origins) {
root["allowed_origins"].append(origin);
}
root["verify_proxy_ssl"] = verify_proxy_ssl;
root["verify_proxy_host_ssl"] = verify_proxy_host_ssl;
root["proxy_url"] = proxy_url;
root["proxy_username"] = proxy_username;
root["proxy_password"] = proxy_password;
root["no_proxy"] = no_proxy;
root["verify_peer_ssl"] = verify_peer_ssl;
root["verify_host_ssl"] = verify_host_ssl;

return root;
}

Expand All @@ -31,6 +81,78 @@ class ApiServerConfiguration {
const std::unordered_map<std::string,
std::function<bool(const Json::Value&)>>
field_updater{
{"verify_peer_ssl",
[this](const Json::Value& value) -> bool {
if (!value.isBool()) {
return false;
}
verify_peer_ssl = value.asBool();
return true;
}},

{"verify_host_ssl",
[this](const Json::Value& value) -> bool {
if (!value.isBool()) {
return false;
}
verify_host_ssl = value.asBool();
return true;
}},

{"verify_proxy_host_ssl",
[this](const Json::Value& value) -> bool {
if (!value.isBool()) {
return false;
}
verify_proxy_host_ssl = value.asBool();
return true;
}},

{"verify_proxy_ssl",
[this](const Json::Value& value) -> bool {
if (!value.isBool()) {
return false;
}
verify_proxy_ssl = value.asBool();
return true;
}},

{"no_proxy",
[this](const Json::Value& value) -> bool {
if (!value.isString()) {
return false;
}
no_proxy = value.asString();
return true;
}},

{"proxy_url",
[this](const Json::Value& value) -> bool {
if (!value.isString()) {
return false;
}
proxy_url = value.asString();
return true;
}},

{"proxy_username",
[this](const Json::Value& value) -> bool {
if (!value.isString()) {
return false;
}
proxy_username = value.asString();
return true;
}},

{"proxy_password",
[this](const Json::Value& value) -> bool {
if (!value.isString()) {
return false;
}
proxy_password = value.asString();
return true;
}},

{"cors",
[this](const Json::Value& value) -> bool {
if (!value.isBool()) {
Expand Down
7 changes: 5 additions & 2 deletions engine/config/yaml_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "utils/file_manager_utils.h"
#include "utils/format_utils.h"
#include "yaml_config.h"

namespace config {
// Method to read YAML file
void YamlHandler::Reset() {
Expand Down Expand Up @@ -44,6 +45,7 @@ void YamlHandler::ReadYamlFile(const std::string& file_path) {
throw;
}
}

void YamlHandler::SplitPromptTemplate(ModelConfig& mc) {
if (mc.prompt_template.size() > 0) {
auto& pt = mc.prompt_template;
Expand Down Expand Up @@ -220,7 +222,7 @@ void YamlHandler::UpdateModelConfig(ModelConfig new_model_config) {
yaml_node_["ngl"] = model_config_.ngl;
if (!std::isnan(static_cast<double>(model_config_.ctx_len)))
yaml_node_["ctx_len"] = model_config_.ctx_len;
if (!std::isnan(static_cast<double>(model_config_.n_parallel)))
if (!std::isnan(static_cast<double>(model_config_.n_parallel)))
yaml_node_["n_parallel"] = model_config_.n_parallel;
if (!std::isnan(static_cast<double>(model_config_.tp)))
yaml_node_["tp"] = model_config_.tp;
Expand Down Expand Up @@ -377,7 +379,8 @@ void YamlHandler::WriteYamlFile(const std::string& file_path) const {
outFile << format_utils::writeKeyValue(
"ctx_len", yaml_node_["ctx_len"],
"llama.context_length | 0 or undefined = loaded from model");
outFile << format_utils::writeKeyValue("n_parallel", yaml_node_["n_parallel"]);
outFile << format_utils::writeKeyValue("n_parallel",
yaml_node_["n_parallel"]);
outFile << format_utils::writeKeyValue("ngl", yaml_node_["ngl"],
"Undefined = loaded from model");
outFile << "# END OPTIONAL\n";
Expand Down
5 changes: 3 additions & 2 deletions engine/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,14 @@ void RunServer(std::optional<int> port, bool ignore_cout) {
auto event_queue_ptr = std::make_shared<EventQueue>();
cortex::event::EventProcessor event_processor(event_queue_ptr);

auto download_service = std::make_shared<DownloadService>(event_queue_ptr);
auto config_service = std::make_shared<ConfigService>();
auto download_service =
std::make_shared<DownloadService>(event_queue_ptr, config_service);
auto engine_service = std::make_shared<EngineService>(download_service);
auto inference_svc =
std::make_shared<services::InferenceService>(engine_service);
auto model_service = std::make_shared<ModelService>(
download_service, inference_svc, engine_service);
auto config_service = std::make_shared<ConfigService>();

// initialize custom controllers
auto engine_ctl = std::make_shared<Engines>(engine_service);
Expand Down
26 changes: 22 additions & 4 deletions engine/services/config_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
cpp::result<ApiServerConfiguration, std::string>
ConfigService::UpdateApiServerConfiguration(const Json::Value& json) {
auto config = file_manager_utils::GetCortexConfig();
ApiServerConfiguration api_server_config{config.enableCors,
config.allowedOrigins};
ApiServerConfiguration api_server_config{
config.enableCors, config.allowedOrigins, config.verifyProxySsl,
config.verifyProxyHostSsl, config.proxyUrl, config.proxyUsername,
config.proxyPassword, config.noProxy, config.verifyPeerSsl,
config.verifyHostSsl};

std::vector<std::string> updated_fields;
std::vector<std::string> invalid_fields;
std::vector<std::string> unknown_fields;
Expand All @@ -20,13 +24,27 @@ ConfigService::UpdateApiServerConfiguration(const Json::Value& json) {

config.enableCors = api_server_config.cors;
config.allowedOrigins = api_server_config.allowed_origins;
config.verifyProxySsl = api_server_config.verify_proxy_ssl;
config.verifyProxyHostSsl = api_server_config.verify_proxy_host_ssl;

config.proxyUrl = api_server_config.proxy_url;
config.proxyUsername = api_server_config.proxy_username;
config.proxyPassword = api_server_config.proxy_password;
config.noProxy = api_server_config.no_proxy;

config.verifyPeerSsl = api_server_config.verify_peer_ssl;
config.verifyHostSsl = api_server_config.verify_host_ssl;

auto result = file_manager_utils::UpdateCortexConfig(config);
return api_server_config;
}

cpp::result<ApiServerConfiguration, std::string>
ConfigService::GetApiServerConfiguration() const {
ConfigService::GetApiServerConfiguration() {
auto config = file_manager_utils::GetCortexConfig();
return ApiServerConfiguration{config.enableCors, config.allowedOrigins};
return ApiServerConfiguration{
config.enableCors, config.allowedOrigins, config.verifyProxySsl,
config.verifyProxyHostSsl, config.proxyUrl, config.proxyUsername,
config.proxyPassword, config.noProxy, config.verifyPeerSsl,
config.verifyHostSsl};
}
3 changes: 1 addition & 2 deletions engine/services/config_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,5 @@ class ConfigService {
cpp::result<ApiServerConfiguration, std::string> UpdateApiServerConfiguration(
const Json::Value& json);

cpp::result<ApiServerConfiguration, std::string> GetApiServerConfiguration()
const;
cpp::result<ApiServerConfiguration, std::string> GetApiServerConfiguration();
};
47 changes: 47 additions & 0 deletions engine/services/download_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ cpp::result<uint64_t, std::string> DownloadService::GetFileSize(
return cpp::fail(static_cast<std::string>("Failed to init CURL"));
}

// TODO: namh add header here
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_NOBODY, 1L);
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
Expand Down Expand Up @@ -189,6 +190,7 @@ cpp::result<bool, std::string> DownloadService::Download(

curl_easy_setopt(curl, CURLOPT_HTTPHEADER, curl_headers);
}
// TODO: namh add proxy setting here
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, &WriteCallback);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, file);
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
Expand Down Expand Up @@ -337,6 +339,7 @@ void DownloadService::ProcessTask(DownloadTask& task, int worker_id) {
});
worker_data->downloading_data_map[item.id] = dl_data_ptr;

CTL_ERR("Namh Setup curl");
SetUpCurlHandle(handle, item, file, dl_data_ptr.get());
curl_multi_add_handle(worker_data->multi_handle, handle);
task_handles.push_back(std::make_pair(handle, file));
Expand Down Expand Up @@ -407,13 +410,57 @@ cpp::result<void, ProcessDownloadFailed> DownloadService::ProcessMultiDownload(

void DownloadService::SetUpCurlHandle(CURL* handle, const DownloadItem& item,
FILE* file, DownloadingData* dl_data) {
auto configuration = config_service_->GetApiServerConfiguration();
if (configuration.has_value()) {
if (!configuration->proxy_url.empty()) {
auto proxy_url = configuration->proxy_url;
auto verify_proxy_ssl = configuration->verify_proxy_ssl;
auto verify_proxy_host_ssl = configuration->verify_proxy_host_ssl;

auto verify_ssl = configuration->verify_peer_ssl;
auto verify_host_ssl = configuration->verify_host_ssl;

auto proxy_username = configuration->proxy_username;
auto proxy_password = configuration->proxy_password;

CTL_ERR("=== Proxy configuration ===");
CTL_ERR("Proxy url: " << proxy_url);
CTL_ERR("Verify proxy ssl: " << verify_proxy_ssl);
CTL_ERR("Verify proxy host ssl: " << verify_proxy_host_ssl);
CTL_ERR("Verify ssl: " << verify_ssl);
CTL_ERR("Verify host ssl: " << verify_host_ssl);

curl_easy_setopt(handle, CURLOPT_PROXY, proxy_url.c_str());
curl_easy_setopt(handle, CURLOPT_SSL_VERIFYPEER, verify_ssl ? 1L : 0L);
curl_easy_setopt(handle, CURLOPT_SSL_VERIFYHOST,
verify_host_ssl ? 2L : 0L);

curl_easy_setopt(handle, CURLOPT_PROXY_SSL_VERIFYPEER,
verify_proxy_ssl ? 1L : 0L);
curl_easy_setopt(handle, CURLOPT_PROXY_SSL_VERIFYHOST,
verify_proxy_host_ssl ? 2L : 0L);

if (!proxy_username.empty()) {
std::string proxy_auth = proxy_username + ":" + proxy_password;
CTL_ERR("Proxy auth: " << proxy_auth);
curl_easy_setopt(handle, CURLOPT_PROXYUSERPWD, proxy_auth.c_str());
}

curl_easy_setopt(handle, CURLOPT_NOPROXY,
configuration->no_proxy.c_str());
}
} else {
CTL_ERR("Failed to get configuration");
}

curl_easy_setopt(handle, CURLOPT_URL, item.downloadUrl.c_str());
curl_easy_setopt(handle, CURLOPT_WRITEFUNCTION, WriteCallback);
curl_easy_setopt(handle, CURLOPT_WRITEDATA, file);
curl_easy_setopt(handle, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(handle, CURLOPT_NOPROGRESS, 0L);
curl_easy_setopt(handle, CURLOPT_XFERINFOFUNCTION, ProgressCallback);
curl_easy_setopt(handle, CURLOPT_XFERINFODATA, dl_data);
curl_easy_setopt(handle, CURLOPT_VERBOSE, 1L);

auto headers = curl_utils::GetHeaders(item.downloadUrl);
if (headers) {
Expand Down
8 changes: 6 additions & 2 deletions engine/services/download_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <unordered_set>
#include "common/download_task_queue.h"
#include "common/event.h"
#include "services/config_service.h"
#include "utils/result.hpp"

struct ProcessDownloadFailed {
Expand All @@ -20,6 +21,8 @@ class DownloadService {
private:
static constexpr int MAX_CONCURRENT_TASKS = 4;

std::shared_ptr<ConfigService> config_service_;

struct DownloadingData {
std::string task_id;
std::string item_id;
Expand Down Expand Up @@ -82,8 +85,9 @@ class DownloadService {

explicit DownloadService() = default;

explicit DownloadService(std::shared_ptr<EventQueue> event_queue)
: event_queue_{event_queue} {
explicit DownloadService(std::shared_ptr<EventQueue> event_queue,
std::shared_ptr<ConfigService> config_service)
: event_queue_{event_queue}, config_service_{config_service} {
InitializeWorkers();
};

Expand Down
1 change: 1 addition & 0 deletions engine/test/components/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ add_executable(${PROJECT_NAME}
${CMAKE_CURRENT_SOURCE_DIR}/../../config/gguf_parser.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../cli/commands/cortex_upd_cmd.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../cli/commands/server_stop_cmd.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../services/config_service.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../services/download_service.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../database/models.cc
)
Expand Down
Loading

0 comments on commit 824de03

Please sign in to comment.