Commit: fix: PR fixes (#4)
* refactor: Rename to targets_remaining

* docs: Formatting and expand docs

* refactor: Move target and reward checks into utils module

* fix: Set agent and target numbers via generator

* refactor: Terminate episode if all targets found

* test: Add swarms.common tests

* refactor: Move agent initialisation into generator

* test: Add environment utility tests
zombie-einstein authored Nov 14, 2024
1 parent 06de3a0 commit 34beab6
Showing 9 changed files with 459 additions and 116 deletions.
27 changes: 20 additions & 7 deletions docs/environments/search_and_rescue.md
@@ -35,20 +35,33 @@ space is a uniform space with unit dimensions, and wrapped at the boundaries.

 where `1.0` indicates there are no agents along that ray, and `0.5` is the normalised
 distance to the other agent.
-- `target_remaining`: float in the range [0, 1]. The normalised number of targets
+- `targets_remaining`: float in the range `[0, 1]`. The normalised number of targets
   remaining to be detected (i.e. 1.0 when no targets have been found).
-- `time_remaining`: float in the range [0, 1]. The normalised number of steps remaining
+- `time_remaining`: float in the range `[0, 1]`. The normalised number of steps remaining
   to locate the targets (i.e. 0.0 at the end of the episode).
 
 ## Actions
 
-Jax array (float) of `(num_searchers, 2)` in the range [-1, 1]. Each entry in the
+Jax array (float) of `(num_searchers, 2)` in the range `[-1, 1]`. Each entry in the
 array represents an update of each agent's velocity in the next step. Searching agents
-update their velocity each step by rotating and accelerating/decelerating. Values
-are clipped to the range `[-1, 1]` and then scaled by max rotation and acceleration
-parameters. Agents are restricted to velocities within a fixed range of speeds.
+update their velocity each step by rotating and accelerating/decelerating, where the
+values are `[rotation, acceleration]`. Values are clipped to the range `[-1, 1]`
+and then scaled by max rotation and acceleration parameters, i.e. the new values each
+step are given by
+
+```
+heading = heading + max_rotation * action[0]
+```
+
+and speed
+
+```
+speed = speed + max_acceleration * action[1]
+```
+
+Once applied, agent speeds are clipped to a fixed range of speeds.
 
 ## Rewards
 
-Jax array (float) of `(num_searchers, 2)`. Rewards are generated for each agent individually.
+Jax array (float) of `(num_searchers,)`. Rewards are generated for each agent individually.
 Agents are rewarded 1.0 for locating a target that has not already been detected.
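The two scalar observations renamed above are plain normalisations. A minimal sketch of how they could be computed, assuming hypothetical counters `num_found`, `num_targets`, `step`, and `max_steps` (not names taken from the environment code):

```python
def scalar_observations(num_found: int, num_targets: int, step: int, max_steps: int):
    # Hedged reading of the docs above: both values lie in [0, 1].
    targets_remaining = 1.0 - num_found / num_targets  # 1.0 before any detection
    time_remaining = 1.0 - step / max_steps  # 0.0 on the final step
    return targets_remaining, time_remaining
```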
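The action snippets above can be tied to the new tests further down: in `test_velocity_update`, `max_rotate=0.5` with `action = [1.0, 0.0]` rotates the heading by `0.5 * pi`, suggesting the rotation term is additionally scaled by `pi`, and the resulting speed is clipped to `[min_speed, max_speed]`. A hedged sketch consistent with those cases (an assumed reading, not the library's actual `update_velocity`):

```python
import jax.numpy as jnp

def velocity_update_sketch(heading, speed, action, params):
    # Actions are first clipped to [-1, 1].
    action = jnp.clip(action, -1.0, 1.0)
    # Rotation: assumed to be scaled by pi, so max_rotate=0.5 and
    # action[0]=1.0 gives a half-pi turn; headings wrap mod 2*pi.
    heading = (heading + params.max_rotate * jnp.pi * action[0]) % (2.0 * jnp.pi)
    # Acceleration: scaled by max_accelerate, then the speed is clipped
    # to the fixed [min_speed, max_speed] range.
    speed = jnp.clip(
        speed + params.max_accelerate * action[1],
        params.min_speed,
        params.max_speed,
    )
    return heading, speed
```

With the fixture parameters below (`max_rotate=0.5`, `max_accelerate=0.01`, speeds in `[0.01, 0.05]`) this reproduces every row of the `test_velocity_update` table.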
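For the corrected reward shape `(num_searchers,)`, a hedged sketch of the per-agent reward, assuming a hypothetical boolean array `newly_found` of shape `(num_searchers, num_targets)` marking targets first detected by each searcher on this step (the actual utils module may differ):

```python
import jax.numpy as jnp

def rewards_sketch(newly_found: jnp.ndarray) -> jnp.ndarray:
    # 1.0 per previously undetected target located this step,
    # summed per searcher, giving shape (num_searchers,).
    return jnp.sum(newly_found.astype(jnp.float32), axis=1)
```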
173 changes: 173 additions & 0 deletions jumanji/environments/swarms/common/test_common.py
@@ -0,0 +1,173 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple

import jax
import jax.numpy as jnp
import matplotlib
import matplotlib.pyplot as plt
import pytest

from jumanji.environments.swarms.common import types, updates, viewer


@pytest.fixture
def params() -> types.AgentParams:
    return types.AgentParams(
        max_rotate=0.5,
        max_accelerate=0.01,
        min_speed=0.01,
        max_speed=0.05,
        view_angle=0.5,
    )


@pytest.mark.parametrize(
    "heading, speed, actions, expected",
    [
        [0.0, 0.01, [1.0, 0.0], (0.5 * jnp.pi, 0.01)],
        [0.0, 0.01, [-1.0, 0.0], (1.5 * jnp.pi, 0.01)],
        [jnp.pi, 0.01, [1.0, 0.0], (1.5 * jnp.pi, 0.01)],
        [jnp.pi, 0.01, [-1.0, 0.0], (0.5 * jnp.pi, 0.01)],
        [1.75 * jnp.pi, 0.01, [1.0, 0.0], (0.25 * jnp.pi, 0.01)],
        [0.0, 0.01, [0.0, 1.0], (0.0, 0.02)],
        [0.0, 0.01, [0.0, -1.0], (0.0, 0.01)],
        [0.0, 0.02, [0.0, -1.0], (0.0, 0.01)],
        [0.0, 0.05, [0.0, -1.0], (0.0, 0.04)],
        [0.0, 0.05, [0.0, 1.0], (0.0, 0.05)],
    ],
)
def test_velocity_update(
    params: types.AgentParams,
    heading: float,
    speed: float,
    actions: List[float],
    expected: Tuple[float, float],
) -> None:
    key = jax.random.PRNGKey(101)

    state = types.AgentState(
        pos=jnp.zeros((1, 2)),
        heading=jnp.array([heading]),
        speed=jnp.array([speed]),
    )
    actions = jnp.array([actions])

    new_heading, new_speed = updates.update_velocity(key, params, (actions, state))

    assert jnp.isclose(new_heading[0], expected[0])
    assert jnp.isclose(new_speed[0], expected[1])


@pytest.mark.parametrize(
    "pos, heading, speed, expected",
    [
        [[0.0, 0.5], 0.0, 0.1, [0.1, 0.5]],
        [[0.0, 0.5], jnp.pi, 0.1, [0.9, 0.5]],
        [[0.5, 0.0], 0.5 * jnp.pi, 0.1, [0.5, 0.1]],
        [[0.5, 0.0], 1.5 * jnp.pi, 0.1, [0.5, 0.9]],
    ],
)
def test_move(pos: List[float], heading: float, speed: float, expected: List[float]) -> None:
    pos = jnp.array(pos)
    new_pos = updates.move(pos, heading, speed)

    assert jnp.allclose(new_pos, jnp.array(expected))


@pytest.mark.parametrize(
    "pos, heading, speed, actions, expected_pos, expected_heading, expected_speed",
    [
        [[0.0, 0.5], 0.0, 0.01, [0.0, 0.0], [0.01, 0.5], 0.0, 0.01],
        [[0.5, 0.0], 0.0, 0.01, [1.0, 0.0], [0.5, 0.01], 0.5 * jnp.pi, 0.01],
        [[0.5, 0.0], 0.0, 0.01, [-1.0, 0.0], [0.5, 0.99], 1.5 * jnp.pi, 0.01],
        [[0.0, 0.5], 0.0, 0.01, [0.0, 1.0], [0.02, 0.5], 0.0, 0.02],
        [[0.0, 0.5], 0.0, 0.01, [0.0, -1.0], [0.01, 0.5], 0.0, 0.01],
        [[0.0, 0.5], 0.0, 0.05, [0.0, 1.0], [0.05, 0.5], 0.0, 0.05],
    ],
)
def test_state_update(
    params: types.AgentParams,
    pos: List[float],
    heading: float,
    speed: float,
    actions: List[float],
    expected_pos: List[float],
    expected_heading: float,
    expected_speed: float,
) -> None:
    key = jax.random.PRNGKey(101)

    state = types.AgentState(
        pos=jnp.array([pos]),
        heading=jnp.array([heading]),
        speed=jnp.array([speed]),
    )
    actions = jnp.array([actions])

    new_state = updates.update_state(key, params, state, actions)

    assert isinstance(new_state, types.AgentState)
    assert jnp.allclose(new_state.pos, jnp.array([expected_pos]))
    assert jnp.allclose(new_state.heading, jnp.array([expected_heading]))
    assert jnp.allclose(new_state.speed, jnp.array([expected_speed]))


@pytest.mark.parametrize(
    "pos, view_angle, expected",
    [
        [[0.05, 0.0], 0.5, [1.0, 1.0, 0.5, 1.0, 1.0]],
        [[0.0, 0.05], 0.5, [0.5, 1.0, 1.0, 1.0, 1.0]],
        [[0.0, 0.95], 0.5, [1.0, 1.0, 1.0, 1.0, 0.5]],
        [[0.95, 0.0], 0.5, [1.0, 1.0, 1.0, 1.0, 1.0]],
        [[0.05, 0.0], 0.25, [1.0, 1.0, 0.5, 1.0, 1.0]],
        [[0.0, 0.05], 0.25, [1.0, 1.0, 1.0, 1.0, 1.0]],
        [[0.0, 0.95], 0.25, [1.0, 1.0, 1.0, 1.0, 1.0]],
        [[0.01, 0.0], 0.5, [1.0, 1.0, 0.1, 1.0, 1.0]],
    ],
)
def test_view(pos: List[float], view_angle: float, expected: List[float]) -> None:
    state_a = types.AgentState(
        pos=jnp.zeros((2,)),
        heading=0.0,
        speed=0.0,
    )

    state_b = types.AgentState(
        pos=jnp.array(pos),
        heading=0.0,
        speed=0.0,
    )

    obs = updates.view(None, (view_angle, 0.02), state_a, state_b, n_view=5, i_range=0.1)
    assert jnp.allclose(obs, jnp.array(expected))


def test_viewer_utils() -> None:
    f, ax = plt.subplots()
    f, ax = viewer.format_plot(f, ax)

    assert isinstance(f, matplotlib.figure.Figure)
    assert isinstance(ax, matplotlib.axes.Axes)

    state = types.AgentState(
        pos=jnp.zeros((3, 2)),
        heading=jnp.zeros((3,)),
        speed=jnp.zeros((3,)),
    )

    quiver = viewer.draw_agents(ax, state, "red")

    assert isinstance(quiver, matplotlib.quiver.Quiver)
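The `test_view` cases above pin the observation model down fairly tightly: rays sweep the view cone from the leftmost (anticlockwise) direction to the rightmost, a visible neighbour registers as its distance normalised by `i_range`, and anything outside the cone or range reads 1.0. A simplified single-neighbour sketch that reproduces the parametrised cases (it ignores the agent-radius argument and the batching of the real `updates.view`):

```python
import jax.numpy as jnp

def view_sketch(pos_a, heading_a, pos_b, view_angle, n_view, i_range):
    # Offset to the other agent on the unit torus, wrapped into [-0.5, 0.5).
    d = ((pos_b - pos_a + 0.5) % 1.0) - 0.5
    dist = jnp.linalg.norm(d)
    # Bearing relative to the viewer's heading, wrapped into [-pi, pi).
    angle = (jnp.arctan2(d[1], d[0]) - heading_a + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
    # Ray directions across the cone; index 0 is the leftmost ray.
    rays = jnp.linspace(view_angle * jnp.pi, -view_angle * jnp.pi, n_view)
    seen = (jnp.abs(angle) <= view_angle * jnp.pi) & (dist < i_range)
    nearest = jnp.arange(n_view) == jnp.argmin(jnp.abs(rays - angle))
    return jnp.where(seen & nearest, dist / i_range, jnp.ones((n_view,)))
```

For example, `view_sketch(jnp.zeros(2), 0.0, jnp.array([0.05, 0.0]), 0.5, 5, 0.1)` gives `[1.0, 1.0, 0.5, 1.0, 1.0]`, matching the first test row.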
25 changes: 0 additions & 25 deletions jumanji/environments/swarms/common/updates.py
@@ -70,31 +70,6 @@ def move(pos: chex.Array, heading: chex.Array, speed: chex.Array) -> chex.Array:
     return (pos + d_pos) % 1.0
 
 
-def init_state(n: int, params: types.AgentParams, key: chex.PRNGKey) -> types.AgentState:
-    """
-    Randomly initialise state of a group of agents
-
-    Args:
-        n: Number of agents to initialise.
-        params: Agent parameters.
-        key: JAX random key.
-
-    Returns:
-        AgentState: Random agent states (i.e. position, headings, and speeds)
-    """
-    k1, k2, k3 = jax.random.split(key, 3)
-
-    positions = jax.random.uniform(k1, (n, 2))
-    speeds = jax.random.uniform(k2, (n,), minval=params.min_speed, maxval=params.max_speed)
-    headings = jax.random.uniform(k3, (n,), minval=0.0, maxval=2.0 * jnp.pi)
-
-    return types.AgentState(
-        pos=positions,
-        speed=speeds,
-        heading=headings,
-    )
-
-
 def update_state(
     key: chex.PRNGKey,
     params: types.AgentParams,
