Advantage Actor Critic (A2C) Model #598

Merged on Aug 13, 2021
Changes from 1 commit
Commits (46)
6f1afc9
a2c draft
blahBlahhhJ Mar 19, 2021
d6e6652
finish logic but not training
blahBlahhhJ Mar 19, 2021
b9ee7e9
cli pass converge on cartpole environment
blahBlahhhJ Mar 19, 2021
9a3a309
test by calling from package, fix code formatting, ready for review
blahBlahhhJ Mar 20, 2021
ed891bc
add tests, fix formatting
blahBlahhhJ Mar 20, 2021
415437b
fix typo
blahBlahhhJ Mar 20, 2021
47932be
fix tests, ready for review
blahBlahhhJ Mar 20, 2021
f2b19c8
Add A2C to __init__
akihironitta Mar 20, 2021
22f3b85
Update docs
akihironitta Mar 20, 2021
8221035
Fix formatting
akihironitta Mar 20, 2021
16bcd4a
Use self.hparams and remove n_steps
akihironitta Mar 20, 2021
e2ffd14
Update CHANGELOG
akihironitta Mar 20, 2021
a06528e
Merge branch 'master' into feature/596_a2c
blahBlahhhJ Mar 20, 2021
e397c47
fix typing hints, add documentation for A2C
blahBlahhhJ Mar 21, 2021
245feb0
minor formatting issue
blahBlahhhJ Mar 21, 2021
9211f20
delete print and add normalization
blahBlahhhJ Mar 21, 2021
17fc418
Adjust fig size
akihironitta Mar 21, 2021
b26b271
Fix typing
akihironitta Mar 21, 2021
f7d0a74
switch to function based pytest
blahBlahhhJ Apr 19, 2021
a1f2949
Merge branch 'feature/596_a2c' of https://github.com/blahBlahhhJ/ligh…
blahBlahhhJ Apr 19, 2021
85c407e
fix formatting
blahBlahhhJ Apr 19, 2021
0d10f0a
fix import
blahBlahhhJ Apr 19, 2021
cc9909b
fix format again
blahBlahhhJ Apr 19, 2021
46785bd
fix format again again
blahBlahhhJ Apr 19, 2021
bf14f13
ad another function test
blahBlahhhJ May 8, 2021
53a5703
Merge branch 'master' into feature/596_a2c
Borda Jun 24, 2021
83f5cef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 24, 2021
fa64829
formt
Borda Jun 24, 2021
8e1c783
Merge branch 'feature/596_a2c' of https://github.com/blahBlahhhJ/ligh…
Borda Jun 24, 2021
023912b
Apply suggestions from code review
Borda Jun 24, 2021
6167d04
Merge branch 'master' into feature/596_a2c
mergify[bot] Jun 25, 2021
53ff8cc
Merge branch 'master' into feature/596_a2c
mergify[bot] Jun 25, 2021
1159c63
Merge branch 'master' into feature/596_a2c
mergify[bot] Jun 29, 2021
1faa5f5
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 1, 2021
73b240f
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 4, 2021
cdada9d
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 4, 2021
89a3b1a
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 7, 2021
c90beb9
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 13, 2021
baa512a
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 13, 2021
eb30b22
fix test
blahBlahhhJ Jul 20, 2021
b37888d
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 26, 2021
a509d04
Merge branch 'master' into feature/596_a2c
mergify[bot] Jul 28, 2021
74bfa34
Merge branch 'master' into feature/596_a2c
mergify[bot] Aug 9, 2021
57542aa
Merge branch 'master' into feature/596_a2c
mergify[bot] Aug 13, 2021
d717a71
Merge branch 'master' into feature/596_a2c
mergify[bot] Aug 13, 2021
4687f9a
Update CHANGELOG.md
Borda Aug 13, 2021
finish logic but not training
blahBlahhhJ committed Mar 19, 2021
commit d6e6652db6e25b6f941e9ec738bcca11c0459e89
56 changes: 39 additions & 17 deletions pl_bolts/models/rl/advantage_actor_critic_model.py
@@ -101,8 +101,6 @@ def __init__(
self.batch_states = []
self.batch_actions = []
self.batch_rewards = []
self.batch_logprobs = []
self.batch_values = []
self.batch_masks = []

self.state = self.env.reset()
@@ -117,6 +115,11 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
Returns:
action log probabilities, values
"""
if not isinstance(x, list):
x = [x]

if not isinstance(x, torch.Tensor):
x = torch.tensor(x, device=self.device)
logprobs, values = self.net(x)
return logprobs, values

@@ -126,18 +129,19 @@ def train_batch(self, ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[to

Returns:
yields a tuple of Lists containing tensors for states, actions, returns, values, and log probabilities of the batch.

states: a list of numpy arrays
actions: a list of lists of int
returns: a torch tensor
"""

for _ in range(self.batch_size):
logprob, value = self.net(self.state)
action = self.agent.get_action(logprob)
action = self.agent(self.state, self.device)[0]

next_state, reward, done, _ = self.env.step(action[0])
next_state, reward, done, _ = self.env.step(action)

self.batch_rewards.append(reward)
self.batch_actions.append(action)
self.batch_logprobs.append(logprob)
self.batch_values.append(value)
self.batch_states.append(self.state)
self.batch_masks.append(done)
self.state = next_state
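For orientation, the loop above is standard on-policy experience collection: the agent samples an action from the current policy, the environment is stepped once, and the transition is buffered until batch_size steps have been gathered. A minimal standalone sketch of that pattern (assuming a Gym-style env and an agent callable that returns one sampled action; the names and signature are illustrative, not the PR's exact API):

def collect_batch(env, agent, state, batch_size):
    """Roll the current policy forward for batch_size steps and buffer the transitions."""
    states, actions, rewards, dones = [], [], [], []
    for _ in range(batch_size):
        action = agent(state)                           # sample from the current policy
        next_state, reward, done, _ = env.step(action)
        states.append(state)
        actions.append(action)
        rewards.append(reward)
        dones.append(done)
        state = env.reset() if done else next_state     # restart at episode boundaries
    return states, actions, rewards, dones, state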
@@ -150,15 +154,15 @@ def train_batch(self, ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[to
self.episode_reward = 0
self.avg_rewards = float(np.mean(self.total_rewards[-self.avg_reward_len:]))

returns = self.compute_returns(self.batch_rewards, self.batch_dones, self.batch_values[-1])
_, last_value = self.forward(self.state)

for idx in range(len(self.batch_actions)):
yield self.batch_states[idx], self.batch_actions[idx], returns[idx], self.batch_values[idx], self.batch_logprobs[idx]
returns = self.compute_returns(self.batch_rewards, self.batch_masks, last_value)
for idx in range(self.batch_size):
yield self.batch_states[idx], self.batch_actions[idx], returns[idx]

self.batch_states = []
self.batch_actions = []
self.batch_values = []
self.batch_logprobs = []
self.batch_rewards = []
self.batch_masks = []

def compute_returns(self, rewards, dones, last_value):
@@ -168,7 +172,7 @@ def compute_returns(self, rewards, dones, last_value):
Args:
rewards: list of batched rewards
dones: list of done masks
last_value: the predicted value for the last state
last_value: the predicted value for the last state (for bootstrap)

Returns:
list of discounted rewards
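The docstring above describes a bootstrapped discounted return: rewards are accumulated backwards, the done mask cuts the recursion at episode boundaries, and the critic's value estimate for the final state seeds the recursion. A standalone sketch with a tiny worked example (not the PR's exact code; gamma and the scalar last_value are illustrative):

import torch

def discounted_returns(rewards, dones, last_value, gamma=0.99):
    """Accumulate rewards backwards, bootstrapping from the value of the last state."""
    g = float(last_value)
    returns = []
    for r, d in zip(reversed(rewards), reversed(dones)):
        g = r + gamma * g * (1 - d)   # a done flag (d=1) stops the bootstrap
        returns.append(g)
    return torch.tensor(returns[::-1], dtype=torch.float32)

# e.g. discounted_returns([1, 1, 1], [0, 0, 0], last_value=0.5, gamma=0.9)
# -> tensor([3.0745, 2.3050, 1.4500])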
@@ -183,11 +187,24 @@ def compute_returns(self, rewards, dones, last_value):
reward = r + self.gamma * reward * (1 - d)
returns.append(reward)

# reverse list and stop the gradients
returns = torch.tensor(returns[::-1])

return returns

def loss(self, states, actions, returns, values, logprobs):
def loss(self, states, actions, returns):
"""
Calculates the loss for A2C which is a weighted sum of
actor loss (policy gradient), critic loss (MSE), entropy (for exploration)

Args:
states: (batch_size, state dimension)
actions: (batch_size, )
returns: (batch_size, )
"""

logprobs, values = self.net(states)

with torch.no_grad():
advs = returns - values
advs = (advs - advs.mean()) / (advs.std() + self.eps)
@@ -197,7 +214,7 @@ def loss(self, states, actions, returns, values, logprobs):
entropy = self.entropy_beta * entropy.sum(1).mean()

# actor loss
logprobs = logprobs.gather(1, actions)
logprobs = logprobs[range(self.batch_size), actions[0]]
actor_loss = -(logprobs * advs).mean()

# critic loss
@@ -207,9 +224,14 @@ def loss(self, states, actions, returns, values, logprobs):
return total_loss

def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedDict:
states, actions, returns, values, logprobs = batch
"""
Perform one actor-critic update using a batch of data

loss = self.loss(states, actions, returns, values)
Args:
batch: a batch of (states, actions, returns)
"""
states, actions, returns = batch
loss = self.loss(states, actions, returns)

log = {
"episodes": self.done_episodes,
18 changes: 0 additions & 18 deletions pl_bolts/models/rl/common/agents.py
@@ -169,21 +169,3 @@ def __call__(self, states: torch.Tensor, device: str) -> List[int]:
actions = [np.random.choice(len(prob), p=prob) for prob in prob_np]

return actions

def get_action(self, logprobs: torch.Tensor):
"""
Takes in the current state and returns the action and value based on the agents policy

Args:
logprobs: the actor head output from the network

Returns:
action sampled according to logits
"""
probabilities = logprobs.exp().squeeze(dim=-1)
prob_np = probabilities.data.cpu().numpy()

# take the numpy values and randomly select action based on prob distribution
actions = [np.random.choice(len(prob), p=prob) for prob in prob_np]

return actions
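The removed get_action duplicated what the agent's __call__ already does: turn the actor head's log-probabilities into sampled actions. A self-contained sketch of that sampling step (illustrative, not the exact bolts implementation):

import numpy as np
import torch

def sample_actions(logprobs: torch.Tensor) -> list:
    """Sample one discrete action per row from a batch of log-probabilities."""
    probs = logprobs.exp().detach().cpu().numpy()            # (B, n_actions)
    return [int(np.random.choice(len(p), p=p)) for p in probs]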
6 changes: 3 additions & 3 deletions pl_bolts/models/rl/common/networks.py
@@ -110,8 +110,8 @@ def __init__(self, input_shape: Tuple[int], n_actions: int, hidden_size: int = 1
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.actor_head = nn.Linear(hidden_size, n_actions)
self.critic_head = nn.Linear(hidden_size, 1)
def forward(self, x) -> Tuple[Tensor]:

def forward(self, x) -> Tuple[Tensor, Tensor]:
"""
Forward pass through network. Calculates the action logits and the value

@@ -121,7 +121,7 @@ def forward(self, x) -> Tuple[Tensor]:
Returns:
action log probs (logits), value
"""
x = F.relu(self.fc1(x))
x = F.relu(self.fc1(x.float()))
x = F.relu(self.fc2(x))
a = F.log_softmax(self.actor_head(x), dim=-1)
v = self.critic_head(x)
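For reference, the network touched here is a small shared-trunk MLP with a log-softmax actor head and a scalar critic head. A sketch of an equivalent module plus a usage example (the 128-unit hidden size and the CartPole-style shapes are assumptions, not taken from this diff):

import torch
from torch import Tensor, nn
from torch.nn import functional as F
from typing import Tuple

class ActorCriticMLP(nn.Module):
    """Two-layer shared trunk with separate actor (log-probs) and critic (value) heads."""

    def __init__(self, input_shape: Tuple[int], n_actions: int, hidden_size: int = 128):
        super().__init__()
        self.fc1 = nn.Linear(input_shape[0], hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.actor_head = nn.Linear(hidden_size, n_actions)
        self.critic_head = nn.Linear(hidden_size, 1)

    def forward(self, x) -> Tuple[Tensor, Tensor]:
        x = F.relu(self.fc1(x.float()))
        x = F.relu(self.fc2(x))
        logprobs = F.log_softmax(self.actor_head(x), dim=-1)  # action log-probabilities
        value = self.critic_head(x)                           # state-value estimate
        return logprobs, value

# usage: CartPole-like observations (4 floats), two discrete actions
net = ActorCriticMLP((4,), n_actions=2)
logprobs, value = net(torch.randn(8, 4))   # shapes: (8, 2) and (8, 1)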