Advantage Actor Critic (A2C) Model #598

Merged 46 commits on Aug 13, 2021

Changes from 1 commit

Commits (46)
6f1afc9
a2c draft
blahBlahhhJ Mar 19, 2021
d6e6652
finish logic but not training
blahBlahhhJ Mar 19, 2021
b9ee7e9
cli pass converge on cartpole environment
blahBlahhhJ Mar 19, 2021
9a3a309
test by calling from package, fix code formatting, ready for review
blahBlahhhJ Mar 20, 2021
ed891bc
add tests, fix formatting
blahBlahhhJ Mar 20, 2021
415437b
fix typo
blahBlahhhJ Mar 20, 2021
47932be
fix tests, ready for review
blahBlahhhJ Mar 20, 2021
f2b19c8
Add A2C to __init__
akihironitta Mar 20, 2021
22f3b85
Update docs
akihironitta Mar 20, 2021
8221035
Fix formatting
akihironitta Mar 20, 2021
16bcd4a
Use self.hparams and remove n_steps
akihironitta Mar 20, 2021
e2ffd14
Update CHANGELOG
akihironitta Mar 20, 2021
a06528e
Merge branch 'master' into feature/596_a2c
blahBlahhhJ Mar 20, 2021
e397c47
fix typing hints, add documentation for A2C
blahBlahhhJ Mar 21, 2021
245feb0
minor formatting issue
blahBlahhhJ Mar 21, 2021
9211f20
delete print and add normalization
blahBlahhhJ Mar 21, 2021
17fc418
Adjust fig size
akihironitta Mar 21, 2021
b26b271
Fix typing
akihironitta Mar 21, 2021
f7d0a74
switch to function based pytest
blahBlahhhJ Apr 19, 2021
a1f2949
Merge branch 'feature/596_a2c' of https://github.com/blahBlahhhJ/ligh…
blahBlahhhJ Apr 19, 2021
85c407e
fix formatting
blahBlahhhJ Apr 19, 2021
0d10f0a
fix import
blahBlahhhJ Apr 19, 2021
cc9909b
fix format again
blahBlahhhJ Apr 19, 2021
46785bd
fix format again again
blahBlahhhJ Apr 19, 2021
bf14f13
ad another function test
blahBlahhhJ May 8, 2021
53a5703
Merge branch 'master' into feature/596_a2c
Borda Jun 24, 2021
83f5cef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 24, 2021
fa64829
formt
Borda Jun 24, 2021
8e1c783
Merge branch 'feature/596_a2c' of https://github.com/blahBlahhhJ/ligh…
Borda Jun 24, 2021
023912b
Apply suggestions from code review
Borda Jun 24, 2021
6167d04
Merge branch 'master' into feature/596_a2c
mergify[bot] Jun 25, 2021
53ff8cc
Merge branch 'master' into feature/596_a2c
mergify[bot] Jun 25, 2021
1159c63
Merge branch 'master' into feature/596_a2c
mergify[bot] Jun 29, 2021
1faa5f5
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 1, 2021
73b240f
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 4, 2021
cdada9d
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 4, 2021
89a3b1a
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 7, 2021
c90beb9
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 13, 2021
baa512a
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 13, 2021
eb30b22
fix test
blahBlahhhJ Jul 20, 2021
b37888d
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 26, 2021
a509d04
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 28, 2021
74bfa34
Merge branch 'master' into feature/596_a2c
mergify[bot] Aug 9, 2021
57542aa
Merge branch 'master' into feature/596_a2c
mergify[bot] Aug 13, 2021
d717a71
Merge branch 'master' into feature/596_a2c
mergify[bot] Aug 13, 2021
4687f9a
Update CHANGELOG.md
Borda Aug 13, 2021
a2c draft
blahBlahhhJ committed Mar 19, 2021

commit 6f1afc92565c2a45ed514db92621efbc82595903
300 changes: 300 additions & 0 deletions pl_bolts/models/rl/advantage_actor_critic_model.py
@@ -0,0 +1,300 @@
"""
Advantage Actor Critic (A2C)
"""
import argparse
from collections import OrderedDict
from typing import List, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from torch import optim as optim
from torch.nn.functional import log_softmax, softmax
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from pl_bolts.datamodules import ExperienceSourceDataset
from pl_bolts.models.rl.common.agents import ActorCriticAgent
from pl_bolts.models.rl.common.networks import ActorCriticMLP
from pl_bolts.utils import _GYM_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _GYM_AVAILABLE:
import gym
else: # pragma: no cover
warn_missing_pkg('gym')


class AdvantageActorCritic(pl.LightningModule):
"""
PyTorch Lightning implementation of `Advantage Actor Critic`
Model implemented by:
- `Jason Wang <https://github.com/blahBlahhhJ>`
Example:
>>> from pl_bolts.models.rl.advantage_actor_critic_model import AdvantageActorCritic
...
>>> model = AdvantageActorCritic("CartPole-v0")
Train::
trainer = Trainer()
trainer.fit(model)
"""
def __init__(
self,
env: str,
gamma: float = 0.99,
lr: float = 0.01,
batch_size: int = 8,
n_steps: int = 10,
avg_reward_len: int = 100,
entropy_beta: float = 0.01,
critic_beta: float = 0.5,
epoch_len: int = 1000,
**kwargs
) -> None:
"""
Args:
env: gym environment tag
gamma: discount factor
lr: learning rate
batch_size: size of minibatch pulled from the DataLoader
n_steps: number of rollout steps per training batch
entropy_beta: coefficient of the entropy bonus added to the loss
critic_beta: coefficient of the critic (value) loss
avg_reward_len: how many episodes to take into account when calculating the avg reward
epoch_len: how many batches make up one pseudo epoch
"""
super().__init__()

if not _GYM_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('This module requires gym, which is not installed yet.')

# Hyperparameters
self.lr = lr
self.batch_size = batch_size
self.batches_per_epoch = self.batch_size * epoch_len
self.entropy_beta = entropy_beta
self.critic_beta = critic_beta
self.gamma = gamma
self.n_steps = n_steps

self.save_hyperparameters()

# Model components
self.env = gym.make(env)
self.net = ActorCriticMLP(self.env.observation_space.shape, self.env.action_space.n)
self.agent = ActorCriticAgent(self.net)

# Tracking metrics
self.total_rewards = []
self.episode_reward = 0
self.done_episodes = 0
self.avg_rewards = 0
self.avg_reward_len = avg_reward_len
self.eps = np.finfo(np.float32).eps.item()
self.batch_states = []
self.batch_actions = []
self.batch_rewards = []
self.batch_logprobs = []
self.batch_values = []
self.batch_masks = []

self.state = self.env.reset()

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Passes a state x through the network and returns the log probability of each action and the value of the state
Args:
x: environment state
Returns:
action log probabilities, values
"""
logprobs, values = self.net(x)
return logprobs, values

def train_batch(self) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
"""
Contains the logic for generating a new batch of data to be passed to the DataLoader
Returns:
yields tuples of (state, action, return, value, log probabilities), one per step in the batch
"""

for _ in range(self.batch_size):
# the environment gives back a raw numpy state, so batch it into a float tensor for the network
logprob, value = self.net(torch.FloatTensor([self.state]))
action = self.agent.get_action(logprob)

next_state, reward, done, _ = self.env.step(action[0])

self.batch_rewards.append(reward)
self.batch_actions.append(action)
self.batch_logprobs.append(logprob)
self.batch_values.append(value)
self.batch_states.append(self.state)
self.batch_masks.append(done)
self.state = next_state
self.episode_reward += reward

if done:
self.done_episodes += 1
self.state = self.env.reset()
self.total_rewards.append(self.episode_reward)
self.episode_reward = 0
self.avg_rewards = float(np.mean(self.total_rewards[-self.avg_reward_len:]))

returns = self.compute_returns(self.batch_rewards, self.batch_masks, self.batch_values[-1])

for idx in range(len(self.batch_actions)):
yield self.batch_states[idx], self.batch_actions[idx], returns[idx], self.batch_values[idx], self.batch_logprobs[idx]

self.batch_states = []
self.batch_actions = []
self.batch_rewards = []
self.batch_values = []
self.batch_logprobs = []
self.batch_masks = []
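# Illustrative note: each yielded sample is a single transition of the form
# (state, sampled action, discounted return, critic value, actor log-probabilities);
# the DataLoader built in _dataloader then collates batch_size of these into a minibatch.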

def compute_returns(self, rewards, dones, last_value):
"""
Calculate the discounted rewards of the batched rewards
Args:
rewards: list of batched rewards
dones: list of done masks
last_value: the predicted value for the last state
Returns:
list of discounted rewards
"""
reward = 0
# if the last state isn't terminal, bootstrap from the critic's value estimate for it
if not dones[-1]:
reward = last_value.item()
returns = []

for r, d in zip(rewards[::-1], dones[::-1]):
reward = r + self.gamma * reward * (1 - d)
returns.append(reward)

returns = torch.tensor(returns[::-1])

return returns
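# Worked example (illustrative), with gamma = 0.99:
#   rewards = [1.0, 1.0, 1.0], dones = [False, False, True] -> the episode ended on the
#   last step, so no bootstrap value is used and the discounted returns, computed
#   back-to-front, are [1 + 0.99 * 1.99, 1 + 0.99 * 1.0, 1.0] = [2.9701, 1.99, 1.0]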

def loss(self, states, actions, returns, values, logprobs):
with torch.no_grad():
advs = returns - values
advs = (advs - advs.mean()) / (advs.std() + self.eps)

# entropy loss
entropy = -logprobs.exp() * logprobs
entropy = self.entropy_beta * entropy.sum(1).mean()

# actor loss
logprobs = logprobs.gather(1, actions)
actor_loss = -(logprobs * advs).mean()

# critic loss
critic_loss = self.critic_beta * torch.square(values - returns).mean()

total_loss = actor_loss + critic_loss - entropy
return total_loss
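# The combined objective above is the standard A2C loss:
#   L = -E[log pi(a|s) * A] + critic_beta * E[(V(s) - R)^2] - entropy_beta * H[pi(.|s)]
# with the advantage A = R - V(s) normalized and computed under torch.no_grad() so that
# the actor term does not backpropagate through the critic's value estimate.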

def training_step(self, batch: Tuple[torch.Tensor, ...], _) -> OrderedDict:
states, actions, returns, values, logprobs = batch

loss = self.loss(states, actions, returns, values, logprobs)

log = {
"episodes": self.done_episodes,
"reward": self.total_rewards[-1],
"avg_reward": self.avg_rewards,
}
return OrderedDict({
"loss": loss,
"avg_reward": self.avg_rewards,
"log": log,
"progress_bar": log,
})

def configure_optimizers(self) -> List[Optimizer]:
""" Initialize Adam optimizer"""
optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
return [optimizer]

def _dataloader(self) -> DataLoader:
"""Initialize the Replay Buffer dataset used for retrieving experiences"""
dataset = ExperienceSourceDataset(self.train_batch)
dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size)
return dataloader

def train_dataloader(self) -> DataLoader:
"""Get train loader"""
return self._dataloader()

def get_device(self, batch) -> str:
"""Retrieve device currently being used by minibatch"""
return batch[0][0][0].device.index if self.on_gpu else "cpu"

@staticmethod
def add_model_specific_args(arg_parser) -> argparse.ArgumentParser:
"""
Adds model-specific arguments for the A2C model
Args:
arg_parser: the current argument parser to add to
Returns:
arg_parser with model-specific args added
"""

arg_parser.add_argument("--entropy_beta", type=float, default=0.01, help="entropy coefficient")
arg_parser.add_argument("--critic_beta", type=float, default=0.5, help="critic loss coefficient")
arg_parser.add_argument("--batches_per_epoch", type=int, default=10000, help="number of batches in an epoch")
arg_parser.add_argument("--batch_size", type=int, default=32, help="size of the batches")
arg_parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
arg_parser.add_argument("--env", type=str, required=True, help="gym environment tag")
arg_parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
arg_parser.add_argument("--seed", type=int, default=123, help="seed for training run")

arg_parser.add_argument(
"--avg_reward_len",
type=int,
default=100,
help="how many episodes to include in avg reward",
)

return arg_parser
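# Example invocation (illustrative; Trainer flags such as --max_epochs come from
# pl.Trainer.add_argparse_args in cli_main below):
#   python -m pl_bolts.models.rl.advantage_actor_critic_model --env CartPole-v0 --max_epochs 3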

def cli_main():
parser = argparse.ArgumentParser(add_help=False)

# trainer args
parser = pl.Trainer.add_argparse_args(parser)

# model args
parser = AdvantageActorCritic.add_model_specific_args(parser)
args = parser.parse_args()

model = AdvantageActorCritic(**args.__dict__)

# save checkpoints based on avg_reward
checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="avg_reward", mode="max", period=1, verbose=True)

seed_everything(123)
trainer = pl.Trainer.from_argparse_args(args, deterministic=True, checkpoint_callback=checkpoint_callback)
trainer.fit(model)


if __name__ == '__main__':
cli_main()
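
For reference, a minimal programmatic training sketch using the module added above (the Trainer settings are illustrative and not taken from this PR):

import pytorch_lightning as pl

from pl_bolts.models.rl.advantage_actor_critic_model import AdvantageActorCritic

# build the model on CartPole and train it for a few epochs
model = AdvantageActorCritic("CartPole-v0", gamma=0.99, lr=0.001, batch_size=32)
trainer = pl.Trainer(max_epochs=3)
trainer.fit(model)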

49 changes: 49 additions & 0 deletions pl_bolts/models/rl/common/agents.py
@@ -138,3 +138,52 @@ def __call__(self, states: torch.Tensor, device: str) -> List[int]:
actions = [np.random.choice(len(prob), p=prob) for prob in prob_np]

return actions


class ActorCriticAgent(Agent):
"""Actor-Critic based agent that returns an action based on the networks policy"""

def __call__(self, states: torch.Tensor, device: str) -> List[int]:
"""
Takes in the current state and returns the action based on the agent's policy
Args:
states: current state of the environment
device: the device used for the current batch
Returns:
action defined by policy
"""
if not isinstance(states, list):
states = [states]

if not isinstance(states, torch.Tensor):
states = torch.tensor(states, device=device)

# get the logits and pass through softmax for probability distribution
logprobs, _ = self.net(states)
probabilities = logprobs.exp().squeeze(dim=-1)
prob_np = probabilities.data.cpu().numpy()

# take the numpy values and randomly select action based on prob distribution
actions = [np.random.choice(len(prob), p=prob) for prob in prob_np]

return actions

def get_action(self, logprobs: torch.Tensor):
"""
Takes the log probabilities from the actor head and samples an action according to the policy
Args:
logprobs: the actor head output from the network
Returns:
list of actions sampled from the action probability distribution
"""
probabilities = logprobs.exp().squeeze(dim=-1)
prob_np = probabilities.data.cpu().numpy()

# take the numpy values and randomly select action based on prob distribution
actions = [np.random.choice(len(prob), p=prob) for prob in prob_np]

return actions
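
To see how the new agent and network fit together, here is a small sketch (illustrative; it uses the ActorCriticMLP added in the networks change below and assumes the pre-0.26 gym API used throughout this PR):

import gym
import torch

from pl_bolts.models.rl.common.agents import ActorCriticAgent
from pl_bolts.models.rl.common.networks import ActorCriticMLP

env = gym.make("CartPole-v0")
net = ActorCriticMLP(env.observation_space.shape, env.action_space.n)
agent = ActorCriticAgent(net)

state = torch.FloatTensor([env.reset()])   # batch of one observation, shape (1, 4)
logprobs, value = net(state)               # actor log-probabilities and critic value
action = agent.get_action(logprobs)[0]     # sample a single action index
next_state, reward, done, _ = env.step(action)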
36 changes: 36 additions & 0 deletions pl_bolts/models/rl/common/networks.py
@@ -92,6 +92,42 @@ def forward(self, input_x):
return self.net(input_x.float())


class ActorCriticMLP(nn.Module):
"""
MLP network with heads for actor and critic
"""

def __init__(self, input_shape: Tuple[int], n_actions: int, hidden_size: int = 128):
"""
Args:
input_shape: observation shape of the environment
n_actions: number of discrete actions available in the environment
hidden_size: size of hidden layers
"""
super().__init__()

self.fc1 = nn.Linear(input_shape[0], hidden_size)
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.actor_head = nn.Linear(hidden_size, n_actions)
self.critic_head = nn.Linear(hidden_size, 1)

def forward(self, x) -> Tuple[Tensor, Tensor]:
"""
Forward pass through the network. Calculates the action log probabilities and the state value
Args:
x: input to network
Returns:
action log probabilities, value
"""
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
a = F.log_softmax(self.actor_head(x), dim=-1)
v = self.critic_head(x)
return a, v
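# Shape sketch (illustrative): for CartPole, input_shape = (4,) and n_actions = 2, so
# ActorCriticMLP((4,), 2) maps a (batch, 4) observation tensor to log-probabilities of
# shape (batch, 2) and values of shape (batch, 1).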


class DuelingMLP(nn.Module):
"""
MLP network with duel heads for val and advantage