Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: return individual rewards in Connector env #263

Merged
merged 2 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions jumanji/environments/routing/connector/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,10 @@ class Connector(Environment[State, specs.MultiDiscreteArray, Observation]):
- can take the values [0,1,2,3,4] which correspond to [No Op, Up, Right, Down, Left].
- each value in the array corresponds to an agent's action.

- reward: jax array (float) of shape ():
- dense: reward is 1 for each successful connection on that step. Additionally,
each pair of points that have not connected receives a penalty reward of -0.03.
- reward: jax array (float) of shape (num_agents,):
- dense: for each agent the reward is 1 for each successful connection on that step.
Additionally, each pair of points that have not connected receives a
penalty reward of -0.03.

- episode termination:
- all agents either can't move (no available actions) or have connected to their target.
Expand Down Expand Up @@ -142,7 +143,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
step_count=state.step_count,
)
extras = self._get_extras(state)
timestep = restart(observation=observation, extras=extras)
timestep = restart(observation=observation, extras=extras, shape=(self.num_agents,))
return state, timestep

def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
Expand Down Expand Up @@ -171,19 +172,23 @@ def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observ
grid=grid, action_mask=action_mask, step_count=new_state.step_count
)

done = jnp.all(jax.vmap(connected_or_blocked)(agents, action_mask))
done = jax.vmap(connected_or_blocked)(agents, action_mask)
discount = (1 - done).astype(float)
extras = self._get_extras(new_state)
timestep = jax.lax.cond(
done | (new_state.step_count >= self.time_limit),
jnp.all(done) | (new_state.step_count >= self.time_limit),
lambda: termination(
reward=reward,
observation=observation,
extras=extras,
shape=(self.num_agents,),
),
lambda: transition(
reward=reward,
observation=observation,
extras=extras,
discount=discount,
shape=(self.num_agents,),
),
)

Expand Down Expand Up @@ -362,3 +367,19 @@ def action_spec(self) -> specs.MultiDiscreteArray:
dtype=jnp.int32,
name="action",
)

@cached_property
def reward_spec(self) -> specs.Array:
"""Returns: a reward per agent."""
return specs.Array(shape=(self.num_agents,), dtype=float, name="reward")

@cached_property
def discount_spec(self) -> specs.BoundedArray:
"""Returns: discount per agent."""
return specs.BoundedArray(
shape=(self.num_agents,),
dtype=float,
minimum=0.0,
maximum=1.0,
name="discount",
)
12 changes: 6 additions & 6 deletions jumanji/environments/routing/connector/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ def test_connector__reset(connector: Connector, key: jax.random.PRNGKey) -> None
assert all(is_head_on_grid(state.agents, state.grid))
assert all(is_target_on_grid(state.agents, state.grid))

assert timestep.discount == 1.0
assert timestep.reward == 0.0
assert jnp.allclose(timestep.discount, jnp.ones((connector.num_agents,)))
assert jnp.allclose(timestep.reward, jnp.zeros((connector.num_agents,)))
assert timestep.step_type == StepType.FIRST


