Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat : MultiCVRP wrapper #1043

Closed
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
6e25369
feat: multiCVRP base implementaion
Louay-Ben-nessir Feb 20, 2024
8ecceee
Merge branch 'develop' into feat--multiCVRP-wrapper
Louay-Ben-nessir Feb 26, 2024
4b804c2
feat: multi
Louay-Ben-nessir Feb 26, 2024
bf85601
Merge branch 'develop' of https://github.com/Louay-Ben-nessir/Mava in…
Louay-Ben-nessir Feb 27, 2024
09d0cd9
fix: a lot of small changes
Louay-Ben-nessir Feb 27, 2024
bf96c8b
fix: minor changes
Louay-Ben-nessir Feb 27, 2024
4d46916
fix: comments + super + observation insted of state
Louay-Ben-nessir Feb 28, 2024
e3e002c
fix: minor comment changes
Louay-Ben-nessir Feb 28, 2024
ee31d6f
Merge branch 'develop' into feat--multiCVRP-wrapper
WiemKhlifi Feb 29, 2024
d3da4ed
fix: use the multiagentWrapper
Louay-Ben-nessir Mar 4, 2024
1486b53
fix: annotiaions
Louay-Ben-nessir Mar 4, 2024
74906b6
Merge remote-tracking branch 'origin/develop' into feat--multiCVRP-wr…
Louay-Ben-nessir Mar 11, 2024
9743c59
fix: pre-commit
Louay-Ben-nessir Mar 11, 2024
7edbc41
Merge branch 'develop' into feat--multiCVRP-wrapper
sash-a Mar 12, 2024
1f20e77
Merge branch 'develop' into feat--multiCVRP-wrapper
Louay-Ben-nessir Mar 12, 2024
3abe88f
Merge branch 'develop' into feat--multiCVRP-wrapper
RuanJohn Mar 15, 2024
1ed331b
chore: node --> Node in docs
Louay-Ben-nessir Mar 18, 2024
52a8299
Merge branch 'develop' into feat--multiCVRP-wrapper
RuanJohn Mar 18, 2024
c0e7d02
chore: hardcoded numbers docs
Louay-Ben-nessir Mar 18, 2024
b45bb53
Merge branch 'feat--multiCVRP-wrapper' of https://github.com/Louay-Be…
Louay-Ben-nessir Mar 18, 2024
9ddc91a
fix: pre-commit
Louay-Ben-nessir Mar 18, 2024
39540d3
Merge branch 'develop' into feat--multiCVRP-wrapper
sash-a Mar 18, 2024
b5b5a4f
chore: removed the reward/action specs
Louay-Ben-nessir Mar 21, 2024
bc4271d
Merge branch 'feat--multiCVRP-wrapper' of https://github.com/Louay-Be…
Louay-Ben-nessir Mar 21, 2024
86ab329
Merge branch 'develop' into feat--multiCVRP-wrapper
Louay-Ben-nessir Mar 21, 2024
02f69e1
fix: added state to cleanr's modifie_timestep function
Louay-Ben-nessir Mar 21, 2024
7be27f8
Merge branch 'develop' into feat--multiCVRP-wrapper
WiemKhlifi Mar 21, 2024
2e30ffd
fix: removed the unneeded action_spec
Louay-Ben-nessir Mar 22, 2024
667456c
Merge branch 'feat--multiCVRP-wrapper' of https://github.com/Louay-Be…
Louay-Ben-nessir Mar 22, 2024
b26e48e
fix: pre-commits
Louay-Ben-nessir Mar 22, 2024
9943929
Merge branch 'develop' into feat--multiCVRP-wrapper
WiemKhlifi May 9, 2024
e8d515d
Merge branch 'develop' into feat--multiCVRP-wrapper
OmaymaMahjoub Jul 5, 2024
023c915
fix: updated to the latest configs
Louay-Ben-nessir Jul 6, 2024
bda1fff
fix: corrected the 20c scenario config
Louay-Ben-nessir Jul 6, 2024
4f1a68a
chore: pre-commits
Louay-Ben-nessir Jul 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions mava/configs/env/multicvrp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# ---Environment Configs---
defaults:
- _self_
- scenario: multicvrp-2v-20c # [multicvrp-2v-20c, multicvrp-2v-6c]

