diff --git a/spinup/algos/pytorch/sac/sac.py b/spinup/algos/pytorch/sac/sac.py
index 8cd0f5ae3..354aad9c6 100644
--- a/spinup/algos/pytorch/sac/sac.py
+++ b/spinup/algos/pytorch/sac/sac.py
@@ -53,7 +53,7 @@ def sample_batch(self, batch_size=32, sample_mode=1, sequence_length=1):
 
 def sac(env_fn, actor_critic=core.MLPActorCritic, ac_kwargs=dict(), seed=0,
         steps_per_epoch=4000, epochs=100, replay_size=int(1e6), gamma=0.99,
-        polyak=0.995, lr=1e-3, alpha=0.2, batch_size=100, start_steps=10000,
+        polyak=0.995, lr=1e-3, alpha_init=0.2, batch_size=100, start_steps=10000,
         update_after=1000, update_every=50, num_test_episodes=10, max_ep_len=1000,
         logger_kwargs=dict(), save_freq=1, initial_actions="random", save_buffer=False, sample_mode=1, automatic_entropy_tuning=False):
     """
@@ -194,6 +194,8 @@ def compute_loss_q(data):
         if automatic_entropy_tuning==True:
             # get updated alpha
             alpha = log_alpha.exp()
+        else:
+            alpha = alpha_init
 
         # Bellman backup for Q functions
         with torch.no_grad():
@@ -227,6 +229,9 @@ def compute_loss_pi(data):
         if automatic_entropy_tuning==True:
             # get updated alpha
             alpha = log_alpha.exp()
+        else:
+            alpha = alpha_init
+
         # Entropy-regularized policy loss
         loss_pi = (alpha * logp_pi - q_pi).mean()
         # Useful info for logging
@@ -282,7 +287,7 @@ def update(data):
         else:
             # alpha_info = dict(LogAlpha=alpha.detach().numpy())
             # logger.store(LossAlpha=alpha_loss.item(), **alpha_info)
-            logger.store(LossAlpha=0, LogAlpha=alpha)
+            logger.store(LossAlpha=0, LogAlpha=alpha_init)
         # Unfreeze Q-networks so you can optimize it at next (DDPG, SAC, ...) step.
         for p in q_params:
             p.requires_grad = True
diff --git a/spinup/examples/pytorch/Hybrid_SAC.py b/spinup/examples/pytorch/Hybrid_SAC.py
index 938ca41c6..b3434ddc3 100644
--- a/spinup/examples/pytorch/Hybrid_SAC.py
+++ b/spinup/examples/pytorch/Hybrid_SAC.py
@@ -18,8 +18,8 @@
     os.makedirs(output_dir)
     logger_kwargs = dict(output_dir=output_dir, exp_name=exp_name)
     sac(env_fn, ac_kwargs={}, seed=0, steps_per_epoch=100, epochs=200, replay_size=1000000, gamma=0.99, polyak=0.995,
-        lr=0.01, alpha=0.001, batch_size=100, start_steps=1000, update_after=1000, update_every=100, num_test_episodes=2,
-        max_ep_len=np.inf, logger_kwargs=logger_kwargs, save_freq=1, initial_actions="zero", save_buffer=True, sample_mode = 1, automatic_entropy_tuning=True)
+        lr=0.01, alpha_init=0.001, batch_size=100, start_steps=1000, update_after=1000, update_every=100, num_test_episodes=2,
+        max_ep_len=np.inf, logger_kwargs=logger_kwargs, save_freq=1, initial_actions="zero", save_buffer=True, sample_mode = 1, automatic_entropy_tuning=False)
 else:
     output_dir='/home/mahdi/ETHZ/codes/spinningup/spinup/examples/pytorch/logs/'+exp_name
     env_loaded, get_action = load_policy_and_env(output_dir)
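
Note on the temperature handling this patch introduces: with the entropy coefficient argument renamed to alpha_init, both loss functions now fall back to that fixed value whenever automatic_entropy_tuning is disabled, instead of reading an alpha that is only defined in the tuning branch. A minimal standalone sketch of the resolution logic (not the repository's code; log_alpha is assumed to be the learnable log-temperature tensor used when tuning is enabled):

    import torch

    alpha_init = 0.2                                # fixed temperature, as in the sac() signature
    automatic_entropy_tuning = False                # flag from the sac() signature
    log_alpha = torch.zeros(1, requires_grad=True)  # assumed learnable log-temperature

    if automatic_entropy_tuning:
        alpha = log_alpha.exp()   # use the current learned temperature
    else:
        alpha = alpha_init        # fall back to the fixed value, as both loss functions now do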