From 6e25369527b9609667b8ac908da1a46de5bda5a2 Mon Sep 17 00:00:00 2001
From: Louay-Ben-nessir
Date: Tue, 20 Feb 2024 18:24:27 +0100
Subject: [PATCH 01/19] feat: multiCVRP base implementation

---
 mava/configs/env/multicvrp.yaml       | 10 +++
 mava/configs/env/scenario/2v-20c.yaml |  7 +++
 mava/configs/env/scenario/2v-6c.yaml  |  7 +++
 mava/configs/logger/base_logger.yaml  |  2 +-
 mava/configs/network/mlp.yaml         |  2 +-
 mava/systems/ff_ippo.py               |  4 +-
 mava/systems/ff_mappo.py              |  2 +-
 mava/systems/rec_ippo.py              |  2 +-
 mava/systems/rec_mappo.py             |  2 +-
 mava/utils/make_env.py                | 14 ++++-
 mava/wrappers/__init__.py             |  2 +-
 mava/wrappers/jumanji.py              | 89 ++++++++++++++++++++++++++-
 12 files changed, 132 insertions(+), 11 deletions(-)
 create mode 100644 mava/configs/env/multicvrp.yaml
 create mode 100644 mava/configs/env/scenario/2v-20c.yaml
 create mode 100644 mava/configs/env/scenario/2v-6c.yaml

diff --git a/mava/configs/env/multicvrp.yaml b/mava/configs/env/multicvrp.yaml
new file mode 100644
index 000000000..41bc9196a
--- /dev/null
+++ b/mava/configs/env/multicvrp.yaml
@@ -0,0 +1,10 @@
+# ---Environment Configs---
+defaults:
+  - _self_
+  - scenario: 2v-20c #2v-20c
+
+
+env_name: MultiCVRP-v0
+
+eval_metric: episode_return
+
diff --git a/mava/configs/env/scenario/2v-20c.yaml b/mava/configs/env/scenario/2v-20c.yaml
new file mode 100644
index 000000000..221afd980
--- /dev/null
+++ b/mava/configs/env/scenario/2v-20c.yaml
@@ -0,0 +1,7 @@
+task_name: 2v-20c
+
+task_config:
+  num_customers : 20
+  num_vehicles : 2
+
+
diff --git a/mava/configs/env/scenario/2v-6c.yaml b/mava/configs/env/scenario/2v-6c.yaml
new file mode 100644
index 000000000..ef1183032
--- /dev/null
+++ b/mava/configs/env/scenario/2v-6c.yaml
@@ -0,0 +1,7 @@
+task_name: 2v-6c
+
+task_config:
+  num_customers : 6
+  num_vehicles : 2
+
+
diff --git a/mava/configs/logger/base_logger.yaml b/mava/configs/logger/base_logger.yaml
index 9a1e37b87..d0ba0f690 100644
--- a/mava/configs/logger/base_logger.yaml
+++ b/mava/configs/logger/base_logger.yaml
@@ -9,7 +9,7 @@ use_neptune: False # Whether to log to neptune.ai.
 # --- Other logger kwargs ---
 kwargs:
   neptune_project: Instadeep/Mava
-  neptune_tag: [rware]
+  neptune_tag: [multiCVRP first run]
   detailed_neptune_logging: False # having mean/std/min/max can clutter neptune so we make it optional
   json_path: ~ # If set, json files will be logged to a set path so that multiple experiments can
     # write to the same json file for easy downstream aggregation and plotting with marl-eval.
diff --git a/mava/configs/network/mlp.yaml b/mava/configs/network/mlp.yaml
index cab9ecd44..7e50fa9ca 100644
--- a/mava/configs/network/mlp.yaml
+++ b/mava/configs/network/mlp.yaml
@@ -2,7 +2,7 @@ actor_network:
   pre_torso:
     _target_: mava.networks.MLPTorso
-    layer_sizes: [128, 128]
+    layer_sizes: [128, 512, 1024, 1024, 512, 512, 64, 64]
     use_layer_norm: False
     activation: relu
diff --git a/mava/systems/ff_ippo.py b/mava/systems/ff_ippo.py
index cb9dd0f8b..8406a7895 100644
--- a/mava/systems/ff_ippo.py
+++ b/mava/systems/ff_ippo.py
@@ -510,7 +510,7 @@ def run_experiment(_config: DictConfig) -> float:
     )
 
     # Run experiment for a total number of evaluations.
-    max_episode_return = jnp.float32(0.0)
+    max_episode_return = jnp.float32("-inf")
     best_params = None
     for eval_step in range(config.arch.num_evaluation):
         # Train.
@@ -574,7 +574,7 @@ def run_experiment(_config: DictConfig) -> float: key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) eval_keys = jnp.stack(eval_keys) eval_keys = eval_keys.reshape(n_devices, -1) - + evaluator_output = absolute_metric_evaluator(best_params, eval_keys) jax.block_until_ready(evaluator_output) diff --git a/mava/systems/ff_mappo.py b/mava/systems/ff_mappo.py index 0ba006577..d272d09c4 100644 --- a/mava/systems/ff_mappo.py +++ b/mava/systems/ff_mappo.py @@ -516,7 +516,7 @@ def run_experiment(_config: DictConfig) -> float: ) # Run experiment for a total number of evaluations. - max_episode_return = jnp.float32(0.0) + max_episode_return = jnp.float32("-inf") best_params = None for eval_step in range(config.arch.num_evaluation): # Train. diff --git a/mava/systems/rec_ippo.py b/mava/systems/rec_ippo.py index ce44ccf79..0dd20f6b8 100644 --- a/mava/systems/rec_ippo.py +++ b/mava/systems/rec_ippo.py @@ -679,7 +679,7 @@ def run_experiment(_config: DictConfig) -> float: ) # Run experiment for a total number of evaluations. - max_episode_return = jnp.float32(0.0) + max_episode_return = jnp.float32("-inf") best_params = None for eval_step in range(config.arch.num_evaluation): # Train. diff --git a/mava/systems/rec_mappo.py b/mava/systems/rec_mappo.py index 9c3d9c1d3..f91d96547 100644 --- a/mava/systems/rec_mappo.py +++ b/mava/systems/rec_mappo.py @@ -686,7 +686,7 @@ def run_experiment(_config: DictConfig) -> float: ) # Run experiment for a total number of evaluations. - max_episode_return = jnp.float32(0.0) + max_episode_return = jnp.float32("-inf") best_params = None for eval_step in range(config.arch.num_evaluation): # Train. diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 2ef7010af..cd24d74cd 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -24,6 +24,9 @@ from jumanji.environments.routing.robot_warehouse.generator import ( RandomGenerator as RwareRandomGenerator, ) +from jumanji.environments.routing.multi_cvrp.generator import ( + UniformRandomGenerator as MultiCVRPRandomGenerator, +) from omegaconf import DictConfig from mava.wrappers import ( @@ -34,12 +37,14 @@ LbfWrapper, RecordEpisodeMetrics, RwareWrapper, + multiCVRPWrapper, ) # Registry mapping environment names to their generator and wrapper classes. _jumanji_registry = { "RobotWarehouse-v0": {"generator": RwareRandomGenerator, "wrapper": RwareWrapper}, "LevelBasedForaging-v0": {"generator": LbfRandomGenerator, "wrapper": LbfWrapper}, + "MultiCVRP-v0": {"generator": MultiCVRPRandomGenerator, "wrapper": multiCVRPWrapper}, } @@ -77,8 +82,13 @@ def make_jumanji_env( wrapper = _jumanji_registry[env_name]["wrapper"] # Create envs. 
- env = jumanji.make(env_name, generator=generator, **config.env.kwargs) - eval_env = jumanji.make(env_name, generator=generator, **config.env.kwargs) + if "kwargs" in config.env: + env = jumanji.make(env_name, generator=generator, **config.env.kwargs) + eval_env = jumanji.make(env_name, generator=generator, **config.env.kwargs) + else: + env = jumanji.make(env_name, generator=generator) + eval_env = jumanji.make(env_name, generator=generator) + env, eval_env = wrapper(env), wrapper(eval_env) env = add_optional_wrappers(env, config, add_global_state) diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index 48d030db4..fd1c772f3 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -15,5 +15,5 @@ from mava.wrappers.auto_reset_wrapper import AutoResetWrapper from mava.wrappers.episode_metrics import RecordEpisodeMetrics from mava.wrappers.jaxmarl import JaxMarlWrapper -from mava.wrappers.jumanji import LbfWrapper, RwareWrapper +from mava.wrappers.jumanji import LbfWrapper, RwareWrapper, multiCVRPWrapper from mava.wrappers.observation import AgentIDWrapper, GlobalStateWrapper diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py index 19c56cafa..60f68b251 100644 --- a/mava/wrappers/jumanji.py +++ b/mava/wrappers/jumanji.py @@ -20,8 +20,12 @@ from jumanji.env import Environment from jumanji.environments.routing.lbf import LevelBasedForaging from jumanji.environments.routing.robot_warehouse import RobotWarehouse -from jumanji.types import TimeStep +from jumanji.types import TimeStep, StepType from jumanji.wrappers import Wrapper + +import jax +from jax import tree_util +from jumanji.environments.routing.multi_cvrp import MultiCVRP from mava.types import Observation, State @@ -117,3 +121,86 @@ def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]: # Aggregate the list of individual rewards and use a single team_reward. 
return self.aggregate_rewards(timestep, modified_observation) + +class multiCVRPWrapper(Wrapper): + def __init__(self, env: MultiCVRP): + self.num_agents = env._num_vehicles + self._env = env + + def reset(self, key: chex.PRNGKey) -> Tuple[State | TimeStep]: + state , timestep = self._env.reset(key) + timestep = self.modify_timestep(timestep, state.step_count) # handeling the step_count is wrong in both gigastep / here it's allways set to 0 + return state, timestep + + def step(self, state: State, action: chex.Array) -> Tuple[State | TimeStep]: + state, timestep = self._env.step(state,action) + timestep = self.modify_timestep(timestep, state.step_count) + return state,timestep + + def modify_timestep(self, timestep: TimeStep, step_count : chex.Array) -> TimeStep[Observation]: + + observation = self._format_observation(timestep.observation) + reward = jnp.repeat(timestep.reward, (self.num_agents)) + discount = jnp.repeat(timestep.discount, (self.num_agents)) + step_count = jnp.repeat(step_count, (self.num_agents)) + observation = Observation( + agents_view=observation, + action_mask=timestep.observation.action_mask, + step_count=step_count, + ) + + + timestep = timestep.replace(observation=observation, reward=reward, discount=discount) + return timestep + + def _format_observation(self, observation): + #flatten and concat all of the observations for now + customers_info, _ = tree_util.tree_flatten((observation.nodes,observation.windows,observation.coeffs)) + vehicles_info , _ = tree_util.tree_flatten(observation.vehicles) + + #this results in c1-info1-c2,info2 + customers_info = jnp.column_stack(customers_info).ravel() + #each agents needs to get the customers_info in their observation , alot of compute is wasted this way + customers_info = jnp.tile(customers_info, (self.num_agents, 1) ) + + vehicles_info = jnp.column_stack(vehicles_info) + + #(num_vechials, obs) with obs (vechial_obs + costumer_obs) if the costumer obs doest change much woudn't be better to pass it thought it's own network once? 
+ observations = jnp.column_stack((vehicles_info, customers_info)) + + return observations + + def observation_spec(self) -> specs.Spec[Observation]: + step_count = specs.BoundedArray( + (self.num_agents,), jnp.int32,0, self._env._num_customers + 1 ,"step_count" + ) + action_mask = specs.BoundedArray( + (self.num_agents, self._env._num_customers + 1), bool, False, True, "action_mask" + ) + agents_view = specs.BoundedArray( + (self.num_agents, (self._env._num_customers + 1) * 7 + 4), #7 is broken into 2 for cords, 1 each of demands,start,end,early,late and the 4 is the cords,capacity of the veichale + jnp.float32, + -jnp.inf, + jnp.inf, + "agents_view", + ) + return specs.Spec( + Observation, + "ObservationSpec", + agents_view=agents_view, + action_mask=action_mask, + step_count=step_count, + ) + + def reward_spec(self) -> specs.Array: + return specs.Array(shape=(self.num_agents,), dtype=float, name="reward") + + def discount_spec(self) -> specs.BoundedArray: + return specs.BoundedArray( + shape=(self.num_agents,), dtype=float, minimum=0.0, maximum=1.0, name="discount" + ) + + def action_spec(self) -> specs.Spec: + return specs.MultiDiscreteArray(num_values=jnp.full(self.num_agents, self._env._num_customers + 1)) + + From 4b804c2958212ad2672b5eb320c886daebff056d Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Mon, 26 Feb 2024 16:17:48 +0100 Subject: [PATCH 02/19] feat: multi --- mava/configs/env/multicvrp.yaml | 4 +- mava/configs/env/scenario/2v-20c.yaml | 1 + mava/configs/env/scenario/2v-6c.yaml | 1 + mava/configs/logger/base_logger.yaml | 2 +- mava/configs/network/mlp.yaml | 2 +- mava/utils/make_env.py | 12 ++++-- mava/wrappers/jumanji.py | 60 +++++++++++++++++---------- 7 files changed, 52 insertions(+), 30 deletions(-) diff --git a/mava/configs/env/multicvrp.yaml b/mava/configs/env/multicvrp.yaml index 41bc9196a..ab3e9d630 100644 --- a/mava/configs/env/multicvrp.yaml +++ b/mava/configs/env/multicvrp.yaml @@ -1,10 +1,10 @@ # ---Environment Configs--- defaults: - _self_ - - scenario: 2v-20c #2v-20c + - scenario: 2v-20c # [2v-20c, 2v-6c] -env_name: MultiCVRP-v0 +env_name: MultiCVRP eval_metric: episode_return diff --git a/mava/configs/env/scenario/2v-20c.yaml b/mava/configs/env/scenario/2v-20c.yaml index 221afd980..49ea3b8c7 100644 --- a/mava/configs/env/scenario/2v-20c.yaml +++ b/mava/configs/env/scenario/2v-20c.yaml @@ -1,3 +1,4 @@ +name: MultiCVRP-v0 task_name: 2v-20c task_config: diff --git a/mava/configs/env/scenario/2v-6c.yaml b/mava/configs/env/scenario/2v-6c.yaml index ef1183032..1952ff557 100644 --- a/mava/configs/env/scenario/2v-6c.yaml +++ b/mava/configs/env/scenario/2v-6c.yaml @@ -1,3 +1,4 @@ +name: MultiCVRP-v0 task_name: 2v-6c task_config: diff --git a/mava/configs/logger/base_logger.yaml b/mava/configs/logger/base_logger.yaml index d0ba0f690..a52f4a8f6 100644 --- a/mava/configs/logger/base_logger.yaml +++ b/mava/configs/logger/base_logger.yaml @@ -4,7 +4,7 @@ base_exp_path: results # Base path for logging. use_console: True # Whether to log to stdout. use_tb: False # Whether to use tensorboard logging. use_json: False # Whether to log marl-eval style to json files. -use_neptune: False # Whether to log to neptune.ai. +use_neptune: True # Whether to log to neptune.ai. 
# --- Other logger kwargs --- kwargs: diff --git a/mava/configs/network/mlp.yaml b/mava/configs/network/mlp.yaml index 7e50fa9ca..cab9ecd44 100644 --- a/mava/configs/network/mlp.yaml +++ b/mava/configs/network/mlp.yaml @@ -2,7 +2,7 @@ actor_network: pre_torso: _target_: mava.networks.MLPTorso - layer_sizes: [128, 512, 1024, 1024, 512, 512, 64, 64] + layer_sizes: [128, 128] use_layer_norm: False activation: relu diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index 56d710894..dfa914d79 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -96,10 +96,14 @@ def make_jumanji_env( env = jumanji.make(env_name, generator=generator) eval_env = jumanji.make(env_name, generator=generator) - env, eval_env = wrapper(env), wrapper(eval_env) - - env = add_optional_wrappers(env, config, add_global_state) - eval_env = add_optional_wrappers(eval_env, config, add_global_state) + if env_name == "MultiCVRP-v0": + env, eval_env = wrapper(env, add_global_state), wrapper(eval_env, add_global_state) + env = add_optional_wrappers(env, config ) + eval_env = add_optional_wrappers(eval_env, config) + else: + env, eval_env = wrapper(env), wrapper(eval_env) + env = add_optional_wrappers(env, config, add_global_state) + eval_env = add_optional_wrappers(eval_env, config, add_global_state) env = AutoResetWrapper(env) env = RecordEpisodeMetrics(env) diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py index 60f68b251..6fce0d4fa 100644 --- a/mava/wrappers/jumanji.py +++ b/mava/wrappers/jumanji.py @@ -22,12 +22,9 @@ from jumanji.environments.routing.robot_warehouse import RobotWarehouse from jumanji.types import TimeStep, StepType from jumanji.wrappers import Wrapper - -import jax from jax import tree_util from jumanji.environments.routing.multi_cvrp import MultiCVRP - -from mava.types import Observation, State +from mava.types import Observation, ObservationGlobalState, State class MultiAgentWrapper(Wrapper): @@ -123,13 +120,14 @@ def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]: return self.aggregate_rewards(timestep, modified_observation) class multiCVRPWrapper(Wrapper): - def __init__(self, env: MultiCVRP): + def __init__(self, env: MultiCVRP, has_global_state : bool = False): self.num_agents = env._num_vehicles self._env = env + self.has_global_state = has_global_state def reset(self, key: chex.PRNGKey) -> Tuple[State | TimeStep]: state , timestep = self._env.reset(key) - timestep = self.modify_timestep(timestep, state.step_count) # handeling the step_count is wrong in both gigastep / here it's allways set to 0 + timestep = self.modify_timestep(timestep, state.step_count) return state, timestep def step(self, state: State, action: chex.Array) -> Tuple[State | TimeStep]: @@ -138,37 +136,41 @@ def step(self, state: State, action: chex.Array) -> Tuple[State | TimeStep]: return state,timestep def modify_timestep(self, timestep: TimeStep, step_count : chex.Array) -> TimeStep[Observation]: - - observation = self._format_observation(timestep.observation) + observation, global_observation = self._format_observation(timestep.observation) + obs_data = { + "agents_view": observation, + "action_mask": timestep.observation.action_mask, + "step_count": jnp.repeat(step_count, (self.num_agents)), + } + if self.has_global_state: + obs_data["global_state"] = global_observation + observation = ObservationGlobalState(**obs_data) + else: + observation = Observation(**obs_data) + reward = jnp.repeat(timestep.reward, (self.num_agents)) discount = jnp.repeat(timestep.discount, 
(self.num_agents)) - step_count = jnp.repeat(step_count, (self.num_agents)) - observation = Observation( - agents_view=observation, - action_mask=timestep.observation.action_mask, - step_count=step_count, - ) - - timestep = timestep.replace(observation=observation, reward=reward, discount=discount) return timestep def _format_observation(self, observation): + global_observation = None #flatten and concat all of the observations for now customers_info, _ = tree_util.tree_flatten((observation.nodes,observation.windows,observation.coeffs)) vehicles_info , _ = tree_util.tree_flatten(observation.vehicles) #this results in c1-info1-c2,info2 customers_info = jnp.column_stack(customers_info).ravel() - #each agents needs to get the customers_info in their observation , alot of compute is wasted this way - customers_info = jnp.tile(customers_info, (self.num_agents, 1) ) - vehicles_info = jnp.column_stack(vehicles_info) - #(num_vechials, obs) with obs (vechial_obs + costumer_obs) if the costumer obs doest change much woudn't be better to pass it thought it's own network once? - observations = jnp.column_stack((vehicles_info, customers_info)) - return observations + if self.has_global_state: + global_observation = jnp.concat((customers_info, vehicles_info.ravel())) + global_observation = jnp.tile(global_observation, (self.num_agents, 1) ) + + customers_info = jnp.tile(customers_info, (self.num_agents, 1) ) + observations = jnp.column_stack((vehicles_info, customers_info)) + return observations, global_observation def observation_spec(self) -> specs.Spec[Observation]: step_count = specs.BoundedArray( @@ -184,6 +186,20 @@ def observation_spec(self) -> specs.Spec[Observation]: jnp.inf, "agents_view", ) + if self.has_global_state: + global_state = specs.Array( + (self.num_agents, (self._env._num_customers + 1) * 7 + 4 * self.num_agents), + jnp.float32, + "global_state", + ) + return specs.Spec( + ObservationGlobalState, + "ObservationSpec", + agents_view=agents_view, + action_mask=action_mask, + global_state=global_state, + step_count=step_count, + ) return specs.Spec( Observation, "ObservationSpec", From 09d0cd9a4ad0e69721c8dde0e0fc26362651f978 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Tue, 27 Feb 2024 13:27:00 +0100 Subject: [PATCH 03/19] fix: a lot of small changes --- mava/configs/env/multicvrp.yaml | 4 +-- .../{2v-20c.yaml => multicvrp-2v-20c.yaml} | 2 +- .../{2v-6c.yaml => multicvrp-2v-6c.yaml} | 2 +- mava/configs/logger/base_logger.yaml | 6 ++--- mava/utils/make_env.py | 26 +++++++----------- mava/wrappers/jumanji.py | 27 +++++++++++++++---- 6 files changed, 39 insertions(+), 28 deletions(-) rename mava/configs/env/scenario/{2v-20c.yaml => multicvrp-2v-20c.yaml} (72%) rename mava/configs/env/scenario/{2v-6c.yaml => multicvrp-2v-6c.yaml} (73%) diff --git a/mava/configs/env/multicvrp.yaml b/mava/configs/env/multicvrp.yaml index ab3e9d630..1b8024c54 100644 --- a/mava/configs/env/multicvrp.yaml +++ b/mava/configs/env/multicvrp.yaml @@ -1,10 +1,10 @@ # ---Environment Configs--- defaults: - _self_ - - scenario: 2v-20c # [2v-20c, 2v-6c] - + - scenario: multicvrp-2v-20c # [multicvrp-2v-20c, multicvrp-2v-6c] env_name: MultiCVRP eval_metric: episode_return +kwargs: {} \ No newline at end of file diff --git a/mava/configs/env/scenario/2v-20c.yaml b/mava/configs/env/scenario/multicvrp-2v-20c.yaml similarity index 72% rename from mava/configs/env/scenario/2v-20c.yaml rename to mava/configs/env/scenario/multicvrp-2v-20c.yaml index 49ea3b8c7..be711bbd0 100644 --- 
a/mava/configs/env/scenario/2v-20c.yaml +++ b/mava/configs/env/scenario/multicvrp-2v-20c.yaml @@ -1,5 +1,5 @@ name: MultiCVRP-v0 -task_name: 2v-20c +task_name: multicvrp-2v-20c task_config: num_customers : 20 diff --git a/mava/configs/env/scenario/2v-6c.yaml b/mava/configs/env/scenario/multicvrp-2v-6c.yaml similarity index 73% rename from mava/configs/env/scenario/2v-6c.yaml rename to mava/configs/env/scenario/multicvrp-2v-6c.yaml index 1952ff557..4c43e85ef 100644 --- a/mava/configs/env/scenario/2v-6c.yaml +++ b/mava/configs/env/scenario/multicvrp-2v-6c.yaml @@ -1,5 +1,5 @@ name: MultiCVRP-v0 -task_name: 2v-6c +task_name: multicvrp-2v-6c task_config: num_customers : 6 diff --git a/mava/configs/logger/base_logger.yaml b/mava/configs/logger/base_logger.yaml index a52f4a8f6..725d2de93 100644 --- a/mava/configs/logger/base_logger.yaml +++ b/mava/configs/logger/base_logger.yaml @@ -4,12 +4,12 @@ base_exp_path: results # Base path for logging. use_console: True # Whether to log to stdout. use_tb: False # Whether to use tensorboard logging. use_json: False # Whether to log marl-eval style to json files. -use_neptune: True # Whether to log to neptune.ai. +use_neptune: False # Whether to log to neptune.ai. # --- Other logger kwargs --- kwargs: neptune_project: Instadeep/Mava - neptune_tag: [multiCVRP first run] + neptune_tag: [rware] detailed_neptune_logging: False # having mean/std/min/max can clutter neptune so we make it optional json_path: ~ # If set, json files will be logged to a set path so that multiple experiments can # write to the same json file for easy downstream aggregation and plotting with marl-eval. @@ -29,4 +29,4 @@ checkpointing: load_model: False # Whether to load model checkpoints. load_args: - checkpoint_uid: "" # Unique identifier for checkpoint to load. + checkpoint_uid: "" # Unique identifier for checkpoint to load. \ No newline at end of file diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index aa3be4d09..a507ef5f8 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -66,7 +66,10 @@ def add_optional_wrappers( ) -> Environment: # Add the global state to observation. if add_global_state: - env = GlobalStateWrapper(env) + if hasattr(env, "has_global_state"): + env.has_global_state = True + else: + env = GlobalStateWrapper(env) # Add agent id to observation. if config.system.add_agent_id: @@ -94,21 +97,12 @@ def make_jumanji_env( wrapper = _jumanji_registry[env_name]["wrapper"] # Create envs. 
- if "kwargs" in config.env: - env = jumanji.make(env_name, generator=generator, **config.env.kwargs) - eval_env = jumanji.make(env_name, generator=generator, **config.env.kwargs) - else: - env = jumanji.make(env_name, generator=generator) - eval_env = jumanji.make(env_name, generator=generator) - - if env_name == "MultiCVRP-v0": - env, eval_env = wrapper(env, add_global_state), wrapper(eval_env, add_global_state) - env = add_optional_wrappers(env, config ) - eval_env = add_optional_wrappers(eval_env, config) - else: - env, eval_env = wrapper(env), wrapper(eval_env) - env = add_optional_wrappers(env, config, add_global_state) - eval_env = add_optional_wrappers(eval_env, config, add_global_state) + env = jumanji.make(env_name, generator=generator, **config.env.kwargs) + eval_env = jumanji.make(env_name, generator=generator, **config.env.kwargs) + + env, eval_env = wrapper(env), wrapper(eval_env) + env = add_optional_wrappers(env, config, add_global_state) + eval_env = add_optional_wrappers(eval_env, config, add_global_state) env = AutoResetWrapper(env) env = RecordEpisodeMetrics(env) diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py index aacc0a8c2..5065c5354 100644 --- a/mava/wrappers/jumanji.py +++ b/mava/wrappers/jumanji.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Tuple, Union, Dict import chex import jax.numpy as jnp @@ -220,6 +220,8 @@ def observation_spec(self) -> specs.Spec[Union[Observation, ObservationGlobalSta return spec class multiCVRPWrapper(Wrapper): + """ Wrapper for MultiCVRP environment. """ + def __init__(self, env: MultiCVRP, has_global_state : bool = False): self.num_agents = env._num_vehicles self._env = env @@ -234,8 +236,9 @@ def step(self, state: State, action: chex.Array) -> Tuple[State | TimeStep]: state, timestep = self._env.step(state,action) timestep = self.modify_timestep(timestep, state.step_count) return state,timestep - + def modify_timestep(self, timestep: TimeStep, step_count : chex.Array) -> TimeStep[Observation]: + #avoided the MultiAgentWrapper wrapper to use the step_count provided by the environment observation, global_observation = self._format_observation(timestep.observation) obs_data = { "agents_view": observation, @@ -253,13 +256,26 @@ def modify_timestep(self, timestep: TimeStep, step_count : chex.Array) -> TimeSt timestep = timestep.replace(observation=observation, reward=reward, discount=discount) return timestep - def _format_observation(self, observation): + def _format_observation(self, observation : Dict[str, chex.Array]) -> Tuple[chex.Array, Union[None, chex.Array]]: + """ + Formats the observation dictionary from the environment into a format suitable for mava. + + Args: + observation (Dict[chex.Array]): The raw observation dict provided by jumanji + + Returns: + observations (chex.Array): Concatenated individual observations for each agent, + shaped (num_agents, vehicle_info + customer_info). + global_observation (Union[None, chex.Array]): Concatenated global observation + shaped (num_agents, global_info) if has_global_state = True, + None otherwise. 
+ """ global_observation = None #flatten and concat all of the observations for now customers_info, _ = tree_util.tree_flatten((observation.nodes,observation.windows,observation.coeffs)) vehicles_info , _ = tree_util.tree_flatten(observation.vehicles) - #this results in c1-info1-c2,info2 + #this results in c1_info1-c2_info customers_info = jnp.column_stack(customers_info).ravel() vehicles_info = jnp.column_stack(vehicles_info) @@ -279,8 +295,9 @@ def observation_spec(self) -> specs.Spec[Observation]: action_mask = specs.BoundedArray( (self.num_agents, self._env._num_customers + 1), bool, False, True, "action_mask" ) + #7 is broken into 2 for cords, 1 each of demands,start,end,early,late and the 4 is the cords,capacity of the vehicle agents_view = specs.BoundedArray( - (self.num_agents, (self._env._num_customers + 1) * 7 + 4), #7 is broken into 2 for cords, 1 each of demands,start,end,early,late and the 4 is the cords,capacity of the veichale + (self.num_agents, (self._env._num_customers + 1) * 7 + 4), jnp.float32, -jnp.inf, jnp.inf, From bf96c8b4a62b6d0aa1d80e36ce3bf4c3c46d994c Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Tue, 27 Feb 2024 13:42:09 +0100 Subject: [PATCH 04/19] fix: minor changes --- mava/configs/env/multicvrp.yaml | 2 +- .../env/scenario/multicvrp-2v-20c.yaml | 2 - .../configs/env/scenario/multicvrp-2v-6c.yaml | 2 - mava/configs/logger/base_logger.yaml | 2 +- mava/utils/make_env.py | 10 +-- mava/wrappers/__init__.py | 7 +- mava/wrappers/jumanji.py | 88 ++++++++++--------- 7 files changed, 61 insertions(+), 52 deletions(-) diff --git a/mava/configs/env/multicvrp.yaml b/mava/configs/env/multicvrp.yaml index 1b8024c54..1090c12df 100644 --- a/mava/configs/env/multicvrp.yaml +++ b/mava/configs/env/multicvrp.yaml @@ -7,4 +7,4 @@ env_name: MultiCVRP eval_metric: episode_return -kwargs: {} \ No newline at end of file +kwargs: {} diff --git a/mava/configs/env/scenario/multicvrp-2v-20c.yaml b/mava/configs/env/scenario/multicvrp-2v-20c.yaml index be711bbd0..66702df5d 100644 --- a/mava/configs/env/scenario/multicvrp-2v-20c.yaml +++ b/mava/configs/env/scenario/multicvrp-2v-20c.yaml @@ -4,5 +4,3 @@ task_name: multicvrp-2v-20c task_config: num_customers : 20 num_vehicles : 2 - - diff --git a/mava/configs/env/scenario/multicvrp-2v-6c.yaml b/mava/configs/env/scenario/multicvrp-2v-6c.yaml index 4c43e85ef..83341b538 100644 --- a/mava/configs/env/scenario/multicvrp-2v-6c.yaml +++ b/mava/configs/env/scenario/multicvrp-2v-6c.yaml @@ -4,5 +4,3 @@ task_name: multicvrp-2v-6c task_config: num_customers : 6 num_vehicles : 2 - - diff --git a/mava/configs/logger/base_logger.yaml b/mava/configs/logger/base_logger.yaml index 725d2de93..9a1e37b87 100644 --- a/mava/configs/logger/base_logger.yaml +++ b/mava/configs/logger/base_logger.yaml @@ -29,4 +29,4 @@ checkpointing: load_model: False # Whether to load model checkpoints. load_args: - checkpoint_uid: "" # Unique identifier for checkpoint to load. \ No newline at end of file + checkpoint_uid: "" # Unique identifier for checkpoint to load. 
diff --git a/mava/utils/make_env.py b/mava/utils/make_env.py index a507ef5f8..d3ab7a822 100644 --- a/mava/utils/make_env.py +++ b/mava/utils/make_env.py @@ -25,12 +25,12 @@ from jumanji.environments.routing.lbf.generator import ( RandomGenerator as LbfRandomGenerator, ) -from jumanji.environments.routing.robot_warehouse.generator import ( - RandomGenerator as RwareRandomGenerator, -) from jumanji.environments.routing.multi_cvrp.generator import ( UniformRandomGenerator as MultiCVRPRandomGenerator, ) +from jumanji.environments.routing.robot_warehouse.generator import ( + RandomGenerator as RwareRandomGenerator, +) from omegaconf import DictConfig from mava.wrappers import ( @@ -41,10 +41,10 @@ LbfWrapper, MabraxWrapper, MatraxWrapper, + MultiCVRPWrapper, RecordEpisodeMetrics, RwareWrapper, SmaxWrapper, - multiCVRPWrapper, ) # Registry mapping environment names to their generator and wrapper classes. @@ -52,7 +52,7 @@ "RobotWarehouse-v0": {"generator": RwareRandomGenerator, "wrapper": RwareWrapper}, "LevelBasedForaging-v0": {"generator": LbfRandomGenerator, "wrapper": LbfWrapper}, "MaConnector-v2": {"generator": ConnectorRandomGenerator, "wrapper": ConnectorWrapper}, - "MultiCVRP-v0": {"generator": MultiCVRPRandomGenerator, "wrapper": multiCVRPWrapper}, + "MultiCVRP-v0": {"generator": MultiCVRPRandomGenerator, "wrapper": MultiCVRPWrapper}, } # Define a different registry for Matrax since it has no generator. diff --git a/mava/wrappers/__init__.py b/mava/wrappers/__init__.py index d7a0b966d..3a5d06ffb 100644 --- a/mava/wrappers/__init__.py +++ b/mava/wrappers/__init__.py @@ -15,6 +15,11 @@ from mava.wrappers.auto_reset_wrapper import AutoResetWrapper from mava.wrappers.episode_metrics import RecordEpisodeMetrics from mava.wrappers.jaxmarl import MabraxWrapper, SmaxWrapper -from mava.wrappers.jumanji import ConnectorWrapper, LbfWrapper, RwareWrapper, multiCVRPWrapper +from mava.wrappers.jumanji import ( + ConnectorWrapper, + LbfWrapper, + MultiCVRPWrapper, + RwareWrapper, +) from mava.wrappers.matrax import MatraxWrapper from mava.wrappers.observation import AgentIDWrapper, GlobalStateWrapper diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py index 5065c5354..5bab2527c 100644 --- a/mava/wrappers/jumanji.py +++ b/mava/wrappers/jumanji.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union, Dict +from typing import Tuple, Union import chex import jax.numpy as jnp +from jax import tree_util from jumanji import specs from jumanji.env import Environment from jumanji.environments.routing.connector import MaConnector @@ -26,11 +27,12 @@ TARGET, ) from jumanji.environments.routing.lbf import LevelBasedForaging +from jumanji.environments.routing.multi_cvrp import MultiCVRP +from jumanji.environments.routing.multi_cvrp.types import State as MultiCVRPState from jumanji.environments.routing.robot_warehouse import RobotWarehouse -from jumanji.types import TimeStep, StepType +from jumanji.types import TimeStep from jumanji.wrappers import Wrapper -from jax import tree_util -from jumanji.environments.routing.multi_cvrp import MultiCVRP + from mava.types import Observation, ObservationGlobalState, State @@ -218,27 +220,28 @@ def observation_spec(self) -> specs.Spec[Union[Observation, ObservationGlobalSta ) return spec - -class multiCVRPWrapper(Wrapper): - """ Wrapper for MultiCVRP environment. 
""" - def __init__(self, env: MultiCVRP, has_global_state : bool = False): + +class MultiCVRPWrapper(Wrapper): + """Wrapper for MultiCVRP environment.""" + + def __init__(self, env: MultiCVRP, has_global_state: bool = False): self.num_agents = env._num_vehicles self._env = env self.has_global_state = has_global_state - def reset(self, key: chex.PRNGKey) -> Tuple[State | TimeStep]: - state , timestep = self._env.reset(key) - timestep = self.modify_timestep(timestep, state.step_count) + def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: + state, timestep = self._env.reset(key) + timestep = self.modify_timestep(timestep, state.step_count) return state, timestep - - def step(self, state: State, action: chex.Array) -> Tuple[State | TimeStep]: - state, timestep = self._env.step(state,action) + + def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: + state, timestep = self._env.step(state, action) timestep = self.modify_timestep(timestep, state.step_count) - return state,timestep - - def modify_timestep(self, timestep: TimeStep, step_count : chex.Array) -> TimeStep[Observation]: - #avoided the MultiAgentWrapper wrapper to use the step_count provided by the environment + return state, timestep + + def modify_timestep(self, timestep: TimeStep, step_count: chex.Array) -> TimeStep[Observation]: + # avoided the MultiAgentWrapper wrapper to use the step_count provided by the environment observation, global_observation = self._format_observation(timestep.observation) obs_data = { "agents_view": observation, @@ -250,13 +253,15 @@ def modify_timestep(self, timestep: TimeStep, step_count : chex.Array) -> TimeSt observation = ObservationGlobalState(**obs_data) else: observation = Observation(**obs_data) - + reward = jnp.repeat(timestep.reward, (self.num_agents)) discount = jnp.repeat(timestep.discount, (self.num_agents)) timestep = timestep.replace(observation=observation, reward=reward, discount=discount) return timestep - - def _format_observation(self, observation : Dict[str, chex.Array]) -> Tuple[chex.Array, Union[None, chex.Array]]: + + def _format_observation( + self, observation: MultiCVRPState + ) -> Tuple[chex.Array, Union[None, chex.Array]]: """ Formats the observation dictionary from the environment into a format suitable for mava. @@ -267,37 +272,40 @@ def _format_observation(self, observation : Dict[str, chex.Array]) -> Tuple[chex observations (chex.Array): Concatenated individual observations for each agent, shaped (num_agents, vehicle_info + customer_info). global_observation (Union[None, chex.Array]): Concatenated global observation - shaped (num_agents, global_info) if has_global_state = True, + shaped (num_agents, global_info) + if has_global_state = True, None otherwise. 
""" - global_observation = None - #flatten and concat all of the observations for now - customers_info, _ = tree_util.tree_flatten((observation.nodes,observation.windows,observation.coeffs)) - vehicles_info , _ = tree_util.tree_flatten(observation.vehicles) - - #this results in c1_info1-c2_info + global_observation = None + # flatten and concat all of the observations for now + customers_info, _ = tree_util.tree_flatten( + (observation.nodes, observation.windows, observation.coeffs) + ) + vehicles_info, _ = tree_util.tree_flatten(observation.vehicles) + + # this results in c1_info1-c2_info customers_info = jnp.column_stack(customers_info).ravel() vehicles_info = jnp.column_stack(vehicles_info) - if self.has_global_state: global_observation = jnp.concat((customers_info, vehicles_info.ravel())) - global_observation = jnp.tile(global_observation, (self.num_agents, 1) ) + global_observation = jnp.tile(global_observation, (self.num_agents, 1)) - customers_info = jnp.tile(customers_info, (self.num_agents, 1) ) - observations = jnp.column_stack((vehicles_info, customers_info)) + customers_info = jnp.tile(customers_info, (self.num_agents, 1)) + observations = jnp.column_stack((vehicles_info, customers_info)) return observations, global_observation - + def observation_spec(self) -> specs.Spec[Observation]: step_count = specs.BoundedArray( - (self.num_agents,), jnp.int32,0, self._env._num_customers + 1 ,"step_count" + (self.num_agents,), jnp.int32, 0, self._env._num_customers + 1, "step_count" ) action_mask = specs.BoundedArray( (self.num_agents, self._env._num_customers + 1), bool, False, True, "action_mask" ) - #7 is broken into 2 for cords, 1 each of demands,start,end,early,late and the 4 is the cords,capacity of the vehicle + # 7 is broken into 2 for cords, 1 each of demands,start,end,early,late + # and the 4 is the cords,capacity of the vehicle agents_view = specs.BoundedArray( - (self.num_agents, (self._env._num_customers + 1) * 7 + 4), + (self.num_agents, (self._env._num_customers + 1) * 7 + 4), jnp.float32, -jnp.inf, jnp.inf, @@ -332,8 +340,8 @@ def discount_spec(self) -> specs.BoundedArray: return specs.BoundedArray( shape=(self.num_agents,), dtype=float, minimum=0.0, maximum=1.0, name="discount" ) - - def action_spec(self) -> specs.Spec: - return specs.MultiDiscreteArray(num_values=jnp.full(self.num_agents, self._env._num_customers + 1)) - + def action_spec(self) -> specs.Spec: + return specs.MultiDiscreteArray( + num_values=jnp.full(self.num_agents, self._env._num_customers + 1) + ) From 4d46916a852f49e42cec41fed9e5c24b1027ec27 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Wed, 28 Feb 2024 17:04:51 +0100 Subject: [PATCH 05/19] fix: comments + super + observation insted of state --- mava/wrappers/jumanji.py | 44 ++++++++++++++++++++++++---------------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py index 5bab2527c..b3e699a2b 100644 --- a/mava/wrappers/jumanji.py +++ b/mava/wrappers/jumanji.py @@ -28,7 +28,7 @@ ) from jumanji.environments.routing.lbf import LevelBasedForaging from jumanji.environments.routing.multi_cvrp import MultiCVRP -from jumanji.environments.routing.multi_cvrp.types import State as MultiCVRPState +from jumanji.environments.routing.multi_cvrp.types import Observation as MultiCvrpObservation from jumanji.environments.routing.robot_warehouse import RobotWarehouse from jumanji.types import TimeStep from jumanji.wrappers import Wrapper @@ -224,14 +224,15 @@ def observation_spec(self) -> 
specs.Spec[Union[Observation, ObservationGlobalSta class MultiCVRPWrapper(Wrapper): """Wrapper for MultiCVRP environment.""" - + def __init__(self, env: MultiCVRP, has_global_state: bool = False): + super().__init__(env) self.num_agents = env._num_vehicles self._env = env self.has_global_state = has_global_state def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: - state, timestep = self._env.reset(key) + state, timestep = self._env.reset(key) timestep = self.modify_timestep(timestep, state.step_count) return state, timestep @@ -242,7 +243,7 @@ def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: def modify_timestep(self, timestep: TimeStep, step_count: chex.Array) -> TimeStep[Observation]: # avoided the MultiAgentWrapper wrapper to use the step_count provided by the environment - observation, global_observation = self._format_observation(timestep.observation) + observation, global_observation = self._flatten_observation(timestep.observation) obs_data = { "agents_view": observation, "action_mask": timestep.observation.action_mask, @@ -259,39 +260,49 @@ def modify_timestep(self, timestep: TimeStep, step_count: chex.Array) -> TimeSte timestep = timestep.replace(observation=observation, reward=reward, discount=discount) return timestep - def _format_observation( - self, observation: MultiCVRPState + def _flatten_observation( + self, observation: MultiCvrpObservation ) -> Tuple[chex.Array, Union[None, chex.Array]]: """ - Formats the observation dictionary from the environment into a format suitable for mava. + Concatenates all observation fields into a single array. Args: - observation (Dict[chex.Array]): The raw observation dict provided by jumanji + observation (MultiCvrpObservation): The raw observation NamedTuple provided by jumanji. Returns: observations (chex.Array): Concatenated individual observations for each agent, - shaped (num_agents, vehicle_info + customer_info). + shaped (num_agents, vehicle_info + customer_info). global_observation (Union[None, chex.Array]): Concatenated global observation - shaped (num_agents, global_info) - if has_global_state = True, - None otherwise. + shaped (num_agents, global_info) if has_global_state = True, None otherwise. 
""" global_observation = None - # flatten and concat all of the observations for now + # N: number of nodes, same as _num_customers + 1 + # V: number of vehicles, same as num_agents + # nodes are composed of (x, y, demands) + # Windows are composed of (start_time, end_time) + # Coeffs are composed of (early, late) + + # Tuple[(N, 3), (N, 2), (N, 2)] customers_info, _ = tree_util.tree_flatten( (observation.nodes, observation.windows, observation.coeffs) ) + # Tuple[(V, 2), (V, 1), (V, 1)] vehicles_info, _ = tree_util.tree_flatten(observation.vehicles) - # this results in c1_info1-c2_info + # (N * 7, ) customers_info = jnp.column_stack(customers_info).ravel() + # (V, 4) vehicles_info = jnp.column_stack(vehicles_info) if self.has_global_state: - global_observation = jnp.concat((customers_info, vehicles_info.ravel())) + #(V * 4 * N * 7, ) + global_observation = jnp.concat((vehicles_info.ravel(), customers_info)) + #(V, N * 7 * V * 4) global_observation = jnp.tile(global_observation, (self.num_agents, 1)) + # (V, N * 7) customers_info = jnp.tile(customers_info, (self.num_agents, 1)) + # (V, 4 + N * 7) observations = jnp.column_stack((vehicles_info, customers_info)) return observations, global_observation @@ -302,8 +313,7 @@ def observation_spec(self) -> specs.Spec[Observation]: action_mask = specs.BoundedArray( (self.num_agents, self._env._num_customers + 1), bool, False, True, "action_mask" ) - # 7 is broken into 2 for cords, 1 each of demands,start,end,early,late - # and the 4 is the cords,capacity of the vehicle + agents_view = specs.BoundedArray( (self.num_agents, (self._env._num_customers + 1) * 7 + 4), jnp.float32, From e3e002c0a2dee99230cd721e8b2f79791c30c732 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Wed, 28 Feb 2024 17:09:14 +0100 Subject: [PATCH 06/19] fix: minor comment changes --- mava/wrappers/jumanji.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py index b3e699a2b..3b79ad926 100644 --- a/mava/wrappers/jumanji.py +++ b/mava/wrappers/jumanji.py @@ -28,7 +28,9 @@ ) from jumanji.environments.routing.lbf import LevelBasedForaging from jumanji.environments.routing.multi_cvrp import MultiCVRP -from jumanji.environments.routing.multi_cvrp.types import Observation as MultiCvrpObservation +from jumanji.environments.routing.multi_cvrp.types import ( + Observation as MultiCvrpObservation, +) from jumanji.environments.routing.robot_warehouse import RobotWarehouse from jumanji.types import TimeStep from jumanji.wrappers import Wrapper @@ -224,7 +226,7 @@ def observation_spec(self) -> specs.Spec[Union[Observation, ObservationGlobalSta class MultiCVRPWrapper(Wrapper): """Wrapper for MultiCVRP environment.""" - + def __init__(self, env: MultiCVRP, has_global_state: bool = False): super().__init__(env) self.num_agents = env._num_vehicles @@ -232,7 +234,7 @@ def __init__(self, env: MultiCVRP, has_global_state: bool = False): self.has_global_state = has_global_state def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: - state, timestep = self._env.reset(key) + state, timestep = self._env.reset(key) timestep = self.modify_timestep(timestep, state.step_count) return state, timestep @@ -281,6 +283,7 @@ def _flatten_observation( # nodes are composed of (x, y, demands) # Windows are composed of (start_time, end_time) # Coeffs are composed of (early, late) + # Vehicles have ((x, y), local_time, capacity) # Tuple[(N, 3), (N, 2), (N, 2)] customers_info, _ = tree_util.tree_flatten( @@ -295,9 +298,9 @@ def 
_flatten_observation( vehicles_info = jnp.column_stack(vehicles_info) if self.has_global_state: - #(V * 4 * N * 7, ) + # (V * 4 * N * 7, ) global_observation = jnp.concat((vehicles_info.ravel(), customers_info)) - #(V, N * 7 * V * 4) + # (V, N * 7 * V * 4) global_observation = jnp.tile(global_observation, (self.num_agents, 1)) # (V, N * 7) From d3da4ed0751555a941f9e9a1c2c45b6b93e818bf Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Mon, 4 Mar 2024 13:53:44 +0100 Subject: [PATCH 07/19] fix: use the multiagentWrapper --- mava/wrappers/jumanji.py | 40 +++++++++++++++------------------------- 1 file changed, 15 insertions(+), 25 deletions(-) diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py index 3b79ad926..7715c3862 100644 --- a/mava/wrappers/jumanji.py +++ b/mava/wrappers/jumanji.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Tuple, Union, Optional import chex import jax.numpy as jnp @@ -44,27 +44,27 @@ def __init__(self, env: Environment): self._num_agents = self._env.num_agents self.time_limit = self._env.time_limit - def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]: + def modify_timestep(self, timestep: TimeStep, state: Optional[State] = None) -> TimeStep[Observation]: """Modify the timestep for `step` and `reset`.""" pass def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]: """Reset the environment.""" state, timestep = self._env.reset(key) - return state, self.modify_timestep(timestep) + return state, self.modify_timestep(timestep, state) def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]: """Step the environment.""" state, timestep = self._env.step(state, action) - return state, self.modify_timestep(timestep) + return state, self.modify_timestep(timestep, state) def observation_spec(self) -> specs.Spec[Observation]: """Specification of the observation of the environment.""" step_count = specs.BoundedArray( (self._num_agents,), - int, - jnp.zeros(self._num_agents, dtype=int), - jnp.repeat(self.time_limit, self._num_agents), + jnp.int32, + [0] * self._num_agents, + [self.time_limit] * self._num_agents, "step_count", ) return self._env.observation_spec().replace(step_count=step_count) @@ -76,7 +76,7 @@ class RwareWrapper(MultiAgentWrapper): def __init__(self, env: RobotWarehouse): super().__init__(env) - def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]: + def modify_timestep(self, timestep: TimeStep, state : State) -> TimeStep[Observation]: """Modify the timestep for the Robotic Warehouse environment.""" observation = Observation( agents_view=timestep.observation.agents_view, @@ -113,7 +113,7 @@ def aggregate_rewards( reward = jnp.repeat(team_reward, self._num_agents) return timestep.replace(observation=observation, reward=reward) - def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]: + def modify_timestep(self, timestep: TimeStep, state : State) -> TimeStep[Observation]: """Modify the timestep for Level-Based Foraging environment and update the reward based on the specified reward handling strategy.""" @@ -142,7 +142,7 @@ def __init__(self, env: MaConnector, has_global_state: bool = False): self.has_global_state = has_global_state def modify_timestep( - self, timestep: TimeStep + self, timestep: TimeStep, state: State ) -> TimeStep[Union[Observation, ObservationGlobalState]]: """Modify the timestep for the Connector environment.""" @@ -224,32 
+224,22 @@ def observation_spec(self) -> specs.Spec[Union[Observation, ObservationGlobalSta
         return spec
 
 
-class MultiCVRPWrapper(Wrapper):
+class MultiCVRPWrapper(MultiAgentWrapper):
     """Wrapper for MultiCVRP environment."""
 
     def __init__(self, env: MultiCVRP, has_global_state: bool = False):
+        env.num_agents = env._num_vehicles
+        env.time_limit = env._num_customers + 1 #added for consistency
         super().__init__(env)
-        self.num_agents = env._num_vehicles
         self._env = env
         self.has_global_state = has_global_state
 
-    def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
-        state, timestep = self._env.reset(key)
-        timestep = self.modify_timestep(timestep, state.step_count)
-        return state, timestep
-
-    def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]:
-        state, timestep = self._env.step(state, action)
-        timestep = self.modify_timestep(timestep, state.step_count)
-        return state, timestep
-
-    def modify_timestep(self, timestep: TimeStep, step_count: chex.Array) -> TimeStep[Observation]:
-        # avoided the MultiAgentWrapper wrapper to use the step_count provided by the environment
+    def modify_timestep(self, timestep: TimeStep, state: State) -> TimeStep[Observation]:
         observation, global_observation = self._flatten_observation(timestep.observation)
         obs_data = {
             "agents_view": observation,
             "action_mask": timestep.observation.action_mask,
-            "step_count": jnp.repeat(step_count, (self.num_agents)),
+            "step_count": jnp.repeat(state.step_count, (self.num_agents)),
         }
         if self.has_global_state:
             obs_data["global_state"] = global_observation

From 1486b53eee77bd829ca31c197d2dd2616c8154ce Mon Sep 17 00:00:00 2001
From: Louay-Ben-nessir
Date: Mon, 4 Mar 2024 13:54:31 +0100
Subject: [PATCH 08/19] fix: annotations

---
 mava/wrappers/jumanji.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py
index 7715c3862..de2c72086 100644
--- a/mava/wrappers/jumanji.py
+++ b/mava/wrappers/jumanji.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Tuple, Union, Optional +from typing import Tuple, Union import chex import jax.numpy as jnp @@ -44,7 +44,7 @@ def __init__(self, env: Environment): self._num_agents = self._env.num_agents self.time_limit = self._env.time_limit - def modify_timestep(self, timestep: TimeStep, state: Optional[State] = None) -> TimeStep[Observation]: + def modify_timestep(self, timestep: TimeStep, state: State) -> TimeStep[Observation]: """Modify the timestep for `step` and `reset`.""" pass @@ -76,7 +76,7 @@ class RwareWrapper(MultiAgentWrapper): def __init__(self, env: RobotWarehouse): super().__init__(env) - def modify_timestep(self, timestep: TimeStep, state : State) -> TimeStep[Observation]: + def modify_timestep(self, timestep: TimeStep, state: State) -> TimeStep[Observation]: """Modify the timestep for the Robotic Warehouse environment.""" observation = Observation( agents_view=timestep.observation.agents_view, @@ -113,7 +113,7 @@ def aggregate_rewards( reward = jnp.repeat(team_reward, self._num_agents) return timestep.replace(observation=observation, reward=reward) - def modify_timestep(self, timestep: TimeStep, state : State) -> TimeStep[Observation]: + def modify_timestep(self, timestep: TimeStep, state: State) -> TimeStep[Observation]: """Modify the timestep for Level-Based Foraging environment and update the reward based on the specified reward handling strategy.""" @@ -229,7 +229,7 @@ class MultiCVRPWrapper(MultiAgentWrapper): def __init__(self, env: MultiCVRP, has_global_state: bool = False): env.num_agents = env._num_vehicles - env.time_limit = env._num_customers + 1 #added for consistency + env.time_limit = env._num_customers + 1 # added for consistency super().__init__(env) self._env = env self.has_global_state = has_global_state @@ -289,7 +289,7 @@ def _flatten_observation( if self.has_global_state: # (V * 4 * N * 7, ) - global_observation = jnp.concat((vehicles_info.ravel(), customers_info)) + global_observation = jnp.concatenate((vehicles_info.ravel(), customers_info)) # (V, N * 7 * V * 4) global_observation = jnp.tile(global_observation, (self.num_agents, 1)) From 9743c597540a0cc3e292a5ab592fff06a0dd16ef Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Mon, 11 Mar 2024 11:32:57 +0100 Subject: [PATCH 09/19] fix: pre-commit --- mava/wrappers/jumanji.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py index 73a9e5bb8..66471a05e 100644 --- a/mava/wrappers/jumanji.py +++ b/mava/wrappers/jumanji.py @@ -39,7 +39,7 @@ from jumanji.wrappers import Wrapper from mava.types import Observation, ObservationGlobalState, State -import jax + class MultiAgentWrapper(Wrapper, ABC): def __init__(self, env: Environment, add_global_state: bool): @@ -285,7 +285,6 @@ def __init__(self, env: MultiCVRP, add_global_state: bool = False): self.has_global_state = add_global_state self._env = env - def modify_timestep(self, timestep: TimeStep, state: State) -> TimeStep[Observation]: observation, global_observation = self._flatten_observation(timestep.observation) obs_data = { @@ -301,8 +300,7 @@ def modify_timestep(self, timestep: TimeStep, state: State) -> TimeStep[Observat reward = jnp.repeat(timestep.reward, (self.num_agents)) discount = jnp.repeat(timestep.discount, (self.num_agents)) - timestep = timestep.replace(observation=observation, reward=reward, discount=discount) - return timestep + return timestep.replace(observation=observation, reward=reward, discount=discount) def _flatten_observation( self, 
observation: MultiCvrpObservation @@ -367,9 +365,9 @@ def observation_spec(self) -> specs.Spec[Observation]: "agents_view", ) obs_data = { - "agents_view":agents_view, - "action_mask":action_mask, - "step_count":step_count, + "agents_view": agents_view, + "action_mask": action_mask, + "step_count": step_count, } if self.has_global_state: @@ -382,7 +380,6 @@ def observation_spec(self) -> specs.Spec[Observation]: return specs.Spec(ObservationGlobalState, "ObservationSpec", **obs_data) return specs.Spec(Observation, "ObservationSpec", **obs_data) - def reward_spec(self) -> specs.Array: return specs.Array(shape=(self.num_agents,), dtype=float, name="reward") From 1ed331b89fadedbecb23cf2e52dd78f6e077a805 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Mon, 18 Mar 2024 10:03:36 +0100 Subject: [PATCH 10/19] chore: node --> Node in docs --- mava/wrappers/jumanji.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py index 66471a05e..880b787b9 100644 --- a/mava/wrappers/jumanji.py +++ b/mava/wrappers/jumanji.py @@ -320,7 +320,7 @@ def _flatten_observation( global_observation = None # N: number of nodes, same as _num_customers + 1 # V: number of vehicles, same as num_agents - # nodes are composed of (x, y, demands) + # Nodes are composed of (x, y, demands) # Windows are composed of (start_time, end_time) # Coeffs are composed of (early, late) # Vehicles have ((x, y), local_time, capacity) @@ -332,7 +332,7 @@ def _flatten_observation( # Tuple[(V, 2), (V, 1), (V, 1)] vehicles_info, _ = tree_util.tree_flatten(observation.vehicles) - # (N * 7, ) + # (N * 7, ) customers_info = jnp.column_stack(customers_info).ravel() # (V, 4) vehicles_info = jnp.column_stack(vehicles_info) From c0e7d026deb397ad9834298c404709002e88b442 Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Mon, 18 Mar 2024 10:44:21 +0100 Subject: [PATCH 11/19] chore: hardcoded numbers docs --- mava/wrappers/jumanji.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py index 880b787b9..dd9a5469e 100644 --- a/mava/wrappers/jumanji.py +++ b/mava/wrappers/jumanji.py @@ -325,16 +325,16 @@ def _flatten_observation( # Coeffs are composed of (early, late) # Vehicles have ((x, y), local_time, capacity) - # Tuple[(N, 3), (N, 2), (N, 2)] + # Tuple[(N, 3) : Nodes, (N, 2) : Windows, (N, 2) : Coeffs] customers_info, _ = tree_util.tree_flatten( (observation.nodes, observation.windows, observation.coeffs) ) - # Tuple[(V, 2), (V, 1), (V, 1)] + # Tuple[(V, 2) : Coordinates, (V, 1) : Local time, (V, 1) : Capacity] vehicles_info, _ = tree_util.tree_flatten(observation.vehicles) - # (N * 7, ) + # (N * 7, ) : N * (7 : Nodes (3) + Windows (2) + Coeffs (2)) customers_info = jnp.column_stack(customers_info).ravel() - # (V, 4) + # (V, 4) : V, (4 : Coordinates (2), Local Time (1), Coeffs (1)) vehicles_info = jnp.column_stack(vehicles_info) if self.has_global_state: From 9ddc91ad8bd8b02051d70b8ef85d6ccabbe54bea Mon Sep 17 00:00:00 2001 From: Louay-Ben-nessir Date: Mon, 18 Mar 2024 10:46:28 +0100 Subject: [PATCH 12/19] fix: pre-commit --- mava/wrappers/jumanji.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py index 6b2687157..3e46d60a5 100644 --- a/mava/wrappers/jumanji.py +++ b/mava/wrappers/jumanji.py @@ -325,14 +325,14 @@ def _flatten_observation( # Coeffs are composed of (early, late) # Vehicles have ((x, y), local_time, capacity) - # 
From 9ddc91ad8bd8b02051d70b8ef85d6ccabbe54bea Mon Sep 17 00:00:00 2001
From: Louay-Ben-nessir
Date: Mon, 18 Mar 2024 10:46:28 +0100
Subject: [PATCH 12/19] fix: pre-commit

---
 mava/wrappers/jumanji.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py
index 6b2687157..3e46d60a5 100644
--- a/mava/wrappers/jumanji.py
+++ b/mava/wrappers/jumanji.py
@@ -325,14 +325,14 @@ def _flatten_observation(
         # Coeffs are composed of (early, late)
         # Vehicles have ((x, y), local_time, capacity)
 
-        # Tuple[(N, 3) : Nodes, (N, 2) : Windows, (N, 2) : Coeffs] 
+        # Tuple[(N, 3) : Nodes, (N, 2) : Windows, (N, 2) : Coeffs]
         customers_info, _ = tree_util.tree_flatten(
             (observation.nodes, observation.windows, observation.coeffs)
         )
         # Tuple[(V, 2) : Coordinates, (V, 1) : Local time, (V, 1) : Capacity]
         vehicles_info, _ = tree_util.tree_flatten(observation.vehicles)
-        # (N * 7, ) : N * (7 : Nodes (3) + Windows (2) + Coeffs (2)) 
+        # (N * 7, ) : N * (7 : Nodes (3) + Windows (2) + Coeffs (2))
         customers_info = jnp.column_stack(customers_info).ravel()
         # (V, 4) : V, (4 : Coordinates (2), Local Time (1), Capacity (1))
         vehicles_info = jnp.column_stack(vehicles_info)

From b5b5a4f478ac659d39b7ce1bcbfcb1ef14751159 Mon Sep 17 00:00:00 2001
From: Louay-Ben-nessir
Date: Thu, 21 Mar 2024 10:13:26 +0100
Subject: [PATCH 13/19] chore: removed the reward/action specs

---
 mava/wrappers/jumanji.py | 27 ++++++++++-----------------
 1 file changed, 10 insertions(+), 17 deletions(-)

diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py
index 3e46d60a5..e6e9eabfd 100644
--- a/mava/wrappers/jumanji.py
+++ b/mava/wrappers/jumanji.py
@@ -281,16 +281,17 @@ def __init__(self, env: MultiCVRP, add_global_state: bool = False):
         env.num_agents = env._num_vehicles
         env.time_limit = env._num_customers + 1  # added for consistency
         env.action_dim = env._num_customers + 1  # num_customers + 1 starting depot node
-        super().__init__(env, False)
         self.has_global_state = add_global_state
+        self.num_customers = env._num_customers
         self._env = env
+        super().__init__(env, False)
 
     def modify_timestep(self, timestep: TimeStep, state: State) -> TimeStep[Observation]:
         observation, global_observation = self._flatten_observation(timestep.observation)
         obs_data = {
             "agents_view": observation,
             "action_mask": timestep.observation.action_mask,
-            "step_count": jnp.repeat(state.step_count, (self.num_agents)),
+            "step_count": jnp.repeat(state.step_count, self.num_agents),
         }
         if self.has_global_state:
             obs_data["global_state"] = global_observation
 
         else:
             observation = Observation(**obs_data)
 
-        reward = jnp.repeat(timestep.reward, (self.num_agents))
-        discount = jnp.repeat(timestep.discount, (self.num_agents))
+        reward = jnp.repeat(timestep.reward, self.num_agents)
+        discount = jnp.repeat(timestep.discount, self.num_agents)
         return timestep.replace(observation=observation, reward=reward, discount=discount)
 
     def _flatten_observation(
@@ -351,14 +352,14 @@ def _flatten_observation(
 
     def observation_spec(self) -> specs.Spec[Observation]:
         step_count = specs.BoundedArray(
-            (self.num_agents,), jnp.int32, 0, self._env._num_customers + 1, "step_count"
+            (self.num_agents,), jnp.int32, 0, self.num_customers + 1, "step_count"
        )
         action_mask = specs.BoundedArray(
-            (self.num_agents, self._env._num_customers + 1), bool, False, True, "action_mask"
+            (self.num_agents, self.num_customers + 1), bool, False, True, "action_mask"
         )
         agents_view = specs.BoundedArray(
-            (self.num_agents, (self._env._num_customers + 1) * 7 + 4),
+            (self.num_agents, (self.num_customers + 1) * 7 + 4),
             jnp.float32,
             -jnp.inf,
             jnp.inf,
@@ -372,7 +373,7 @@ def observation_spec(self) -> specs.Spec[Observation]:
 
         if self.has_global_state:
             global_state = specs.Array(
-                (self.num_agents, (self._env._num_customers + 1) * 7 + 4 * self.num_agents),
+                (self.num_agents, (self.num_customers + 1) * 7 + 4 * self.num_agents),
                 jnp.float32,
                 "global_state",
             )
@@ -381,15 +382,7 @@
         return specs.Spec(Observation, "ObservationSpec", **obs_data)
 
-    def reward_spec(self) -> specs.Array:
-        return specs.Array(shape=(self.num_agents,), dtype=float, name="reward")
-
-    def discount_spec(self) -> specs.BoundedArray:
-        return specs.BoundedArray(
-            shape=(self.num_agents,), dtype=float, minimum=0.0, maximum=1.0, name="discount"
-        )
-
     def action_spec(self) -> specs.Spec:
         return specs.MultiDiscreteArray(
-            num_values=jnp.full(self.num_agents, self._env._num_customers + 1)
+            num_values=jnp.full(self.num_agents, self.num_customers + 1)
         )
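For the specs introduced above, the hardcoded 7 and 4 work out as follows for the 2v-6c scenario; a quick sanity sketch (scenario sizes taken from the configs in this series, not part of the patch itself):

    import jax.numpy as jnp

    num_vehicles, num_customers = 2, 6  # the multicvrp-2v-6c scenario
    num_nodes = num_customers + 1  # customers plus the starting depot

    agents_view_width = num_nodes * 7 + 4  # per-agent feature width
    global_state_width = num_nodes * 7 + 4 * num_vehicles
    assert agents_view_width == 53
    assert global_state_width == 57

    # MultiCVRP emits one team reward/discount; the wrapper broadcasts them
    # to one entry per agent so they match the (num_agents,) specs.
    reward = jnp.repeat(jnp.float32(-3.7), num_vehicles)  # shape (2,)
    discount = jnp.repeat(jnp.float32(1.0), num_vehicles)  # shape (2,)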
From 02f69e12b39db5da00a7b4cea6fc0fa75165dba5 Mon Sep 17 00:00:00 2001
From: Louay-Ben-nessir
Date: Thu, 21 Mar 2024 10:31:13 +0100
Subject: [PATCH 14/19] fix: added state to Cleaner's modify_timestep function

---
 mava/wrappers/jumanji.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py
index bb72c0d8b..8deb2ef04 100644
--- a/mava/wrappers/jumanji.py
+++ b/mava/wrappers/jumanji.py
@@ -283,7 +283,7 @@ def __init__(self, env: Cleaner, add_global_state: bool = False):
         super().__init__(env, add_global_state)
         self._env: Cleaner
 
-    def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
+    def modify_timestep(self, timestep: TimeStep, state: State) -> TimeStep[Observation]:
         """Modify the timestep for the Cleaner environment."""
 
         def create_agents_view(grid: chex.Array, agents_locations: chex.Array) -> chex.Array:

From 2e30ffd618d727f8095eca1f47d3af5157437b0e Mon Sep 17 00:00:00 2001
From: Louay-Ben-nessir
Date: Fri, 22 Mar 2024 08:55:51 +0100
Subject: [PATCH 15/19] fix: removed the unneeded action_spec

---
 mava/wrappers/jumanji.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py
index 8deb2ef04..c783b5b07 100644
--- a/mava/wrappers/jumanji.py
+++ b/mava/wrappers/jumanji.py
@@ -390,8 +390,8 @@ def __init__(self, env: MultiCVRP, add_global_state: bool = False):
         env.action_dim = env._num_customers + 1  # num_customers + 1 starting depot node
         self.has_global_state = add_global_state
         self.num_customers = env._num_customers
-        self._env = env
         super().__init__(env, False)
+        self._env : MultiCVRP
 
     def modify_timestep(self, timestep: TimeStep, state: State) -> TimeStep[Observation]:
         observation, global_observation = self._flatten_observation(timestep.observation)
@@ -489,7 +489,3 @@ def observation_spec(self) -> specs.Spec[Observation]:
 
         return specs.Spec(Observation, "ObservationSpec", **obs_data)
 
-    def action_spec(self) -> specs.Spec:
-        return specs.MultiDiscreteArray(
-            num_values=jnp.full(self.num_agents, self.num_customers + 1)
-        )

From b26e48e3007a2f339c454a50bb11665a85de593c Mon Sep 17 00:00:00 2001
From: Louay-Ben-nessir
Date: Fri, 22 Mar 2024 08:57:58 +0100
Subject: [PATCH 16/19] fix: pre-commits

---
 mava/wrappers/jumanji.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py
index c783b5b07..52d25c32c 100644
--- a/mava/wrappers/jumanji.py
+++ b/mava/wrappers/jumanji.py
@@ -391,7 +391,7 @@ def __init__(self, env: MultiCVRP, add_global_state: bool = False):
         self.has_global_state = add_global_state
         self.num_customers = env._num_customers
         super().__init__(env, False)
-        self._env : MultiCVRP
+        self._env: MultiCVRP
 
     def modify_timestep(self, timestep: TimeStep, state: State) -> TimeStep[Observation]:
         observation, global_observation = self._flatten_observation(timestep.observation)
@@ -488,4 +488,3 @@ def observation_spec(self) -> specs.Spec[Observation]:
         return specs.Spec(ObservationGlobalState, "ObservationSpec", **obs_data)
         return specs.Spec(Observation, "ObservationSpec", **obs_data)
 
-

From 023c915164498b9f8a579bb9070dace239e34176 Mon Sep 17 00:00:00 2001
From: Louay-Ben-nessir
Date: Sat, 6 Jul 2024 12:14:56 +0100
Subject: [PATCH 17/19] fix: updated to the latest configs

---
 mava/configs/env/multicvrp.yaml                 | 2 ++
 mava/configs/env/scenario/multicvrp-2v-20c.yaml | 5 ++++-
 mava/configs/env/scenario/multicvrp-2v-6c.yaml  | 3 +++
 mava/wrappers/jumanji.py                        | 2 +-
 4 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/mava/configs/env/multicvrp.yaml b/mava/configs/env/multicvrp.yaml
index ee4f2e092..708d30cb8 100644
--- a/mava/configs/env/multicvrp.yaml
+++ b/mava/configs/env/multicvrp.yaml
@@ -11,4 +11,6 @@ implicit_agent_id: False
 
 eval_metric: episode_return
 
+log_win_rate: False
+
 kwargs: {}
diff --git a/mava/configs/env/scenario/multicvrp-2v-20c.yaml b/mava/configs/env/scenario/multicvrp-2v-20c.yaml
index 66702df5d..c43abb16b 100644
--- a/mava/configs/env/scenario/multicvrp-2v-20c.yaml
+++ b/mava/configs/env/scenario/multicvrp-2v-20c.yaml
@@ -2,5 +2,8 @@ name: MultiCVRP-v0
 task_name: multicvrp-2v-20c
 
 task_config:
-  num_customers : 20
+  num_customers : 6
   num_vehicles : 2
+
+env_kwargs:
+  {} # there are no scenario specific env_kwargs for this env
\ No newline at end of file
diff --git a/mava/configs/env/scenario/multicvrp-2v-6c.yaml b/mava/configs/env/scenario/multicvrp-2v-6c.yaml
index 83341b538..418e49434 100644
--- a/mava/configs/env/scenario/multicvrp-2v-6c.yaml
+++ b/mava/configs/env/scenario/multicvrp-2v-6c.yaml
@@ -4,3 +4,6 @@ task_name: multicvrp-2v-6c
 task_config:
   num_customers : 6
   num_vehicles : 2
+
+env_kwargs:
+  {} # there are no scenario specific env_kwargs for this env
\ No newline at end of file
diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py
index 52d25c32c..658504f83 100644
--- a/mava/wrappers/jumanji.py
+++ b/mava/wrappers/jumanji.py
@@ -386,7 +386,7 @@ class MultiCVRPWrapper(MultiAgentWrapper):
 
     def __init__(self, env: MultiCVRP, add_global_state: bool = False):
         env.num_agents = env._num_vehicles
-        env.time_limit = env._num_customers + 1  # added for consistency
+        env.time_limit = None  # added for consistency
         env.action_dim = env._num_customers + 1  # num_customers + 1 starting depot node
         self.has_global_state = add_global_state
         self.num_customers = env._num_customers
         super().__init__(env, False)
From bda1fffc48c84d235d5ebf9f09f29b2dfc5deb6a Mon Sep 17 00:00:00 2001
From: Louay-Ben-nessir
Date: Sat, 6 Jul 2024 12:16:24 +0100
Subject: [PATCH 18/19] fix: corrected the 20c scenario config

---
 mava/configs/env/scenario/multicvrp-2v-20c.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mava/configs/env/scenario/multicvrp-2v-20c.yaml b/mava/configs/env/scenario/multicvrp-2v-20c.yaml
index c43abb16b..d3aac6783 100644
--- a/mava/configs/env/scenario/multicvrp-2v-20c.yaml
+++ b/mava/configs/env/scenario/multicvrp-2v-20c.yaml
@@ -2,7 +2,7 @@
 task_config:
-  num_customers : 6
+  num_customers : 20
   num_vehicles : 2

From 4f1a68a3b60c6abf02b0e4184832d6bbd774bd14 Mon Sep 17 00:00:00 2001
From: Louay-Ben-nessir
Date: Sat, 6 Jul 2024 12:20:45 +0100
Subject: [PATCH 19/19] chore: pre-commits

---
 mava/configs/env/scenario/multicvrp-2v-20c.yaml | 2 +-
 mava/configs/env/scenario/multicvrp-2v-6c.yaml  | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mava/configs/env/scenario/multicvrp-2v-20c.yaml b/mava/configs/env/scenario/multicvrp-2v-20c.yaml
index d3aac6783..bd817a47d 100644
--- a/mava/configs/env/scenario/multicvrp-2v-20c.yaml
+++ b/mava/configs/env/scenario/multicvrp-2v-20c.yaml
@@ -6,4 +6,4 @@ task_config:
   num_vehicles : 2
 
 env_kwargs:
-  {} # there are no scenario specific env_kwargs for this env
\ No newline at end of file
+  {} # there are no scenario specific env_kwargs for this env
diff --git a/mava/configs/env/scenario/multicvrp-2v-6c.yaml b/mava/configs/env/scenario/multicvrp-2v-6c.yaml
index 418e49434..ebf0c884a 100644
--- a/mava/configs/env/scenario/multicvrp-2v-6c.yaml
+++ b/mava/configs/env/scenario/multicvrp-2v-6c.yaml
@@ -6,4 +6,4 @@ task_config:
   num_vehicles : 2
 
 env_kwargs:
-  {} # there are no scenario specific env_kwargs for this env
\ No newline at end of file
+  {} # there are no scenario specific env_kwargs for this env
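With the full series applied, MultiCVRP should behave like the other wrapped Jumanji environments. A hypothetical end-to-end smoke test follows — the wrapper class name and module path are taken from the patches above, while the MultiCVRP and UniformRandomGenerator constructor arguments are assumptions (the real generator may require further parameters such as map size):

    import jax
    from jumanji.environments import MultiCVRP
    from jumanji.environments.routing.multi_cvrp.generator import (
        UniformRandomGenerator,
    )

    from mava.wrappers.jumanji import MultiCVRPWrapper

    # Mirrors the multicvrp-2v-6c scenario: 6 customers, 2 vehicles.
    generator = UniformRandomGenerator(num_customers=6, num_vehicles=2)
    env = MultiCVRPWrapper(MultiCVRP(generator=generator), add_global_state=False)

    state, timestep = env.reset(jax.random.PRNGKey(0))
    print(timestep.observation.agents_view.shape)  # expected: (2, 53)
    print(timestep.reward.shape)                   # expected: (2,)

If this runs, the per-agent observation width and the broadcast reward shape confirm the spec arithmetic sketched earlier in the series.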