Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Soft Actor Critic (SAC) Model #627

Merged
merged 43 commits into from
Sep 8, 2021
Merged
Changes from 1 commit
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
8f1bf23
finish soft actor critic
blahBlahhhJ Apr 28, 2021
8c2145f
added tests
blahBlahhhJ Apr 29, 2021
0c872a1
finish document and init
blahBlahhhJ May 1, 2021
742943e
fix style 1
blahBlahhhJ May 1, 2021
700cdbb
fix style 2
blahBlahhhJ May 1, 2021
08ce087
fix style 3
blahBlahhhJ May 7, 2021
26ccf1c
Merge branch 'master' into feature/596-sac
Borda Jun 24, 2021
a544901
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 24, 2021
71e0dec
formt
Borda Jun 24, 2021
d4abe63
Merge branch 'feature/596-sac' of https://github.com/blahBlahhhJ/ligh…
Borda Jun 24, 2021
557ea57
Apply suggestions from code review
Borda Jun 24, 2021
ad47e34
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 24, 2021
8c44f3e
Merge branch 'master' into feature/596-sac
mergify[bot] Jun 25, 2021
c26a88b
Merge branch 'master' into feature/596-sac
Borda Jul 4, 2021
d0e60d3
Merge branch 'master' into feature/596-sac
mergify[bot] Jul 4, 2021
3254dbd
Merge branch 'master' into feature/596-sac
mergify[bot] Jul 4, 2021
d81e8e0
use hyperparameters in hparams
blahBlahhhJ Jul 7, 2021
1a8e73f
Merge branch 'feature/596-sac' of https://github.com/blahBlahhhJ/ligh…
blahBlahhhJ Jul 7, 2021
d101d50
Add CHANGELOG
blahBlahhhJ Jul 7, 2021
c52ea1a
Merge branch 'master' into feature/596-sac
mergify[bot] Jul 7, 2021
48800c9
Merge branch 'master' into feature/596-sac
mergify[bot] Jul 13, 2021
47bb401
Merge branch 'master' into feature/596-sac
mergify[bot] Jul 13, 2021
43daba3
fix test
blahBlahhhJ Jul 20, 2021
bfc7028
Merge branch 'master' into feature/596-sac
mergify[bot] Jul 26, 2021
fd0964b
Merge branch 'master' into feature/596-sac
mergify[bot] Jul 28, 2021
2576333
fix format
blahBlahhhJ Aug 1, 2021
a1ec703
Merge branch 'feature/596-sac' of https://github.com/blahBlahhhJ/ligh…
blahBlahhhJ Aug 1, 2021
4723212
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 9, 2021
05b1084
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 13, 2021
c1660af
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 13, 2021
b207d3c
Merge branch 'master' into feature/596-sac
blahBlahhhJ Aug 13, 2021
73a13d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2021
be19c64
fix __init__
blahBlahhhJ Aug 13, 2021
25aa7e0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2021
c6104c0
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 19, 2021
4486569
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 27, 2021
427d5ab
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 27, 2021
cbcc5c0
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 29, 2021
41d7365
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 29, 2021
cccd10d
Merge branch 'master' into feature/596-sac
Sep 7, 2021
bfbae6b
Fix tests
Sep 8, 2021
c0d16fd
Fix reference
Sep 8, 2021
7a0e944
Fix duplication
Sep 8, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
use hyperparameters in hparams
blahBlahhhJ committed Jul 7, 2021
commit d81e8e0a31d68fde9ef2a8a71d735458a97368fb
22 changes: 7 additions & 15 deletions pl_bolts/models/rl/sac_model.py
Original file line number Diff line number Diff line change
@@ -73,14 +73,6 @@ def __init__(
self.agent = SoftActorCriticAgent(self.policy)

# Hyperparameters
- self.sync_rate = sync_rate
- self.gamma = gamma
- self.batch_size = batch_size
- self.replay_size = replay_size
- self.warm_start_size = warm_start_size
- self.batches_per_epoch = batches_per_epoch
- self.n_steps = n_steps

self.save_hyperparameters()

# Metrics
@@ -227,13 +219,13 @@ def train_batch(self, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch
episode_steps = 0
episode_reward = 0

- states, actions, rewards, dones, new_states = self.buffer.sample(self.batch_size)
+ states, actions, rewards, dones, new_states = self.buffer.sample(self.hparams.batch_size)

for idx, _ in enumerate(dones):
yield states[idx], actions[idx], rewards[idx], dones[idx], new_states[idx]

# Simulates epochs
- if self.total_steps % self.batches_per_epoch == 0:
+ if self.total_steps % self.hparams.batches_per_epoch == 0:
break

def loss(
@@ -276,7 +268,7 @@ def loss(
next_q1_values = self.target_q1(new_next_states_actions)
next_q2_values = self.target_q2(new_next_states_actions)
next_qmin_values = torch.min(next_q1_values, next_q2_values) - new_next_logprobs
- target_values = rewards + (1. - dones) * self.gamma * next_qmin_values
+ target_values = rewards + (1. - dones) * self.hparams.gamma * next_qmin_values

q1_loss = F.mse_loss(q1_values, target_values)
q2_loss = F.mse_loss(q2_values, target_values)
@@ -309,7 +301,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _, optimizer_i
q2_optim.step()

# Soft update of target network
- if self.global_step % self.sync_rate == 0:
+ if self.global_step % self.hparams.sync_rate == 0:
self.soft_update_target(self.q1, self.target_q1)
self.soft_update_target(self.q2, self.target_q2)

@@ -338,11 +330,11 @@ def test_epoch_end(self, outputs) -> Dict[str, torch.Tensor]:

def _dataloader(self) -> DataLoader:
"""Initialize the Replay Buffer dataset used for retrieving experiences"""
- self.buffer = MultiStepBuffer(self.replay_size, self.n_steps)
- self.populate(self.warm_start_size)
+ self.buffer = MultiStepBuffer(self.hparams.replay_size, self.hparams.n_steps)
+ self.populate(self.hparams.warm_start_size)

self.dataset = ExperienceSourceDataset(self.train_batch)
- return DataLoader(dataset=self.dataset, batch_size=self.batch_size)
+ return DataLoader(dataset=self.dataset, batch_size=self.hparams.batch_size)

def train_dataloader(self) -> DataLoader:
"""Get train loader"""