diff --git a/spinup/exercises/problem_set_1/exercise1_3.py b/spinup/exercises/problem_set_1/exercise1_3.py index f5c87b165..ca8da96b8 100644 --- a/spinup/exercises/problem_set_1/exercise1_3.py +++ b/spinup/exercises/problem_set_1/exercise1_3.py @@ -258,7 +258,7 @@ def td3(env_fn, actor_critic=core.mlp_actor_critic, ac_kwargs=dict(), seed=0, logger.setup_tf_saver(sess, inputs={'x': x_ph, 'a': a_ph}, outputs={'pi': pi, 'q1': q1, 'q2': q2}) def get_action(o, noise_scale): - a = sess.run(pi, feed_dict={x_ph: o.reshape(1,-1)}) + a = sess.run(pi, feed_dict={x_ph: o.reshape(1,-1)})[0] a += noise_scale * np.random.randn(act_dim) return np.clip(a, -act_limit, act_limit) @@ -382,4 +382,4 @@ def test_agent(n=10): if args.use_soln: true_td3(**all_kwargs) else: - td3(**all_kwargs) \ No newline at end of file + td3(**all_kwargs)