diff --git a/tests/test_space.py b/tests/test_space.py index 7a1c6d3b0bd..0282b8f3f82 100644 --- a/tests/test_space.py +++ b/tests/test_space.py @@ -300,15 +300,13 @@ def test_set_cells_no_condition(self): np.testing.assert_array_equal(self.layer.data, np.full((10, 10), 2)) def test_set_cells_with_condition(self): - condition = np.full((10, 10), False) - condition[5, :] = True # Only update the 5th row + self.layer.set_cell((5, 5), 1) + condition = lambda x: x == 0 self.layer.set_cells(3, condition) - self.assertEqual(np.sum(self.layer.data[5, :] == 3), 10) - self.assertEqual(np.sum(self.layer.data != 3), 90) - - def test_set_cells_invalid_condition(self): - with self.assertRaises(ValueError): - self.layer.set_cells(4, condition=np.full((5, 5), False)) # Invalid shape + self.assertEqual(self.layer.data[5, 5], 1) + self.assertEqual(self.layer.data[0, 0], 3) + # Check if the sum is correct + self.assertEqual(np.sum(self.layer.data), 3 * 99 + 1) # Modify Cells Test def test_modify_cell_lambda(self):