From 3dab01150e2590dd60731c662089289445494564 Mon Sep 17 00:00:00 2001
From: Benjamin Piwowarski
Date: Sat, 18 Nov 2023 12:15:19 +0100
Subject: [PATCH] Fixed bugs and documentation

---
 CHANGELOG.md                         |   8 ++
 README.md                            |  30 +++++-
 pyproject.toml                       |   5 +-
 src/pystk2_gymnasium/__init__.py     |   2 +
 src/pystk2_gymnasium/definitions.py  |  64 +++++++++++++
 src/pystk2_gymnasium/envs.py         | 136 +++++++++++++++------------
 src/pystk2_gymnasium/stk_wrappers.py |  42 +--------
 src/pystk2_gymnasium/wrappers.py     | 130 +++++++++++++++++++++++--
 tests/test_envs.py                   |   4 +-
 9 files changed, 305 insertions(+), 116 deletions(-)
 create mode 100644 CHANGELOG.md
 create mode 100644 src/pystk2_gymnasium/definitions.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..6ce66ea
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,8 @@
+# Version 0.4.0
+
+- Multi-agent environment
+- Use a polar representation instead of Cartesian coordinates (except for the "full" environment)
+- Only two base environments (multi/mono-agent) and wrappers for the rest: this allows races to be organized with different sets of wrappers (depending on the agent)
+- Added `distance_center_path`
+- Allow changing the player name and camera mode
+- Breaking: `AgentSpec` is now used for mono-kart environments
diff --git a/README.md b/README.md
index 32b1919..c6d3c20 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,8 @@
 # PySuperTuxKart gymnasium wrapper
 
-*warning*: pystk2-gymnasium is in alpha stage - the environments might change abruptly!
+[![PyPI version](https://badge.fury.io/py/pystk2-gymnasium.svg)](https://badge.fury.io/py/pystk2-gymnasium)
+
+Read the [Changelog](./CHANGELOG.md)
 
 ## Install
 
@@ -10,19 +12,29 @@ The PySuperKart2 gymnasium wrapper is a Python package, so installing is fairly easy
 
 Note that during the first run, SuperTuxKart assets are downloaded in the cache directory.
 
+## AgentSpec
+
+Each controlled kart is parametrized by `pystk2_gymnasium.AgentSpec`:
+
+- `name` defines the name of the player (displayed on top of the kart)
+- `rank_start` defines the starting position (None for random, which is the default)
+- `use_ai` flag (False by default) to ignore actions when calling `step` and let a SuperTuxKart bot drive instead
+- `camera_mode` can be set to `AUTO` (camera on, unless the kart is controlled by a SuperTuxKart bot), `ON` (camera on) or `OFF` (no camera).
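+
+For instance, a kart driven by the SuperTuxKart bot with a custom on-screen name (the name is illustrative):
+
+```py3
+from pystk2_gymnasium import AgentSpec
+
+# Actions are ignored: the SuperTuxKart bot drives this kart
+agent = AgentSpec(name="my-bot", use_ai=True)
+```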
+
+
 ## Environments
+
 *Warning* only one SuperTuxKart environment can be created for now. Moreover, no graphics information is available for now.
 
 After importing `pystk2_gymnasium`, the following environments are available:
 
 - `supertuxkart/full-v0` is the main environment containing complete observations. The observation and action spaces are both dictionaries with continuous or discrete variables (see below). The exact structure can be found using `env.observation_space` and `env.action_space`. The following options can be used to modify the environment:
+    - `agent` is an `AgentSpec` (see above)
    - `render_mode` can be None or `human`
    - `track` defines the SuperTuxKart track to use (None for random). The full list can be found in `STKRaceEnv.TRACKS` after initialization with `initialize.initialize(with_graphics: bool)` has been called.
    - `num_kart` defines the number of karts on the track (3 by default)
-    - `rank_start` defines the starting position (None for random, which is the default)
-    - `use_ai` flag (False by default) to ignore actions (when calling `step`, and use a SuperTuxKart bot)
    - `max_paths` the maximum number of the (nearest) paths (a track is made of paths) to consider in the observation state
    - `laps` is the number of laps (1 by default)
    - `difficulty` is the difficulty of the AI bots (lowest 0 to highest 2, default to 2)
@@ -39,6 +51,13 @@ $$ r_{t} = \frac{1}{10}(d_{t} - d_{t-1}) + (1 - \frac{\mathrm{pos}_t}{K}) \times f_{t} $$
 
 where $d_t$ is the overall track distance at time $t$, $\mathrm{pos}_t$ the position among the $K$ karts at time $t$, and $f_t$ is $1$ when the kart finishes the race.
 
+## Multi-agent environment
+
+`supertuxkart/multi-full-v0` can be used to control multiple karts. It takes an
+`agents` parameter that is a list of `AgentSpec`. Observations and actions are dictionaries of single-kart ones, with **string** keys ranging from `0` to `n-1`, where `n` is the number of karts.
+
+To use different gymnasium wrappers, one can use a `MonoAgentWrapperAdapter`.
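+
+A minimal example (the agent names are illustrative):
+
+```py3
+import gymnasium as gym
+from pystk2_gymnasium import AgentSpec
+
+# One AgentSpec per controlled kart
+env = gym.make(
+    "supertuxkart/multi-full-v0",
+    agents=[AgentSpec(name="agent-0"), AgentSpec(name="agent-1", use_ai=True)],
+)
+
+obs, _ = env.reset()
+# Observations (and actions) are indexed by string keys
+obs_first_kart = obs["0"]
+```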
 
 ## Action and observation space
 
 All the 3D vectors are within the kart referential (`z` front, `x` left, `y` up):
 
@@ -53,6 +72,7 @@ All the 3D vectors are within the kart referential (`z` front, `x` left, `y` up)
 - `jumping`: is the kart jumping
 - `karts_position`: position of other karts, beginning with the ones in front
 - `max_steer_angle` the max angle of the steering (given the current speed)
+- `distance_center_path`: distance to the center of the path
 - `paths_distance`: the distance of the paths
 - `paths_start`, `paths_end`, `paths_width`: 3D vector to the paths start and end, with their widths (sccalar)
 - `paths_start`: 3D vectors to the the path s
@@ -65,13 +85,13 @@ All the 3D vectors are within the kart referential (`z` front, `x` left, `y` up)
 
 ## Example
 
 ```py3
 import gymnasium as gym
-import pystk2_gymnasium
+from pystk2_gymnasium import AgentSpec
 
 # Use a a flattened version of the observation and action spaces
 # In both case, this corresponds to a dictionary with two keys:
 # - `continuous` is a vector corresponding to the continuous observations
 # - `discrete` is a vector (of integers) corresponding to discrete observations
-env = gym.make("supertuxkart/flattened-v0", render_mode="human", use_ai=False)
+env = gym.make("supertuxkart/flattened-v0", render_mode="human", agent=AgentSpec(use_ai=False))
 
 ix = 0
 done = False
diff --git a/pyproject.toml b/pyproject.toml
index 2b1da17..0d08f56 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,13 +5,14 @@ description = "Gymnasium wrapper for PySTK2"
 authors = ["Benjamin Piwowarski "]
 license = "GPL"
 readme = "README.md"
-
 homepage = "https://github.com/bpiwowar/pystk2-gymnasium"
 repository = "https://github.com/bpiwowar/pystk2-gymnasium"
 
+include = ["CHANGELOG.md"]
+
 [tool.poetry.dependencies]
 python = "^3.8"
-PySuperTuxKart2 = "=0.3.4"
+PySuperTuxKart2 = ">=0.3.5"
 gymnasium = ">0.29.0"
 
 [build-system]
diff --git a/src/pystk2_gymnasium/__init__.py b/src/pystk2_gymnasium/__init__.py
index b0afa4a..c3ba4fb 100644
--- a/src/pystk2_gymnasium/__init__.py
+++ b/src/pystk2_gymnasium/__init__.py
@@ -1,4 +1,6 @@
 from gymnasium.envs.registration import register, WrapperSpec
+from .definitions import ActionObservationWrapper, AgentSpec  # noqa: F401
+from .wrappers import MonoAgentWrapperAdapter  # noqa: F401
 
 register(
     id="supertuxkart/full-v0",
diff --git a/src/pystk2_gymnasium/definitions.py b/src/pystk2_gymnasium/definitions.py
new file mode 100644
index 0000000..0362f2a
--- /dev/null
+++ b/src/pystk2_gymnasium/definitions.py
@@ -0,0 +1,64 @@
+"""
+This module contains common definitions (agent specification and base
+wrapper classes)
+"""
+
+from typing import Any, Dict, Optional, Tuple
+from dataclasses import dataclass
+import pystk2
+
+import gymnasium as gym
+from gymnasium.core import (
+    Wrapper,
+    WrapperActType,
+    WrapperObsType,
+    ObsType,
+    ActType,
+    SupportsFloat,
+)
+
+CameraMode = pystk2.PlayerConfig.CameraMode
+
+
+@dataclass
+class AgentSpec:
+    #: The position of the controlled kart, defaults to None for random, 0 to
+    #  num_kart-1 assigns a rank, all the other values discard the controlled
+    #  kart.
+    rank_start: Optional[int] = None
+    #: Use the STK AI agent (ignores actions)
+    use_ai: bool = False
+    #: Player name
+    name: str = ""
+    #: Camera mode (AUTO, ON, OFF). By default, only non-AI agents get a camera
+    camera_mode: CameraMode = CameraMode.AUTO
+
+
+class ActionObservationWrapper(Wrapper[ObsType, WrapperActType, ObsType, ActType]):
+    """Combines action and observation wrapper"""
+
+    def action(self, action: WrapperActType) -> ActType:
+        raise NotImplementedError
+
+    def observation(self, observation: ObsType) -> WrapperObsType:
+        raise NotImplementedError
+
+    def __init__(self, env: gym.Env[ObsType, ActType]):
+        """Constructor for the action wrapper."""
+        Wrapper.__init__(self, env)
+
+    def reset(
+        self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
+    ) -> Tuple[WrapperObsType, Dict[str, Any]]:
+        """Modifies the :attr:`env` after calling :meth:`reset`, returning a
+        modified observation using :meth:`self.observation`."""
+        obs, info = self.env.reset(seed=seed, options=options)
+        return self.observation(obs), info
+
+    def step(
+        self, action: ActType
+    ) -> Tuple[WrapperObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
+        """Modifies the :attr:`env` after calling :meth:`step` using
+        :meth:`self.observation` on the returned observations."""
+        action = self.action(action)
+        observation, reward, terminated, truncated, info = self.env.step(action)
+        return self.observation(observation), reward, terminated, truncated, info
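
For reference, a subclass only needs to implement `action` and `observation`; the sketch below is illustrative (the wrapper and the 0.5 cap are hypothetical, not part of this patch):

```py3
import numpy as np

from pystk2_gymnasium import ActionObservationWrapper


class CappedAcceleration(ActionObservationWrapper):
    """Hypothetical wrapper: caps the acceleration, passes observations through."""

    def action(self, action):
        # Called on the agent action before it reaches the wrapped environment
        action = dict(action)
        action["acceleration"] = np.minimum(action["acceleration"], 0.5)
        return action

    def observation(self, observation):
        # Called on the observations returned by `reset` and `step`
        return observation
```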
diff --git a/src/pystk2_gymnasium/envs.py b/src/pystk2_gymnasium/envs.py
index 185ba24..087b85a 100644
--- a/src/pystk2_gymnasium/envs.py
+++ b/src/pystk2_gymnasium/envs.py
@@ -1,5 +1,3 @@
-from dataclasses import dataclass
-from itertools import repeat
 import logging
 import functools
 from typing import Any, ClassVar, Dict, List, Optional, Tuple, TypedDict
@@ -9,15 +7,12 @@
 import pystk2
 from gymnasium import spaces
 
-from pystk2_gymnasium.utils import max_enum_value, rotate
+from .utils import max_enum_value, rotate
+from .definitions import AgentSpec
 
 logger = logging.getLogger("pystk2-gym")
 
-float3D = Tuple[float, float, float]
-float4D = Tuple[float, float, float, float]
-
-
 @functools.lru_cache
 def kart_action_space():
     return spaces.Dict(
@@ -58,6 +53,9 @@ def kart_observation_space(use_ai: bool):
             ),
             "max_steer_angle": spaces.Box(-1, 1, dtype=np.float32, shape=(1,)),
             "distance_down_track": spaces.Box(0.0, float("inf")),
+            "distance_center_path": spaces.Box(
+                0, float("inf"), dtype=np.float32, shape=(1,)
+            ),
             "front": spaces.Box(
                 -float("inf"), float("inf"), dtype=np.float32, shape=(3,)
             ),
@@ -121,12 +119,6 @@ class BaseSTKRaceEnv(gym.Env[Any, STKAction]):
     #: List of available tracks
     TRACKS: ClassVar[List[str]] = []
 
-    #: Rank of the observed kart (random if None)
-    rank_start: Optional[int]
-
-    #: Use AI
-    use_ai: bool
-
     @staticmethod
     def initialize(with_graphics: bool):
         if BaseSTKRaceEnv.INITIALIZED is None:
@@ -326,6 +318,15 @@ def sort_closest(positions, *lists):
         items_type = [item.type.value for item in self.world.items]
         sort_closest(items_position, items_type)
 
+        # Distance to the center of the path: distance from the kart (the
+        # origin in kart coordinates) to the line through the current path
+        # nodes; the direction vector is normalized before projecting
+        start, end = kartview(self.track.path_nodes[path_ix][0]), kartview(
+            self.track.path_nodes[path_ix][1]
+        )
+        s_e = start - end
+        s_e = s_e / np.linalg.norm(s_e)
+        distance_center_path = np.linalg.norm(start - np.dot(s_e, start) * s_e)
+
         # Add action if using AI bot
         obs = {}
         if use_ai:
@@ -359,6 +360,7 @@ def sort_closest(positions, *lists):
             "distance_down_track": np.array(
                 [kart.distance_down_track], dtype=np.float32
             ),
+            "distance_center_path": np.array([distance_center_path], dtype=np.float32),
             "velocity": kart.velocity_lc,
             "front": kartview(kart.front),
             # Items (kart point of view)
@@ -385,27 +387,26 @@ def render(self):
 class STKRaceEnv(BaseSTKRaceEnv):
     """Single player race environment"""
 
-    def __init__(self, *, rank_start=None, use_ai=False, **kwargs):
+    #: Agent specification
+    agent: AgentSpec
+
+    def __init__(self, *, agent: Optional[AgentSpec] = None, **kwargs):
         """Creates a new race
 
-        :param use_ai: Use STK built AI bot instead of the agent action
-        :param rank_start: The position of the controlled kart, defaults to None
-            for random, 0 to num_kart-1 assigns a rank, all the other values
-            discard the controlled kart.
+        :param agent: Agent specification (defaults to `AgentSpec()`)
         :param kwargs: General parameters, see BaseSTKRaceEnv
         """
         super().__init__(**kwargs)
 
         # Setup the variables
-        self.rank_start = rank_start
-        self.use_ai = use_ai
+        self.agent = agent if agent is not None else AgentSpec()
 
         # Those will be set when the race is setup
         self.kart_ix = None
 
         # We have 4 actions, corresponding to "right", "up", "left", "down"
         self.action_space = kart_action_space()
-        self.observation_space = kart_observation_space(self.use_ai)
+        self.observation_space = kart_observation_space(self.agent.use_ai)
 
     def reset(
         self,
@@ -418,16 +419,18 @@ def reset(
         super().reset_race(random, options=options)
 
         # Set the controlled kart position (if any)
-        self.kart_ix = self.rank_start
+        self.kart_ix = self.agent.rank_start
         if self.kart_ix is None:
             self.kart_ix = np.random.randint(0, self.num_kart)
         logging.debug("Observed kart index %d", self.kart_ix)
 
-        if self.use_ai:
-            self.config.players[
-                self.kart_ix
-            ].camera_mode = pystk2.PlayerConfig.CameraMode.ON
-        else:
+        # Camera and player name setup (honors the agent specification)
+        self.config.players[self.kart_ix].camera_mode = self.agent.camera_mode
+        self.config.players[self.kart_ix].name = self.agent.name
+
+        if not self.agent.use_ai:
             self.config.players[
                 self.kart_ix
             ].controller = pystk2.PlayerConfig.Controller.PLAYER_CONTROL
@@ -435,29 +438,23 @@ def reset(
 
         self.warmup_race()
         self.world.update()
 
-        return self.get_observation(self.kart_ix, self.use_ai), {}
+        return self.get_observation(self.kart_ix, self.agent.use_ai), {}
 
     def step(
         self, action: STKAction
     ) -> Tuple[pystk2.WorldState, float, bool, bool, Dict[str, Any]]:
-        if self.use_ai:
+        if self.agent.use_ai:
             self.race.step()
         else:
             self.race.step(get_action(action))
 
         self.world_update()
 
-        obs, reward, terminated, info = self.get_state(self.kart_ix, self.use_ai)
+        obs, reward, terminated, info = self.get_state(self.kart_ix, self.agent.use_ai)
 
         return (obs, reward, terminated, False, info)
 
-@dataclass
-class AgentSpec:
-    rank_start: Optional[int] = None
-    use_ai: bool = False
-
-
 class STKRaceMultiEnv(BaseSTKRaceEnv):
     """Multi-agent race environment"""
 
@@ -473,7 +470,9 @@ def __init__(self, *, agents: List[AgentSpec] = None, **kwargs):
 
         # Setup the variables
         self.agents = agents
-        assert len(self.agents) <= self.num_kart, f"Too many agents ({len(self.agents)}) for {self.num_kart} karts"
+        assert (
+            len(self.agents) <= self.num_kart
+        ), f"Too many agents ({len(self.agents)}) for {self.num_kart} karts"
 
         # Kart index for each agent (set when the race is setup)
         self.kart_indices = None
@@ -491,9 +490,14 @@ def __init__(self, *, agents: List[AgentSpec] = None, **kwargs):
             ix for ix in range(self.num_kart) if ix not in ranked_agents
         ]
 
-        self.action_space = spaces.Tuple(repeat(kart_action_space(), len(self.agents)))
-        self.observation_space = spaces.Tuple(
-            kart_observation_space(agent.use_ai) for agent in self.agents
+        self.action_space = spaces.Dict(
+            {str(ix): kart_action_space() for ix in range(len(self.agents))}
+        )
+        self.observation_space = spaces.Dict(
+            {
+                str(ix): kart_observation_space(agent.use_ai)
+                for ix, agent in enumerate(self.agents)
+            }
         )
 
     def reset(
@@ -514,11 +518,12 @@ def reset(
         for agent in self.agents:
-            kart_ix = agent.rank_start or next(pos_iter)
+            # `rank_start or ...` would discard a requested rank of 0
+            kart_ix = agent.rank_start if agent.rank_start is not None else next(pos_iter)
             self.kart_indices.append(kart_ix)
-            self.config.players[kart_ix].camera_mode = pystk2.PlayerConfig.CameraMode.ON
+            self.config.players[kart_ix].camera_mode = agent.camera_mode
             if not agent.use_ai:
                 self.config.players[
                     kart_ix
                 ].controller = pystk2.PlayerConfig.Controller.PLAYER_CONTROL
+            self.config.players[kart_ix].name = agent.name
 
         logging.debug("Observed kart indices %s", self.kart_indices)
 
@@ -526,22 +531,24 @@ def reset(
         self.warmup_race()
         self.world.update()
 
         return (
-            tuple(
-                self.get_observation(ix, agent.use_ai)
-                for agent, ix in zip(self.agents, self.kart_indices)
-            ),
+            {
+                str(agent_ix): self.get_observation(kart_ix, agent.use_ai)
+                for agent_ix, (agent, kart_ix) in enumerate(
+                    zip(self.agents, self.kart_indices)
+                )
+            },
             {},
         )
 
     def step(
-        self, actions: Tuple[STKAction]
+        self, actions: Dict[str, STKAction]
     ) -> Tuple[pystk2.WorldState, float, bool, bool, Dict[str, Any]]:
         # Performs the action
         assert len(actions) == len(self.agents)
         self.race.step(
             [
-                get_action(action)
-                for agent, action in zip(self.agents, actions)
+                get_action(actions[str(agent_ix)])
+                for agent_ix, agent in enumerate(self.agents)
                 if not agent.use_ai
             ]
         )
@@ -549,27 +556,36 @@ def step(
 
         # Update the world state
         self.world_update()
 
-        observations = []
-        rewards = []
-        infos = []
+        observations = {}
+        rewards = {}
+        infos = {}
+        multi_terminated = {}
+        multi_done = {}
         terminated_count = 0
-        for agent, kart_ix in zip(self.agents, self.kart_indices):
+        for agent_ix, (agent, kart_ix) in enumerate(
+            zip(self.agents, self.kart_indices)
+        ):
             obs, reward, terminated, info = self.get_state(kart_ix, agent.use_ai)
-            observations.append(obs)
-            rewards.append(reward)
+            key = str(agent_ix)
+
+            observations[key] = obs
+            rewards[key] = reward
+            multi_terminated[key] = terminated
+            # No per-agent truncation for now, so "done" matches "terminated"
+            multi_done[key] = terminated
+            infos[key] = info
+
             if terminated:
                 terminated_count += 1
-            infos.append(info)
 
         return (
-            tuple(observations),
-            # Only scalar rewards can be given
-            np.sum(rewards),
+            observations,
+            # Only scalar rewards can be given: we sum them all
+            np.sum(list(rewards.values())),
             terminated_count == len(self.agents),
             False,
             {
                 "infos": infos,
-                # We put back individual rewards
-                "rewards": rewards,
+                "done": multi_done,
+                "terminated": multi_terminated,
+                "reward": rewards,
             },
         )
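
Downstream code can read the per-agent values back from `info`; a short sketch (assuming `env` is a `supertuxkart/multi-full-v0` instance and `actions` a valid action dictionary):

```py3
obs, reward, terminated, truncated, info = env.step(actions)

# `reward` is the sum over agents; the individual rewards and termination
# flags are keyed by the agent index ("0", "1", ...)
for key in info["reward"]:
    print(key, info["reward"][key], info["terminated"][key])
```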
diff --git a/src/pystk2_gymnasium/stk_wrappers.py b/src/pystk2_gymnasium/stk_wrappers.py
index 1fcea04..f02b5ce 100644
--- a/src/pystk2_gymnasium/stk_wrappers.py
+++ b/src/pystk2_gymnasium/stk_wrappers.py
@@ -3,56 +3,18 @@
 """
 
 import copy
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Tuple
 
 import gymnasium as gym
 import numpy as np
 import pystk2
 from gymnasium import spaces
-from gymnasium.core import (
-    Wrapper,
-    WrapperActType,
-    WrapperObsType,
-    ObsType,
-    ActType,
-    SupportsFloat,
-)
 
 from .envs import STKAction
+from .definitions import ActionObservationWrapper
 from pystk2_gymnasium.utils import Discretizer, max_enum_value
 
 
-class ActionObservationWrapper(Wrapper[ObsType, WrapperActType, ObsType, ActType]):
-    """Combines action and observation wrapper"""
-
-    def action(self, action: WrapperActType) -> ActType:
-        raise NotImplementedError
-
-    def observation(self, observation: ObsType) -> WrapperObsType:
-        raise NotImplementedError
-
-    def __init__(self, env: gym.Env[ObsType, ActType]):
-        """Constructor for the action wrapper."""
-        Wrapper.__init__(self, env)
-
-    def reset(
-        self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
-    ) -> Tuple[WrapperObsType, Dict[str, Any]]:
-        """Modifies the :attr:`env` after calling :meth:`reset`, returning a
-        modified observation using :meth:`self.observation`."""
-        obs, info = self.env.reset(seed=seed, options=options)
-        return self.observation(obs), info
-
-    def step(
-        self, action: ActType
-    ) -> Tuple[WrapperObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
-        """Modifies the :attr:`env` after calling :meth:`step` using
-        :meth:`self.observation` on the returned observations."""
-        action = self.action(action)
-        observation, reward, terminated, truncated, info = self.env.step(action)
-        return self.observation(observation), reward, terminated, truncated, info
-
-
 class PolarObservations(gym.ObservationWrapper):
     """Modifies position to polar positions
 
diff --git a/src/pystk2_gymnasium/wrappers.py b/src/pystk2_gymnasium/wrappers.py
index 05bb66f..30911ba 100644
--- a/src/pystk2_gymnasium/wrappers.py
+++ b/src/pystk2_gymnasium/wrappers.py
@@ -1,12 +1,21 @@
 """
 This module contains generic wrappers
 """
-from typing import Any, Dict, Tuple
+from typing import Any, Callable, Dict, SupportsFloat, Tuple
 
 import gymnasium as gym
-import numpy as np
 from gymnasium import spaces
-from gymnasium.core import Env
+from gymnasium.core import (
+    Wrapper,
+    WrapperActType,
+    WrapperObsType,
+    ObsType,
+    ActType,
+    Env,
+)
+import numpy as np
+
+from pystk2_gymnasium.definitions import ActionObservationWrapper
 
 
 class SpaceFlattener:
@@ -64,7 +73,7 @@ def __init__(self, space: gym.Space):
         )
 
 
-class FlattenerWrapper(gym.ObservationWrapper):
+class FlattenerWrapper(ActionObservationWrapper):
     def __init__(self, env: gym.Env):
         super().__init__(env)
 
@@ -115,9 +124,6 @@ def observation(self, observation):
 
         return new_obs
 
-    def step(self, action) -> Tuple[Any, float, bool, bool, Dict[str, Any]]:
-        return super().step(self.action(action))
-
     def action(self, action):
         discrete_actions = {}
         if not self.action_flattener.only_continuous:
@@ -172,3 +178,113 @@ def action(self, action):
             actions.append(action % n)
             action = action // n
         return actions
+
+
+class MultiMonoEnv(gym.Env):
+    """Fake mono-kart environment for mono-kart wrappers"""
+
+    def __init__(self, env: gym.Env, key: str):
+        self._env = env
+        self.observation_space = env.observation_space[key]
+        self.action_space = env.action_space[key]
+
+    def reset(self, **kwargs):
+        raise RuntimeError("Should not be called - fake mono environment")
+
+    def step(
+        self, action: Any
+    ) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, Any]]:
+        raise RuntimeError("Should not be called - fake mono environment")
+
+
+class MonoAgentWrapperAdapter(ActionObservationWrapper):
+    """Adapts single-agent wrappers for use in a multi-agent environment"""
+
+    def __init__(
+        self,
+        env: gym.Env,
+        *,
+        wrapper_factories: Dict[str, Callable[[gym.Env], Wrapper]],
+    ):
+        """Initialize an adapter that uses distinct wrappers for each agent
+
+        It assumes that the observation and action spaces are dictionaries
+        where each key corresponds to a different agent.
+
+        :param env: The base environment
+        :param wrapper_factories: Maps every key of the observation/action
+            space dictionaries to a factory that returns a wrapper for this
+            agent. Supported wrappers are `ActionObservationWrapper`,
+            `ObservationWrapper`, and `ActionWrapper`.
+        """
+        super().__init__(env)
+
+        # Perform some checks
+        self.keys = set(self.action_space.keys())
+        assert self.keys == set(
+            self.observation_space.keys()
+        ), "Observation and action keys differ"
+
+        # Setup the wrapped environment
+        self.mono_envs = {}
+        self.wrappers = {}
+
+        for key in env.observation_space.keys():
+            mono_env = MultiMonoEnv(env, key)
+            self.mono_envs[key] = mono_env
+            wrapper = wrapper_factories[key](mono_env)
+
+            # Build up the list of action/observation wrappers
+            self.wrappers[key] = wrappers = []
+            while wrapper is not mono_env:
+                assert isinstance(
+                    wrapper,
+                    (
+                        gym.ObservationWrapper,
+                        gym.ActionWrapper,
+                        ActionObservationWrapper,
+                    ),
+                ), f"{type(wrapper)} is not an action/observation wrapper"
+                wrappers.append(wrapper)
+                wrapper = wrapper.env
+
+        # Change the action and observation spaces
+        self._action_space = spaces.Dict(
+            {
+                key: self.wrappers[key][0].action_space
+                if len(self.wrappers[key]) > 0
+                else self.mono_envs[key].action_space
+                for key in self.keys
+            }
+        )
+        self._observation_space = spaces.Dict(
+            {
+                key: self.wrappers[key][0].observation_space
+                if len(self.wrappers[key]) > 0
+                else self.mono_envs[key].observation_space
+                for key in self.keys
+            }
+        )
+
+    def action(self, actions: WrapperActType) -> ActType:
+        new_action = {}
+        for key in self.keys:
+            action = actions[key]
+            for wrapper in self.wrappers[key]:
+                if isinstance(wrapper, (gym.ActionWrapper, ActionObservationWrapper)):
+                    action = wrapper.action(action)
+            new_action[key] = action
+
+        return new_action
+
+    def observation(self, observations: ObsType) -> WrapperObsType:
+        new_observation = {}
+        for key in self.keys:
+            observation = observations[key]
+            for wrapper in reversed(self.wrappers[key]):
+                if isinstance(
+                    wrapper, (gym.ObservationWrapper, ActionObservationWrapper)
+                ):
+                    observation = wrapper.observation(observation)
+            new_observation[key] = observation
+
+        return new_observation
diff --git a/tests/test_envs.py b/tests/test_envs.py
index 509876c..500aa19 100644
--- a/tests/test_envs.py
+++ b/tests/test_envs.py
@@ -13,7 +13,7 @@ def test_env(name, use_ai):
     if name.startswith("supertuxkart/multi-"):
         kwargs = {"agents": [AgentSpec(use_ai=use_ai), AgentSpec(use_ai=use_ai)]}
     else:
-        kwargs = {"use_ai": use_ai}
+        kwargs = {"agent": AgentSpec(use_ai=use_ai)}
 
     try:
         env = gym.make(name, render_mode=None, **kwargs)
@@ -27,7 +27,7 @@ def test_env(name, use_ai):
             # print(action)
             state, reward, terminated, truncated, _ = env.step(action)
             done = truncated or terminated
-            if done:
+            if done:
                 break
     finally:
         if env is not None:
             env.close()
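
For illustration, a sketch of using the adapter (the agent names are illustrative, `PolarObservations` is used only as an example of a single-kart wrapper, and the identity factory keeps the raw spaces):

```py3
import gymnasium as gym
from pystk2_gymnasium import AgentSpec, MonoAgentWrapperAdapter
from pystk2_gymnasium.stk_wrappers import PolarObservations

env = gym.make(
    "supertuxkart/multi-full-v0",
    agents=[AgentSpec(name="agent-0"), AgentSpec(name="agent-1", use_ai=True)],
)

# One wrapper factory per agent key; factories may differ between agents
env = MonoAgentWrapperAdapter(
    env,
    wrapper_factories={
        "0": lambda env: PolarObservations(env),
        "1": lambda env: env,  # identity: keep the raw spaces for this agent
    },
)
```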