Skip to content

Commit

Permalink
script_interface: MPI-safe exceptions
Browse files Browse the repository at this point in the history
Make C++ exceptions from core classes safe in an MPI-parallel context.
When an exception occurs, re-throw the exception on the head node
and throw a ScriptInterface::Exception on worker nodes.
  • Loading branch information
jngrad committed May 17, 2022
1 parent 5c2435b commit 03d25e9
Showing 21 changed files with 418 additions and 56 deletions.
14 changes: 10 additions & 4 deletions src/core/RuntimeErrorCollector.cpp
Original file line number Diff line number Diff line change
@@ -36,11 +36,10 @@ RuntimeErrorCollector::RuntimeErrorCollector(boost::mpi::communicator comm)
: m_comm(std::move(comm)) {}

RuntimeErrorCollector::~RuntimeErrorCollector() {
if (!m_errors.empty())
if (!m_errors.empty()) {
/* Print remaining error messages on destruction */
std::cerr << "There were unhandled errors.\n";
/* Print remaining error messages on destruction */
for (auto const &e : m_errors) {
std::cerr << e.format() << std::endl;
flush();
}
}

@@ -108,6 +107,13 @@ int RuntimeErrorCollector::count(RuntimeError::ErrorLevel level) {

void RuntimeErrorCollector::clear() { m_errors.clear(); }

void RuntimeErrorCollector::flush() {
for (auto const &e : m_errors) {
std::cerr << e.format() << std::endl;
}
this->clear();
}

std::vector<RuntimeError> RuntimeErrorCollector::gather() {
std::vector<RuntimeError> all_errors{};
std::swap(all_errors, m_errors);
5 changes: 5 additions & 0 deletions src/core/RuntimeErrorCollector.hpp
Original file line number Diff line number Diff line change
@@ -74,6 +74,11 @@ class RuntimeErrorCollector {
*/
void clear();

/**
* @brief Flush error messages to standard error.
*
* Each queued message is formatted and printed on its own line; the
* queue is cleared afterwards.
*/
void flush();

std::vector<RuntimeError> gather();
void gather_local();

14 changes: 13 additions & 1 deletion src/core/errorhandling.cpp
Original file line number Diff line number Diff line change
@@ -62,7 +62,7 @@ RuntimeErrorStream _runtimeMessageStream(RuntimeError::ErrorLevel level,
return {*runtimeErrorCollector, level, file, line, function};
}

void mpi_gather_runtime_errors_local() {
static void mpi_gather_runtime_errors_local() {
runtimeErrorCollector->gather_local();
}

@@ -72,6 +72,14 @@ std::vector<RuntimeError> mpi_gather_runtime_errors() {
m_callbacks->call(mpi_gather_runtime_errors_local);
return runtimeErrorCollector->gather();
}

/**
 * Gather runtime errors; meant to be entered by every rank.
 *
 * Worker ranks hand their messages over through gather_local() and get an
 * empty list back. The head node returns whatever gather() yields.
 */
std::vector<RuntimeError> mpi_gather_runtime_errors_all(bool is_head_node) {
  if (!is_head_node) {
    runtimeErrorCollector->gather_local();
    return {};
  }
  return runtimeErrorCollector->gather();
}
} // namespace ErrorHandling

void errexit() {
@@ -89,3 +97,7 @@ int check_runtime_errors(boost::mpi::communicator const &comm) {
return boost::mpi::all_reduce(comm, check_runtime_errors_local(),
std::plus<int>());
}

/* Print and discard the runtime errors queued on this rank. */
void flush_runtime_errors_local() {
  auto &collector = *ErrorHandling::runtimeErrorCollector;
  collector.flush();
}
11 changes: 11 additions & 0 deletions src/core/errorhandling.hpp
Original file line number Diff line number Diff line change
@@ -72,6 +72,14 @@ int check_runtime_errors(boost::mpi::communicator const &comm);
*/
int check_runtime_errors_local();

/**
* @brief Flush runtime errors to standard error on the local node.
* This is used to clear pending runtime error messages when the
* call site is handling an exception that needs to be re-thrown
* instead of being queued as an additional runtime error message.
*/
void flush_runtime_errors_local();

namespace ErrorHandling {
/**
* @brief Initialize the error collection system.
@@ -94,7 +102,10 @@ RuntimeErrorStream _runtimeMessageStream(RuntimeError::ErrorLevel level,
ErrorHandling::RuntimeError::ErrorLevel::WARNING, __FILE__, __LINE__, \
PRETTY_FUNCTION_EXTENSION)

/** @brief Gather messages on main rank. Only call from main rank. */
std::vector<RuntimeError> mpi_gather_runtime_errors();
/**
 * @brief Gather messages on main rank. Call on all ranks.
 * @param is_head_node  Whether the calling rank is the head node.
 * @return The gathered messages on the head node, an empty vector on
 *         every other rank.
 */
std::vector<RuntimeError> mpi_gather_runtime_errors_all(bool is_head_node);

} // namespace ErrorHandling

2 changes: 1 addition & 1 deletion src/script_interface/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
add_library(
ScriptInterface SHARED
initialize.cpp ObjectHandle.cpp object_container_mpi_guard.cpp
GlobalContext.cpp ContextManager.cpp)
GlobalContext.cpp ContextManager.cpp ParallelExceptionHandler.cpp)

add_subdirectory(accumulators)
add_subdirectory(bond_breakage)
1 change: 1 addition & 0 deletions src/script_interface/Context.hpp
Original file line number Diff line number Diff line change
@@ -98,6 +98,7 @@ class Context : public std::enable_shared_from_this<Context> {
virtual boost::string_ref name(const ObjectHandle *o) const = 0;

virtual bool is_head_node() const = 0;
/**
 * @brief Run a callback under the context's parallel exception handling.
 * NOTE(review): presumably this must be entered by all MPI ranks
 * together, with exceptions re-thrown on the head node and converted to
 * a ScriptInterface::Exception on worker nodes -- confirm against the
 * implementing classes.
 * @param cb  Callback to execute.
 */
virtual void parallel_try_catch(std::function<void()> const &cb) const = 0;

virtual ~Context() = default;
};
4 changes: 2 additions & 2 deletions src/script_interface/ContextManager.cpp
Original file line number Diff line number Diff line change
@@ -55,8 +55,8 @@ std::string ContextManager::serialize(const ObjectHandle *o) const {

ContextManager::ContextManager(Communication::MpiCallbacks &callbacks,
const Utils::Factory<ObjectHandle> &factory) {
auto const mpi_rank = callbacks.comm().rank();
auto local_context = std::make_shared<LocalContext>(factory, mpi_rank);
auto local_context =
std::make_shared<LocalContext>(factory, callbacks.comm());

/* If there is only one node, we can treat all objects as local, and thus
* never invoke any callback. */
2 changes: 1 addition & 1 deletion src/script_interface/ContextManager.hpp
Original file line number Diff line number Diff line change
@@ -36,7 +36,7 @@
#include "Context.hpp"
#include "Variant.hpp"

#include "MpiCallbacks.hpp"
#include "core/MpiCallbacks.hpp"

#include <utils/Factory.hpp>

13 changes: 10 additions & 3 deletions src/script_interface/GlobalContext.hpp
Original file line number Diff line number Diff line change
@@ -29,16 +29,19 @@

#include "Context.hpp"
#include "LocalContext.hpp"
#include "MpiCallbacks.hpp"
#include "ObjectHandle.hpp"
#include "ParallelExceptionHandler.hpp"
#include "packed_variant.hpp"

#include "core/MpiCallbacks.hpp"

#include <utils/Factory.hpp>

#include <boost/serialization/utility.hpp>

#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
@@ -61,7 +64,6 @@ namespace ScriptInterface {
class GlobalContext : public Context {
using ObjectId = std::size_t;

private:
/* Instances on this node that are managed by the
* head node. */
std::unordered_map<ObjectId, ObjectRef> m_local_objects;
@@ -70,7 +72,8 @@ class GlobalContext : public Context {

bool m_is_head_node;

private:
ParallelExceptionHandler m_parallel_exception_handler;

Communication::CallbackHandle<ObjectId, const std::string &,
const PackedMap &>
cb_make_handle;
@@ -87,6 +90,7 @@ class GlobalContext : public Context {
std::shared_ptr<LocalContext> node_local_context)
: m_local_objects(), m_node_local_context(std::move(node_local_context)),
m_is_head_node(callbacks.comm().rank() == 0),
m_parallel_exception_handler(callbacks.comm()),
cb_make_handle(&callbacks,
[this](ObjectId id, const std::string &name,
const PackedMap &parameters) {
@@ -162,6 +166,9 @@ class GlobalContext : public Context {
boost::string_ref name(const ObjectHandle *o) const override;

bool is_head_node() const override { return m_is_head_node; }
/**
 * Forward @p cb to the parallel exception handler owned by this context.
 * NOTE(review): the std::exception template argument presumably selects
 * the exception base type that gets intercepted -- confirm in
 * ParallelExceptionHandler.hpp.
 */
void parallel_try_catch(std::function<void()> const &cb) const override {
m_parallel_exception_handler.parallel_try_catch<std::exception>(cb);
}
};
} // namespace ScriptInterface

14 changes: 12 additions & 2 deletions src/script_interface/LocalContext.hpp
Original file line number Diff line number Diff line change
@@ -21,11 +21,15 @@

#include "Context.hpp"
#include "ObjectHandle.hpp"
#include "ParallelExceptionHandler.hpp"

#include <utils/Factory.hpp>

#include <boost/mpi/communicator.hpp>

#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

@@ -39,10 +43,13 @@ namespace ScriptInterface {
class LocalContext : public Context {
Utils::Factory<ObjectHandle> m_factory;
bool m_is_head_node;
ParallelExceptionHandler m_parallel_exception_handler;

public:
LocalContext(Utils::Factory<ObjectHandle> factory, int mpi_rank)
: m_factory(std::move(factory)), m_is_head_node(mpi_rank == 0) {}
LocalContext(Utils::Factory<ObjectHandle> factory,
boost::mpi::communicator const &comm)
: m_factory(std::move(factory)), m_is_head_node(comm.rank() == 0),
m_parallel_exception_handler(comm) {}

const Utils::Factory<ObjectHandle> &factory() const { return m_factory; }

@@ -68,6 +75,9 @@ class LocalContext : public Context {
}

bool is_head_node() const override { return m_is_head_node; }
/**
 * Forward @p cb to the parallel exception handler owned by this context.
 * NOTE(review): the std::exception template argument presumably selects
 * the exception base type that gets intercepted -- confirm in
 * ParallelExceptionHandler.hpp.
 */
void parallel_try_catch(std::function<void()> const &cb) const override {
m_parallel_exception_handler.parallel_try_catch<std::exception>(cb);
}
};
} // namespace ScriptInterface

85 changes: 85 additions & 0 deletions src/script_interface/ParallelExceptionHandler.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright (C) 2022 The ESPResSo project
*
* This file is part of ESPResSo.
*
* ESPResSo is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ESPResSo is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

#include "ParallelExceptionHandler.hpp"

#include "Exception.hpp"

#include "core/MpiCallbacks.hpp"
#include "core/RuntimeError.hpp"
#include "core/communication.hpp"
#include "core/errorhandling.hpp"

#include <boost/mpi/collectives.hpp>
#include <boost/serialization/string.hpp>

#include <cassert>
#include <cstddef>
#include <functional>
#include <stdexcept>
#include <string>
#include <vector>

namespace ScriptInterface {

/**
 * @brief Agree across all MPI ranks on the outcome of a guarded call.
 *
 * Each rank contributes a failure flag to a bitwise-OR all-reduce, so every
 * rank ends up in the same branch:
 * - the head node failed: pending runtime errors are flushed locally, the
 *   head node re-throws the exception currently being handled and all other
 *   ranks throw an @c Exception with an empty message;
 * - only worker ranks failed: pending runtime errors are flushed locally and
 *   the per-rank messages are gathered on the head node, which throws a
 *   @c std::runtime_error listing one line per rank, while all other ranks
 *   throw an @c Exception with an empty message;
 * - nobody failed: the function returns normally.
 *
 * @param error  Exception caught on this rank, or @c nullptr if there is none.
 */
void ParallelExceptionHandler::handle_impl(std::exception const *error) const {
  enum : unsigned char {
    NO_RANK_FAILED = 0u,
    SOME_RANK_FAILED = 1u,
    THIS_RANK_SUCCESS = 0u,
    THIS_RANK_FAILED = 1u,
    MAIN_RANK_FAILED = 2u,
  };
  auto const root = 0;
  auto const is_root = (m_comm.rank() == root);

  /* A failure on the head node is encoded in its own bit so that it can be
   * told apart from worker failures after the reduction. */
  unsigned char local_flag = THIS_RANK_SUCCESS;
  if (error != nullptr) {
    local_flag = is_root ? MAIN_RANK_FAILED : THIS_RANK_FAILED;
  }
  auto const global_flag =
      boost::mpi::all_reduce(m_comm, local_flag, std::bit_or<>());

  if ((global_flag & MAIN_RANK_FAILED) != 0) {
    flush_runtime_errors_local();
    if (is_root) {
      /* bare re-throw: relies on the caller being inside a catch handler */
      throw;
    }
    throw Exception("");
  }

  if ((global_flag & SOME_RANK_FAILED) == 0) {
    return;
  }

  flush_runtime_errors_local();
  std::string const local_what{(error != nullptr) ? error->what() : ""};
  std::vector<std::string> collected;
  boost::mpi::gather(m_comm, local_what, collected, root);
  if (!is_root) {
    throw Exception("");
  }

  /* Head node: build a report with one line per rank. */
  std::string report{"an error occurred on one or more MPI ranks:"};
  std::size_t rank_index = 0;
  for (auto const &what : collected) {
    report.append("\n rank ")
        .append(std::to_string(rank_index))
        .append(": ")
        .append(what);
    ++rank_index;
  }
  throw std::runtime_error(report.c_str());
}

} // namespace ScriptInterface
Loading

0 comments on commit 03d25e9

Please sign in to comment.