From db4fb3f39c0f67959ca2f2241fd028a44fb8db9d Mon Sep 17 00:00:00 2001
From: Collin Donahue-Oponski
Date: Thu, 31 Jan 2019 10:07:26 -0700
Subject: [PATCH] [Exercise 1.3] Flatten actions in policy rollout

Updates exercise 1.3 to match the change made in commit 07bb739.
---
 spinup/exercises/problem_set_1/exercise1_3.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/spinup/exercises/problem_set_1/exercise1_3.py b/spinup/exercises/problem_set_1/exercise1_3.py
index f5c87b165..ca8da96b8 100644
--- a/spinup/exercises/problem_set_1/exercise1_3.py
+++ b/spinup/exercises/problem_set_1/exercise1_3.py
@@ -258,7 +258,7 @@ def td3(env_fn, actor_critic=core.mlp_actor_critic, ac_kwargs=dict(), seed=0,
     logger.setup_tf_saver(sess, inputs={'x': x_ph, 'a': a_ph}, outputs={'pi': pi, 'q1': q1, 'q2': q2})
 
     def get_action(o, noise_scale):
-        a = sess.run(pi, feed_dict={x_ph: o.reshape(1,-1)})
+        a = sess.run(pi, feed_dict={x_ph: o.reshape(1,-1)})[0]
         a += noise_scale * np.random.randn(act_dim)
         return np.clip(a, -act_limit, act_limit)
 
@@ -382,4 +382,4 @@ def test_agent(n=10):
     if args.use_soln:
         true_td3(**all_kwargs)
     else:
-        td3(**all_kwargs)
\ No newline at end of file
+        td3(**all_kwargs)