env_name: MultiCVRP

# Whether the environment observations encode implicit agent IDs. If True, the AgentID wrapper is not used.
# This should not be changed.
implicit_agent_id: False

eval_metric: episode_return

kwargs: {}
RuanJohn marked this conversation as resolved.
Show resolved Hide resolved
6 changes: 6 additions & 0 deletions mava/configs/env/scenario/multicvrp-2v-20c.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
name: MultiCVRP-v0
task_name: multicvrp-2v-20c

task_config:
num_customers : 20
num_vehicles : 2
6 changes: 6 additions & 0 deletions mava/configs/env/scenario/multicvrp-2v-6c.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
name: MultiCVRP-v0
task_name: multicvrp-2v-6c

task_config:
num_customers : 6
num_vehicles : 2
5 changes: 5 additions & 0 deletions mava/utils/make_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from jumanji.environments.routing.lbf.generator import (
RandomGenerator as LbfRandomGenerator,
)
from jumanji.environments.routing.multi_cvrp.generator import (
UniformRandomGenerator as MultiCVRPRandomGenerator,
)
from jumanji.environments.routing.robot_warehouse.generator import (
RandomGenerator as RwareRandomGenerator,
)
Expand All @@ -39,6 +42,7 @@
LbfWrapper,
MabraxWrapper,
MatraxWrapper,
MultiCVRPWrapper,
RecordEpisodeMetrics,
RwareWrapper,
SmaxWrapper,
Expand All @@ -49,6 +53,7 @@
"RobotWarehouse-v0": {"generator": RwareRandomGenerator, "wrapper": RwareWrapper},
"LevelBasedForaging-v0": {"generator": LbfRandomGenerator, "wrapper": LbfWrapper},
"MaConnector-v2": {"generator": ConnectorRandomGenerator, "wrapper": ConnectorWrapper},
"MultiCVRP-v0": {"generator": MultiCVRPRandomGenerator, "wrapper": MultiCVRPWrapper},
}

# Define a different registry for Matrax since it has no generator.
Expand Down
7 changes: 6 additions & 1 deletion mava/wrappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
from mava.wrappers.episode_metrics import RecordEpisodeMetrics
from mava.wrappers.gigastep import GigastepWrapper
from mava.wrappers.jaxmarl import MabraxWrapper, SmaxWrapper
from mava.wrappers.jumanji import ConnectorWrapper, LbfWrapper, RwareWrapper
from mava.wrappers.jumanji import (
ConnectorWrapper,
LbfWrapper,
MultiCVRPWrapper,
RwareWrapper,
)
from mava.wrappers.matrax import MatraxWrapper
from mava.wrappers.observation import AgentIDWrapper
140 changes: 134 additions & 6 deletions mava/wrappers/jumanji.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import chex
import jax.numpy as jnp
from jax import tree_util
from jumanji import specs
from jumanji.env import Environment
from jumanji.environments.routing.connector import MaConnector
Expand All @@ -29,6 +30,10 @@
TARGET,
)
from jumanji.environments.routing.lbf import LevelBasedForaging
from jumanji.environments.routing.multi_cvrp import MultiCVRP
from jumanji.environments.routing.multi_cvrp.types import (
Observation as MultiCvrpObservation,
)
from jumanji.environments.routing.robot_warehouse import RobotWarehouse
from jumanji.types import TimeStep
from jumanji.wrappers import Wrapper
Expand All @@ -44,7 +49,7 @@ def __init__(self, env: Environment, add_global_state: bool):
self.add_global_state = add_global_state

@abstractmethod
def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
def modify_timestep(self, timestep: TimeStep, state: State) -> TimeStep[Observation]:
"""Modify the timestep for `step` and `reset`."""
pass

