Skip to content

Commit

Permalink
Rewrite script interface object containers serialization (#4724)
Browse files Browse the repository at this point in the history
Fixes #4280

Description of changes:
- checkpoint restrictions on the number of MPI ranks have been lifted
  • Loading branch information
kodiakhq[bot] authored May 12, 2023
2 parents e2503b6 + a90ec05 commit b3efb66
Show file tree
Hide file tree
Showing 15 changed files with 177 additions and 363 deletions.
5 changes: 0 additions & 5 deletions doc/sphinx/io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,6 @@ Be aware of the following limitations:
for a specific combination of features, please share your findings
with the |es| community.

* Checkpointing only supports recursion on the head node. It is therefore
impossible to checkpoint a :class:`espressomd.system.System` instance that
contains LB boundaries, constraint unions or auto-update accumulators when the
simulation is running with 2 or more MPI nodes.

* The active actors, i.e., the content of ``system.actors``, are checkpointed.
For lattice-Boltzmann fluids, this only includes the parameters such as the
lattice constant (``agrid``). The actual flow field has to be saved
Expand Down
14 changes: 11 additions & 3 deletions src/python/espressomd/interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1546,8 +1546,6 @@ def add(self, *args):
return bond_id

def __getitem__(self, bond_id):
self._assert_key_type(bond_id)

if self.call_method('has_bond', bond_id=bond_id):
bond_obj = self.call_method('get_bond', bond_id=bond_id)
bond_obj._bond_id = bond_id
Expand Down Expand Up @@ -1590,7 +1588,6 @@ def _insert_bond(self, bond_id, bond_obj):
bond_id = self.call_method("insert", object=bond_obj)
else:
# Throw error if attempting to overwrite a bond of different type
self._assert_key_type(bond_id)
if self.call_method("contains", key=bond_id):
old_type = self._bond_classes[
self.call_method("get_zero_based_type", bond_id=bond_id)]
Expand Down Expand Up @@ -1625,3 +1622,14 @@ def __getstate__(self):
def __setstate__(self, params):
for bond_id, (type_number, bond_params) in params.items():
self[bond_id] = self._bond_classes[type_number](**bond_params)

def __reduce__(self):
so_callback, (so_name, so_bytestring) = super().__reduce__()
return (BondedInteractions._restore_object,
(so_callback, (so_name, so_bytestring), self.__getstate__()))

@classmethod
def _restore_object(cls, so_callback, so_callback_args, state):
so = so_callback(*so_callback_args)
so.__setstate__(state)
return so
63 changes: 0 additions & 63 deletions src/python/espressomd/script_interface.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -457,15 +457,6 @@ class ScriptObjectList(ScriptInterfaceHelper):
"""

def __init__(self, *args, **kwargs):
if args:
params, (_unpickle_so_class, (_so_name, bytestring)) = args
assert _so_name == self._so_name
self = _unpickle_so_class(_so_name, bytestring)
self.__setstate__(params)
else:
super().__init__(**kwargs)

def __getitem__(self, key):
return self.call_method("get_elements")[key]

Expand All @@ -477,24 +468,6 @@ class ScriptObjectList(ScriptInterfaceHelper):
def __len__(self):
return self.call_method("size")

@classmethod
def _restore_object(cls, so_callback, so_callback_args, state):
so = so_callback(*so_callback_args)
so.__setstate__(state)
return so

def __reduce__(self):
so_callback, (so_name, so_bytestring) = super().__reduce__()
return (ScriptObjectList._restore_object,
(so_callback, (so_name, so_bytestring), self.__getstate__()))

def __getstate__(self):
return self.call_method("get_elements")

def __setstate__(self, object_list):
for item in object_list:
self.add(item)


class ScriptObjectMap(ScriptInterfaceHelper):
"""
Expand All @@ -507,17 +480,6 @@ class ScriptObjectMap(ScriptInterfaceHelper):
"""

_key_type = int

def __init__(self, *args, **kwargs):
if args:
params, (_unpickle_so_class, (_so_name, bytestring)) = args
assert _so_name == self._so_name
self = _unpickle_so_class(_so_name, bytestring)
self.__setstate__(params)
else:
super().__init__(**kwargs)

def remove(self, key):
"""
Remove the element with the given key.
Expand All @@ -536,15 +498,12 @@ class ScriptObjectMap(ScriptInterfaceHelper):
return self.call_method("size")

def __getitem__(self, key):
self._assert_key_type(key)
return self.call_method("get", key=key)

def __setitem__(self, key, value):
self._assert_key_type(key)
self.call_method("insert", key=key, object=value)

def __delitem__(self, key):
self._assert_key_type(key)
self.call_method("erase", key=key)

def keys(self):
Expand All @@ -556,28 +515,6 @@ class ScriptObjectMap(ScriptInterfaceHelper):
def items(self):
for k in self.keys(): yield k, self[k]

def _assert_key_type(self, key):
if not utils.is_valid_type(key, self._key_type):
raise TypeError(f"Key has to be of type {self._key_type.__name__}")

@classmethod
def _restore_object(cls, so_callback, so_callback_args, state):
so = so_callback(*so_callback_args)
so.__setstate__(state)
return so

def __reduce__(self):
so_callback, (so_name, so_bytestring) = super().__reduce__()
return (ScriptObjectMap._restore_object,
(so_callback, (so_name, so_bytestring), self.__getstate__()))

def __getstate__(self):
return dict(self.items())

def __setstate__(self, params):
for key, val in params.items():
self[key] = val


# Map from script object names to their corresponding python classes
_python_class_by_so_name = {}
Expand Down
4 changes: 2 additions & 2 deletions src/script_interface/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

add_library(
espresso_script_interface SHARED
initialize.cpp ObjectHandle.cpp object_container_mpi_guard.cpp
GlobalContext.cpp ContextManager.cpp ParallelExceptionHandler.cpp)
initialize.cpp ObjectHandle.cpp GlobalContext.cpp ContextManager.cpp
ParallelExceptionHandler.cpp)
add_library(espresso::script_interface ALIAS espresso_script_interface)
set_target_properties(espresso_script_interface
PROPERTIES CXX_CLANG_TIDY "${ESPRESSO_CXX_CLANG_TIDY}")
Expand Down
42 changes: 42 additions & 0 deletions src/script_interface/ObjectContainer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright (C) 2023 The ESPResSo project
*
* This file is part of ESPResSo.
*
* ESPResSo is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ESPResSo is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef SCRIPT_INTERFACE_OBJECT_CONTAINER_HPP
#define SCRIPT_INTERFACE_OBJECT_CONTAINER_HPP

