Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Agent/Worker comm using files in Python and Cpp #421

Merged
merged 4 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions packages/cpp/ArmoniK.Api.Common/header/utils/string_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#pragma once

#include <algorithm>
#include <cctype>
#include <locale>
#include <string>

namespace armonik {
namespace api {
namespace common {
namespace utils {
// trim from start (in place)
static inline void ltrim(std::string &s) {
s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { return !std::isspace(ch); }));
}

// trim from end (in place)
static inline void rtrim(std::string &s) {
s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base(), s.end());
}

// trim from both ends (in place)
static inline void trim(std::string &s) {
rtrim(s);
ltrim(s);
}

// trim from start (copying)
static inline std::string ltrim_copy(std::string s) {
ltrim(s);
return s;
}

// trim from end (copying)
static inline std::string rtrim_copy(std::string s) {
rtrim(s);
return s;
}

// trim from both ends (copying)
static inline std::string trim_copy(std::string s) {
trim(s);
return s;
}

inline std::string pathJoin(const std::string &p1, const std::string &p2) {
#ifdef _WIN32
constexpr char sep = '\\';
#else
constexpr char sep = '/';
#endif
std::string tmp = trim_copy(p1);

if (tmp[tmp.length() - 1] != sep) {
tmp += sep;
}
return tmp + trim_copy(p2);
}
} // namespace utils
} // namespace common
} // namespace api
} // namespace armonik
5 changes: 1 addition & 4 deletions packages/cpp/ArmoniK.Api.Worker.Tests/source/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,7 @@ class TestWorker : public armonik::api::worker::ArmoniKWorker {

try {
if (!taskHandler.getExpectedResults().empty()) {
auto res = taskHandler.send_result(taskHandler.getExpectedResults()[0], taskHandler.getPayload()).get();
if (res.has_error()) {
throw armonik::api::common::exceptions::ArmoniKApiException(res.error());
}
taskHandler.send_result(taskHandler.getExpectedResults()[0], taskHandler.getPayload()).get();
}

} catch (const std::exception &e) {
Expand Down
9 changes: 4 additions & 5 deletions packages/cpp/ArmoniK.Api.Worker/header/Worker/ArmoniKWorker.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,14 @@ class ArmoniKWorker : public armonik::api::grpc::v1::worker::Worker::Service {
* @brief Implements the Process method of the Worker service.
*
* @param context The ServerContext object.
* @param reader The request iterator
* @param request The Process request
* @param response The ProcessReply object.
*
* @return The status of the method.
*/
[[maybe_unused]] ::grpc::Status
Process(::grpc::ServerContext *context,
::grpc::ServerReader<::armonik::api::grpc::v1::worker::ProcessRequest> *reader,
::armonik::api::grpc::v1::worker::ProcessReply *response) override;
::grpc::Status Process(::grpc::ServerContext *context,
const ::armonik::api::grpc::v1::worker::ProcessRequest *request,
::armonik::api::grpc::v1::worker::ProcessReply *response) override;

/**
* @brief Function which does the actual work
Expand Down
15 changes: 5 additions & 10 deletions packages/cpp/ArmoniK.Api.Worker/header/Worker/TaskHandler.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class TaskHandler {

private:
armonik::api::grpc::v1::agent::Agent::Stub &stub_;
::grpc::ServerReader<armonik::api::grpc::v1::worker::ProcessRequest> &request_iterator_;
const armonik::api::grpc::v1::worker::ProcessRequest &request_;
std::string session_id_;
std::string task_id_;
armonik::api::grpc::v1::TaskOptions task_options_;
Expand All @@ -31,22 +31,17 @@ class TaskHandler {
std::map<std::string, std::string> data_dependencies_;
std::string token_;
armonik::api::grpc::v1::Configuration config_;
std::string data_folder_;

public:
/**
* @brief Construct a new Task Handler object
*
* @param client the agent client
* @param request_iterator The request iterator
* @param request The process request
*/
TaskHandler(armonik::api::grpc::v1::agent::Agent::Stub &client,
::grpc::ServerReader<armonik::api::grpc::v1::worker::ProcessRequest> &request_iterator);

/**
* @brief Initialise the task handler
*
*/
void init();
const armonik::api::grpc::v1::worker::ProcessRequest &request);

/**
* @brief Create a task_chunk_stream.
Expand Down Expand Up @@ -89,7 +84,7 @@ class TaskHandler {
* @param data The result data
* @return A future containing a vector of ResultReply
*/
std::future<armonik::api::grpc::v1::agent::ResultReply> send_result(std::string key, absl::string_view data);
std::future<void> send_result(std::string key, absl::string_view data);

/**
* @brief Get the result ids object
Expand Down
43 changes: 24 additions & 19 deletions packages/cpp/ArmoniK.Api.Worker/source/Worker/ArmoniKWorker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,32 +47,37 @@ armonik::api::worker::ArmoniKWorker::ArmoniKWorker(std::unique_ptr<armonik::api:
* @return The status of the method.
*/
[[maybe_unused]] Status
armonik::api::worker::ArmoniKWorker::Process([[maybe_unused]] ::grpc::ServerContext *context,
::grpc::ServerReader<ProcessRequest> *reader,
armonik::api::worker::ArmoniKWorker::Process(::grpc::ServerContext *context,
const ::armonik::api::grpc::v1::worker::ProcessRequest *request,
::armonik::api::grpc::v1::worker::ProcessReply *response) {

(void)context;
logger_.debug("Receive new request From C++ Worker");

TaskHandler task_handler(*agent_, *reader);

task_handler.init();
try {
ProcessStatus status = Execute(task_handler);

logger_.debug("Finish call C++");

armonik::api::grpc::v1::Output output;
if (status.ok()) {
*output.mutable_ok() = armonik::api::grpc::v1::Empty();
} else {
output.mutable_error()->set_details(std::move(status).details());
TaskHandler task_handler(*agent_, *request);
try {
ProcessStatus status = Execute(task_handler);

logger_.debug("Finish call C++");

armonik::api::grpc::v1::Output output;
if (status.ok()) {
*output.mutable_ok() = armonik::api::grpc::v1::Empty();
} else {
output.mutable_error()->set_details(std::move(status).details());
}
*response->mutable_output() = std::move(output);
} catch (const std::exception &e) {
logger_.error("Error processing task : {what}", {{"what", e.what()}});
std::stringstream ss;
ss << "Error processing task : " << e.what();
return {::grpc::StatusCode::UNAVAILABLE, ss.str(), e.what()};
}
*response->mutable_output() = std::move(output);
} catch (const std::exception &e) {
logger_.error("Error processing task : {what}", {{"what", e.what()}});
logger_.error("Error in the request handling : {what}", {{"what", e.what()}});
std::stringstream ss;
ss << "Error processing task : " << e.what();
return {::grpc::StatusCode::UNAVAILABLE, ss.str(), e.what()};
ss << "Error in the request handling : " << e.what();
return {::grpc::StatusCode::INVALID_ARGUMENT, ss.str(), e.what()};
}

return ::grpc::Status::OK;
Expand Down
159 changes: 42 additions & 117 deletions packages/cpp/ArmoniK.Api.Worker/source/Worker/TaskHandler.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "Worker/TaskHandler.h"
#include "exceptions/ArmoniKApiException.h"
#include "utils/string_utils.h"
#include <fstream>
#include <future>
#include <sstream>
#include <string>
Expand Down Expand Up @@ -28,89 +30,29 @@ using ::grpc::Status;
* @param client the agent client
* @param request_iterator The request iterator
*/
armonik::api::worker::TaskHandler::TaskHandler(Agent::Stub &client,
::grpc::ServerReader<ProcessRequest> &request_iterator)
: stub_(client), request_iterator_(request_iterator) {}

/**
* @brief Initialise the task handler
*
*/
void armonik::api::worker::TaskHandler::init() {
ProcessRequest Request;
if (!request_iterator_.Read(&Request)) {
throw std::runtime_error("Request stream ended unexpectedly.");
armonik::api::worker::TaskHandler::TaskHandler(Agent::Stub &client, const ProcessRequest &request)
: stub_(client), request_(request) {
token_ = request_.communication_token();
session_id_ = request_.session_id();
task_id_ = request_.task_id();
task_options_ = request_.task_options();
const std::string payload_id = request_.payload_id();
data_folder_ = request_.data_folder();
std::ostringstream string_stream(std::ios::binary);
string_stream
<< std::ifstream(armonik::api::common::utils::pathJoin(data_folder_, payload_id), std::fstream::binary).rdbuf();
payload_ = string_stream.str();
string_stream.clear();
config_ = request_.configuration();
expected_result_.assign(request_.expected_output_keys().begin(), request_.expected_output_keys().end());

for (auto &&dd : request_.data_dependencies()) {
// TODO Replace with lazy loading via a custom std::map (to not break compatibility)
string_stream
<< std::ifstream(armonik::api::common::utils::pathJoin(data_folder_, dd), std::fstream::binary).rdbuf();
data_dependencies_[dd] = string_stream.str();
string_stream.clear();
}

if (Request.compute().type_case() != armonik::api::grpc::v1::worker::ProcessRequest_ComputeRequest::kInitRequest) {
throw std::runtime_error("Expected a Compute request type with InitRequest to start the stream.");
}
auto *init_request = Request.mutable_compute()->mutable_init_request();
session_id_ = init_request->session_id();
task_id_ = init_request->task_id();
task_options_ = init_request->task_options();
expected_result_.assign(std::make_move_iterator(init_request->mutable_expected_output_keys()->begin()),
std::make_move_iterator(init_request->mutable_expected_output_keys()->end()));
token_ = Request.communication_token();
config_ = std::move(*init_request->mutable_configuration());

auto *datachunk = &init_request->payload();
assert(payload_.empty());
payload_.append(datachunk->data());

while (!datachunk->data_complete()) {
if (!request_iterator_.Read(&Request)) {
throw std::runtime_error("Request stream ended unexpectedly.");
}
if (Request.compute().type_case() != armonik::api::grpc::v1::worker::ProcessRequest_ComputeRequest::kPayload) {
throw std::runtime_error("Expected a Compute request type with Payload to continue the stream.");
}

datachunk = &Request.compute().payload();
if (datachunk->type_case() == armonik::api::grpc::v1::DataChunk::kData) {
payload_.append(datachunk->data());
} else if (datachunk->type_case() == armonik::api::grpc::v1::DataChunk::TYPE_NOT_SET) {
throw std::runtime_error("Expected a Compute request type with a DataChunk Payload to continue the stream.");
} else if (datachunk->type_case() == armonik::api::grpc::v1::DataChunk::kDataComplete) {
break;
}
}

armonik::api::grpc::v1::worker::ProcessRequest_ComputeRequest::InitData *init_data;

do {
if (!request_iterator_.Read(&Request)) {
throw std::runtime_error("Request stream ended unexpectedly.");
}
if (Request.compute().type_case() != armonik::api::grpc::v1::worker::ProcessRequest_ComputeRequest::kInitData) {
throw std::runtime_error("Expected a Compute request type with InitData to continue the stream.");
}

init_data = Request.mutable_compute()->mutable_init_data();
if (init_data->type_case() == armonik::api::grpc::v1::worker::ProcessRequest_ComputeRequest_InitData::kKey) {
const std::string &key = init_data->key();
std::string data_dep;
while (true) {
ProcessRequest dep_request;
if (!request_iterator_.Read(&dep_request)) {
throw std::runtime_error("Request stream ended unexpectedly.");
}
if (dep_request.compute().type_case() != armonik::api::grpc::v1::worker::ProcessRequest_ComputeRequest::kData) {
throw std::runtime_error("Expected a Compute request type with Data to continue the stream.");
}

auto chunk = dep_request.compute().data();
if (chunk.type_case() == armonik::api::grpc::v1::DataChunk::kData) {
data_dep.append(chunk.data());
} else if (datachunk->type_case() == armonik::api::grpc::v1::DataChunk::TYPE_NOT_SET) {
throw std::runtime_error("Expected a Compute request type with a DataChunk Payload to continue the stream.");
} else if (datachunk->type_case() == armonik::api::grpc::v1::DataChunk::kDataComplete) {
break;
}
}
data_dependencies_[key] = data_dep;
}
} while (init_data->type_case() == armonik::api::grpc::v1::worker::ProcessRequest_ComputeRequest_InitData::kKey);
}

/**
Expand Down Expand Up @@ -273,52 +215,35 @@ armonik::api::worker::TaskHandler::create_tasks_async(TaskOptions task_options,
* @param data The result data
* @return A future containing a vector of ResultReply
*/
std::future<armonik::api::grpc::v1::agent::ResultReply>
armonik::api::worker::TaskHandler::send_result(std::string key, absl::string_view data) {
std::future<void> armonik::api::worker::TaskHandler::send_result(std::string key, absl::string_view data) {
return std::async(std::launch::async, [this, key = std::move(key), data]() mutable {
::grpc::ClientContext context_client_writer;

armonik::api::grpc::v1::agent::ResultReply reply;

size_t max_chunk = config_.data_chunk_max_size();
const size_t data_size = data.size();
size_t start = 0;
::grpc::ClientContext context;

auto stream = stub_.SendResult(&context_client_writer, &reply);
std::ofstream output(armonik::api::common::utils::pathJoin(data_folder_, key),
std::fstream::binary | std::fstream::trunc);
output << data;
output.close();

armonik::api::grpc::v1::agent::Result init_msg;
init_msg.mutable_init()->set_key(std::move(key));
init_msg.set_communication_token(token_);
armonik::api::grpc::v1::agent::NotifyResultDataResponse reply;
armonik::api::grpc::v1::agent::NotifyResultDataRequest request;
request.set_communication_token(token_);
armonik::api::grpc::v1::agent::NotifyResultDataRequest::ResultIdentifier result_id;
result_id.set_session_id(session_id_);
result_id.set_result_id(key);
*(request.mutable_ids()->Add()) = result_id;

stream->Write(init_msg);

while (start < data_size) {
size_t chunkSize = std::min(max_chunk, data_size - start);

armonik::api::grpc::v1::agent::Result msg;
msg.set_communication_token(token_);
msg.mutable_data()->mutable_data()->assign(data.data() + start, chunkSize);

stream->Write(msg);

start += chunkSize;
}

armonik::api::grpc::v1::agent::Result end_msg;
end_msg.set_communication_token(token_);
end_msg.mutable_data()->set_data_complete(true);
stream->Write(end_msg);

stream->WritesDone();
::grpc::Status status = stream->Finish();
auto status = stub_.NotifyResultData(&context, request, &reply);

if (!status.ok()) {
std::stringstream message;
message << "Error: " << status.error_code() << ": " << status.error_message()
<< ". details: " << status.error_details() << std::endl;
throw armonik::api::common::exceptions::ArmoniKApiException(message.str());
}
return reply;

if (reply.result_ids_size() != 1) {
throw armonik::api::common::exceptions::ArmoniKApiException("Received erroneous reply for send data");
}
});
}

Expand Down
Loading