-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmulti_agent_env.py
93 lines (75 loc) · 2.83 KB
/
multi_agent_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""
Abstract base class for multi agent gym environments with JAX
Based on the Gymnax and PettingZoo APIs
"""
import jax
import jax.numpy as jnp
from typing import Dict
import chex
from functools import partial
from flax import struct
from typing import Tuple, Optional
@struct.dataclass
class State:
    """Base environment state. A flax struct dataclass, so instances are
    registered as JAX pytrees and can flow through jit/vmap/scan."""
    done: chex.Array  # termination flag(s) — presumably per-agent plus "__all__"; confirm in subclasses
    step: int  # number of environment steps taken since the last reset
class MultiAgentEnv(object):
    """Jittable abstract base class for all jaxmarl Environments.

    Subclasses implement `reset`, `step_env`, and `get_obs`; this base class
    provides the jitted `step` wrapper with auto-reset on termination.
    """

    def __init__(
        self,
        num_agents: int,
    ) -> None:
        """
        Args:
            num_agents (int): maximum number of agents within the environment,
                used to set array dimensions.
        """
        self.num_agents = num_agents
        # Populated by subclasses: maps agent name -> space.
        self.observation_spaces = dict()
        self.action_spaces = dict()

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, key: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
        """Performs resetting of the environment.

        Args:
            key: PRNG key for any randomness in the initial state.

        Returns:
            Tuple of (per-agent observations dict, initial State).
        """
        raise NotImplementedError

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: State,
        actions: Dict[str, chex.Array],
    ) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
        """Performs step transitions in the environment, with auto-reset.

        When `dones["__all__"]` is set by `step_env`, the returned observations
        and state come from a fresh `reset` instead of the terminal transition,
        so rollouts can be collected inside jit without host-side control flow.
        Rewards, dones, and infos are always those of the actual transition.

        Args:
            key: PRNG key; split so the step and the potential reset use
                independent randomness.
            state: current environment State.
            actions: per-agent actions keyed by agent name.

        Returns:
            Tuple of (obs, state, rewards, dones, infos).
        """
        key, key_reset = jax.random.split(key)
        obs_st, states_st, rewards, dones, infos = self.step_env(key, state, actions)
        obs_re, states_re = self.reset(key_reset)
        # Auto-reset environment based on termination. `jax.tree_map` is
        # deprecated (and removed in newer JAX releases); `jax.tree_util.tree_map`
        # is the stable spelling with identical semantics.
        states = jax.tree_util.tree_map(
            lambda x, y: jax.lax.select(dones["__all__"], x, y), states_re, states_st
        )
        obs = jax.tree_util.tree_map(
            lambda x, y: jax.lax.select(dones["__all__"], x, y), obs_re, obs_st
        )
        return obs, states, rewards, dones, infos

    def step_env(
        self, key: chex.PRNGKey, state: State, actions: Dict[str, chex.Array]
    ) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
        """Environment-specific step transition (no auto-reset)."""
        raise NotImplementedError

    def get_obs(self, state: State) -> Dict[str, chex.Array]:
        """Applies observation function to state."""
        raise NotImplementedError

    def observation_space(self, agent: str):
        """Observation space for a given agent."""
        return self.observation_spaces[agent]

    def action_space(self, agent: str):
        """Action space for a given agent."""
        return self.action_spaces[agent]

    @property
    def name(self) -> str:
        """Environment name (the concrete subclass's name)."""
        return type(self).__name__

    @property
    def agent_classes(self) -> dict:
        """Returns a dictionary with agent classes, used in environments with heterogeneous agents.

        Format:
            agent_base_name: [agent_base_name_1, agent_base_name_2, ...]
        """
        raise NotImplementedError