-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* refactor: Rename to targets_remaining * docs: Formatting and expand docs * refactor: Move target and reward checks into utils module * fix: Set agent and target numbers via generator * refactor: Terminate episode if all targets found * test: Add swarms.common tests * refactor: Move agent initialisation into generator * test: Add environment utility tests
- Loading branch information
1 parent
06de3a0
commit 34beab6
Showing
9 changed files
with
459 additions
and
116 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
# Copyright 2022 InstaDeep Ltd. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import List, Tuple | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import matplotlib | ||
import matplotlib.pyplot as plt | ||
import pytest | ||
|
||
from jumanji.environments.swarms.common import types, updates, viewer | ||
|
||
|
||
@pytest.fixture | ||
def params() -> types.AgentParams: | ||
return types.AgentParams( | ||
max_rotate=0.5, | ||
max_accelerate=0.01, | ||
min_speed=0.01, | ||
max_speed=0.05, | ||
view_angle=0.5, | ||
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"heading, speed, actions, expected", | ||
[ | ||
[0.0, 0.01, [1.0, 0.0], (0.5 * jnp.pi, 0.01)], | ||
[0.0, 0.01, [-1.0, 0.0], (1.5 * jnp.pi, 0.01)], | ||
[jnp.pi, 0.01, [1.0, 0.0], (1.5 * jnp.pi, 0.01)], | ||
[jnp.pi, 0.01, [-1.0, 0.0], (0.5 * jnp.pi, 0.01)], | ||
[1.75 * jnp.pi, 0.01, [1.0, 0.0], (0.25 * jnp.pi, 0.01)], | ||
[0.0, 0.01, [0.0, 1.0], (0.0, 0.02)], | ||
[0.0, 0.01, [0.0, -1.0], (0.0, 0.01)], | ||
[0.0, 0.02, [0.0, -1.0], (0.0, 0.01)], | ||
[0.0, 0.05, [0.0, -1.0], (0.0, 0.04)], | ||
[0.0, 0.05, [0.0, 1.0], (0.0, 0.05)], | ||
], | ||
) | ||
def test_velocity_update( | ||
params: types.AgentParams, | ||
heading: float, | ||
speed: float, | ||
actions: List[float], | ||
expected: Tuple[float, float], | ||
) -> None: | ||
key = jax.random.PRNGKey(101) | ||
|
||
state = types.AgentState( | ||
pos=jnp.zeros((1, 2)), | ||
heading=jnp.array([heading]), | ||
speed=jnp.array([speed]), | ||
) | ||
actions = jnp.array([actions]) | ||
|
||
new_heading, new_speed = updates.update_velocity(key, params, (actions, state)) | ||
|
||
assert jnp.isclose(new_heading[0], expected[0]) | ||
assert jnp.isclose(new_speed[0], expected[1]) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"pos, heading, speed, expected", | ||
[ | ||
[[0.0, 0.5], 0.0, 0.1, [0.1, 0.5]], | ||
[[0.0, 0.5], jnp.pi, 0.1, [0.9, 0.5]], | ||
[[0.5, 0.0], 0.5 * jnp.pi, 0.1, [0.5, 0.1]], | ||
[[0.5, 0.0], 1.5 * jnp.pi, 0.1, [0.5, 0.9]], | ||
], | ||
) | ||
def test_move(pos: List[float], heading: float, speed: float, expected: List[float]) -> None: | ||
pos = jnp.array(pos) | ||
new_pos = updates.move(pos, heading, speed) | ||
|
||
assert jnp.allclose(new_pos, jnp.array(expected)) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"pos, heading, speed, actions, expected_pos, expected_heading, expected_speed", | ||
[ | ||
[[0.0, 0.5], 0.0, 0.01, [0.0, 0.0], [0.01, 0.5], 0.0, 0.01], | ||
[[0.5, 0.0], 0.0, 0.01, [1.0, 0.0], [0.5, 0.01], 0.5 * jnp.pi, 0.01], | ||
[[0.5, 0.0], 0.0, 0.01, [-1.0, 0.0], [0.5, 0.99], 1.5 * jnp.pi, 0.01], | ||
[[0.0, 0.5], 0.0, 0.01, [0.0, 1.0], [0.02, 0.5], 0.0, 0.02], | ||
[[0.0, 0.5], 0.0, 0.01, [0.0, -1.0], [0.01, 0.5], 0.0, 0.01], | ||
[[0.0, 0.5], 0.0, 0.05, [0.0, 1.0], [0.05, 0.5], 0.0, 0.05], | ||
], | ||
) | ||
def test_state_update( | ||
params: types.AgentParams, | ||
pos: List[float], | ||
heading: float, | ||
speed: float, | ||
actions: List[float], | ||
expected_pos: List[float], | ||
expected_heading: float, | ||
expected_speed: float, | ||
) -> None: | ||
key = jax.random.PRNGKey(101) | ||
|
||
state = types.AgentState( | ||
pos=jnp.array([pos]), | ||
heading=jnp.array([heading]), | ||
speed=jnp.array([speed]), | ||
) | ||
actions = jnp.array([actions]) | ||
|
||
new_state = updates.update_state(key, params, state, actions) | ||
|
||
assert isinstance(new_state, types.AgentState) | ||
assert jnp.allclose(new_state.pos, jnp.array([expected_pos])) | ||
assert jnp.allclose(new_state.heading, jnp.array([expected_heading])) | ||
assert jnp.allclose(new_state.speed, jnp.array([expected_speed])) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"pos, view_angle, expected", | ||
[ | ||
[[0.05, 0.0], 0.5, [1.0, 1.0, 0.5, 1.0, 1.0]], | ||
[[0.0, 0.05], 0.5, [0.5, 1.0, 1.0, 1.0, 1.0]], | ||
[[0.0, 0.95], 0.5, [1.0, 1.0, 1.0, 1.0, 0.5]], | ||
[[0.95, 0.0], 0.5, [1.0, 1.0, 1.0, 1.0, 1.0]], | ||
[[0.05, 0.0], 0.25, [1.0, 1.0, 0.5, 1.0, 1.0]], | ||
[[0.0, 0.05], 0.25, [1.0, 1.0, 1.0, 1.0, 1.0]], | ||
[[0.0, 0.95], 0.25, [1.0, 1.0, 1.0, 1.0, 1.0]], | ||
[[0.01, 0.0], 0.5, [1.0, 1.0, 0.1, 1.0, 1.0]], | ||
], | ||
) | ||
def test_view(pos: List[float], view_angle: float, expected: List[float]) -> None: | ||
state_a = types.AgentState( | ||
pos=jnp.zeros((2,)), | ||
heading=0.0, | ||
speed=0.0, | ||
) | ||
|
||
state_b = types.AgentState( | ||
pos=jnp.array(pos), | ||
heading=0.0, | ||
speed=0.0, | ||
) | ||
|
||
obs = updates.view(None, (view_angle, 0.02), state_a, state_b, n_view=5, i_range=0.1) | ||
assert jnp.allclose(obs, jnp.array(expected)) | ||
|
||
|
||
def test_viewer_utils() -> None: | ||
f, ax = plt.subplots() | ||
f, ax = viewer.format_plot(f, ax) | ||
|
||
assert isinstance(f, matplotlib.figure.Figure) | ||
assert isinstance(ax, matplotlib.axes.Axes) | ||
|
||
state = types.AgentState( | ||
pos=jnp.zeros((3, 2)), | ||
heading=jnp.zeros((3,)), | ||
speed=jnp.zeros((3,)), | ||
) | ||
|
||
quiver = viewer.draw_agents(ax, state, "red") | ||
|
||
assert isinstance(quiver, matplotlib.quiver.Quiver) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.