From 80fe370a6132e1663fd2dadb4d8397a5adea1b3b Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Mon, 3 Jul 2023 12:11:25 +0200 Subject: [PATCH] fix: ensure 'test_policy' works with gymnasium>=0.28.1 (#276) This commit ensures the `test_policy` utility works with gymnasium>=0.28.1 (see https://gymnasium.farama.org/content/migration-guide/). --- stable_learning_control/utils/test_policy.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/stable_learning_control/utils/test_policy.py b/stable_learning_control/utils/test_policy.py index a983a4403..d30a560d2 100644 --- a/stable_learning_control/utils/test_policy.py +++ b/stable_learning_control/utils/test_policy.py @@ -285,7 +285,8 @@ def run_policy( ) logger = EpochLogger(verbose_fmt="table") - o, r, d, ep_ret, ep_len, n = env.reset(), 0, False, 0, 0, 0 + o, _ = env.reset() + r, d, ep_ret, ep_len, n = 0, False, 0, 0, 0 supports_deterministic = True # Only supported with gaussian algorithms. render_error = False while n < num_episodes: @@ -298,7 +299,7 @@ def run_policy( render_error = True log_to_std_out( ( - "WARNING: Nothing was rendered since no render method was " + "Nothing was rendered since no render method was " f"implemented for the '{env.unwrapped.spec.id}' environment." ), type="warning", @@ -321,13 +322,14 @@ def run_policy( a = policy.get_action(o) # Perform action in the environment and store result. - o, r, d, _ = env.step(a) + o, r, d, truncated, _ = env.step(a) ep_ret += r ep_len += 1 - if d or (ep_len == max_ep_len): + if d or truncated: logger.store(EpRet=ep_ret, EpLen=ep_len) logger.log("Episode %d \t EpRet %.3f \t EpLen %d" % (n, ep_ret, ep_len)) - o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0 + o, _ = env.reset() + r, d, ep_ret, ep_len = 0, False, 0, 0 n += 1 print("")