Skip to content

Commit

Permalink
feat: Allow variable environment dimensions (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
zombie-einstein authored Nov 19, 2024
1 parent 072db18 commit 162a74d
Show file tree
Hide file tree
Showing 11 changed files with 144 additions and 72 deletions.
67 changes: 40 additions & 27 deletions jumanji/environments/swarms/common/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,30 +72,38 @@ def test_velocity_update(


@pytest.mark.parametrize(
"pos, heading, speed, expected",
"pos, heading, speed, expected, env_size",
[
[[0.0, 0.5], 0.0, 0.1, [0.1, 0.5]],
[[0.0, 0.5], jnp.pi, 0.1, [0.9, 0.5]],
[[0.5, 0.0], 0.5 * jnp.pi, 0.1, [0.5, 0.1]],
[[0.5, 0.0], 1.5 * jnp.pi, 0.1, [0.5, 0.9]],
[[0.0, 0.5], 0.0, 0.1, [0.1, 0.5], 1.0],
[[0.0, 0.5], jnp.pi, 0.1, [0.9, 0.5], 1.0],
[[0.5, 0.0], 0.5 * jnp.pi, 0.1, [0.5, 0.1], 1.0],
[[0.5, 0.0], 1.5 * jnp.pi, 0.1, [0.5, 0.9], 1.0],
[[0.4, 0.2], 0.0, 0.2, [0.1, 0.2], 0.5],
[[0.1, 0.2], jnp.pi, 0.2, [0.4, 0.2], 0.5],
[[0.2, 0.4], 0.5 * jnp.pi, 0.2, [0.2, 0.1], 0.5],
[[0.2, 0.1], 1.5 * jnp.pi, 0.2, [0.2, 0.4], 0.5],
],
)
def test_move(pos: List[float], heading: float, speed: float, expected: List[float]) -> None:
def test_move(
pos: List[float], heading: float, speed: float, expected: List[float], env_size: float
) -> None:
pos = jnp.array(pos)
new_pos = updates.move(pos, heading, speed)
new_pos = updates.move(pos, heading, speed, env_size)

assert jnp.allclose(new_pos, jnp.array(expected))


@pytest.mark.parametrize(
"pos, heading, speed, actions, expected_pos, expected_heading, expected_speed",
"pos, heading, speed, actions, expected_pos, expected_heading, expected_speed, env_size",
[
[[0.0, 0.5], 0.0, 0.01, [0.0, 0.0], [0.01, 0.5], 0.0, 0.01],
[[0.5, 0.0], 0.0, 0.01, [1.0, 0.0], [0.5, 0.01], 0.5 * jnp.pi, 0.01],
[[0.5, 0.0], 0.0, 0.01, [-1.0, 0.0], [0.5, 0.99], 1.5 * jnp.pi, 0.01],
[[0.0, 0.5], 0.0, 0.01, [0.0, 1.0], [0.02, 0.5], 0.0, 0.02],
[[0.0, 0.5], 0.0, 0.01, [0.0, -1.0], [0.01, 0.5], 0.0, 0.01],
[[0.0, 0.5], 0.0, 0.05, [0.0, 1.0], [0.05, 0.5], 0.0, 0.05],
[[0.0, 0.5], 0.0, 0.01, [0.0, 0.0], [0.01, 0.5], 0.0, 0.01, 1.0],
[[0.5, 0.0], 0.0, 0.01, [1.0, 0.0], [0.5, 0.01], 0.5 * jnp.pi, 0.01, 1.0],
[[0.5, 0.0], 0.0, 0.01, [-1.0, 0.0], [0.5, 0.99], 1.5 * jnp.pi, 0.01, 1.0],
[[0.0, 0.5], 0.0, 0.01, [0.0, 1.0], [0.02, 0.5], 0.0, 0.02, 1.0],
[[0.0, 0.5], 0.0, 0.01, [0.0, -1.0], [0.01, 0.5], 0.0, 0.01, 1.0],
[[0.0, 0.5], 0.0, 0.05, [0.0, 1.0], [0.05, 0.5], 0.0, 0.05, 1.0],
[[0.495, 0.25], 0.0, 0.01, [0.0, 0.0], [0.005, 0.25], 0.0, 0.01, 0.5],
[[0.25, 0.005], 1.5 * jnp.pi, 0.01, [0.0, 0.0], [0.25, 0.495], 1.5 * jnp.pi, 0.01, 0.5],
],
)
def test_state_update(
Expand All @@ -107,6 +115,7 @@ def test_state_update(
expected_pos: List[float],
expected_heading: float,
expected_speed: float,
env_size: float,
) -> None:
key = jax.random.PRNGKey(101)

Expand All @@ -117,7 +126,7 @@ def test_state_update(
)
actions = jnp.array([actions])

new_state = updates.update_state(key, params, state, actions)
new_state = updates.update_state(key, env_size, params, state, actions)

assert isinstance(new_state, types.AgentState)
assert jnp.allclose(new_state.pos, jnp.array([expected_pos]))
Expand All @@ -133,19 +142,21 @@ def test_view_reduction() -> None:


@pytest.mark.parametrize(
"pos, view_angle, expected",
"pos, view_angle, env_size, expected",
[
[[0.05, 0.0], 0.5, [-1.0, -1.0, 0.5, -1.0, -1.0]],
[[0.0, 0.05], 0.5, [0.5, -1.0, -1.0, -1.0, -1.0]],
[[0.0, 0.95], 0.5, [-1.0, -1.0, -1.0, -1.0, 0.5]],
[[0.95, 0.0], 0.5, [-1.0, -1.0, -1.0, -1.0, -1.0]],
[[0.05, 0.0], 0.25, [-1.0, -1.0, 0.5, -1.0, -1.0]],
[[0.0, 0.05], 0.25, [-1.0, -1.0, -1.0, -1.0, -1.0]],
[[0.0, 0.95], 0.25, [-1.0, -1.0, -1.0, -1.0, -1.0]],
[[0.01, 0.0], 0.5, [-1.0, -1.0, 0.1, -1.0, -1.0]],
[[0.05, 0.0], 0.5, 1.0, [-1.0, -1.0, 0.5, -1.0, -1.0]],
[[0.0, 0.05], 0.5, 1.0, [0.5, -1.0, -1.0, -1.0, -1.0]],
[[0.0, 0.95], 0.5, 1.0, [-1.0, -1.0, -1.0, -1.0, 0.5]],
[[0.95, 0.0], 0.5, 1.0, [-1.0, -1.0, -1.0, -1.0, -1.0]],
[[0.05, 0.0], 0.25, 1.0, [-1.0, -1.0, 0.5, -1.0, -1.0]],
[[0.0, 0.05], 0.25, 1.0, [-1.0, -1.0, -1.0, -1.0, -1.0]],
[[0.0, 0.95], 0.25, 1.0, [-1.0, -1.0, -1.0, -1.0, -1.0]],
[[0.01, 0.0], 0.5, 1.0, [-1.0, -1.0, 0.1, -1.0, -1.0]],
[[0.0, 0.45], 0.5, 1.0, [4.5, -1.0, -1.0, -1.0, -1.0]],
[[0.0, 0.45], 0.5, 0.5, [-1.0, -1.0, -1.0, -1.0, 0.5]],
],
)
def test_view(pos: List[float], view_angle: float, expected: List[float]) -> None:
def test_view(pos: List[float], view_angle: float, env_size: float, expected: List[float]) -> None:
state_a = types.AgentState(
pos=jnp.zeros((2,)),
heading=0.0,
Expand All @@ -158,13 +169,15 @@ def test_view(pos: List[float], view_angle: float, expected: List[float]) -> Non
speed=0.0,
)

obs = updates.view(None, (view_angle, 0.02), state_a, state_b, n_view=5, i_range=0.1)
obs = updates.view(
None, (view_angle, 0.02), state_a, state_b, n_view=5, i_range=0.1, env_size=env_size
)
assert jnp.allclose(obs, jnp.array(expected))


def test_viewer_utils() -> None:
f, ax = plt.subplots()
f, ax = viewer.format_plot(f, ax)
f, ax = viewer.format_plot(f, ax, (1.0, 1.0))

assert isinstance(f, matplotlib.figure.Figure)
assert isinstance(ax, matplotlib.axes.Axes)
Expand Down
19 changes: 12 additions & 7 deletions jumanji/environments/swarms/common/updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,24 +54,26 @@ def update_velocity(
return new_heading, new_speeds


def move(pos: chex.Array, heading: chex.Array, speed: chex.Array) -> chex.Array:
def move(pos: chex.Array, heading: chex.Array, speed: chex.Array, env_size: float) -> chex.Array:
"""
Get updated agent positions from current speed and heading
Args:
pos: Agent position
pos: Agent position.
heading: Agent heading (angle).
speed: Agent speed
speed: Agent speed.
env_size: Size of the environment.
Returns:
jax array (float32): Updated agent position
jax array (float32): Updated agent position.
"""
d_pos = jnp.array([speed * jnp.cos(heading), speed * jnp.sin(heading)])
return (pos + d_pos) % 1.0
return (pos + d_pos) % env_size


def update_state(
key: chex.PRNGKey,
env_size: float,
params: types.AgentParams,
state: types.AgentState,
actions: chex.Array,
Expand All @@ -81,6 +83,7 @@ def update_state(
Args:
key: Dummy JAX random key.
env_size: Size of the environment.
params: Agent parameters.
state: Current agent states.
actions: Agent actions, i.e. a 2D array holding one action per agent.
Expand All @@ -91,7 +94,7 @@ def update_state(
"""
actions = jnp.clip(actions, min=-1.0, max=1.0)
headings, speeds = update_velocity(key, params, (actions, state))
positions = jax.vmap(move)(state.pos, headings, speeds)
positions = jax.vmap(move, in_axes=(0, 0, 0, None))(state.pos, headings, speeds, env_size)

return types.AgentState(
pos=positions,
Expand Down Expand Up @@ -133,6 +136,7 @@ def view(
*,
n_view: int,
i_range: float,
env_size: float,
) -> chex.Array:
"""
Simple agent view model
Expand All @@ -153,6 +157,7 @@ def view(
n_view: Static number of view rays/subdivisions (i.e. how
many cells the resulting array contains).
i_range: Static agent view/interaction range.
env_size: Size of the environment.
Returns:
jax array (float32): 1D array representing the distance
Expand All @@ -165,7 +170,7 @@ def view(
n_view,
endpoint=True,
)
dx = esquilax.utils.shortest_vector(viewing_agent.pos, viewed_agent.pos)
dx = esquilax.utils.shortest_vector(viewing_agent.pos, viewed_agent.pos, length=env_size)
d = jnp.sqrt(jnp.sum(dx * dx)) / i_range
phi = jnp.arctan2(dx[1], dx[0]) % (2 * jnp.pi)
dh = esquilax.utils.shortest_vector(phi, viewing_agent.heading, 2 * jnp.pi)
Expand Down
9 changes: 6 additions & 3 deletions jumanji/environments/swarms/common/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,15 @@ def draw_agents(ax: Axes, agent_states: AgentState, color: str) -> Quiver:
return q


def format_plot(fig: Figure, ax: Axes, border: float = 0.01) -> Tuple[Figure, Axes]:
def format_plot(
fig: Figure, ax: Axes, env_dims: Tuple[float, float], border: float = 0.01
) -> Tuple[Figure, Axes]:
"""Format a flock/swarm plot, remove ticks and bound to the environment dimensions
Args:
fig: Matplotlib figure.
ax: Matplotlib axes.
env_dims: Environment dimensions (i.e. its boundaries).
border: Border padding to apply around plot.
Returns:
Expand All @@ -67,7 +70,7 @@ def format_plot(fig: Figure, ax: Axes, border: float = 0.01) -> Tuple[Figure, Ax
)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_xlim(0, env_dims[0])
ax.set_ylim(0, env_dims[1])

return fig, ax
18 changes: 15 additions & 3 deletions jumanji/environments/swarms/search_and_rescue/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,12 @@ def __init__(
view_angle=searcher_view_angle,
)
self.max_steps = max_steps
self._viewer = viewer or SearchAndRescueViewer()
self._target_dynamics = target_dynamics or RandomWalk(0.01)
self.generator = generator or RandomGenerator(num_targets=20, num_searchers=10)
self._viewer = viewer or SearchAndRescueViewer()
# Needed to set environment boundaries for plots
if isinstance(self._viewer, SearchAndRescueViewer):
self._viewer.env_size = (self.generator.env_size, self.generator.env_size)
super().__init__()

def __repr__(self) -> str:
Expand All @@ -186,6 +189,7 @@ def __repr__(self) -> str:
f" - num vision: {self.num_vision}",
f" - agent radius: {self.agent_radius}",
f" - max steps: {self.max_steps},"
f" - env size: {self.generator.env_size}"
f" - target dynamics: {self._target_dynamics.__class__.__name__}",
f" - generator: {self.generator.__class__.__name__}",
]
Expand Down Expand Up @@ -222,9 +226,11 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser
# Note: only one new key is needed for the targets, as all other
# keys are just dummy values required by Esquilax
key, target_key = jax.random.split(state.key, num=2)
searchers = update_state(key, self.searcher_params, state.searchers, actions)
searchers = update_state(
key, self.generator.env_size, self.searcher_params, state.searchers, actions
)
# Ensure target positions are wrapped
target_pos = self._target_dynamics(target_key, state.targets.pos) % 1.0
target_pos = self._target_dynamics(target_key, state.targets.pos) % self.generator.env_size
# Grant searchers rewards if in range and not already detected
# spatial maps the has_found_target function over all pairs of targets and
# searchers within range of each other and sums rewards per agent.
Expand All @@ -233,13 +239,15 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser
reduction=jnp.add,
default=0.0,
i_range=self.target_contact_range,
dims=self.generator.env_size,
)(
key,
self.searcher_params.view_angle,
searchers,
state.targets,
pos=searchers.pos,
pos_b=target_pos,
env_size=self.generator.env_size,
)
# Mark targets as found if within contact range and view angle of a searcher
# spatial maps the has_been_found function over all pairs of targets and
Expand All @@ -249,13 +257,15 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser
reduction=jnp.logical_or,
default=False,
i_range=self.target_contact_range,
dims=self.generator.env_size,
)(
key,
self.searcher_params.view_angle,
state.targets.pos,
searchers,
pos=target_pos,
pos_b=searchers.pos,
env_size=self.generator.env_size,
)
# Targets need to remain found if they already have been
targets_found = jnp.logical_or(targets_found, state.targets.found)
Expand All @@ -282,6 +292,7 @@ def _state_to_observation(self, state: State) -> Observation:
default=-jnp.ones((self.num_vision,)),
include_self=False,
i_range=self.searcher_vision_range,
dims=self.generator.env_size,
)(
state.key,
(self.searcher_params.view_angle, self.agent_radius),
Expand All @@ -290,6 +301,7 @@ def _state_to_observation(self, state: State) -> Observation:
pos=state.searchers.pos,
n_view=self.num_vision,
i_range=self.searcher_vision_range,
env_size=self.generator.env_size,
)

return Observation(
Expand Down
26 changes: 20 additions & 6 deletions jumanji/environments/swarms/search_and_rescue/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,15 @@ def test_env_init(env: SearchAndRescue, key: chex.PRNGKey) -> None:
assert timestep.step_type == StepType.FIRST


def test_env_step(env: SearchAndRescue, key: chex.PRNGKey) -> None:
@pytest.mark.parametrize("env_size", [1.0, 0.2])
def test_env_step(env: SearchAndRescue, key: chex.PRNGKey, env_size: float) -> None:
"""
Run several steps of the environment with random actions and
check states (i.e. positions, heading, speeds) all fall
inside expected ranges.
"""
n_steps = 22
env.generator.env_size = env_size

def step(
carry: Tuple[chex.PRNGKey, State], _: None
Expand All @@ -89,7 +91,7 @@ def step(
assert isinstance(state_history, State)

assert state_history.searchers.pos.shape == (n_steps, env.generator.num_searchers, 2)
assert jnp.all((0.0 <= state_history.searchers.pos) & (state_history.searchers.pos <= 1.0))
assert jnp.all((0.0 <= state_history.searchers.pos) & (state_history.searchers.pos <= env_size))
assert state_history.searchers.speed.shape == (n_steps, env.generator.num_searchers)
assert jnp.all(
(env.searcher_params.min_speed <= state_history.searchers.speed)
Expand All @@ -101,7 +103,7 @@ def step(
)

assert state_history.targets.pos.shape == (n_steps, env.generator.num_targets, 2)
assert jnp.all((0.0 <= state_history.targets.pos) & (state_history.targets.pos <= 1.0))
assert jnp.all((0.0 <= state_history.targets.pos) & (state_history.targets.pos <= env_size))


def test_env_does_not_smoke(env: SearchAndRescue) -> None:
Expand All @@ -122,28 +124,38 @@ def test_env_specs_do_not_smoke(env: SearchAndRescue) -> None:


@pytest.mark.parametrize(
"searcher_positions, searcher_headings, view_updates",
"searcher_positions, searcher_headings, env_size, view_updates",
[
# Both out of view range
([[0.8, 0.5], [0.2, 0.5]], [jnp.pi, 0.0], []),
([[0.8, 0.5], [0.2, 0.5]], [jnp.pi, 0.0], 1.0, []),
# Both view each other
([[0.25, 0.5], [0.2, 0.5]], [jnp.pi, 0.0], [(0, 5, 0.25), (1, 5, 0.25)]),
([[0.25, 0.5], [0.2, 0.5]], [jnp.pi, 0.0], 1.0, [(0, 5, 0.25), (1, 5, 0.25)]),
# One facing wrong direction
(
[[0.25, 0.5], [0.2, 0.5]],
[jnp.pi, jnp.pi],
1.0,
[(0, 5, 0.25)],
),
# Only see closest neighbour
(
[[0.35, 0.5], [0.25, 0.5], [0.2, 0.5]],
[jnp.pi, 0.0, 0.0],
1.0,
[(0, 5, 0.5), (1, 5, 0.5), (2, 5, 0.25)],
),
# Observed around wrapped edge
(
[[0.025, 0.5], [0.975, 0.5]],
[jnp.pi, 0.0],
1.0,
[(0, 5, 0.25), (1, 5, 0.25)],
),
# Observed around wrapped edge of smaller env
(
[[0.025, 0.25], [0.475, 0.25]],
[jnp.pi, 0.0],
0.5,
[(0, 5, 0.25), (1, 5, 0.25)],
),
],
Expand All @@ -153,12 +165,14 @@ def test_searcher_view(
key: chex.PRNGKey,
searcher_positions: List[List[float]],
searcher_headings: List[float],
env_size: float,
view_updates: List[Tuple[int, int, float]],
) -> None:
"""
Test view model generates expected array with different
configurations of agents.
"""
env.generator.env_size = env_size

searcher_positions = jnp.array(searcher_positions)
searcher_headings = jnp.array(searcher_headings)
Expand Down
Loading

0 comments on commit 162a74d

Please sign in to comment.