Update ppo_pettingzoo_ma_atari.py #408

Open

elliottower wants to merge 21 commits into master from patch-1

Changes from 9 commits

Commits (21)
d9b9b11
Update ppo_pettingzoo_ma_atari.py
elliottower Jul 12, 2023
edc79d6
Pre-commit
elliottower Jul 13, 2023
d39da5e
Update PZ version
elliottower Jul 13, 2023
2b2dfce
Update Super
elliottower Jul 13, 2023
6d37313
Run pre-commit --hook-stage manual --all-files
elliottower Jul 13, 2023
0168986
run poetry lock --no-update to fix inconsistencies with versions
elliottower Jul 13, 2023
b7bffe9
re-run pre-commit with --hook-stage manual
elliottower Jul 13, 2023
2c76bb1
Change torch.maximum to torch.logical_or for dones
elliottower Jul 17, 2023
025f491
Use np.logical_or instead of torch (allows subtraction)
elliottower Jul 18, 2023
09f7a7f
Merge remote-tracking branch 'upstream/master' into patch-1
elliottower Jan 18, 2024
16e0764
Finish merge with upstream master
elliottower Jan 18, 2024
928b7b3
Fix SuperSuit to most recent version
elliottower Jan 18, 2024
d7a2aa2
Fix SuperSuit version in poetry lockfile and tinyscaler in pettingzoo…
elliottower Jan 18, 2024
d77cca0
Fix pettingzoo-requirements export (pre-commit hooks)
elliottower Jan 18, 2024
afba4e8
Test updating pettingzoo to new version 1.24.3
elliottower Jan 18, 2024
8671154
Update ma_atari to match regular atari (tyro, minor code style changes)
elliottower Jan 18, 2024
d2cf1a5
pre-commit
elliottower Jan 18, 2024
981bc63
Revert accidentally changed files (zoo and ipynb, which randomly seem…
elliottower Jan 18, 2024
454364d
Revert ipynb change
elliottower Jan 18, 2024
06473b2
Update dead pettingzoo.ml links to Farama foundation links
elliottower Jan 18, 2024
1b725cf
Update to newly release SuperSuit 3.9.2 (minor bugfixes but best to k…
elliottower Jan 18, 2024
cleanrl/ppo_pettingzoo_ma_atari.py (28 changes: 19 additions & 9 deletions)
@@ -6,7 +6,7 @@
 import time
 from distutils.util import strtobool

-import gym
+import gymnasium as gym
 import numpy as np
 import supersuit as ss
 import torch
@@ -156,11 +156,10 @@ def get_action_and_value(self, x, action=None):
 env = ss.frame_stack_v1(env, 4)
 env = ss.agent_indicator_v0(env, type_only=False)
 env = ss.pettingzoo_env_to_vec_env_v1(env)
-envs = ss.concat_vec_envs_v1(env, args.num_envs // 2, num_cpus=0, base_class="gym")
+envs = ss.concat_vec_envs_v1(env, args.num_envs // 2, num_cpus=0, base_class="gymnasium")
 envs.single_observation_space = envs.observation_space
 envs.single_action_space = envs.action_space
-envs.is_vector_env = True
 envs = gym.wrappers.RecordEpisodeStatistics(envs)
 if args.capture_video:
     envs = gym.wrappers.RecordVideo(envs, f"videos/{run_name}")
 assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
@@ -173,14 +172,17 @@ def get_action_and_value(self, x, action=None):
 actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
 logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
 rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
-dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
+terminations = torch.zeros((args.num_steps, args.num_envs)).to(device)
+truncations = torch.zeros((args.num_steps, args.num_envs)).to(device)
 values = torch.zeros((args.num_steps, args.num_envs)).to(device)

 # TRY NOT TO MODIFY: start the game
 global_step = 0
 start_time = time.time()
-next_obs = torch.Tensor(envs.reset()).to(device)
-next_done = torch.zeros(args.num_envs).to(device)
+next_obs, info = envs.reset(seed=args.seed)
+next_obs = torch.Tensor(next_obs).to(device)
+next_termination = torch.zeros(args.num_envs).to(device)
+next_truncation = torch.zeros(args.num_envs).to(device)
 num_updates = args.total_timesteps // args.batch_size

 for update in range(1, num_updates + 1):
@@ -193,7 +195,8 @@ def get_action_and_value(self, x, action=None):
 for step in range(0, args.num_steps):
     global_step += 1 * args.num_envs
     obs[step] = next_obs
-    dones[step] = next_done
+    terminations[step] = next_termination
+    truncations[step] = next_truncation

     # ALGO LOGIC: action logic
     with torch.no_grad():
@@ -203,10 +206,15 @@ def get_action_and_value(self, x, action=None):
     logprobs[step] = logprob

     # TRY NOT TO MODIFY: execute the game and log data.
-    next_obs, reward, done, info = envs.step(action.cpu().numpy())
+    next_obs, reward, termination, truncation, info = envs.step(action.cpu().numpy())
     rewards[step] = torch.tensor(reward).to(device).view(-1)
-    next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
+    next_obs, next_termination, next_truncation = (
+        torch.Tensor(next_obs).to(device),
+        torch.Tensor(termination).to(device),
+        torch.Tensor(truncation).to(device),
+    )

+    # TODO: fix this
     for idx, item in enumerate(info):
         player_idx = idx % 2
         if "episode" in item.keys():
@@ -219,6 +227,8 @@ def get_action_and_value(self, x, action=None):
     next_value = agent.get_value(next_obs).reshape(1, -1)
     advantages = torch.zeros_like(rewards).to(device)
     lastgaelam = 0
+    next_done = np.logical_or(next_termination, next_truncation)
@KaleabTessera commented Nov 2, 2023:

I think there is a bug here. We should still bootstrap if next_truncation=True:

"you should bootstrap if infos[env_idx]["TimeLimit.truncated"] is True (episode over due to a timeout/truncation) or dones[env_idx] is False (episode not finished)." - stable baselines

So next_done=next_termination and dones=terminations (probably just use next_termination and terminations directly, e.g. nextnonterminal = 1.0 - next_termination).

To implement this correctly we also need access to terminal_observation from pettingzoo_env_to_vec_env_v1, since we need the true terminal obs and not the obs returned by the next restart (which is currently the case, i.e. we need infos to provide access to the terminal obs). I have a PR out for this. Then we can implement something like this to do correct bootstrapping on truncation/timeout.
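A minimal sketch of what that fix could look like at collection time, following the quoted Stable Baselines guidance: fold the value of the true terminal observation into the reward whenever an episode ends by truncation alone, so the existing done = termination OR truncation masking stays valid. The "terminal_observation" info key is an assumption here (it mirrors SB3's naming; the exact key exposed by pettingzoo_env_to_vec_env_v1 may differ):

next_obs, reward, termination, truncation, info = envs.step(action.cpu().numpy())
for idx in range(args.num_envs):
    # Truncated but not terminated: the episode was cut off, not finished,
    # so bootstrap from the value of the true pre-reset observation.
    if truncation[idx] and not termination[idx]:
        terminal_obs = torch.Tensor(info[idx]["terminal_observation"]).to(device)  # assumed key
        with torch.no_grad():
            terminal_value = agent.get_value(terminal_obs.unsqueeze(0))
        reward[idx] += args.gamma * terminal_value.item()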

@elliottower (Author) commented:

Good catch @KaleabTessera, would you be willing to update this branch with the changes? I can give you edit access; I currently have a lot of other obligations from work, so I don't have much time for this.

@elliottower (Author) commented:

Oh shoot, it's a patch-1 branch, so I don't know if you can be given access. But if you clone the repo, you could make a new branch from this branch and open a new PR if it's not possible to edit this one? Or maybe make a PR to update this branch itself. Sorry I can't help more.

Owner commented:

FYI I am doing a refactor at #424. Gonna try to run a whole suite of benchmarks soon.

@elliottower (Author) commented:

Oh okay cool, sorry, I remember you gave access to the WandB thing but I've not had time to do it. Probably simplest if you do it anyway, so thanks for that. It may be interesting to compare performance with the AgileRL multi-agent Atari example: https://docs.agilerl.com/en/latest/tutorials/pettingzoo/maddpg.html

I see the issue linked in that PR mentions timeout handling; is that the same as mentioned below with termination vs truncation? Anyway, if there's anything needed from PettingZoo or SuperSuit's end, let me know.

@elliottower (Author) commented Jan 18, 2024:

> I think there is a bug here. We should still bootstrap if next_truncation=True:
>
> "you should bootstrap if infos[env_idx]["TimeLimit.truncated"] is True (episode over due to a timeout/truncation) or dones[env_idx] is False (episode not finished)." - stable baselines
>
> So next_done=next_termination and dones=terminations (probably just use next_termination and terminations directly, e.g. nextnonterminal = 1.0 - next_termination).
>
> To implement this correctly we also need access to terminal_observation from pettingzoo_env_to_vec_env_v1, since we need the true terminal obs and not the obs returned by the next restart (which is currently the case, i.e. we need infos to provide access to the terminal obs). I have a PR out for this. Then we can implement something like this to do correct bootstrapping on truncation/timeout.

Btw, just as an update, the SuperSuit PR linked above has been merged. My only concern with this is that whatever bootstrapping behavior is done here should mirror what is done with the single-agent PPO implementations, so this is a question for @vwxyzjn.

My inclination is to keep the logic as it currently is in this PR and address the bootstrapping issue in another PR (maybe @KaleabTessera is interested in doing that? I don't have a whole lot of time to look into it, nor am I the best person to do it as I'm not an expert).

+    dones = np.logical_or(terminations, truncations)
     for t in reversed(range(args.num_steps)):
         if t == args.num_steps - 1:
             nextnonterminal = 1.0 - next_done
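The diff is cut off here. For reference, a sketch of how the recursion presumably continues, assuming it follows the standard CleanRL GAE loop (the visible lines above match that pattern; next_value, values, rewards, args.gamma, and args.gae_lambda are all defined earlier in the script):

for t in reversed(range(args.num_steps)):
    if t == args.num_steps - 1:
        nextnonterminal = 1.0 - next_done
        nextvalues = next_value
    else:
        nextnonterminal = 1.0 - dones[t + 1]
        nextvalues = values[t + 1]
    # One-step TD error, discounted and masked at episode boundaries
    delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
    advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
returns = advantages + values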