
Soft Actor Critic (SAC) Model #627

Merged: 43 commits, Sep 8, 2021
Changes from 1 commit
Commits (43)
8f1bf23
finish soft actor critic
blahBlahhhJ Apr 28, 2021
8c2145f
added tests
blahBlahhhJ Apr 29, 2021
0c872a1
finish document and init
blahBlahhhJ May 1, 2021
742943e
fix style 1
blahBlahhhJ May 1, 2021
700cdbb
fix style 2
blahBlahhhJ May 1, 2021
08ce087
fix style 3
blahBlahhhJ May 7, 2021
26ccf1c
Merge branch 'master' into feature/596-sac
Borda Jun 24, 2021
a544901
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 24, 2021
71e0dec
formt
Borda Jun 24, 2021
d4abe63
Merge branch 'feature/596-sac' of https://github.com/blahBlahhhJ/ligh…
Borda Jun 24, 2021
557ea57
Apply suggestions from code review
Borda Jun 24, 2021
ad47e34
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 24, 2021
8c44f3e
Merge branch 'master' into feature/596-sac
mergify[bot] Jun 25, 2021
c26a88b
Merge branch 'master' into feature/596-sac
Borda Jul 4, 2021
d0e60d3
Merge branch 'master' into feature/596-sac
mergify[bot] Jul 4, 2021
3254dbd
Merge branch 'master' into feature/596-sac
mergify[bot] Jul 4, 2021
d81e8e0
use hyperparameters in hparams
blahBlahhhJ Jul 7, 2021
1a8e73f
Merge branch 'feature/596-sac' of https://github.com/blahBlahhhJ/ligh…
blahBlahhhJ Jul 7, 2021
d101d50
Add CHANGELOG
blahBlahhhJ Jul 7, 2021
c52ea1a
Merge branch 'master' into feature/596-sac
mergify[bot] Jul 7, 2021
48800c9
Merge branch 'master' into feature/596-sac
mergify[bot] Jul 13, 2021
47bb401
Merge branch 'master' into feature/596-sac
mergify[bot] Jul 13, 2021
43daba3
fix test
blahBlahhhJ Jul 20, 2021
bfc7028
Merge branch 'master' into feature/596-sac
mergify[bot] Jul 26, 2021
fd0964b
Merge branch 'master' into feature/596-sac
mergify[bot] Jul 28, 2021
2576333
fix format
blahBlahhhJ Aug 1, 2021
a1ec703
Merge branch 'feature/596-sac' of https://github.com/blahBlahhhJ/ligh…
blahBlahhhJ Aug 1, 2021
4723212
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 9, 2021
05b1084
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 13, 2021
c1660af
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 13, 2021
b207d3c
Merge branch 'master' into feature/596-sac
blahBlahhhJ Aug 13, 2021
73a13d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2021
be19c64
fix __init__
blahBlahhhJ Aug 13, 2021
25aa7e0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2021
c6104c0
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 19, 2021
4486569
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 27, 2021
427d5ab
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 27, 2021
cbcc5c0
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 29, 2021
41d7365
Merge branch 'master' into feature/596-sac
mergify[bot] Aug 29, 2021
cccd10d
Merge branch 'master' into feature/596-sac
Sep 7, 2021
bfbae6b
Fix tests
Sep 8, 2021
c0d16fd
Fix reference
Sep 8, 2021
7a0e944
Fix duplication
Sep 8, 2021
formt
Borda committed Jun 24, 2021
commit 71e0decdab93c18043d3a1ff7c3df378fb4a147f
12 changes: 6 additions & 6 deletions docs/source/reinforce_learn.rst
@@ -688,13 +688,13 @@ Paper authors: Tuomas Haarnoja, Aurick Zhou, Pieter Abbeel, Sergey Levine

Original implementation by: `Jason Wang <https://github.com/blahBlahhhJ>`_

Soft Actor Critic (SAC) is a powerful actor critic algorithm in reinforcement learning. Unlike A2C, SAC's policy outputs a
special continuous distribution for actions, and its critic estimates the Q value instead of the state value, which
means it now takes in not only states but also actions. The new actor allows SAC to support continuous action tasks such
as controlling robots, and the new critic allows SAC to support off-policy learning which is more sample efficient.

The actor has a new objective to maximize entropy to encourage exploration while maximizing the expected rewards.
The critic uses two separate Q functions to "mitigate positive bias" during training by picking the minimum of the
two as the predicted Q value.
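
In the standard SAC formulation this corresponds to the entropy-regularised (soft) Bellman target and policy objective, written here in the paper's notation rather than this PR's exact code:

y = r + \gamma \Big( \min_{i=1,2} Q_{\bar\theta_i}(s', a') - \alpha \log \pi_\phi(a' \mid s') \Big), \qquad a' \sim \pi_\phi(\cdot \mid s')

J_\pi(\phi) = \mathbb{E}_{s \sim \mathcal{D},\, a \sim \pi_\phi} \big[ \alpha \log \pi_\phi(a \mid s) - \min_{i=1,2} Q_{\theta_i}(s, a) \big]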

Since SAC is off-policy, its algorithm's training step is quite similar to DQN:
9 changes: 5 additions & 4 deletions pl_bolts/models/rl/common/agents.py
@@ -142,7 +142,8 @@ def __call__(self, states: Tensor, device: str) -> List[int]:

class SoftActorCriticAgent(Agent):
"""Actor-Critic based agent that returns a continuous action based on the policy"""
def __call__(self, states: torch.Tensor, device: str) -> List[float]:

def __call__(self, states: Tensor, device: str) -> List[float]:
"""
Takes in the current state and returns the action based on the agent's policy

@@ -156,15 +157,15 @@ def __call__(self, states: torch.Tensor, device: str) -> List[float]:
if not isinstance(states, list):
states = [states]

if not isinstance(states, torch.Tensor):
if not isinstance(states, Tensor):
states = torch.tensor(states, device=device)

dist = self.net(states)
actions = [a for a in dist.sample().cpu().numpy()]

return actions

def get_action(self, states: torch.Tensor, device: str) -> List[float]:
def get_action(self, states: Tensor, device: str) -> List[float]:
"""
Get the action greedily (without sampling)

@@ -178,7 +179,7 @@ def get_action(self, states: torch.Tensor, device: str) -> List[float]:
if not isinstance(states, list):
states = [states]

if not isinstance(states, torch.Tensor):
if not isinstance(states, Tensor):
states = torch.tensor(states, device=device)

actions = [self.net.get_action(states).cpu().numpy()]
6 changes: 4 additions & 2 deletions pl_bolts/models/rl/common/distributions.py
@@ -10,6 +10,7 @@ class TanhMultivariateNormal(torch.distributions.MultivariateNormal):
X = action_scale * tanh(Z) + action_bias
Z ~ Normal(mean, variance)
"""

def __init__(self, action_bias, action_scale, **kwargs):
super().__init__(**kwargs)

@@ -40,7 +41,7 @@ def log_prob_with_z(self, value, z):
"""
value = (value - self.action_bias) / self.action_scale
z_logprob = super().log_prob(z)
correction = torch.log(self.action_scale * (1 - value ** 2) + 1e-7).sum(1)
correction = torch.log(self.action_scale * (1 - value**2) + 1e-7).sum(1)
return z_logprob - correction

def rsample_and_log_prob(self, sample_shape=torch.Size()):
@@ -53,12 +54,13 @@ def rsample_and_log_prob(self, sample_shape=torch.Size()):
z = super().rsample()
z_logprob = super().log_prob(z)
value = torch.tanh(z)
correction = torch.log(self.action_scale * (1 - value ** 2) + 1e-7).sum(1)
correction = torch.log(self.action_scale * (1 - value**2) + 1e-7).sum(1)
return self.action_scale * value + self.action_bias, z_logprob - correction

"""
Some override methods
"""

def rsample(self, sample_shape=torch.Size()):
fz, z = self.rsample_with_z(sample_shape)
return fz
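
For context, the correction term in log_prob_with_z and rsample_and_log_prob above is the standard change-of-variables adjustment for the squashing X = action_scale * tanh(Z) + action_bias described in the class docstring:

\log p_X(x) = \log p_Z(z) - \sum_j \log\!\big( \text{action\_scale}_j \, (1 - \tanh^2 z_j) \big)

The 1e-7 added inside the log is a small epsilon for numerical stability when tanh saturates near plus or minus 1.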
17 changes: 6 additions & 11 deletions pl_bolts/models/rl/common/networks.py
@@ -6,7 +6,7 @@

import numpy as np
import torch
from torch import nn, Tensor
from torch import FloatTensor, nn, Tensor
from torch.nn import functional as F

from pl_bolts.models.rl.common.distributions import TanhMultivariateNormal
@@ -98,6 +98,7 @@ class ContinuousMLP(nn.Module):
"""
MLP network that outputs continuous value via Gaussian distribution
"""

def __init__(
self,
input_shape: Tuple[int],
@@ -119,15 +120,12 @@ def __init__(
self.action_scale = action_scale

self.shared_net = nn.Sequential(
nn.Linear(input_shape[0], hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU()
nn.Linear(input_shape[0], hidden_size), nn.ReLU(), nn.Linear(hidden_size, hidden_size), nn.ReLU()
)
self.mean_layer = nn.Linear(hidden_size, n_actions)
self.logstd_layer = nn.Linear(hidden_size, n_actions)

def forward(self, x: torch.FloatTensor) -> TanhMultivariateNormal:
def forward(self, x: FloatTensor) -> TanhMultivariateNormal:
"""
Forward pass through network. Calculates the action distribution

@@ -141,13 +139,10 @@ def forward(self, x: torch.FloatTensor) -> TanhMultivariateNormal:
logstd = torch.clamp(self.logstd_layer(x), -20, 2)
batch_scale_tril = torch.diag_embed(torch.exp(logstd))
return TanhMultivariateNormal(
action_bias=self.action_bias,
action_scale=self.action_scale,
loc=batch_mean,
scale_tril=batch_scale_tril
action_bias=self.action_bias, action_scale=self.action_scale, loc=batch_mean, scale_tril=batch_scale_tril
)

def get_action(self, x: torch.FloatTensor) -> torch.Tensor:
def get_action(self, x: FloatTensor) -> Tensor:
"""
Get the action greedily (without sampling)
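
For readers new to this pattern, here is a minimal, self-contained sketch of the same "shared MLP, then mean and clamped log-std heads" policy used by ContinuousMLP. It is generic PyTorch with a plain diagonal Normal, not the Bolts class or its TanhMultivariateNormal; all names and sizes are illustrative.

import torch
from torch import nn
from torch.distributions import Normal


class GaussianPolicyHead(nn.Module):
    """Toy version of the pattern above: shared body, then mean and log-std heads."""

    def __init__(self, obs_dim: int, act_dim: int, hidden: int = 128):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(), nn.Linear(hidden, hidden), nn.ReLU()
        )
        self.mean = nn.Linear(hidden, act_dim)
        self.log_std = nn.Linear(hidden, act_dim)

    def forward(self, obs: torch.Tensor) -> Normal:
        h = self.body(obs)
        # Clamp log-std to a sane range, as the diff does with (-20, 2), so the variance stays finite.
        log_std = torch.clamp(self.log_std(h), -20, 2)
        return Normal(self.mean(h), log_std.exp())


# Usage: sample actions for a batch of 4 observations of size 8, then squash to (-1, 1).
dist = GaussianPolicyHead(obs_dim=8, act_dim=2)(torch.randn(4, 8))
actions = torch.tanh(dist.rsample())  # SAC additionally rescales by action_scale and shifts by action_bias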

49 changes: 17 additions & 32 deletions pl_bolts/models/rl/sac_model.py
@@ -5,13 +5,13 @@
from typing import Dict, List, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning import seed_everything
from pytorch_lightning import LightningModule, seed_everything, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torch import optim as optim
from torch.optim.optimizer import Optimizer
from torch import Tensor
from torch.nn import functional as F
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from pl_bolts.datamodules.experience_source import Experience, ExperienceSourceDataset
@@ -28,7 +28,8 @@
Env = object


class SAC(pl.LightningModule):
class SAC(LightningModule):

def __init__(
self,
env: str,
@@ -134,13 +135,7 @@ def populate(self, warm_start: int) -> None:
for _ in range(warm_start):
action = self.agent(self.state, self.device)
next_state, reward, done, _ = self.env.step(action[0])
exp = Experience(
state=self.state,
action=action[0],
reward=reward,
done=done,
new_state=next_state
)
exp = Experience(state=self.state, action=action[0], reward=reward, done=done, new_state=next_state)
self.buffer.append(exp)
self.state = next_state

@@ -151,12 +146,7 @@ def build_networks(self) -> None:
"""Initializes the SAC policy and q networks (with targets)"""
action_bias = torch.from_numpy((self.env.action_space.high + self.env.action_space.low) / 2)
action_scale = torch.from_numpy((self.env.action_space.high - self.env.action_space.low) / 2)
self.policy = ContinuousMLP(
self.obs_shape,
self.n_actions,
action_bias=action_bias,
action_scale=action_scale
)
self.policy = ContinuousMLP(self.obs_shape, self.n_actions, action_bias=action_bias, action_scale=action_scale)

concat_shape = [self.obs_shape[0] + self.n_actions]
self.q1 = MLP(concat_shape, 1)
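
As a concrete example of the bias/scale computation above: for a one-dimensional action space with low = -2 and high = 2 (for example Gym's Pendulum), action_bias = (2 + (-2)) / 2 = 0 and action_scale = (2 - (-2)) / 2 = 2, so the tanh-squashed policy output covers exactly [-2, 2].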
@@ -176,12 +166,10 @@ def soft_update_target(self, q_net, target_net):
target_net: the target (q) network
"""
for q_param, target_param in zip(q_net.parameters(), target_net.parameters()):
target_param.data.copy_(
(1.0 - self.hparams.target_alpha) * target_param.data +
self.hparams.target_alpha * q_param
)
target_param.data.copy_((1.0 - self.hparams.target_alpha) * target_param.data
+ self.hparams.target_alpha * q_param)
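
In symbols, this is the usual Polyak (soft) target update with \tau equal to target_alpha:

\bar\theta \leftarrow (1 - \tau)\, \bar\theta + \tau\, \theta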

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: Tensor) -> Tensor:
"""
Passes in a state x through the policy network and samples an action from the resulting distribution

@@ -194,7 +182,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
output = self.policy(x).sample()
return output

def train_batch(self, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
def train_batch(self, ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""
Contains the logic for generating a new batch of data to be passed to the DataLoader

@@ -236,10 +224,7 @@ def train_batch(self, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch
if self.total_steps % self.batches_per_epoch == 0:
break

def loss(
self,
batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def loss(self, batch: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
"""
Calculates the loss for SAC which contains a total of 3 losses

@@ -283,7 +268,7 @@ def loss(

return policy_loss, q1_loss, q2_loss
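
The loss bodies themselves are elided by this diff. As a rough sketch of how the three SAC losses are typically computed, assuming the batch ordering (states, actions, rewards, dones, next_states), a fixed entropy coefficient alpha, and Q networks that take the concatenated state-action vector (this mirrors the structure above but is not the exact Bolts code):

import torch
from torch.nn import functional as F


def sac_losses(policy, q1, q2, target_q1, target_q2, batch, gamma: float = 0.99, alpha: float = 0.2):
    """Two Q regressions against a shared soft target, plus the policy loss."""
    states, actions, rewards, dones, next_states = batch

    with torch.no_grad():
        next_dist = policy(next_states)
        next_actions, next_logprobs = next_dist.rsample_and_log_prob()
        next_q_in = torch.cat((next_states, next_actions), dim=1)
        next_q = torch.min(target_q1(next_q_in), target_q2(next_q_in)).squeeze(-1)
        # Soft Bellman target: entropy-regularised, no bootstrap on terminal states.
        target = rewards + gamma * (1 - dones.float()) * (next_q - alpha * next_logprobs)

    q_in = torch.cat((states, actions), dim=1)
    q1_loss = F.mse_loss(q1(q_in).squeeze(-1), target)
    q2_loss = F.mse_loss(q2(q_in).squeeze(-1), target)

    new_actions, logprobs = policy(states).rsample_and_log_prob()
    new_q_in = torch.cat((states, new_actions), dim=1)
    min_q = torch.min(q1(new_q_in), q2(new_q_in)).squeeze(-1)
    policy_loss = (alpha * logprobs - min_q).mean()

    return policy_loss, q1_loss, q2_loss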

def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _, optimizer_idx):
def training_step(self, batch: Tuple[Tensor, Tensor], _, optimizer_idx):
"""
Carries out a single step through the environment to update the replay buffer.
Then calculates loss based on the minibatch received
@@ -323,13 +308,13 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _, optimizer_i
"episode_steps": self.total_episode_steps[-1]
})

def test_step(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
def test_step(self, *args, **kwargs) -> Dict[str, Tensor]:
"""Evaluate the agent for 10 episodes"""
test_reward = self.run_n_episodes(self.test_env, 1)
avg_reward = sum(test_reward) / len(test_reward)
return {"test_reward": avg_reward}

def test_epoch_end(self, outputs) -> Dict[str, torch.Tensor]:
def test_epoch_end(self, outputs) -> Dict[str, Tensor]:
"""Log the avg of the test results"""
rewards = [x["test_reward"] for x in outputs]
avg_reward = sum(rewards) / len(rewards)
@@ -415,7 +400,7 @@ def cli_main():
parser = argparse.ArgumentParser(add_help=False)

# trainer args
parser = pl.Trainer.add_argparse_args(parser)
parser = Trainer.add_argparse_args(parser)

# model args
parser = SAC.add_model_specific_args(parser)
@@ -427,7 +412,7 @@ def cli_main():
checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="avg_reward", mode="max", period=1, verbose=True)

seed_everything(123)
trainer = pl.Trainer.from_argparse_args(args, deterministic=True, checkpoint_callback=checkpoint_callback)
trainer = Trainer.from_argparse_args(args, deterministic=True, checkpoint_callback=checkpoint_callback)

trainer.fit(model)

6 changes: 3 additions & 3 deletions tests/models/rl/integration/test_actor_critic_models.py
@@ -1,6 +1,6 @@
import argparse

import pytorch_lightning as pl
from pytorch_lightning import Trainer

from pl_bolts.models.rl.sac_model import SAC

@@ -9,7 +9,7 @@ def test_sac():
"""Smoke test that the SAC model runs"""

parent_parser = argparse.ArgumentParser(add_help=False)
parent_parser = pl.Trainer.add_argparse_args(parent_parser)
parent_parser = Trainer.add_argparse_args(parent_parser)
parent_parser = SAC.add_model_specific_args(parent_parser)
args_list = [
"--warm_start_size",
@@ -23,7 +23,7 @@ def test_sac():
]
hparams = parent_parser.parse_args(args_list)

trainer = pl.Trainer(
trainer = Trainer(
gpus=hparams.gpus,
max_steps=100,
max_epochs=100, # Set this as the same as max steps to ensure that it doesn't stop early
17 changes: 9 additions & 8 deletions tests/models/rl/unit/test_sac.py
@@ -1,6 +1,7 @@
import argparse

import torch
from torch import Tensor

from pl_bolts.models.rl.sac_model import SAC

@@ -27,9 +28,9 @@ def test_sac_loss():

policy_loss, q1_loss, q2_loss = model.loss(batch)

assert isinstance(policy_loss, torch.Tensor)
assert isinstance(q1_loss, torch.Tensor)
assert isinstance(q2_loss, torch.Tensor)
assert isinstance(policy_loss, Tensor)
assert isinstance(q1_loss, Tensor)
assert isinstance(q2_loss, Tensor)


def test_sac_train_batch():
@@ -52,8 +53,8 @@ def test_sac_train_batch():
assert len(batch) == 5
assert len(batch[0]) == model.hparams.batch_size
assert isinstance(batch, list)
assert isinstance(batch[0], torch.Tensor)
assert isinstance(batch[1], torch.Tensor)
assert isinstance(batch[2], torch.Tensor)
assert isinstance(batch[3], torch.Tensor)
assert isinstance(batch[4], torch.Tensor)
assert isinstance(batch[0], Tensor)
assert isinstance(batch[1], Tensor)
assert isinstance(batch[2], Tensor)
assert isinstance(batch[3], Tensor)
assert isinstance(batch[4], Tensor)