Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix 'test_policy' rendering #291

Merged
merged 1 commit into from
Jul 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions stable_learning_control/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from stable_learning_control.utils.import_utils import tf_installed

# Make module version available.
from .version import __version__ # noqa: F401
from .version import __version_tuple__ # noqa: F401
from .version import __version__
from .version import __version_tuple__

if tf_installed():
from stable_learning_control.algos.tf2.lac.lac import lac as lac_tf2
Expand Down
2 changes: 1 addition & 1 deletion stable_learning_control/algos/pytorch/lac/lac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,7 @@ def lac(
logger.log(
(
f"You defined your 'max_ep_len' to be {max_ep_len} "
"while the environment 'max_epsisode_steps' is "
"while the environment 'max_episode_steps' is "
f"{env.env._max_episode_steps}. As a result the environment "
f"'max_episode_steps' has been increased to {max_ep_len}"
),
Expand Down
2 changes: 1 addition & 1 deletion stable_learning_control/algos/pytorch/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,7 @@ def sac(
logger.log(
(
f"You defined your 'max_ep_len' to be {max_ep_len} "
"while the environment 'max_epsisode_steps' is "
"while the environment 'max_episode_steps' is "
f"{env.env._max_episode_steps}. As a result the environment "
f"'max_episode_steps' has been increased to {max_ep_len}"
),
Expand Down
2 changes: 1 addition & 1 deletion stable_learning_control/algos/tf2/lac/lac.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ def lac(
logger.log(
(
f"You defined your 'max_ep_len' to be {max_ep_len} "
"while the environment 'max_epsisode_steps' is "
"while the environment 'max_episode_steps' is "
f"{env.env._max_episode_steps}. As a result the environment "
f"'max_episode_steps' has been increased to {max_ep_len}"
),
Expand Down
2 changes: 1 addition & 1 deletion stable_learning_control/algos/tf2/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,7 @@ def sac(
logger.log(
(
f"You defined your 'max_ep_len' to be {max_ep_len} "
"while the environment 'max_epsisode_steps' is "
"while the environment 'max_episode_steps' is "
f"{env.env._max_episode_steps}. As a result the environment "
f"'max_episode_steps' has been increased to {max_ep_len}"
),
Expand Down
36 changes: 20 additions & 16 deletions stable_learning_control/utils/test_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,27 +284,31 @@ def run_policy(
"handle this situation."
)

# Apply episode length and render settings.
if max_ep_len is not None:
env.env._max_episode_steps = max_ep_len
if render: # Enable rendering if requested.
render_modes = env.unwrapped.metadata.get("render_modes", [])
if render_modes:
env.unwrapped.render_mode = (
"human"
if "human" in render_modes
else None
)
else:
log_to_std_out(
(
f"Nothing was rendered since the '{env.unwrapped.spec.id}' "
f"environment does not contain a 'human' render mode."
),
type="warning",
)

logger = EpochLogger(verbose_fmt="table")
o, _ = env.reset()
r, d, ep_ret, ep_len, n = 0, False, 0, 0, 0
supports_deterministic = True # Only supported with gaussian algorithms.
render_error = False
while n < num_episodes:
# Render env if requested.
if render and not render_error:
try:
env.render()
time.sleep(1e-3)
except NotImplementedError:
render_error = True
log_to_std_out(
(
"Nothing was rendered since no render method was "
f"implemented for the '{env.unwrapped.spec.id}' environment."
),
type="warning",
)

# Retrieve action.
if deterministic and supports_deterministic:
try:
Expand Down