Skip to content

Commit

Permalink
fix: fix 'test_policy' rendering (#291)
Browse files Browse the repository at this point in the history
This commit ensures that the 'test_policy' utility is compatible with
the rendering behavior in Gymnasium > v21 (see https://younis.dev/blog/render-api/).
  • Loading branch information
rickstaa authored Jul 9, 2023
1 parent be99008 commit 48443ca
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 22 deletions.
4 changes: 2 additions & 2 deletions stable_learning_control/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from stable_learning_control.utils.import_utils import tf_installed

# Make module version available.
from .version import __version__ # noqa: F401
from .version import __version_tuple__ # noqa: F401
from .version import __version__
from .version import __version_tuple__

if tf_installed():
from stable_learning_control.algos.tf2.lac.lac import lac as lac_tf2
Expand Down
2 changes: 1 addition & 1 deletion stable_learning_control/algos/pytorch/lac/lac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,7 @@ def lac(
logger.log(
(
f"You defined your 'max_ep_len' to be {max_ep_len} "
"while the environment 'max_epsisode_steps' is "
"while the environment 'max_episode_steps' is "
f"{env.env._max_episode_steps}. As a result the environment "
f"'max_episode_steps' has been increased to {max_ep_len}"
),
Expand Down
2 changes: 1 addition & 1 deletion stable_learning_control/algos/pytorch/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,7 @@ def sac(
logger.log(
(
f"You defined your 'max_ep_len' to be {max_ep_len} "
"while the environment 'max_epsisode_steps' is "
"while the environment 'max_episode_steps' is "
f"{env.env._max_episode_steps}. As a result the environment "
f"'max_episode_steps' has been increased to {max_ep_len}"
),
Expand Down
2 changes: 1 addition & 1 deletion stable_learning_control/algos/tf2/lac/lac.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ def lac(
logger.log(
(
f"You defined your 'max_ep_len' to be {max_ep_len} "
"while the environment 'max_epsisode_steps' is "
"while the environment 'max_episode_steps' is "
f"{env.env._max_episode_steps}. As a result the environment "
f"'max_episode_steps' has been increased to {max_ep_len}"
),
Expand Down
2 changes: 1 addition & 1 deletion stable_learning_control/algos/tf2/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,7 @@ def sac(
logger.log(
(
f"You defined your 'max_ep_len' to be {max_ep_len} "
"while the environment 'max_epsisode_steps' is "
"while the environment 'max_episode_steps' is "
f"{env.env._max_episode_steps}. As a result the environment "
f"'max_episode_steps' has been increased to {max_ep_len}"
),
Expand Down
36 changes: 20 additions & 16 deletions stable_learning_control/utils/test_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,27 +284,31 @@ def run_policy(
"handle this situation."
)

# Apply episode length and render settings.
if max_ep_len is not None:
env.env._max_episode_steps = max_ep_len
if render: # Enable rendering if requested.
render_modes = env.unwrapped.metadata.get("render_modes", [])
if render_modes:
env.unwrapped.render_mode = (
"human"
if "human" in render_modes
else None
)
else:
log_to_std_out(
(
f"Nothing was rendered since the '{env.unwrapped.spec.id}' "
f"environment does not contain a 'human' render mode."
),
type="warning",
)

logger = EpochLogger(verbose_fmt="table")
o, _ = env.reset()
r, d, ep_ret, ep_len, n = 0, False, 0, 0, 0
supports_deterministic = True # Only supported with gaussian algorithms.
render_error = False
while n < num_episodes:
# Render env if requested.
if render and not render_error:
try:
env.render()
time.sleep(1e-3)
except NotImplementedError:
render_error = True
log_to_std_out(
(
"Nothing was rendered since no render method was "
f"implemented for the '{env.unwrapped.spec.id}' environment."
),
type="warning",
)

# Retrieve action.
if deterministic and supports_deterministic:
try:
Expand Down

0 comments on commit 48443ca

Please sign in to comment.