
Commit

adding temperature optimization, stable, openai#1
mahdinobar committed Apr 29, 2024
1 parent 4332106 commit e9a02d0
Showing 2 changed files with 9 additions and 4 deletions.
9 changes: 7 additions & 2 deletions spinup/algos/pytorch/sac/sac.py
@@ -53,7 +53,7 @@ def sample_batch(self, batch_size=32, sample_mode=1, sequence_length=1):

 def sac(env_fn, actor_critic=core.MLPActorCritic, ac_kwargs=dict(), seed=0,
         steps_per_epoch=4000, epochs=100, replay_size=int(1e6), gamma=0.99,
-        polyak=0.995, lr=1e-3, alpha=0.2, batch_size=100, start_steps=10000,
+        polyak=0.995, lr=1e-3, alpha_init=0.2, batch_size=100, start_steps=10000,
         update_after=1000, update_every=50, num_test_episodes=10, max_ep_len=1000,
         logger_kwargs=dict(), save_freq=1, initial_actions="random", save_buffer=False, sample_mode=1, automatic_entropy_tuning=False):
     """
@@ -194,6 +194,8 @@ def compute_loss_q(data):
         if automatic_entropy_tuning==True:
             # get updated alpha
             alpha = log_alpha.exp()
+        else:
+            alpha = alpha_init

         # Bellman backup for Q functions
         with torch.no_grad():
@@ -227,6 +229,9 @@ def compute_loss_pi(data):
         if automatic_entropy_tuning==True:
             # get updated alpha
             alpha = log_alpha.exp()
+        else:
+            alpha = alpha_init
+
         # Entropy-regularized policy loss
         loss_pi = (alpha * logp_pi - q_pi).mean()
         # Useful info for logging
@@ -282,7 +287,7 @@ def update(data):
         else:
             # alpha_info = dict(LogAlpha=alpha.detach().numpy())
             # logger.store(LossAlpha=alpha_loss.item(), **alpha_info)
-            logger.store(LossAlpha=0, LogAlpha=alpha)
+            logger.store(LossAlpha=0, LogAlpha=alpha_init)
         # Unfreeze Q-networks so you can optimize it at next (DDPG, SAC, ...) step.
         for p in q_params:
             p.requires_grad = True
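
The hunks above read log_alpha.exp() but do not show how log_alpha is created or updated. For context, here is a minimal sketch of the standard SAC temperature optimization (Haarnoja et al., 2018) that the commit title refers to; act_dim, lr, and update_alpha are illustrative placeholders, not names taken from this repository:

    import torch

    # Assumed setup (not in this diff): learn log(alpha) so that
    # alpha = exp(log_alpha) stays positive without a constraint.
    act_dim, lr = 6, 1e-3                  # placeholders for dim(A) and sac()'s lr
    target_entropy = -act_dim              # common heuristic target: -dim(A)
    log_alpha = torch.zeros(1, requires_grad=True)
    alpha_optimizer = torch.optim.Adam([log_alpha], lr=lr)

    def update_alpha(logp_pi):
        # Raise alpha when the policy's entropy (-logp_pi) is below
        # target_entropy, lower it otherwise; detach() keeps the gradient
        # off the policy so only log_alpha is updated here.
        alpha_loss = -(log_alpha * (logp_pi + target_entropy).detach()).mean()
        alpha_optimizer.zero_grad()
        alpha_loss.backward()
        alpha_optimizer.step()
        return alpha_loss.item()

Learning the log of the temperature is the usual trick for keeping alpha positive while still using an unconstrained optimizer.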
4 changes: 2 additions & 2 deletions spinup/examples/pytorch/Hybrid_SAC.py
@@ -18,8 +18,8 @@
     os.makedirs(output_dir)
     logger_kwargs = dict(output_dir=output_dir, exp_name=exp_name)
     sac(env_fn, ac_kwargs={}, seed=0, steps_per_epoch=100, epochs=200, replay_size=1000000, gamma=0.99, polyak=0.995,
-        lr=0.01, alpha=0.001, batch_size=100, start_steps=1000, update_after=1000, update_every=100, num_test_episodes=2,
-        max_ep_len=np.inf, logger_kwargs=logger_kwargs, save_freq=1, initial_actions="zero", save_buffer=True, sample_mode = 1, automatic_entropy_tuning=True)
+        lr=0.01, alpha_init=0.001, batch_size=100, start_steps=1000, update_after=1000, update_every=100, num_test_episodes=2,
+        max_ep_len=np.inf, logger_kwargs=logger_kwargs, save_freq=1, initial_actions="zero", save_buffer=True, sample_mode = 1, automatic_entropy_tuning=False)
 else:
     output_dir='/home/mahdi/ETHZ/codes/spinningup/spinup/examples/pytorch/logs/'+exp_name
     env_loaded, get_action = load_policy_and_env(output_dir)
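
Taken together, the two files now default to a fixed temperature: the example passes alpha_init=0.001 with automatic_entropy_tuning=False, so alpha never changes during training. A self-contained toy recreation of the selection logic this commit adds (alpha_init, log_alpha, and automatic_entropy_tuning mirror the diff's names; the tensors are fake stand-ins for policy and critic outputs):

    import torch

    alpha_init = 0.001
    log_alpha = torch.tensor(0.0, requires_grad=True)

    logp_pi = torch.randn(4)   # stand-in for the policy's log-probabilities
    q_pi = torch.randn(4)      # stand-in for the critic's Q-values
    for automatic_entropy_tuning in (False, True):
        # Same branch the commit adds to compute_loss_q and compute_loss_pi.
        alpha = log_alpha.exp() if automatic_entropy_tuning else alpha_init
        # Entropy-regularized policy loss, as in compute_loss_pi above.
        loss_pi = (alpha * logp_pi - q_pi).mean()
        print(automatic_entropy_tuning, float(loss_pi))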
