Skip to content

Commit

Permalink
fix: Agent/Worker comm using files in Python and Cpp (#421)
Browse files Browse the repository at this point in the history
  • Loading branch information
aneojgurhem authored Sep 28, 2023
2 parents 9fd5060 + 6c662c2 commit dafff72
Show file tree
Hide file tree
Showing 10 changed files with 236 additions and 368 deletions.
62 changes: 62 additions & 0 deletions packages/cpp/ArmoniK.Api.Common/header/utils/string_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#pragma once

#include <algorithm>
#include <cctype>
#include <locale>
#include <string>

namespace armonik {
namespace api {
namespace common {
namespace utils {
// trim from start (in place)
static inline void ltrim(std::string &s) {
s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { return !std::isspace(ch); }));
}

// trim from end (in place)
static inline void rtrim(std::string &s) {
s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base(), s.end());
}

// trim from both ends (in place)
static inline void trim(std::string &s) {
rtrim(s);
ltrim(s);
}

// trim from start (copying)
static inline std::string ltrim_copy(std::string s) {
ltrim(s);
return s;
}

// trim from end (copying)
static inline std::string rtrim_copy(std::string s) {
rtrim(s);
return s;
}

// trim from both ends (copying)
static inline std::string trim_copy(std::string s) {
trim(s);
return s;
}

inline std::string pathJoin(const std::string &p1, const std::string &p2) {
#ifdef _WIN32
constexpr char sep = '\\';
#else
constexpr char sep = '/';
#endif
std::string tmp = trim_copy(p1);

if (tmp[tmp.length() - 1] != sep) {
tmp += sep;
}
return tmp + trim_copy(p2);
}
} // namespace utils
} // namespace common
} // namespace api
} // namespace armonik
5 changes: 1 addition & 4 deletions packages/cpp/ArmoniK.Api.Worker.Tests/source/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,7 @@ class TestWorker : public armonik::api::worker::ArmoniKWorker {

try {
if (!taskHandler.getExpectedResults().empty()) {
auto res = taskHandler.send_result(taskHandler.getExpectedResults()[0], taskHandler.getPayload()).get();
if (res.has_error()) {
throw armonik::api::common::exceptions::ArmoniKApiException(res.error());
}
taskHandler.send_result(taskHandler.getExpectedResults()[0], taskHandler.getPayload()).get();
}

} catch (const std::exception &e) {
Expand Down
9 changes: 4 additions & 5 deletions packages/cpp/ArmoniK.Api.Worker/header/Worker/ArmoniKWorker.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,14 @@ class ArmoniKWorker : public armonik::api::grpc::v1::worker::Worker::Service {
* @brief Implements the Process method of the Worker service.
*
* @param context The ServerContext object.
* @param reader The request iterator
* @param request The Process request
* @param response The ProcessReply object.
*
* @return The status of the method.
*/
[[maybe_unused]] ::grpc::Status
Process(::grpc::ServerContext *context,
::grpc::ServerReader<::armonik::api::grpc::v1::worker::ProcessRequest> *reader,
::armonik::api::grpc::v1::worker::ProcessReply *response) override;
::grpc::Status Process(::grpc::ServerContext *context,
const ::armonik::api::grpc::v1::worker::ProcessRequest *request,
::armonik::api::grpc::v1::worker::ProcessReply *response) override;

/**
* @brief Function which does the actual work
Expand Down
15 changes: 5 additions & 10 deletions packages/cpp/ArmoniK.Api.Worker/header/Worker/TaskHandler.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class TaskHandler {

private:
armonik::api::grpc::v1::agent::Agent::Stub &stub_;
::grpc::ServerReader<armonik::api::grpc::v1::worker::ProcessRequest> &request_iterator_;
const armonik::api::grpc::v1::worker::ProcessRequest &request_;
std::string session_id_;
std::string task_id_;
armonik::api::grpc::v1::TaskOptions task_options_;
Expand All @@ -31,22 +31,17 @@ class TaskHandler {
std::map<std::string, std::string> data_dependencies_;
std::string token_;
armonik::api::grpc::v1::Configuration config_;
std::string data_folder_;

public:
/**
* @brief Construct a new Task Handler object
*
* @param client the agent client
* @param request_iterator The request iterator
* @param request The process request
*/
TaskHandler(armonik::api::grpc::v1::agent::Agent::Stub &client,
::grpc::ServerReader<armonik::api::grpc::v1::worker::ProcessRequest> &request_iterator);

/**
* @brief Initialise the task handler
*
*/
void init();
const armonik::api::grpc::v1::worker::ProcessRequest &request);

/**
* @brief Create a task_chunk_stream.
Expand Down Expand Up @@ -89,7 +84,7 @@ class TaskHandler {
* @param data The result data
* @return A future containing a vector of ResultReply
*/
std::future<armonik::api::grpc::v1::agent::ResultReply> send_result(std::string key, absl::string_view data);
std::future<void> send_result(std::string key, absl::string_view data);

/**
* @brief Get the result ids object
Expand Down
43 changes: 24 additions & 19 deletions packages/cpp/ArmoniK.Api.Worker/source/Worker/ArmoniKWorker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,32 +47,37 @@ armonik::api::worker::ArmoniKWorker::ArmoniKWorker(std::unique_ptr<armonik::api:
* @return The status of the method.
*/
[[maybe_unused]] Status
armonik::api::worker::ArmoniKWorker::Process([[maybe_unused]] ::grpc::ServerContext *context,
::grpc::ServerReader<ProcessRequest> *reader,
armonik::api::worker::ArmoniKWorker::Process(::grpc::ServerContext *context,
const ::armonik::api::grpc::v1::worker::ProcessRequest *request,
::armonik::api::grpc::v1::worker::ProcessReply *response) {

(void)context;
logger_.debug("Receive new request From C++ Worker");

TaskHandler task_handler(*agent_, *reader);

task_handler.init();
try {
ProcessStatus status = Execute(task_handler);

logger_.debug("Finish call C++");

armonik::api::grpc::v1::Output output;
if (status.ok()) {
*output.mutable_ok() = armonik::api::grpc::v1::Empty();
} else {
output.mutable_error()->set_details(std::move(status).details());
TaskHandler task_handler(*agent_, *request);
try {
ProcessStatus status = Execute(task_handler);

logger_.debug("Finish call C++");

armonik::api::grpc::v1::Output output;
if (status.ok()) {
*output.mutable_ok() = armonik::api::grpc::v1::Empty();
} else {
output.mutable_error()->set_details(std::move(status).details());
}
*response->mutable_output() = std::move(output);
} catch (const std::exception &e) {
logger_.error("Error processing task : {what}", {{"what", e.what()}});
std::stringstream ss;
ss << "Error processing task : " << e.what();
return {::grpc::StatusCode::UNAVAILABLE, ss.str(), e.what()};
}
*response->mutable_output() = std::move(output);
} catch (const std::exception &e) {
logger_.error("Error processing task : {what}", {{"what", e.what()}});
logger_.error("Error in the request handling : {what}", {{"what", e.what()}});
std::stringstream ss;
ss << "Error processing task : " << e.what();
return {::grpc::StatusCode::UNAVAILABLE, ss.str(), e.what()};
ss << "Error in the request handling : " << e.what();
return {::grpc::StatusCode::INVALID_ARGUMENT, ss.str(), e.what()};
}

return ::grpc::Status::OK;
Expand Down
159 changes: 42 additions & 117 deletions packages/cpp/ArmoniK.Api.Worker/source/Worker/TaskHandler.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "Worker/TaskHandler.h"
#include "exceptions/ArmoniKApiException.h"
#include "utils/string_utils.h"
#include <fstream>
#include <future>
#include <sstream>
#include <string>
Expand Down Expand Up @@ -28,89 +30,29 @@ using ::grpc::Status;
* @param client the agent client
* @param request_iterator The request iterator
*/
armonik::api::worker::TaskHandler::TaskHandler(Agent::Stub &client,
::grpc::ServerReader<ProcessRequest> &request_iterator)
: stub_(client), request_iterator_(request_iterator) {}

/**
* @brief Initialise the task handler
*
*/
void armonik::api::worker::TaskHandler::init() {
ProcessRequest Request;
if (!request_iterator_.Read(&Request)) {
throw std::runtime_error("Request stream ended unexpectedly.");
armonik::api::worker::TaskHandler::TaskHandler(Agent::Stub &client, const ProcessRequest &request)
: stub_(client), request_(request) {
token_ = request_.communication_token();
session_id_ = request_.session_id();
task_id_ = request_.task_id();
task_options_ = request_.task_options();
const std::string payload_id = request_.payload_id();
data_folder_ = request_.data_folder();
std::ostringstream string_stream(std::ios::binary);
string_stream
<< std::ifstream(armonik::api::common::utils::pathJoin(data_folder_, payload_id), std::fstream::binary).rdbuf();
payload_ = string_stream.str();
string_stream.clear();
config_ = request_.configuration();
expected_result_.assign(request_.expected_output_keys().begin(), request_.expected_output_keys().end());

for (auto &&dd : request_.data_dependencies()) {
// TODO Replace with lazy loading via a custom std::map (to not break compatibility)
string_stream
<< std::ifstream(armonik::api::common::utils::pathJoin(data_folder_, dd), std::fstream::binary).rdbuf();
data_dependencies_[dd] = string_stream.str();
string_stream.clear();
}

if (Request.compute().type_case() != armonik::api::grpc::v1::worker::ProcessRequest_ComputeRequest::kInitRequest) {
throw std::runtime_error("Expected a Compute request type with InitRequest to start the stream.");
}
auto *init_request = Request.mutable_compute()->mutable_init_request();
session_id_ = init_request->session_id();
task_id_ = init_request->task_id();
task_options_ = init_request->task_options();
expected_result_.assign(std::make_move_iterator(init_request->mutable_expected_output_keys()->begin()),
std::make_move_iterator(init_request->mutable_expected_output_keys()->end()));
token_ = Request.communication_token();
config_ = std::move(*init_request->mutable_configuration());

auto *datachunk = &init_request->payload();
assert(payload_.empty());
payload_.append(datachunk->data());

while (!datachunk->data_complete()) {
if (!request_iterator_.Read(&Request)) {
throw std::runtime_error("Request stream ended unexpectedly.");
}
if (Request.compute().type_case() != armonik::api::grpc::v1::worker::ProcessRequest_ComputeRequest::kPayload) {
throw std::runtime_error("Expected a Compute request type with Payload to continue the stream.");
}

datachunk = &Request.compute().payload();
if (datachunk->type_case() == armonik::api::grpc::v1::DataChunk::kData) {
payload_.append(datachunk->data());
} else if (datachunk->type_case() == armonik::api::grpc::v1::DataChunk::TYPE_NOT_SET) {
throw std::runtime_error("Expected a Compute request type with a DataChunk Payload to continue the stream.");
} else if (datachunk->type_case() == armonik::api::grpc::v1::DataChunk::kDataComplete) {
break;
}
}

armonik::api::grpc::v1::worker::ProcessRequest_ComputeRequest::InitData *init_data;

do {
if (!request_iterator_.Read(&Request)) {
throw std::runtime_error("Request stream ended unexpectedly.");
}
if (Request.compute().type_case() != armonik::api::grpc::v1::worker::ProcessRequest_ComputeRequest::kInitData) {
throw std::runtime_error("Expected a Compute request type with InitData to continue the stream.");
}

init_data = Request.mutable_compute()->mutable_init_data();
if (init_data->type_case() == armonik::api::grpc::v1::worker::ProcessRequest_ComputeRequest_InitData::kKey) {
const std::string &key = init_data->key();
std::string data_dep;
while (true) {
ProcessRequest dep_request;
if (!request_iterator_.Read(&dep_request)) {
throw std::runtime_error("Request stream ended unexpectedly.");
}
if (dep_request.compute().type_case() != armonik::api::grpc::v1::worker::ProcessRequest_ComputeRequest::kData) {
throw std::runtime_error("Expected a Compute request type with Data to continue the stream.");
}

auto chunk = dep_request.compute().data();
if (chunk.type_case() == armonik::api::grpc::v1::DataChunk::kData) {
data_dep.append(chunk.data());
} else if (datachunk->type_case() == armonik::api::grpc::v1::DataChunk::TYPE_NOT_SET) {
throw std::runtime_error("Expected a Compute request type with a DataChunk Payload to continue the stream.");
} else if (datachunk->type_case() == armonik::api::grpc::v1::DataChunk::kDataComplete) {
break;
}
}
data_dependencies_[key] = data_dep;
}
} while (init_data->type_case() == armonik::api::grpc::v1::worker::ProcessRequest_ComputeRequest_InitData::kKey);
}

/**
Expand Down Expand Up @@ -273,52 +215,35 @@ armonik::api::worker::TaskHandler::create_tasks_async(TaskOptions task_options,
* @param data The result data
* @return A future containing a vector of ResultReply
*/
std::future<armonik::api::grpc::v1::agent::ResultReply>
armonik::api::worker::TaskHandler::send_result(std::string key, absl::string_view data) {
std::future<void> armonik::api::worker::TaskHandler::send_result(std::string key, absl::string_view data) {
return std::async(std::launch::async, [this, key = std::move(key), data]() mutable {
::grpc::ClientContext context_client_writer;

armonik::api::grpc::v1::agent::ResultReply reply;

size_t max_chunk = config_.data_chunk_max_size();
const size_t data_size = data.size();
size_t start = 0;
::grpc::ClientContext context;

auto stream = stub_.SendResult(&context_client_writer, &reply);
std::ofstream output(armonik::api::common::utils::pathJoin(data_folder_, key),
std::fstream::binary | std::fstream::trunc);
output << data;
output.close();

armonik::api::grpc::v1::agent::Result init_msg;
init_msg.mutable_init()->set_key(std::move(key));
init_msg.set_communication_token(token_);
armonik::api::grpc::v1::agent::NotifyResultDataResponse reply;
armonik::api::grpc::v1::agent::NotifyResultDataRequest request;
request.set_communication_token(token_);
armonik::api::grpc::v1::agent::NotifyResultDataRequest::ResultIdentifier result_id;
result_id.set_session_id(session_id_);
result_id.set_result_id(key);
*(request.mutable_ids()->Add()) = result_id;

stream->Write(init_msg);

while (start < data_size) {
size_t chunkSize = std::min(max_chunk, data_size - start);

armonik::api::grpc::v1::agent::Result msg;
msg.set_communication_token(token_);
msg.mutable_data()->mutable_data()->assign(data.data() + start, chunkSize);

stream->Write(msg);

start += chunkSize;
}

armonik::api::grpc::v1::agent::Result end_msg;
end_msg.set_communication_token(token_);
end_msg.mutable_data()->set_data_complete(true);
stream->Write(end_msg);

stream->WritesDone();
::grpc::Status status = stream->Finish();
auto status = stub_.NotifyResultData(&context, request, &reply);

if (!status.ok()) {
std::stringstream message;
message << "Error: " << status.error_code() << ": " << status.error_message()
<< ". details: " << status.error_details() << std::endl;
throw armonik::api::common::exceptions::ArmoniKApiException(message.str());
}
return reply;

if (reply.result_ids_size() != 1) {
throw armonik::api::common::exceptions::ArmoniKApiException("Received erroneous reply for send data");
}
});
}

Expand Down
Loading

0 comments on commit dafff72

Please sign in to comment.