diff --git a/sac_ae.py b/sac_ae.py index 38b5f15..0e5d915 100644 --- a/sac_ae.py +++ b/sac_ae.py @@ -412,7 +412,7 @@ def update(self, replay_buffer, L, step): self.encoder_tau ) - if self.decoder is None and step % self.decoder_update_freq == 0: + if self.decoder is not None and step % self.decoder_update_freq == 0: self.update_decoder(obs, obs, L, step) def save(self, model_dir, step): diff --git a/train.py b/train.py index 35b740e..4f6cde4 100644 --- a/train.py +++ b/train.py @@ -32,15 +32,15 @@ def parse_args(): parser.add_argument('--agent', default='sac_ae', type=str) parser.add_argument('--init_steps', default=1000, type=int) parser.add_argument('--num_train_steps', default=1000000, type=int) - parser.add_argument('--batch_size', default=512, type=int) - parser.add_argument('--hidden_dim', default=256, type=int) + parser.add_argument('--batch_size', default=128, type=int) + parser.add_argument('--hidden_dim', default=1024, type=int) # eval parser.add_argument('--eval_freq', default=10000, type=int) parser.add_argument('--num_eval_episodes', default=10, type=int) # critic parser.add_argument('--critic_lr', default=1e-3, type=float) parser.add_argument('--critic_beta', default=0.9, type=float) - parser.add_argument('--critic_tau', default=0.005, type=float) + parser.add_argument('--critic_tau', default=0.01, type=float) parser.add_argument('--critic_target_update_freq', default=2, type=int) # actor parser.add_argument('--actor_lr', default=1e-3, type=float) @@ -52,19 +52,19 @@ def parse_args(): parser.add_argument('--encoder_type', default='pixel', type=str) parser.add_argument('--encoder_feature_dim', default=50, type=int) parser.add_argument('--encoder_lr', default=1e-3, type=float) - parser.add_argument('--encoder_tau', default=0.005, type=float) + parser.add_argument('--encoder_tau', default=0.05, type=float) parser.add_argument('--decoder_type', default='pixel', type=str) parser.add_argument('--decoder_lr', default=1e-3, type=float) parser.add_argument('--decoder_update_freq', default=1, type=int) - parser.add_argument('--decoder_latent_lambda', default=0.0, type=float) - parser.add_argument('--decoder_weight_lambda', default=0.0, type=float) + parser.add_argument('--decoder_latent_lambda', default=1e-6, type=float) + parser.add_argument('--decoder_weight_lambda', default=1e-7, type=float) parser.add_argument('--num_layers', default=4, type=int) parser.add_argument('--num_filters', default=32, type=int) # sac parser.add_argument('--discount', default=0.99, type=float) - parser.add_argument('--init_temperature', default=0.01, type=float) - parser.add_argument('--alpha_lr', default=1e-3, type=float) - parser.add_argument('--alpha_beta', default=0.9, type=float) + parser.add_argument('--init_temperature', default=0.1, type=float) + parser.add_argument('--alpha_lr', default=1e-4, type=float) + parser.add_argument('--alpha_beta', default=0.5, type=float) # misc parser.add_argument('--seed', default=1, type=int) parser.add_argument('--work_dir', default='.', type=str)