Skip to content

Commit

Permalink
Merge pull request espressomd#2342 from RudolfWeeber/coldet_checkpoint
Browse files Browse the repository at this point in the history
Py: pickle support for collision detection + test
  • Loading branch information
fweik authored and RudolfWeeber committed Oct 28, 2018
1 parent f88858c commit 8e9f756
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/python/espressomd/collision_detection.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,12 @@ class CollisionDetection(ScriptInterfaceHelper):
if self._int_mode[key] == int_mode:
return key
raise Exception("Unknown integer collision mode %d" % int_mode)

# Pickle support
def __reduce__(self):
return _restore_collision_detection, (self.get_params(),)


def _restore_collision_detection(params):
print(params)
return CollisionDetection(**params)
3 changes: 3 additions & 0 deletions src/python/espressomd/system.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ cdef class System(object):
odict['minimize_energy'] = System.__getattribute__(
self, "minimize_energy")
odict['thermostat'] = System.__getattribute__(self, "thermostat")
IF COLLISION_DETECTION:
odict['collision_detection'] = System.__getattribute__(
self, "collision_detection")
return odict

def __setstate__(self, params):
Expand Down
22 changes: 22 additions & 0 deletions testsuite/python/collision_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ class CollisionDetection(ut.TestCase):
part_type_after_glueing = 3
other_type = 5

def get_state_set_state_consistency(self):
state = self.s.collision_detection.get_params()
self.s.collision_detection.set_params(**state)
self.assertEqual(state, self.s.collision_detection.get_params())

def test_00_interface_and_defaults(self):
# Is it off by default
self.assertEqual(self.s.collision_detection.mode, "off")
Expand Down Expand Up @@ -84,6 +89,7 @@ def test_bind_centers(self):
# Check that it cannot be activated
self.s.collision_detection.set_params(
mode="bind_centers", distance=0.11, bond_centers=self.H)
self.get_state_set_state_consistency()
self.s.integrator.run(1, recalc_forces=True)
bond0 = ((self.s.bonded_inter[0], 1),)
bond1 = ((self.s.bonded_inter[0], 0),)
Expand All @@ -99,6 +105,7 @@ def test_bind_centers(self):

# Check turning it off
self.s.collision_detection.set_params(mode="off")
self.get_state_set_state_consistency()
self.assertEqual(self.s.collision_detection.mode, "off")

def run_test_bind_at_point_of_collision_for_pos(self, *positions):
Expand All @@ -119,6 +126,7 @@ def run_test_bind_at_point_of_collision_for_pos(self, *positions):

self.s.collision_detection.set_params(
mode="bind_at_point_of_collision", distance=0.11, bond_centers=self.H, bond_vs=self.H2, part_type_vs=1, vs_placement=0.4)
self.get_state_set_state_consistency()
self.s.integrator.run(0, recalc_forces=True)
self.verify_state_after_bind_at_poc(expected_np)

Expand Down Expand Up @@ -265,6 +273,7 @@ def test_bind_at_point_of_collision_random(self):
bond_vs=self.H2,
part_type_vs=1,
vs_placement=0.4)
self.get_state_set_state_consistency()

# Integrate lj liquid
self.s.integrator.set_vv()
Expand Down Expand Up @@ -344,6 +353,7 @@ def run_test_glue_to_surface_for_pos(self, *positions):

self.s.collision_detection.set_params(
mode="glue_to_surface", distance=0.11, distance_glued_particle_to_vs=0.02, bond_centers=self.H, bond_vs=self.H2, part_type_vs=self.part_type_vs, part_type_to_attach_vs_to=self.part_type_to_attach_vs_to, part_type_to_be_glued=self.part_type_to_be_glued, part_type_after_glueing=self.part_type_after_glueing)
self.get_state_set_state_consistency()
self.s.integrator.run(0, recalc_forces=True)
self.verify_state_after_glue_to_surface(expected_np)

Expand Down Expand Up @@ -493,6 +503,7 @@ def test_glue_to_surface_random(self):
# Collision detection
self.s.collision_detection.set_params(
mode="glue_to_surface", distance=0.11, distance_glued_particle_to_vs=0.02, bond_centers=self.H, bond_vs=self.H2, part_type_vs=self.part_type_vs, part_type_to_attach_vs_to=self.part_type_to_attach_vs_to, part_type_to_be_glued=self.part_type_to_be_glued, part_type_after_glueing=self.part_type_after_glueing)
self.get_state_set_state_consistency()

# Integrate lj liquid
self.s.integrator.set_vv()
Expand Down Expand Up @@ -603,6 +614,7 @@ def test_AngleHarmonic(self):
self.s.collision_detection.set_params(
mode="bind_three_particles", bond_centers=self.H,
bond_three_particles=2, three_particle_binding_angle_resolution=res, distance=cutoff)
self.get_state_set_state_consistency()
self.s.integrator.run(0, recalc_forces=True)
self.verify_triangle_binding(cutoff, self.s.bonded_inter[2], res)

Expand Down Expand Up @@ -714,6 +726,16 @@ def verify_triangle_binding(self, distance, first_bond, angle_res):

self.assertEqual(expected_angle_bonds, found_angle_bonds)

def test_zz_serialization(self):
self.s.collision_detection.set_params(
mode="bind_centers", distance=0.11, bond_centers=self.H)
reduce = self.s.collision_detection.__reduce__()
res = reduce[0](reduce[1][0])
self.assertEqual(res.__class__.__name__, "CollisionDetection")
self.assertEqual(res.mode, "bind_centers")
self.assertAlmostEqual(res.distance, 0.11, delta=1E-9)
self.assertEqual(res.bond_centers, self.H)


if __name__ == "__main__":
ut.main()
3 changes: 3 additions & 0 deletions testsuite/python/save_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,7 @@
if espressomd.has_features('LB'):
lbf[1, 1, 1].velocity = [0.1, 0.2, 0.3]
lbf.save_checkpoint("@CMAKE_CURRENT_BINARY_DIR@/lb.cpt", 1)
if espressomd.has_features("COLLISION_DETECTION"):
system.collision_detection.set_params(
mode="bind_centers", distance=0.11, bond_centers=harmonic_bond)
checkpoint.save(0)
7 changes: 7 additions & 0 deletions testsuite/python/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ def test_mean_variance_calculator(self):
def test_p3m(self):
self.assertTrue(any(isinstance(actor, espressomd.electrostatics.P3M)
for actor in system.actors.active_actors))

@ut.skipIf(not espressomd.has_features("COLLISION_DETECTION"), "skipped for missing features")
def test_collision_detection(self):
coldet = system.collision_detection
self.assertEqual(coldet.mode, "bind_centers")
self.assertAlmostEqual(coldet.distance, 0.11, delta=1E-9)
self.assertTrue(coldet.bond_centers, system.bonded_inter[0])


if __name__ == '__main__':
Expand Down

0 comments on commit 8e9f756

Please sign in to comment.