Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove more global variables #4950

Merged
merged 12 commits into from
Aug 1, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "nonbonded_interactions/nonbonded_interaction_data.hpp"

#include "electrostatics/coulomb.hpp"
#include "system/System.hpp"

#include <algorithm>
#include <cassert>
Expand Down Expand Up @@ -122,3 +123,7 @@ double InteractionsNonBonded::maximal_cutoff() const {
}
return max_cut_nonbonded;
}

void InteractionsNonBonded::on_non_bonded_ia_change() const {
get_system().on_non_bonded_ia_change();
}
12 changes: 8 additions & 4 deletions src/core/nonbonded_interactions/nonbonded_interaction_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include "TabulatedPotential.hpp"
#include "config/config.hpp"
#include "system/Leaf.hpp"

#include <utils/index.hpp>
#include <utils/math/int_pow.hpp>
Expand Down Expand Up @@ -353,7 +354,7 @@ struct IA_parameters {
#endif
};

class InteractionsNonBonded {
class InteractionsNonBonded : public System::Leaf<InteractionsNonBonded> {
/** @brief List of pairwise interactions. */
std::vector<std::shared_ptr<IA_parameters>> m_nonbonded_ia_params{};
/** @brief Maximal particle type seen so far. */
Expand Down Expand Up @@ -414,15 +415,15 @@ class InteractionsNonBonded {
* @return Reference to interaction parameters for the type pair.
*/
auto &get_ia_param(int i, int j) {
return *m_nonbonded_ia_params[get_ia_param_key(i, j)];
return *m_nonbonded_ia_params.at(get_ia_param_key(i, j));
jngrad marked this conversation as resolved.
Show resolved Hide resolved
}

auto const &get_ia_param(int i, int j) const {
return *m_nonbonded_ia_params[get_ia_param_key(i, j)];
return *m_nonbonded_ia_params.at(get_ia_param_key(i, j));
}

auto get_ia_param_ref_counted(int i, int j) const {
return m_nonbonded_ia_params[get_ia_param_key(i, j)];
return m_nonbonded_ia_params.at(get_ia_param_key(i, j));
}

void set_ia_param(int i, int j, std::shared_ptr<IA_parameters> const &ia) {
Expand All @@ -436,4 +437,7 @@ class InteractionsNonBonded {

/** @brief Get maximal cutoff. */
double maximal_cutoff() const;

/** @brief Notify system that non-bonded interactions changed. */
void on_non_bonded_ia_change() const;
};
1 change: 1 addition & 0 deletions src/core/system/System.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ void System::initialize() {
cell_structure->bind_system(handle);
lees_edwards->bind_system(handle);
thermostat->bind_system(handle);
nonbonded_ias->bind_system(handle);
auto_update_accumulators->bind_system(handle);
constraints->bind_system(handle);
#ifdef CUDA
Expand Down
67 changes: 2 additions & 65 deletions src/python/espressomd/interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,6 @@ def set_params(self, **kwargs):
err_msg = f"setting {self.__class__.__name__} raised an error"
self.call_method("set_params", handle_errors_message=err_msg, **params)

def __reduce__(self):
return (NonBondedInteraction._restore_object,
(self.__class__, self.get_params()))

@classmethod
def _restore_object(cls, derived_class, kwargs):
return derived_class(**kwargs)

@abc.abstractmethod
def default_params(self):
pass
Expand Down Expand Up @@ -681,31 +673,8 @@ class NonBondedInteractionHandle(ScriptInterfaceHelper):
"""
_so_name = "Interactions::NonBondedInteractionHandle"

def __getattr__(self, key):
obj = super().__getattr__(key)
return globals()[obj.__class__.__name__](
_types=self.call_method("get_types"), **obj.get_params())

def _serialize(self):
serialized = []
for name, obj in self.get_params().items():
serialized.append((name, obj.__reduce__()[1]))
return serialized

def reset(self):
for key in self._valid_parameters():
getattr(self, key).deactivate()

@classmethod
def _restore_object(cls, types, kwargs):
objects = {}
for name, (obj_class, obj_params) in kwargs:
objects[name] = obj_class(**obj_params)
return NonBondedInteractionHandle(_types=types, **objects)

def __reduce__(self):
return (NonBondedInteractionHandle._restore_object,
(self.call_method("get_types"), self._serialize()))
self.call_method("reset")


@script_interface_register
Expand All @@ -724,40 +693,8 @@ class NonBondedInteractions(ScriptInterfaceHelper):
_so_creation_policy = "GLOBAL"
_so_bind_methods = ("reset",)

def keys(self):
return [tuple(x) for x in self.call_method("keys")]

def __getitem__(self, key):
self.call_method("check_key", key=key)
return NonBondedInteractionHandle(_types=key)

def __setitem__(self, key, value):
self.call_method("insert", key=key, object=value)

def __getstate__(self):
n_types = self.call_method("get_n_types")
state = []
for i in range(n_types):
for j in range(i, n_types):
handle = NonBondedInteractionHandle(_types=(i, j))
state.append(((i, j), handle._serialize()))
return {"state": state}

def __setstate__(self, params):
for types, kwargs in params["state"]:
obj = NonBondedInteractionHandle._restore_object(types, kwargs)
self.call_method("insert", key=types, object=obj)

@classmethod
def _restore_object(cls, so_callback, so_callback_args, state):
so = so_callback(*so_callback_args)
so.__setstate__(state)
return so

def __reduce__(self):
so_callback, (so_name, so_bytestring) = super().__reduce__()
return (NonBondedInteractions._restore_object,
(so_callback, (so_name, so_bytestring), self.__getstate__()))
return self.call_method("get_handle", key=key)


class BONDED_IA(enum.IntEnum):
Expand Down
2 changes: 1 addition & 1 deletion src/python/espressomd/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def _restore_object(cls, so_callback, so_callback_args, state):

def __getstate__(self):
checkpointable_properties = [
"non_bonded_inter", "bonded_inter",
"bonded_inter",
"part",
]
if has_features("COLLISION_DETECTION"):
Expand Down
1 change: 0 additions & 1 deletion src/script_interface/ObjectHandle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ class ObjectHandle {
*/
void construct(VariantMap const &params) { do_construct(params); }

private:
virtual void do_construct(VariantMap const &params) {
for (auto const &p : params) {
do_set_parameter(p.first, p.second);
Expand Down
Loading