#include "script_interface/auto_parameters/AutoParameters.hpp"

#include <type_traits>

namespace ScriptInterface {

/**
* @brief Base class for containers whose @c BaseType might be a full
* specialization of @ref AutoParameters.
*/
template <template <typename...> class Container, typename ManagedType,
class BaseType,
class =
std::enable_if_t<std::is_base_of_v<ObjectHandle, ManagedType>>>
using ObjectContainer = std::conditional_t<
std::is_same_v<BaseType, ObjectHandle>,
AutoParameters<Container<ManagedType, BaseType>, BaseType>, BaseType>;

} // namespace ScriptInterface

#endif
65 changes: 27 additions & 38 deletions src/script_interface/ObjectList.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
#ifndef SCRIPT_INTERFACE_OBJECT_LIST_HPP
#define SCRIPT_INTERFACE_OBJECT_LIST_HPP

#include "script_interface/ObjectContainer.hpp"
#include "script_interface/ScriptInterface.hpp"
#include "script_interface/get_value.hpp"
#include "script_interface/object_container_mpi_guard.hpp"

#include <utils/serialization/pack.hpp>

Expand All @@ -35,20 +35,39 @@
#include <vector>

namespace ScriptInterface {

/**
* @brief Owning list of ObjectHandles
* @tparam ManagedType Type of the managed objects, needs to be
* derived from ObjectHandle
* derived from @ref ObjectHandle
*/
template <typename ManagedType, class BaseType = ObjectHandle,
class =
std::enable_if_t<std::is_base_of_v<ObjectHandle, ManagedType>>>
class ObjectList : public BaseType {
template <typename ManagedType, class BaseType = ObjectHandle>
class ObjectList : public ObjectContainer<ObjectList, ManagedType, BaseType> {
public:
using Base = ObjectContainer<ObjectList, ManagedType, BaseType>;
using Base::add_parameters;

private:
std::vector<std::shared_ptr<ManagedType>> m_elements;

virtual void add_in_core(const std::shared_ptr<ManagedType> &obj_ptr) = 0;
virtual void remove_in_core(const std::shared_ptr<ManagedType> &obj_ptr) = 0;

public:
ObjectList() {
add_parameters({
{"_objects", AutoParameter::read_only,
[this]() { return make_vector_of_variants(m_elements); }},
});
}

void do_construct(VariantMap const &params) override {
m_elements = get_value_or<decltype(m_elements)>(params, "_objects", {});
for (auto const &object : m_elements) {
add_in_core(object);
}
}

/**
* @brief Add an element to the list.
*
Expand Down Expand Up @@ -107,12 +126,7 @@ class ObjectList : public BaseType {
}

if (method == "get_elements") {
std::vector<Variant> ret;
ret.reserve(m_elements.size());
for (auto const &e : m_elements)
ret.emplace_back(e);

return ret;
return make_vector_of_variants(m_elements);
}

if (method == "clear") {
Expand All @@ -128,33 +142,8 @@ class ObjectList : public BaseType {
return m_elements.empty();
}

return BaseType::do_call_method(method, parameters);
return Base::do_call_method(method, parameters);
}

private:
std::string get_internal_state() const override {
object_container_mpi_guard(BaseType::name(), m_elements.size(),
BaseType::context()->get_comm().size());

std::vector<std::string> object_states(m_elements.size());

boost::transform(m_elements, object_states.begin(),
[](auto const &e) { return e->serialize(); });

return Utils::pack(object_states);
}

void set_internal_state(std::string const &state) override {
auto const object_states = Utils::unpack<std::vector<std::string>>(state);

for (auto const &packed_object : object_states) {
auto o = std::dynamic_pointer_cast<ManagedType>(
BaseType::deserialize(packed_object, *BaseType::context()));
add(std::move(o));
}
}

std::vector<std::shared_ptr<ManagedType>> m_elements;
};
} // Namespace ScriptInterface
#endif
Loading

0 comments on commit b3efb66

Please sign in to comment.