adding temperature optimization, stable, #0
mahdinobar committed Apr 29, 2024
1 parent 102c3c2 commit 4332106
Showing 2 changed files with 29 additions and 24 deletions.
51 changes: 28 additions & 23 deletions spinup/algos/pytorch/sac/sac.py
@@ -55,7 +55,7 @@ def sac(env_fn, actor_critic=core.MLPActorCritic, ac_kwargs=dict(), seed=0,
steps_per_epoch=4000, epochs=100, replay_size=int(1e6), gamma=0.99,
polyak=0.995, lr=1e-3, alpha=0.2, batch_size=100, start_steps=10000,
update_after=1000, update_every=50, num_test_episodes=10, max_ep_len=1000,
-logger_kwargs=dict(), save_freq=1, initial_actions="random", save_buffer=False, sample_mode=1):
+logger_kwargs=dict(), save_freq=1, initial_actions="random", save_buffer=False, sample_mode=1, automatic_entropy_tuning=False):
"""
Soft Actor-Critic (SAC)
@@ -191,9 +191,9 @@ def compute_loss_q(data):

q1 = ac.q1(o, a)
q2 = ac.q2(o, a)

-# get updated alpha
-alpha = log_alpha.exp()
+if automatic_entropy_tuning==True:
+    # get updated alpha
+    alpha = log_alpha.exp()

# Bellman backup for Q functions
with torch.no_grad():
@@ -224,8 +224,9 @@ def compute_loss_pi(data):
q1_pi = ac.q1(o, pi)
q2_pi = ac.q2(o, pi)
q_pi = torch.min(q1_pi, q2_pi)
-# get updated alpha
-alpha = log_alpha.exp()
+if automatic_entropy_tuning==True:
+    # get updated alpha
+    alpha = log_alpha.exp()
# Entropy-regularized policy loss
loss_pi = (alpha * logp_pi - q_pi).mean()
# Useful info for logging
@@ -236,11 +237,12 @@ def compute_loss_pi(data):
pi_optimizer = Adam(ac.pi.parameters(), lr=lr)
q_optimizer = Adam(q_params, lr=lr)

-device = torch.device("cpu")
-target_entropy = -torch.prod(torch.Tensor(ac.pi.mu_layer.out_features).to(device)).item()
-# log_alpha=torch.zeros(1, requires_grad=True, device=device)
-log_alpha = torch.tensor([np.exp(alpha)], requires_grad=True, device=device)
-alpha_optim = Adam([log_alpha], lr=lr)
+if automatic_entropy_tuning is True:
+    device = torch.device("cpu")
+    target_entropy = -torch.prod(torch.Tensor(ac.pi.mu_layer.out_features).to(device)).item()
+    # log_alpha=torch.zeros(1, requires_grad=True, device=device)
+    log_alpha = torch.tensor([np.exp(alpha)], requires_grad=True, device=device)
+    alpha_optim = Adam([log_alpha], lr=lr)

# Set up model saving
logger.setup_pytorch_saver(ac)
@@ -266,18 +268,21 @@ def update(data):
loss_pi.backward()
pi_optimizer.step()


-o = data['obs']
-pi, logp_pi = ac.pi(o)
-alpha_loss = -(log_alpha * (logp_pi + target_entropy).detach()).mean()
-alpha_optim.zero_grad()
-alpha_loss.backward()
-alpha_optim.step()
-alpha = log_alpha.exp()
-# alpha_info = dict(LogAlpha=alpha.detach().numpy())
-# logger.store(LossAlpha=alpha_loss.item(), **alpha_info)
-logger.store(LossAlpha=alpha_loss.item(), LogAlpha=alpha.detach().numpy())

+if automatic_entropy_tuning is True:
+    o = data['obs']
+    pi, logp_pi = ac.pi(o)
+    alpha_loss = -(log_alpha * (logp_pi + target_entropy).detach()).mean()
+    alpha_optim.zero_grad()
+    alpha_loss.backward()
+    alpha_optim.step()
+    alpha = log_alpha.exp()
+    # alpha_info = dict(LogAlpha=alpha.detach().numpy())
+    # logger.store(LossAlpha=alpha_loss.item(), **alpha_info)
+    logger.store(LossAlpha=alpha_loss.item(), LogAlpha=alpha.detach().numpy())
+else:
+    # alpha_info = dict(LogAlpha=alpha.detach().numpy())
+    # logger.store(LossAlpha=alpha_loss.item(), **alpha_info)
+    logger.store(LossAlpha=0, LogAlpha=alpha)
# Unfreeze Q-networks so you can optimize it at next (DDPG, SAC, ...) step.
for p in q_params:
p.requires_grad = True
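For readers skimming the diff, the new automatic_entropy_tuning branch follows the standard SAC temperature update: a learnable log-alpha is adjusted so that the policy entropy tracks a fixed target. The snippet below is a minimal, self-contained sketch of that update in isolation, not the committed code; the action dimension, initial temperature, learning rate, and the random batch of log-probabilities are placeholder assumptions.

import math
import torch
from torch.optim import Adam

torch.manual_seed(0)

act_dim = 2                        # placeholder action dimension
alpha_init = 0.2                   # placeholder initial temperature
target_entropy = -float(act_dim)   # common heuristic: minus the action dimension

# Learnable log-temperature; optimizing log(alpha) keeps alpha positive.
log_alpha = torch.tensor([math.log(alpha_init)], requires_grad=True)
alpha_optim = Adam([log_alpha], lr=1e-3)

# Stand-in for log-probabilities of a batch of actions sampled from the current policy.
logp_pi = torch.randn(100)

# Same loss shape as alpha_loss in update() above: the gradient raises alpha when the
# policy entropy is below target_entropy and lowers it otherwise (logp_pi is detached).
alpha_loss = -(log_alpha * (logp_pi + target_entropy).detach()).mean()

alpha_optim.zero_grad()
alpha_loss.backward()
alpha_optim.step()

alpha = log_alpha.exp()            # temperature then used in the Q and policy losses
print(alpha_loss.item(), alpha.item())

Optimizing log-alpha rather than alpha itself avoids a positivity constraint, which is why the commit keeps a log_alpha tensor and calls log_alpha.exp() wherever the current temperature is needed.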
2 changes: 1 addition & 1 deletion spinup/examples/pytorch/Hybrid_SAC.py
@@ -19,7 +19,7 @@
logger_kwargs = dict(output_dir=output_dir, exp_name=exp_name)
sac(env_fn, ac_kwargs={}, seed=0, steps_per_epoch=100, epochs=200, replay_size=1000000, gamma=0.99, polyak=0.995,
lr=0.01, alpha=0.001, batch_size=100, start_steps=1000, update_after=1000, update_every=100, num_test_episodes=2,
-max_ep_len=np.inf, logger_kwargs=logger_kwargs, save_freq=1, initial_actions="zero", save_buffer=True, sample_mode = 1)
+max_ep_len=np.inf, logger_kwargs=logger_kwargs, save_freq=1, initial_actions="zero", save_buffer=True, sample_mode = 1, automatic_entropy_tuning=True)
else:
output_dir='/home/mahdi/ETHZ/codes/spinningup/spinup/examples/pytorch/logs/'+exp_name
env_loaded, get_action = load_policy_and_env(output_dir)
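With automatic_entropy_tuning=True, as in the updated example call above, the fixed alpha=0.001 argument effectively only seeds the learnable log-temperature, and the entropy coefficient is then re-estimated at every update step; leaving the flag at its default of False preserves the previous constant-alpha behaviour.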
