Feat: Full Level-Based Foraging(LBF) environment #218

Merged
merged 88 commits into from
Oct 30, 2024
d7b804a
feat: types + generator poc done
sash-a Jul 25, 2023
181b314
feat: collision checking and resetting
sash-a Jul 26, 2023
5521591
feat: update types and implement more utils methods
sash-a Jul 27, 2023
6fc12f3
fix: generator
sash-a Jul 27, 2023
4fcf676
feat: more utils tests
sash-a Jul 27, 2023
631fc70
feat: reward and test
sash-a Aug 1, 2023
24e881a
feat: observation spec and action mask
sash-a Aug 1, 2023
b78e65d
fix: loading and eating and tests to go with it
sash-a Aug 2, 2023
34be44a
feat: no foods are placed on edge or adj to other food
sash-a Aug 3, 2023
37556df
feat: test that no foods are adjacent to eachother
sash-a Aug 3, 2023
013acd5
chore: test slice around
sash-a Aug 4, 2023
108f894
fix: state_to_obs
sash-a Aug 4, 2023
2f84d26
fix: obs - env no longer smokes - piping through
sash-a Aug 23, 2023
e7cc966
Merge branch 'main' into feat/lbf
sash-a Aug 23, 2023
366772c
feat: test state_to_obs
sash-a Aug 23, 2023
26f7533
chore: clean and add in tests skeleton methods
sash-a Aug 23, 2023
236bf63
chore: pre-commit
sash-a Aug 23, 2023
8ea7960
feat: place food on grid test
sash-a Aug 28, 2023
c6a0084
feat: test step and reset and done - fix small bugs
sash-a Aug 28, 2023
042289e
feat: tested eating multiple foods at once
sash-a Aug 28, 2023
af715bb
feat: vector obs
sash-a Aug 28, 2023
f45d86e
fix: local agent obs
sash-a Aug 28, 2023
2052fbc
feat: fix masking bug + add attribs to look more like rware
sash-a Aug 28, 2023
208b366
fix: fov was hardcoded
sash-a Aug 28, 2023
0eb27bf
fix: correct reward
sash-a Aug 28, 2023
c9b78df
feat: registered lbf
sash-a Aug 29, 2023
ec78da2
feat: implemented observers
sash-a Sep 4, 2023
03c584f
feat: observer test and fix out of bounds bug in vector observer
sash-a Sep 4, 2023
812b71a
fix: tests and move obs spec to observers
sash-a Sep 4, 2023
48996eb
chore: pre-commit
sash-a Sep 4, 2023
3166a7b
chore: clean up
sash-a Sep 4, 2023
2e6aabf
fix: made valid moves consistent with lbf
sash-a Sep 5, 2023
db4676c
chore: lots of docstrings
sash-a Sep 5, 2023
a91a101
chore: remove unused arg
sash-a Sep 5, 2023
495a15f
feat: initial viewer - needs numbers and better icons
sash-a Sep 5, 2023
be032c0
feat: exposing properties for mava
sash-a Sep 7, 2023
5c89e99
feat: correct termination/truncation behaviour
sash-a Sep 12, 2023
f9c905c
fix: max food level bug
sash-a Sep 20, 2023
bd490e6
Merge branch 'instadeepai:main' into feat/lbf-truncate
sash-a Jan 8, 2024
d7a78ea
feat: full lbf implementation
WiemKhlifi Jan 18, 2024
92fc862
Merge branch 'instadeepai:main' into feat/lbf-truncate
WiemKhlifi Jan 18, 2024
6176c99
feat: multi-agent connector
sash-a Jan 27, 2024
98e3273
Merge branch 'main' into feat/connector-multiagent
sash-a Jan 27, 2024
8015531
integrate Ruan fixes
SimonDuToit Feb 8, 2024
b70dc3d
feat: correction for linter + tests
WiemKhlifi Feb 9, 2024
9650b45
Merge pull request #1 from SimonDuToit/feat/connector-multiagent
sash-a Feb 12, 2024
dec1b8f
Merge branch 'main' into feat/lbf-truncate
sash-a Feb 12, 2024
e3c70cb
Merge pull request #2 from sash-a/feat/lbf-truncate
sash-a Feb 12, 2024
569a003
Merge pull request #3 from sash-a/feat/connector-multiagent
sash-a Feb 12, 2024
6826397
fix: requirements
sash-a Feb 14, 2024
9aa74ab
Merge branch 'fix/sokoban-requirements'
sash-a Feb 14, 2024
79b453e
fix: fix lbf bug with new jax version
WiemKhlifi Feb 28, 2024
b276f42
chore: delete unnecessary file
WiemKhlifi Mar 1, 2024
adbb258
chore: delete unnecessary file
WiemKhlifi Mar 1, 2024
74d4c54
Merge branch 'main' into feat/lbf-truncate
WiemKhlifi Mar 6, 2024
f938f80
chore: less strict cleaner types
sash-a Apr 30, 2024
8a58b1f
chore: small changes
WiemKhlifi May 17, 2024
6082c50
fix: fix dataclass with defaults issue with recent python version
WiemKhlifi Jul 10, 2024
857ce8d
Merge branch 'main' into feat/lbf-truncate
WiemKhlifi Jul 10, 2024
234b32e
fix: fix missing attribute in utils test
WiemKhlifi Jul 10, 2024
0408965
feat: undo changes related to cleaner and conncetor
WiemKhlifi Jul 10, 2024
a7692e7
chore: resolve conflict with main branch and merge changes
WiemKhlifi Jul 10, 2024
15aa74d
feat: adapt lbf to the recent changes
WiemKhlifi Jul 11, 2024
2945c3b
test: fixing conflict
WiemKhlifi Jul 11, 2024
6c51988
Merge branch 'main' into feat/lbf-truncate
WiemKhlifi Jul 11, 2024
f964b40
chore: small changes for lbf
WiemKhlifi Jul 11, 2024
be14b0b
fix: fix lbf env collision test
WiemKhlifi Jul 11, 2024
b90360c
fix: dataclass default setting issue + few changes
WiemKhlifi Jul 11, 2024
f102941
Merge latest changes in jumanji + lbf
WiemKhlifi Oct 10, 2024
077807f
docs: edit lbf documentation
WiemKhlifi Oct 10, 2024
ba3a4a2
chore: naming typos
WiemKhlifi Oct 21, 2024
d7ddb47
feat: refactor GridObserver and make few fixes for action_mask
WiemKhlifi Oct 22, 2024
9adee92
chore: small changes based on review
WiemKhlifi Oct 24, 2024
ab72da5
chore: add few more tests
WiemKhlifi Oct 24, 2024
65489f6
chore: add some asserts
WiemKhlifi Oct 24, 2024
3f8297d
Merge branch 'main' into feat/lbf-truncate
sash-a Oct 25, 2024
67878f4
chore: small changes based on review
WiemKhlifi Oct 25, 2024
b3d2495
fix: invert the condition for assertions
WiemKhlifi Oct 25, 2024
9aa8967
docs: edit assert message
WiemKhlifi Oct 25, 2024
bf91646
Merge branch 'main' into feat/lbf-truncate
WiemKhlifi Oct 28, 2024
a828313
Merge branch 'main' into feat/lbf-truncate
sash-a Oct 28, 2024
4ed4509
docs: edit assert message
WiemKhlifi Oct 28, 2024
6e28a67
chore: update how we read imgs
WiemKhlifi Oct 28, 2024
42c171d
chore: change test files path
WiemKhlifi Oct 28, 2024
969fdaa
Merge branch 'main' into feat/lbf-truncate
WiemKhlifi Oct 28, 2024
477b10b
chore: remove uneeded files used for testing
WiemKhlifi Oct 29, 2024
0260b9f
chore: add a more protective assert for generator
WiemKhlifi Oct 30, 2024
d47de71
fix: increase grid_size to verify the new assert
WiemKhlifi Oct 30, 2024
4 changes: 3 additions & 1 deletion README.md
@@ -39,7 +39,7 @@
</div>
<div class="row" align="center">
<img src="docs/env_anim/multi_cvrp.gif" alt="MultiCVRP" width="16%">
-<img src="docs/env_anim/pac_man.gif" alt="PacMan" width="16%">
+<img src="docs/env_anim/pac_man.gif" alt="PacMan" width="12.9%">
<img src="docs/env_anim/robot_warehouse.gif" alt="RobotWarehouse" width="16%">
<img src="docs/env_anim/rubiks_cube.gif" alt="RubiksCube" width="16%">
<img src="docs/env_anim/sliding_tile_puzzle.gif" alt="SlidingTilePuzzle" width="16%">
@@ -50,6 +50,7 @@
<img src="docs/env_anim/sudoku.gif" alt="Sudoku" width="16%">
<img src="docs/env_anim/tetris.gif" alt="Tetris" width="16%">
<img src="docs/env_anim/tsp.gif" alt="TSP" width="16%">
<img src="docs/env_anim/lbf.gif" alt="Level-Based Foraging" width="16%">
</div>
</div>

@@ -121,6 +122,7 @@ problems.
| Multi Minimum Spanning Tree Problem | Routing | `MMST-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/mmst) | [doc](https://instadeepai.github.io/jumanji/environments/mmst/) |
| ᗧ•••ᗣ•• PacMan | Routing | `PacMan-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/pac_man/) | [doc](https://instadeepai.github.io/jumanji/environments/pac_man/)
| 👾 Sokoban | Routing | `Sokoban-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/sokoban/) | [doc](https://instadeepai.github.io/jumanji/environments/sokoban/) |
| 🍎 Level-Based Foraging | Routing | `LevelBasedForaging-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/lbf/) | [doc](https://instadeepai.github.io/jumanji/environments/lbf/) |

<h2 name="install" id="install">Installation 🎬</h2>

9 changes: 9 additions & 0 deletions docs/api/environments/lbf.md
@@ -0,0 +1,9 @@
::: jumanji.environments.routing.lbf.env.LevelBasedForaging
    selection:
        members:
            - __init__
            - reset
            - step
            - observation_spec
            - action_spec
            - render
Binary file added docs/env_anim/lbf.gif
43 changes: 43 additions & 0 deletions docs/environments/lbf.md
@@ -0,0 +1,43 @@
# Level-Based Foraging Environment

<p align="center">
<img src="../env_anim/lbf.gif" width="600"/>
</p>

We provide a JAX jit-able implementation of the [Level-Based Foraging](https://github.com/semitable/lb-foraging/tree/master)
environment.

Level-Based Foraging (LBF) is a mixed cooperative-competitive environment that emphasises coordination between agents. As illustrated above, agents are placed within a grid world and each is assigned a level.

To collect food, agents must be adjacent to it and the cumulative level of participating agents must meet or exceed the food's designated level. Agents receive points based on the level of the collected food and their own level.
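
The collection rule above can be sketched in plain Python. This is a minimal illustration, not the environment's JAX code: the helper names `is_adjacent` and `can_collect` are hypothetical, adjacency is assumed to mean the four orthogonal neighbours, and an agent is assumed to participate only while it is loading.

```python
from typing import List, Tuple

Position = Tuple[int, int]  # (row, col)


def is_adjacent(a: Position, b: Position) -> bool:
    """Assumption: adjacency means sharing an edge (Manhattan distance 1)."""
    return abs(a[0] - b[0]) + abs(a[1] - b[1]) == 1


def can_collect(
    food_position: Position,
    food_level: int,
    agents: List[Tuple[Position, int, bool]],
) -> bool:
    """Return True if the loading agents next to the food have enough combined level.

    Each agent is a (position, level, is_loading) triple.
    """
    combined_level = sum(
        level
        for position, level, is_loading in agents
        if is_loading and is_adjacent(position, food_position)
    )
    return combined_level >= food_level


# A level-4 food at (2, 1), a level-2 agent above it and a level-4 agent to its right.
food_position, food_level = (2, 1), 4
alone = [((1, 1), 2, True), ((2, 2), 4, False)]
together = [((1, 1), 2, True), ((2, 2), 4, True)]
print(can_collect(food_position, food_level, alone))
print(can_collect(food_position, food_level, together))
```

In this sketch the level-2 agent cannot collect the food on its own, but it succeeds once the level-4 agent also loads, since their combined level of 6 exceeds the food's level of 4.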

## Observation

The **observation** seen by the agent is a `NamedTuple` containing the following:

- `agents_view`: jax array (int32) of shape `(num_agents, num_obs_features)`, array representing the agent's view of other agents
and food.

- `action_mask`: jax array (bool) of shape `(num_agents, 6)`, array specifying, for each agent,
which action (noop, up, down, left, right, load) is legal.

- `step_count`: jax array (int32) of shape `()`, number of steps elapsed in the current episode.
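
To make the role of `action_mask` concrete, the sketch below filters the six actions by one agent's mask row. It uses plain Python lists rather than JAX arrays, and `ACTION_NAMES` and `legal_actions` are hypothetical names introduced only for this illustration.

```python
from typing import List, Sequence

# Action order as listed for `action_mask`: index i is the integer action i.
ACTION_NAMES = ["noop", "up", "down", "left", "right", "load"]


def legal_actions(mask_row: Sequence[bool]) -> List[int]:
    """Return the integer actions whose mask entry is True for one agent."""
    if len(mask_row) != len(ACTION_NAMES):
        raise ValueError("expected one mask entry per action")
    return [action for action, is_legal in enumerate(mask_row) if is_legal]


# Hypothetical mask row for a single agent: up, left, and load are masked out.
mask_row = [True, False, True, False, True, False]
allowed = legal_actions(mask_row)
print([ACTION_NAMES[a] for a in allowed])
```

A policy would typically restrict its choice to the returned indices, for example by assigning zero probability to masked actions.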

## Action

The action space is a `MultiDiscreteArray` containing an integer value in `[0, 1, 2, 3, 4, 5]` for each
agent. Each agent can take one of six actions: noop (`0`), up (`1`), down (`2`), left (`3`), right (`4`), or load, i.e. pick up food (`5`).

The episode terminates under the following conditions:

- An invalid action is taken, or

- An agent collides with another agent.

## Reward

The reward is equal to the sum of the levels of collected food divided by the level of the agents that collected them.
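
The sentence above leaves some room for interpretation, so the following is only one plausible reading, written as a dependency-free sketch. It assumes that each collecting agent receives a share of the food's level proportional to its own level, and that normalisation, when enabled, divides by the total level of all food on the grid. The function name `reward_shares` and both assumptions are illustrative and are not taken from the environment code.

```python
from typing import List, Optional


def reward_shares(
    food_level: int,
    collector_levels: List[int],
    total_food_level: Optional[int] = None,
) -> List[float]:
    """Split one collected food's level among the agents that collected it.

    Assumption: each collector's share is proportional to its own level, so the
    shares sum to `food_level`. If `total_food_level` is given, every share is
    also divided by it (an assumed form of reward normalisation).
    """
    combined_level = sum(collector_levels)
    shares = [food_level * level / combined_level for level in collector_levels]
    if total_food_level is not None:
        shares = [share / total_food_level for share in shares]
    return shares


# A level-4 food collected jointly by a level-2 and a level-4 agent.
raw = reward_shares(4, [2, 4])
# The same event, normalised by the total level of all food on the grid (4 + 4 + 3).
normalised = reward_shares(4, [2, 4], total_food_level=11)
print(raw, normalised)
```

Under these assumptions the higher-level agent earns the larger share, and the unnormalised shares add up to the food's level.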

## Registered Versions 📖

- `LevelBasedForaging-v0`, a grid with 2 agents each with a field of view equal to the grid size (full observation case), with 2 food items and forcing the cooperation between agents.
6 changes: 6 additions & 0 deletions jumanji/__init__.py
@@ -140,3 +140,9 @@
register(
    id="SlidingTilePuzzle-v0", entry_point="jumanji.environments:SlidingTilePuzzle"
)

# LevelBasedForaging with a random generator: a grid size of 8,
# 2 agents, 2 food items, and a maximum agent level of 2.
register(
    id="LevelBasedForaging-v0", entry_point="jumanji.environments:LevelBasedForaging"
)
1 change: 1 addition & 0 deletions jumanji/environments/__init__.py
@@ -50,6 +50,7 @@
from jumanji.environments.routing.cleaner.env import Cleaner
from jumanji.environments.routing.connector.env import Connector
from jumanji.environments.routing.cvrp.env import CVRP
from jumanji.environments.routing.lbf.env import LevelBasedForaging
from jumanji.environments.routing.maze.env import Maze
from jumanji.environments.routing.mmst.env import MMST
from jumanji.environments.routing.multi_cvrp.env import MultiCVRP
2 changes: 1 addition & 1 deletion jumanji/environments/routing/connector/env.py
@@ -89,7 +89,7 @@ class Connector(Environment[State, specs.MultiDiscreteArray, Observation]):
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
-action = env.action_specc.generate_value()
+action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)
```
17 changes: 17 additions & 0 deletions jumanji/environments/routing/lbf/__init__.py
@@ -0,0 +1,17 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jumanji.environments.routing.lbf.env import LevelBasedForaging
from jumanji.environments.routing.lbf.observer import GridObserver, VectorObserver
from jumanji.environments.routing.lbf.types import Agent, Food, Observation, State
205 changes: 205 additions & 0 deletions jumanji/environments/routing/lbf/conftest.py
@@ -0,0 +1,205 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import chex
import jax
import jax.numpy as jnp
import pytest

from jumanji.environments.routing.lbf.env import LevelBasedForaging
from jumanji.environments.routing.lbf.generator import RandomGenerator
from jumanji.environments.routing.lbf.types import Agent, Food, State
from jumanji.tree_utils import tree_transpose

# create food and agents for grid that looks like:
# "AGENT" | EMPTY | EMPTY | EMPTY | EMPTY | EMPTY
# EMPTY | "AGENT" | EMPTY | EMPTY | EMPTY | EMPTY
# EMPTY | "FOOD" | "AGENT" | "FOOD" | EMPTY | EMPTY
# EMPTY | EMPTY | EMPTY | EMPTY | EMPTY | EMPTY
# EMPTY | EMPTY | "FOOD" | EMPTY | EMPTY | EMPTY
# EMPTY | EMPTY | EMPTY | EMPTY | EMPTY | EMPTY


@pytest.fixture
def key() -> chex.PRNGKey:
    return jax.random.PRNGKey(42)


@pytest.fixture
def agent0() -> Agent:
    return Agent(
        id=jnp.asarray(0),
        position=jnp.array([0, 0]),
        level=jnp.asarray(1),
        loading=jnp.asarray(False),
    )


@pytest.fixture
def agent1() -> Agent:
    return Agent(
        id=jnp.asarray(1),
        position=jnp.array([1, 1]),
        level=jnp.asarray(2),
        loading=jnp.asarray(False),
    )


@pytest.fixture
def agent2() -> Agent:
    return Agent(
        id=jnp.asarray(2),
        position=jnp.array([2, 2]),
        level=jnp.asarray(4),
        loading=jnp.asarray(False),
    )


@pytest.fixture
def food0() -> Food:
    return Food(
        id=jnp.asarray(0),
        position=jnp.array([2, 1]),
        level=jnp.asarray(4),
        eaten=jnp.asarray(False),
    )


@pytest.fixture
def food1() -> Food:
    return Food(
        id=jnp.asarray(1),
        position=jnp.array([2, 3]),
        level=jnp.asarray(4),
        eaten=jnp.asarray(False),
    )


@pytest.fixture
def food2() -> Food:
    return Food(
        id=jnp.asarray(2),
        position=jnp.array([4, 2]),
        level=jnp.asarray(3),
        eaten=jnp.asarray(False),
    )


@pytest.fixture
def agents(agent0: Agent, agent1: Agent, agent2: Agent) -> Agent:
    return tree_transpose([agent0, agent1, agent2])


@pytest.fixture
def food_items(food0: Food, food1: Food, food2: Food) -> Food:
    return tree_transpose([food0, food1, food2])


@pytest.fixture
def state(agents: Agent, food_items: Food, key: chex.PRNGKey) -> State:
    return State(agents=agents, food_items=food_items, step_count=0, key=key)


@pytest.fixture
def agent_grid() -> chex.Array:
    """Returns the agents' levels at their positions on the grid."""
    return jnp.array(
        [
            [1, 0, 0, 0, 0, 0],
            [0, 2, 0, 0, 0, 0],
            [0, 0, 4, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
        ]
    )


@pytest.fixture
def food_grid() -> chex.Array:
    """Returns the food items' levels at their positions on the grid."""
    return jnp.array(
        [
            [0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
            [0, 4, 0, 4, 0, 0],
            [0, 0, 0, 0, 0, 0],
            [0, 0, 3, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
        ]
    )


@pytest.fixture
def random_generator() -> RandomGenerator:
    return RandomGenerator(
        grid_size=8,
        fov=2,
        num_agents=2,
        num_food=2,
        max_agent_level=2,
        force_coop=True,
    )


@pytest.fixture
def lbf_environment() -> LevelBasedForaging:
    generator = RandomGenerator(
        grid_size=8,
        fov=6,
        num_agents=3,
        num_food=3,
        max_agent_level=4,
        force_coop=True,
    )

    return LevelBasedForaging(generator=generator, time_limit=5)


@pytest.fixture
def lbf_env_2s() -> LevelBasedForaging:
    generator = RandomGenerator(
        grid_size=8,
        fov=2,
        num_agents=2,
        num_food=2,
        max_agent_level=2,
        force_coop=False,
    )

    return LevelBasedForaging(generator=generator, time_limit=5)


@pytest.fixture
def lbf_env_grid_obs() -> LevelBasedForaging:
    generator = RandomGenerator(
        grid_size=8,
        fov=6,
        num_agents=3,
        num_food=3,
        max_agent_level=4,
        force_coop=True,
    )

    return LevelBasedForaging(generator=generator, grid_observation=True)


@pytest.fixture
def lbf_with_penalty() -> LevelBasedForaging:
    return LevelBasedForaging(penalty=1.0)


@pytest.fixture
def lbf_with_no_norm_reward() -> LevelBasedForaging:
    return LevelBasedForaging(normalize_reward=False)
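
The `agent_grid` and `food_grid` fixtures in this file place each entity's level at its `(row, col)` position on a 6x6 grid. A dependency-free sketch of that construction is given below; `levels_on_grid` is a hypothetical helper written with plain lists for illustration, not a function from this PR.

```python
from typing import List, Sequence, Tuple


def levels_on_grid(
    grid_size: int,
    positions: Sequence[Tuple[int, int]],
    levels: Sequence[int],
) -> List[List[int]]:
    """Build a square grid of zeros and write each level at its (row, col) position."""
    grid = [[0] * grid_size for _ in range(grid_size)]
    for (row, col), level in zip(positions, levels):
        grid[row][col] = level
    return grid


# Positions and levels copied from the agent and food fixtures.
agent_levels_grid = levels_on_grid(6, [(0, 0), (1, 1), (2, 2)], [1, 2, 4])
food_levels_grid = levels_on_grid(6, [(2, 1), (2, 3), (4, 2)], [4, 4, 3])
for row in agent_levels_grid:
    print(row)
```

With these inputs the two grids have the same entries as the arrays returned by the `agent_grid` and `food_grid` fixtures.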
32 changes: 32 additions & 0 deletions jumanji/environments/routing/lbf/constants.py
@@ -0,0 +1,32 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax.numpy as jnp

# Actions
NOOP = 0
UP = 1
DOWN = 2
LEFT = 3
RIGHT = 4
LOAD = 5

# NOOP, UP, DOWN, LEFT, RIGHT, LOAD
MOVES = jnp.array([[0, 0], [-1, 0], [1, 0], [0, -1], [0, 1], [0, 0]])

# viewer constants
_FIGURE_SIZE = (5, 5)

# Define some colors for visualization.
_GRID_COLOR = (0, 0, 0) # black
_LINE_COLOR = (1, 1, 1) # white
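
As a sketch of how the `MOVES` table above can be used, the snippet below adds the `(row, col)` offset of an action to an agent's position. It is written in plain Python rather than JAX, `apply_move` is a hypothetical helper, and keeping the agent in place when a move would leave the grid is an assumption made for this example (the environment may rule out such moves through the action mask instead).

```python
from typing import Tuple

# Action indices and offsets mirror the constants above, as plain Python tuples.
NOOP, UP, DOWN, LEFT, RIGHT, LOAD = range(6)
MOVES = [(0, 0), (-1, 0), (1, 0), (0, -1), (0, 1), (0, 0)]


def apply_move(position: Tuple[int, int], action: int, grid_size: int) -> Tuple[int, int]:
    """Return the position after `action`, staying put if the move leaves the grid."""
    row_offset, col_offset = MOVES[action]
    row, col = position[0] + row_offset, position[1] + col_offset
    if 0 <= row < grid_size and 0 <= col < grid_size:
        return (row, col)
    return position  # Assumption: out-of-bounds moves leave the agent in place.


print(apply_move((1, 1), DOWN, 6))
print(apply_move((0, 0), UP, 6))
print(apply_move((1, 1), LOAD, 6))
```

Note that `NOOP` and `LOAD` both map to a zero offset, so neither changes the agent's position.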