Skip to content

Commit

Permalink
Flattens action before sending to environment in DDPG/TD3/SAC.
Browse files Browse the repository at this point in the history
  • Loading branch information
jachiam committed Nov 24, 2018
1 parent 52e7ec3 commit 07bb739
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion spinup/algos/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def ddpg(env_fn, actor_critic=core.mlp_actor_critic, ac_kwargs=dict(), seed=0,
logger.setup_tf_saver(sess, inputs={'x': x_ph, 'a': a_ph}, outputs={'pi': pi, 'q': q})

def get_action(o, noise_scale):
a = sess.run(pi, feed_dict={x_ph: o.reshape(1,-1)})
a = sess.run(pi, feed_dict={x_ph: o.reshape(1,-1)})[0]
a += noise_scale * np.random.randn(act_dim)
return np.clip(a, -act_limit, act_limit)

Expand Down
2 changes: 1 addition & 1 deletion spinup/algos/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def sac(env_fn, actor_critic=core.mlp_actor_critic, ac_kwargs=dict(), seed=0,

def get_action(o, deterministic=False):
act_op = mu if deterministic else pi
return sess.run(act_op, feed_dict={x_ph: o.reshape(1,-1)})
return sess.run(act_op, feed_dict={x_ph: o.reshape(1,-1)})[0]

def test_agent(n=10):
global sess, mu, pi, q1, q2, q1_pi, q2_pi
Expand Down
2 changes: 1 addition & 1 deletion spinup/algos/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def td3(env_fn, actor_critic=core.mlp_actor_critic, ac_kwargs=dict(), seed=0,
logger.setup_tf_saver(sess, inputs={'x': x_ph, 'a': a_ph}, outputs={'pi': pi, 'q1': q1, 'q2': q2})

def get_action(o, noise_scale):
a = sess.run(pi, feed_dict={x_ph: o.reshape(1,-1)})
a = sess.run(pi, feed_dict={x_ph: o.reshape(1,-1)})[0]
a += noise_scale * np.random.randn(act_dim)
return np.clip(a, -act_limit, act_limit)

Expand Down

0 comments on commit 07bb739

Please sign in to comment.