diff --git a/jumanji/environments/routing/connector/env.py b/jumanji/environments/routing/connector/env.py index dae55c0fc..2aa71fe6c 100644 --- a/jumanji/environments/routing/connector/env.py +++ b/jumanji/environments/routing/connector/env.py @@ -70,9 +70,10 @@ class Connector(Environment[State, specs.MultiDiscreteArray, Observation]): - can take the values [0,1,2,3,4] which correspond to [No Op, Up, Right, Down, Left]. - each value in the array corresponds to an agent's action. - - reward: jax array (float) of shape (): - - dense: reward is 1 for each successful connection on that step. Additionally, - each pair of points that have not connected receives a penalty reward of -0.03. + - reward: jax array (float) of shape (num_agents,): + - dense: for each agent the reward is 1 for each successful connection on that step. + Additionally, each pair of points that have not connected receives a + penalty reward of -0.03. - episode termination: - all agents either can't move (no available actions) or have connected to their target. @@ -142,7 +143,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: step_count=state.step_count, ) extras = self._get_extras(state) - timestep = restart(observation=observation, extras=extras) + timestep = restart(observation=observation, extras=extras, shape=(self.num_agents,)) return state, timestep def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]: @@ -171,19 +172,23 @@ def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observ grid=grid, action_mask=action_mask, step_count=new_state.step_count ) - done = jnp.all(jax.vmap(connected_or_blocked)(agents, action_mask)) + done = jax.vmap(connected_or_blocked)(agents, action_mask) + discount = (1 - done).astype(float) extras = self._get_extras(new_state) timestep = jax.lax.cond( - done | (new_state.step_count >= self.time_limit), + jnp.all(done) | (new_state.step_count >= self.time_limit), lambda: termination( reward=reward, observation=observation, extras=extras, + shape=(self.num_agents,), ), lambda: transition( reward=reward, observation=observation, extras=extras, + discount=discount, + shape=(self.num_agents,), ), ) @@ -362,3 +367,19 @@ def action_spec(self) -> specs.MultiDiscreteArray: dtype=jnp.int32, name="action", ) + + @cached_property + def reward_spec(self) -> specs.Array: + """Returns: a reward per agent.""" + return specs.Array(shape=(self.num_agents,), dtype=float, name="reward") + + @cached_property + def discount_spec(self) -> specs.BoundedArray: + """Returns: discount per agent.""" + return specs.BoundedArray( + shape=(self.num_agents,), + dtype=float, + minimum=0.0, + maximum=1.0, + name="discount", + ) diff --git a/jumanji/environments/routing/connector/env_test.py b/jumanji/environments/routing/connector/env_test.py index 5abe0dbbe..9e31bf364 100644 --- a/jumanji/environments/routing/connector/env_test.py +++ b/jumanji/environments/routing/connector/env_test.py @@ -57,8 +57,8 @@ def test_connector__reset(connector: Connector, key: jax.random.PRNGKey) -> None assert all(is_head_on_grid(state.agents, state.grid)) assert all(is_target_on_grid(state.agents, state.grid)) - assert timestep.discount == 1.0 - assert timestep.reward == 0.0 + assert jnp.allclose(timestep.discount, jnp.ones((connector.num_agents,))) + assert jnp.allclose(timestep.reward, jnp.zeros((connector.num_agents,))) assert timestep.step_type == StepType.FIRST @@ -94,7 +94,7 @@ def test_connector__step_connected( chex.assert_trees_all_equal(real_state2, state2) assert timestep.step_type == StepType.LAST - assert jnp.array_equal(timestep.discount, jnp.asarray(0)) + assert jnp.array_equal(timestep.discount, jnp.zeros(connector.num_agents)) reward = connector._reward_fn(real_state1, action2, real_state2) assert jnp.array_equal(timestep.reward, reward) @@ -146,7 +146,7 @@ def test_connector__step_blocked( assert jnp.array_equal(state.grid, expected_grid) assert timestep.step_type == StepType.LAST - assert jnp.array_equal(timestep.discount, jnp.asarray(0)) + assert jnp.array_equal(timestep.discount, jnp.zeros(connector.num_agents)) assert all(is_head_on_grid(state.agents, state.grid)) assert all(is_target_on_grid(state.agents, state.grid)) @@ -165,12 +165,12 @@ def test_connector__step_horizon(connector: Connector, state: State) -> None: state, timestep = step_fn(state, actions) assert timestep.step_type != StepType.LAST - assert jnp.array_equal(timestep.discount, jnp.asarray(1)) + assert jnp.array_equal(timestep.discount, jnp.ones(connector.num_agents)) # step 5 state, timestep = step_fn(state, actions) assert timestep.step_type == StepType.LAST - assert jnp.array_equal(timestep.discount, jnp.asarray(0)) + assert jnp.array_equal(timestep.discount, jnp.zeros(connector.num_agents)) def test_connector__step_agents_collision( diff --git a/jumanji/environments/routing/connector/reward.py b/jumanji/environments/routing/connector/reward.py index f29dbf2d5..8d176f74c 100644 --- a/jumanji/environments/routing/connector/reward.py +++ b/jumanji/environments/routing/connector/reward.py @@ -71,4 +71,4 @@ def __call__( ~state.agents.connected & next_state.agents.connected, float ) timestep_rewards = self.timestep_reward * jnp.asarray(~state.agents.connected, float) - return jnp.sum(connected_rewards + timestep_rewards) + return connected_rewards + timestep_rewards diff --git a/jumanji/environments/routing/connector/reward_test.py b/jumanji/environments/routing/connector/reward_test.py index 226a42ed1..205b61689 100644 --- a/jumanji/environments/routing/connector/reward_test.py +++ b/jumanji/environments/routing/connector/reward_test.py @@ -35,26 +35,27 @@ def test_dense_reward( # Reward of moving between the same states should be 0. reward = dense_reward_fn(state, jnp.array([0, 0, 0]), state) - chex.assert_rank(reward, 0) - assert jnp.isclose(reward, jnp.asarray(timestep_reward * 3)) + chex.assert_rank(reward, 1) + assert jnp.allclose(reward, jnp.array([timestep_reward] * 3)) # Reward for no agents finished to 2 agents finished. reward = dense_reward_fn(state, action1, state1) - chex.assert_rank(reward, 0) - expected_reward = connected_reward * 2 + timestep_reward * 3 - assert jnp.isclose(reward, expected_reward) + chex.assert_rank(reward, 1) + expected_reward = jnp.array([connected_reward, 0, connected_reward]) + timestep_reward + assert jnp.allclose(reward, expected_reward) # Reward for some agents finished to all agents finished. reward = dense_reward_fn(state1, action2, state2) - chex.assert_rank(reward, 0) - assert jnp.isclose(reward, jnp.array(connected_reward + timestep_reward)) + chex.assert_rank(reward, 1) + expected_reward = jnp.array([0, connected_reward + timestep_reward, 0]) + assert jnp.allclose(reward, expected_reward) # Reward for none finished to all finished reward = dense_reward_fn(state, action1, state2) - chex.assert_rank(reward, 0) - assert jnp.isclose(reward, jnp.array((connected_reward + timestep_reward) * 3)) + chex.assert_rank(reward, 1) + assert jnp.allclose(reward, jnp.array([connected_reward + timestep_reward] * 3)) # Reward of all finished to all finished. reward = dense_reward_fn(state2, jnp.zeros(3), state2) - chex.assert_rank(reward, 0) - assert jnp.isclose(reward, jnp.zeros(1)) + chex.assert_rank(reward, 1) + assert jnp.allclose(reward, jnp.zeros(1)) diff --git a/jumanji/training/setup_train.py b/jumanji/training/setup_train.py index d8612bed9..7daa0201d 100644 --- a/jumanji/training/setup_train.py +++ b/jumanji/training/setup_train.py @@ -90,7 +90,7 @@ def setup_logger(cfg: DictConfig) -> Logger: def _make_raw_env(cfg: DictConfig) -> Environment: env = jumanji.make(cfg.env.registered_version) - if cfg.env.name in {"lbf"}: + if cfg.env.name in {"lbf", "connector"}: # Convert a multi-agent environment to a single-agent environment env = MultiToSingleWrapper(env) return env