Expand Down Expand Up @@ -94,7 +94,7 @@ def test_connector__step_connected(
chex.assert_trees_all_equal(real_state2, state2)

assert timestep.step_type == StepType.LAST
assert jnp.array_equal(timestep.discount, jnp.asarray(0))
assert jnp.array_equal(timestep.discount, jnp.zeros(connector.num_agents))
reward = connector._reward_fn(real_state1, action2, real_state2)
assert jnp.array_equal(timestep.reward, reward)

Expand Down Expand Up @@ -146,7 +146,7 @@ def test_connector__step_blocked(

assert jnp.array_equal(state.grid, expected_grid)
assert timestep.step_type == StepType.LAST
assert jnp.array_equal(timestep.discount, jnp.asarray(0))
assert jnp.array_equal(timestep.discount, jnp.zeros(connector.num_agents))

assert all(is_head_on_grid(state.agents, state.grid))
assert all(is_target_on_grid(state.agents, state.grid))
Expand All @@ -165,12 +165,12 @@ def test_connector__step_horizon(connector: Connector, state: State) -> None:
state, timestep = step_fn(state, actions)

assert timestep.step_type != StepType.LAST
assert jnp.array_equal(timestep.discount, jnp.asarray(1))
assert jnp.array_equal(timestep.discount, jnp.ones(connector.num_agents))

# step 5
state, timestep = step_fn(state, actions)
assert timestep.step_type == StepType.LAST
assert jnp.array_equal(timestep.discount, jnp.asarray(0))
assert jnp.array_equal(timestep.discount, jnp.zeros(connector.num_agents))


def test_connector__step_agents_collision(
Expand Down
2 changes: 1 addition & 1 deletion jumanji/environments/routing/connector/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,4 @@ def __call__(
~state.agents.connected & next_state.agents.connected, float
)
timestep_rewards = self.timestep_reward * jnp.asarray(~state.agents.connected, float)
return jnp.sum(connected_rewards + timestep_rewards)
return connected_rewards + timestep_rewards
23 changes: 12 additions & 11 deletions jumanji/environments/routing/connector/reward_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,27 @@ def test_dense_reward(

# Reward of moving between the same states should be 0.
reward = dense_reward_fn(state, jnp.array([0, 0, 0]), state)
chex.assert_rank(reward, 0)
assert jnp.isclose(reward, jnp.asarray(timestep_reward * 3))
chex.assert_rank(reward, 1)
assert jnp.allclose(reward, jnp.array([timestep_reward] * 3))

# Reward for no agents finished to 2 agents finished.
reward = dense_reward_fn(state, action1, state1)
chex.assert_rank(reward, 0)
expected_reward = connected_reward * 2 + timestep_reward * 3
assert jnp.isclose(reward, expected_reward)
chex.assert_rank(reward, 1)
expected_reward = jnp.array([connected_reward, 0, connected_reward]) + timestep_reward
assert jnp.allclose(reward, expected_reward)

# Reward for some agents finished to all agents finished.
reward = dense_reward_fn(state1, action2, state2)
chex.assert_rank(reward, 0)
assert jnp.isclose(reward, jnp.array(connected_reward + timestep_reward))
chex.assert_rank(reward, 1)
expected_reward = jnp.array([0, connected_reward + timestep_reward, 0])
assert jnp.allclose(reward, expected_reward)

# Reward for none finished to all finished
reward = dense_reward_fn(state, action1, state2)
chex.assert_rank(reward, 0)
assert jnp.isclose(reward, jnp.array((connected_reward + timestep_reward) * 3))
chex.assert_rank(reward, 1)
assert jnp.allclose(reward, jnp.array([connected_reward + timestep_reward] * 3))

# Reward of all finished to all finished.
reward = dense_reward_fn(state2, jnp.zeros(3), state2)
chex.assert_rank(reward, 0)
assert jnp.isclose(reward, jnp.zeros(1))
chex.assert_rank(reward, 1)
assert jnp.allclose(reward, jnp.zeros(1))
2 changes: 1 addition & 1 deletion jumanji/training/configs/config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
defaults:
- _self_
- env: snake # [bin_pack, cleaner, connector, cvrp, flat_pack, game_2048, graph_coloring, job_shop, knapsack, maze, minesweeper, mmst, multi_cvrp, pac_man, robot_warehouse, lbf, rubiks_cube, sliding_tile_puzzle, snake, sokoban, sudoku, tetris, tsp]
- env: connector # [bin_pack, cleaner, connector, cvrp, flat_pack, game_2048, graph_coloring, job_shop, knapsack, maze, minesweeper, mmst, multi_cvrp, pac_man, robot_warehouse, lbf, rubiks_cube, sliding_tile_puzzle, snake, sokoban, sudoku, tetris, tsp]
sash-a marked this conversation as resolved.
Show resolved Hide resolved

agent: random # [random, a2c]

Expand Down
2 changes: 1 addition & 1 deletion jumanji/training/setup_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def setup_logger(cfg: DictConfig) -> Logger:

def _make_raw_env(cfg: DictConfig) -> Environment:
env = jumanji.make(cfg.env.registered_version)
if cfg.env.name in {"lbf"}:
if cfg.env.name in {"lbf", "connector"}:
# Convert a multi-agent environment to a single-agent environment
env = MultiToSingleWrapper(env)
return env
Expand Down