Skip to content

Commit

Permalink
Improve particle exception handling
Browse files Browse the repository at this point in the history
Convert fatal errors into ValueError for invalid particle ids.
Make python code for particle creation more readable.
  • Loading branch information
jngrad committed Mar 18, 2022
1 parent 4d319c3 commit db253de
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 38 deletions.
8 changes: 6 additions & 2 deletions src/core/particle_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,8 +487,9 @@ void build_particle_node() { mpi_who_has(); }
* @brief Get the mpi rank which owns the particle with id.
*/
int get_particle_node(int id) {
if (id < 0)
throw std::runtime_error("Invalid particle id!");
if (id < 0) {
throw std::domain_error("Invalid particle id: " + std::to_string(id));
}

if (particle_node.empty())
build_particle_node();
Expand Down Expand Up @@ -724,6 +725,9 @@ void mpi_place_particle(int node, int p_id, const Utils::Vector3d &pos) {
}

int place_particle(int p_id, Utils::Vector3d const &pos) {
if (p_id < 0) {
throw std::domain_error("Invalid particle id: " + std::to_string(p_id));
}
if (particle_exists(p_id)) {
mpi_place_particle(get_particle_node(p_id), p_id, pos);

Expand Down
2 changes: 1 addition & 1 deletion src/python/espressomd/particle_data.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ cdef extern from "particle_data.hpp":
# Setter/getter/modifier functions functions
void prefetch_particle_data(vector[int] ids)

int place_particle(int part, const Vector3d & p)
int place_particle(int part, const Vector3d & p) except +

void set_particle_v(int part, const Vector3d & v)

Expand Down
65 changes: 32 additions & 33 deletions src/python/espressomd/particle_data.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1747,84 +1747,83 @@ cdef class ParticleList:
# Did we get a dictionary
if len(args) == 1:
if hasattr(args[0], "__getitem__"):
P = args[0]
particles_dict = args[0]
else:
if len(args) == 0 and len(kwargs.keys()) != 0:
P = kwargs
if len(args) == 0 and len(kwargs) != 0:
particles_dict = kwargs
else:
raise ValueError(
"add() takes either a dictionary or a bunch of keyword args.")

# Check for presence of pos attribute
if "pos" not in P:
if "pos" not in particles_dict:
raise ValueError(
"pos attribute must be specified for new particle")

if len(np.array(P["pos"]).shape) == 2:
return self._place_new_particles(P)
if len(np.array(particles_dict["pos"]).shape) == 2:
return self._place_new_particles(particles_dict)
else:
return self._place_new_particle(P)
return self._place_new_particle(particles_dict)

def _place_new_particle(self, P):
def _place_new_particle(self, p_dict):
# Handling of particle id
if "id" not in P:
if "id" not in p_dict:
# Generate particle id
P["id"] = get_maximal_particle_id() + 1
p_dict["id"] = get_maximal_particle_id() + 1
else:
if particle_exists(P["id"]):
raise Exception(f"Particle {P['id']} already exists.")
if particle_exists(p_dict["id"]):
raise Exception(f"Particle {p_dict['id']} already exists.")

# Prevent setting of contradicting attributes
IF DIPOLES:
if 'dip' in P and 'dipm' in P:
if 'dip' in p_dict and 'dipm' in p_dict:
raise ValueError("Contradicting attributes: dip and dipm. Setting \
dip is sufficient as the length of the vector defines the scalar dipole moment.")
IF ROTATION:
if 'dip' in P and 'quat' in P:
if 'dip' in p_dict and 'quat' in p_dict:
raise ValueError("Contradicting attributes: dip and quat. \
Setting dip overwrites the rotation of the particle around the dipole axis. \
Set quat and scalar dipole moment (dipm) instead.")

# The ParticleList can not be used yet, as the particle
# doesn't yet exist. Hence, the setting of position has to be
# done here. the code is from the pos:property of ParticleHandle
# done here.
check_type_or_throw_except(
P["pos"], 3, float, "Position must be 3 floats.")
if place_particle(P["id"], make_Vector3d(P["pos"])) == -1:
p_dict["pos"], 3, float, "Position must be 3 floats.")
error_code = place_particle(p_dict["id"], make_Vector3d(p_dict["pos"]))
if error_code == -1:
raise Exception("particle could not be set.")

# Pos is taken care of
del P["pos"]
pid = P["id"]
del P["id"]
# position is taken care of
del p_dict["pos"]
pid = p_dict.pop("id")

if P != {}:
self.by_id(pid).update(P)
if p_dict != {}:
self.by_id(pid).update(p_dict)

return self.by_id(pid)

def _place_new_particles(self, Ps):
def _place_new_particles(self, p_list_dict):
# Check if all entries have the same length
n_parts = len(Ps["pos"])
if not all(np.shape(Ps[k]) and len(Ps[k]) == n_parts for k in Ps):
n_parts = len(p_list_dict["pos"])
if not all(np.shape(v) and len(v) ==
n_parts for v in p_list_dict.values()):
raise ValueError(
"When adding several particles at once, all lists of attributes have to have the same size")

# If particle ids haven't been provided, use free ones
# beyond the highest existing one
if not "id" in Ps:
if "id" not in p_list_dict:
first_id = get_maximal_particle_id() + 1
Ps["id"] = range(first_id, first_id + n_parts)
p_list_dict["id"] = range(first_id, first_id + n_parts)

# Place the particles
for i in range(n_parts):
P = {}
for k in Ps:
P[k] = Ps[k][i]
self._place_new_particle(P)
p_dict = {k: v[i] for k, v in p_list_dict.items()}
self._place_new_particle(p_dict)

# Return slice of added particles
return self.by_ids(Ps["id"])
return self.by_ids(p_list_dict["id"])

# Iteration over all existing particles
def __iter__(self):
Expand Down
14 changes: 12 additions & 2 deletions testsuite/python/particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,11 +274,21 @@ def test_image_box(self):

np.testing.assert_equal(np.copy(p.image_box), [1, 1, 1])

def test_accessing_invalid_id_raises(self):
def test_invalid_particle_ids_exceptions(self):
self.system.part.clear()
handle_to_non_existing_particle = self.system.part.by_id(42)
with self.assertRaises(RuntimeError):
with self.assertRaisesRegex(RuntimeError, "Particle node for id 42 not found"):
handle_to_non_existing_particle.id
p = self.system.part.add(pos=[0., 0., 0.], id=0)
with self.assertRaisesRegex(RuntimeError, "Particle node for id 42 not found"):
p._id = 42
p.node
for i in range(1, 10):
with self.assertRaisesRegex(ValueError, f"Invalid particle id: {-i}"):
p._id = -i
p.node
with self.assertRaisesRegex(ValueError, f"Invalid particle id: {-i}"):
self.system.part.add(pos=[0., 0., 0.], id=-i)

def test_parallel_property_setters(self):
s = self.system
Expand Down

0 comments on commit db253de

Please sign in to comment.