Expand All @@ -59,7 +64,7 @@ def get_global_state(self, obs: Observation) -> chex.Array:
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
"""Reset the environment."""
state, timestep = self._env.reset(key)
timestep = self.modify_timestep(timestep)
timestep = self.modify_timestep(timestep, state)
if self.add_global_state:
global_state = self.get_global_state(timestep.observation)
observation = ObservationGlobalState(
Expand All @@ -75,7 +80,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]:
"""Step the environment."""
state, timestep = self._env.step(state, action)
timestep = self.modify_timestep(timestep)
timestep = self.modify_timestep(timestep, state)
if self.add_global_state:
global_state = self.get_global_state(timestep.observation)
observation = ObservationGlobalState(
Expand Down Expand Up @@ -130,7 +135,7 @@ def __init__(self, env: RobotWarehouse, add_global_state: bool = False):
super().__init__(env, add_global_state)
self._env: RobotWarehouse

def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
def modify_timestep(self, timestep: TimeStep, state: State) -> TimeStep[Observation]:
"""Modify the timestep for the Robotic Warehouse environment."""
observation = Observation(
agents_view=timestep.observation.agents_view,
Expand Down Expand Up @@ -172,7 +177,7 @@ def aggregate_rewards(
reward = jnp.repeat(team_reward, self.num_agents)
return timestep.replace(observation=observation, reward=reward)

def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
def modify_timestep(self, timestep: TimeStep, state: State) -> TimeStep[Observation]:
"""Modify the timestep for Level-Based Foraging environment and update
the reward based on the specified reward handling strategy.
"""
Expand Down Expand Up @@ -201,7 +206,9 @@ def __init__(self, env: MaConnector, add_global_state: bool = False):
super().__init__(env, add_global_state)
self._env: MaConnector

def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
def modify_timestep(
self, timestep: TimeStep, state: State
) -> TimeStep[Union[Observation, ObservationGlobalState]]:
"""Modify the timestep for the Connector environment."""

# TARGET = 3 = The number of different types of items on the grid.
Expand Down Expand Up @@ -265,3 +272,124 @@ def observation_spec(self) -> specs.Spec[Union[Observation, ObservationGlobalSta
return specs.Spec(ObservationGlobalState, "ObservationSpec", **obs_data)

return specs.Spec(Observation, "ObservationSpec", **obs_data)


class MultiCVRPWrapper(MultiAgentWrapper):
"""Wrapper for MultiCVRP environment."""

def __init__(self, env: MultiCVRP, add_global_state: bool = False):
env.num_agents = env._num_vehicles
env.time_limit = env._num_customers + 1 # added for consistency
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since I'm not very familiar with MultiCVRP 😅, I wanted to ask if the time_limit is only controlled in this way or if it can be set manually like in other environments?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Jumanji's Multicvrp doesn't offer a way to manually set the time_limit. I changed it to None to better indicate this.

env.action_dim = env._num_customers + 1 # n_costumers + 1 starter node
super().__init__(env, False)
self.has_global_state = add_global_state
self._env = env
Louay-Ben-nessir marked this conversation as resolved.
Show resolved Hide resolved

def modify_timestep(self, timestep: TimeStep, state: State) -> TimeStep[Observation]:
observation, global_observation = self._flatten_observation(timestep.observation)
obs_data = {
"agents_view": observation,
"action_mask": timestep.observation.action_mask,
"step_count": jnp.repeat(state.step_count, (self.num_agents)),
Louay-Ben-nessir marked this conversation as resolved.
Show resolved Hide resolved
}
if self.has_global_state:
obs_data["global_state"] = global_observation
observation = ObservationGlobalState(**obs_data)
else:
observation = Observation(**obs_data)

reward = jnp.repeat(timestep.reward, (self.num_agents))
discount = jnp.repeat(timestep.discount, (self.num_agents))
Louay-Ben-nessir marked this conversation as resolved.
Show resolved Hide resolved
return timestep.replace(observation=observation, reward=reward, discount=discount)

def _flatten_observation(
self, observation: MultiCvrpObservation
) -> Tuple[chex.Array, Union[None, chex.Array]]:
"""
Concatenates all observation fields into a single array.

Args:
observation (MultiCvrpObservation): The raw observation NamedTuple provided by jumanji.

Returns:
observations (chex.Array): Concatenated individual observations for each agent,
shaped (num_agents, vehicle_info + customer_info).
global_observation (Union[None, chex.Array]): Concatenated global observation
shaped (num_agents, global_info) if has_global_state = True, None otherwise.
"""
global_observation = None
# N: number of nodes, same as _num_customers + 1
# V: number of vehicles, same as num_agents
# nodes are composed of (x, y, demands)
Louay-Ben-nessir marked this conversation as resolved.
Show resolved Hide resolved
# Windows are composed of (start_time, end_time)
# Coeffs are composed of (early, late)
# Vehicles have ((x, y), local_time, capacity)

# Tuple[(N, 3), (N, 2), (N, 2)]
customers_info, _ = tree_util.tree_flatten(
(observation.nodes, observation.windows, observation.coeffs)
)
# Tuple[(V, 2), (V, 1), (V, 1)]
vehicles_info, _ = tree_util.tree_flatten(observation.vehicles)
Louay-Ben-nessir marked this conversation as resolved.
Show resolved Hide resolved

# (N * 7, )
Louay-Ben-nessir marked this conversation as resolved.
Show resolved Hide resolved
customers_info = jnp.column_stack(customers_info).ravel()
# (V, 4)
vehicles_info = jnp.column_stack(vehicles_info)

if self.has_global_state:
# (V * 4 + N * 7, )
global_observation = jnp.concatenate((vehicles_info.ravel(), customers_info))
# (V, N * 7 + V * 4)
global_observation = jnp.tile(global_observation, (self.num_agents, 1))

# (V, N * 7)
Louay-Ben-nessir marked this conversation as resolved.
Show resolved Hide resolved
customers_info = jnp.tile(customers_info, (self.num_agents, 1))
# (V, 4 + N * 7)
observations = jnp.column_stack((vehicles_info, customers_info))
return observations, global_observation

def observation_spec(self) -> specs.Spec[Observation]:
step_count = specs.BoundedArray(
(self.num_agents,), jnp.int32, 0, self._env._num_customers + 1, "step_count"
Louay-Ben-nessir marked this conversation as resolved.
Show resolved Hide resolved
)
action_mask = specs.BoundedArray(
(self.num_agents, self._env._num_customers + 1), bool, False, True, "action_mask"
)

agents_view = specs.BoundedArray(
(self.num_agents, (self._env._num_customers + 1) * 7 + 4),
jnp.float32,
-jnp.inf,
jnp.inf,
"agents_view",
)
obs_data = {
"agents_view": agents_view,
"action_mask": action_mask,
"step_count": step_count,
}

if self.has_global_state:
global_state = specs.Array(
(self.num_agents, (self._env._num_customers + 1) * 7 + 4 * self.num_agents),
jnp.float32,
"global_state",
)
obs_data["global_state"] = global_state
return specs.Spec(ObservationGlobalState, "ObservationSpec", **obs_data)

return specs.Spec(Observation, "ObservationSpec", **obs_data)

def reward_spec(self) -> specs.Array:
return specs.Array(shape=(self.num_agents,), dtype=float, name="reward")

def discount_spec(self) -> specs.BoundedArray:
return specs.BoundedArray(
shape=(self.num_agents,), dtype=float, minimum=0.0, maximum=1.0, name="discount"
)
Louay-Ben-nessir marked this conversation as resolved.
Show resolved Hide resolved

def action_spec(self) -> specs.Spec:
return specs.MultiDiscreteArray(
num_values=jnp.full(self.num_agents, self._env._num_customers + 1)
)
Loading