diff --git a/hpc/LoadBalancer.hpp b/hpc/LoadBalancer.hpp index e8d6158..7e5f36a 100644 --- a/hpc/LoadBalancer.hpp +++ b/hpc/LoadBalancer.hpp @@ -59,6 +59,8 @@ std::string readLineFromFile(const std::string& filename) return line; } +using SafeUniqueModelPointer = std::unique_ptr>; + class JobManager { public: @@ -66,67 +68,46 @@ class JobManager // The returned object MUST release any resources that it holds once it goes out of scope in the code of the caller. // This can be achieved by returning a unique pointer with an appropriate deleter. // This method may return a nullptr to deny a request. - virtual std::unique_ptr requestModelAccess(const std::string& model_name) = 0; + virtual SafeUniqueModelPointer requestModelAccess(const std::string& model_name) = 0; // To initialize the load balancer we first need a list of model names that are available on a server. // Typically, this can be achieved by simply running the model code and requesting the model names from the server. - // Therefore, the implementation can most likely use the same mechanism that is also used for granting model access, - // which is why this method was placed in this class to avoid code duplication. + // Therefore, the implementation can most likely use the same mechanism that is also used for granting model access. virtual std::vector getModelNames() = 0; virtual ~JobManager() {}; }; -class FileBasedModelDeleter -{ -public: - FileBasedModelDeleter(std::string cancelation_command, std::string file_to_delete) - : cancelation_command(cancelation_command), file_to_delete(file_to_delete) {} - - void operator()(umbridge::Model* model) { - delete model; - std::filesystem::remove(file_to_delete); - std::system(cancelation_command.c_str()); - } - -protected: - std::string cancelation_command; - std::string file_to_delete; -}; -using unique_file_based_model_ptr = std::unique_ptr; - class FileBasedJobManager : public JobManager { public: - virtual std::unique_ptr requestModelAccess(const std::string& model_name) override + virtual SafeUniqueModelPointer requestModelAccess(const std::string& model_name) override { - std::string submission_command = getSubmissionCommand(); - std::string job_id = submitJob(submission_command); + std::string job_id = submitJob(); std::string server_url = readURL(job_id); - FileBasedModelDeleter deleter(getCancelationCommand(job_id), getURLFileName(job_id)); - unique_file_based_model_ptr client(new umbridge::HTTPModel(server_url, model_name), deleter); + + SafeUniqueModelPointer client(new umbridge::HTTPModel(server_url, model_name), createModelDeleter(job_id)); return client; } protected: virtual std::string getSubmissionCommand() = 0; virtual std::string getCancelationCommand(const std::string& job_id) = 0; - std::unique_ptr setDeleter(std::unique_ptr client) - { - - } - std::unique_ptr submitJobAndStartClient(const std::string& model_name) { - - } - - std::unique_ptr connectToServer(const std::string& server_url, const std::string& model_name) + std::function createModelDeleter(const std::string& job_id) { - return std::make_unique(server_url, model_name); + std::string file_to_delete = getURLFileName(job_id); + std::string cancelation_command = getCancelationCommand(job_id); + return [file_to_delete, cancelation_command](umbridge::Model* model) { + delete model; + std::filesystem::remove(file_to_delete); + std::system(cancelation_command.c_str()); + }; } std::string getURLFileName(const std::string& job_id) { - return url_file_prefix + job_id + url_file_suffix; + std::filesystem::path url_file_name(url_file_prefix + job_id + url_file_suffix); + return (url_dir / url_file_name).string(); } std::string readURL(const std::string& job_id) @@ -134,7 +115,7 @@ class FileBasedJobManager : public JobManager return readLineFromFile(getURLFileName(job_id)); } - std::string submitJob(const std::string& command) + std::string submitJob() { // Add optional delay to job submissions to prevent issues in some cases. if (submission_delay_ms) { @@ -142,7 +123,7 @@ class FileBasedJobManager : public JobManager std::this_thread::sleep_for(std::chrono::milliseconds(submission_delay_ms)); } // Submit job and increase job count - std::string command_output = getCommandOutput(command); + std::string command_output = getCommandOutput(getSubmissionCommand()); job_count++; // Extract the actual job id from the command output @@ -179,16 +160,16 @@ class FileBasedJobManager : public JobManager return job_script; } - const std::filesystem::path submission_script_dir; - const std::filesystem::path submission_script_default; + std::filesystem::path submission_script_dir; + std::filesystem::path submission_script_default; // Model-specifc job-script format: - const std::string submission_script_model_specific_prefix; - const std::string submission_script_model_specific_suffix; + std::string submission_script_model_specific_prefix; + std::string submission_script_model_specific_suffix; // URL file format: - const std::filesystem::path url_dir; - const std::string url_file_prefix; - const std::string url_file_suffix; + std::filesystem::path url_dir; + std::string url_file_prefix; + std::string url_file_suffix; int submission_delay_ms = 0; std::mutex submission_mutex;