Skip to content

Commit

Permalink
python: Use standard checkpointing mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
jngrad committed Sep 10, 2020
1 parent df90e9c commit 3115d3d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 18 deletions.
43 changes: 27 additions & 16 deletions src/python/espressomd/interactions.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1700,6 +1700,15 @@ cdef class BondedInteraction:
def __reduce__(self):
return (self.__class__, (self._bond_id,))

def __getstate__(self):
# for most bonds, we only need to pickle arguments passed to the
# constructor, yet some bond classes need to pass additional data
return (self.params,)

def __setstate__(self, params):
# parameters are already set in the core by the contructor
pass

def is_valid(self):
"""Check, if the data stored in the instance still matches what is in ESPResSo.
Expand Down Expand Up @@ -2091,6 +2100,15 @@ class ThermalizedBond(BondedInteraction):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def __getstate__(self):
return (self.params, thermalized_bond.rng_counter())

def __setstate__(self, params):
# parameters are already set in the core by the contructor, we only
# need the RNG state
rng_counter = params[1]
thermalized_bond_set_rng_counter(rng_counter)

def type_number(self):
return BONDED_IA_THERMALIZED_DIST

Expand All @@ -2099,13 +2117,13 @@ class ThermalizedBond(BondedInteraction):

def valid_keys(self):
return {"temp_com", "gamma_com", "temp_distance",
"gamma_distance", "r_cut", "seed", "_counter"}
"gamma_distance", "r_cut", "seed"}

def required_keys(self):
return {"temp_com", "gamma_com", "temp_distance", "gamma_distance"}

def set_default_params(self):
self._params = {"r_cut": 0., "seed": None, "_counter": None}
self._params = {"r_cut": 0., "seed": None}

def _get_params_from_es_core(self):
return \
Expand All @@ -2119,7 +2137,6 @@ class ThermalizedBond(BondedInteraction):
bonded_ia_params[
self._bond_id].p.thermalized_bond.gamma_distance,
"r_cut": bonded_ia_params[self._bond_id].p.thermalized_bond.r_cut,
"_counter": thermalized_bond.rng_counter(),
"seed": thermalized_bond.rng_seed()
}

Expand All @@ -2133,8 +2150,6 @@ class ThermalizedBond(BondedInteraction):
if self.params["seed"] < 0:
raise ValueError("seed must be a positive integer")
thermalized_bond_set_rng_seed(self.params["seed"])
if self.params.get("_counter") is not None:
thermalized_bond_set_rng_counter(self.params["_counter"])

thermalized_bond_set_params(
self._bond_id, self._params["temp_com"], self._params["gamma_com"],
Expand Down Expand Up @@ -3359,17 +3374,13 @@ class BondedInteractions:

def __getstate__(self):
params = {}
for i, bonded_instance in enumerate(self):
if hasattr(bonded_instance, 'params'):
params[i] = bonded_instance.params
params[i]['bond_type'] = bonded_instance.type_number()
else:
params[i] = None
for i, bond_object in enumerate(self):
params[i] = (bond_object.__getstate__(), bond_object.type_number())
return params

def __setstate__(self, params):
for i in params:
if params[i] is not None:
bond_type = params[i]['bond_type']
del params[i]['bond_type']
self[i] = bonded_interaction_classes[bond_type](**params[i])
for i, (bond_pickle, bond_type) in params.items():
bond_params = bond_pickle[0]
self[i] = bonded_interaction_classes[bond_type](**bond_params)
if len(bond_pickle) > 1:
self[i].__setstate__(bond_pickle)
3 changes: 1 addition & 2 deletions testsuite/python/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,7 @@ def test_bonded_inter(self):
if 'THERM.LB' not in modes:
state = system.part[1].bonds[1][0]._get_params_from_es_core()
reference = {'temp_com': 0., 'gamma_com': 0., 'temp_distance': 0.2,
'gamma_distance': 0.5, 'r_cut': 2.0, 'seed': 51,
'_counter': 0}
'gamma_distance': 0.5, 'r_cut': 2.0, 'seed': 51}
self.assertEqual(state, reference)
state = system.part[1].bonds[1][0].params
self.assertEqual(state, reference)
Expand Down

0 comments on commit 3115d3d

Please sign in to comment.