From d9b9b11169f1fc4677fafc3f4b2bf3803a525d77 Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Wed, 12 Jul 2023 14:10:21 -0400 Subject: [PATCH 01/20] Update ppo_pettingzoo_ma_atari.py Updates to use Gymnasium and current PettingZoo API --- cleanrl/ppo_pettingzoo_ma_atari.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/cleanrl/ppo_pettingzoo_ma_atari.py b/cleanrl/ppo_pettingzoo_ma_atari.py index bc51c703c..4053c32fa 100644 --- a/cleanrl/ppo_pettingzoo_ma_atari.py +++ b/cleanrl/ppo_pettingzoo_ma_atari.py @@ -6,7 +6,7 @@ import time from distutils.util import strtobool -import gym +import gymnasium as gym import numpy as np import supersuit as ss import torch @@ -156,11 +156,10 @@ def get_action_and_value(self, x, action=None): env = ss.frame_stack_v1(env, 4) env = ss.agent_indicator_v0(env, type_only=False) env = ss.pettingzoo_env_to_vec_env_v1(env) - envs = ss.concat_vec_envs_v1(env, args.num_envs // 2, num_cpus=0, base_class="gym") + envs = ss.concat_vec_envs_v1(env, args.num_envs // 2, num_cpus=0, base_class="gymnasium") envs.single_observation_space = envs.observation_space envs.single_action_space = envs.action_space envs.is_vector_env = True - envs = gym.wrappers.RecordEpisodeStatistics(envs) if args.capture_video: envs = gym.wrappers.RecordVideo(envs, f"videos/{run_name}") assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" @@ -173,14 +172,17 @@ def get_action_and_value(self, x, action=None): actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) - dones = torch.zeros((args.num_steps, args.num_envs)).to(device) + terminations = torch.zeros((args.num_steps, args.num_envs)).to(device) + truncations = torch.zeros((args.num_steps, args.num_envs)).to(device) values = torch.zeros((args.num_steps, args.num_envs)).to(device) # TRY NOT TO MODIFY: start the game global_step = 0 start_time = time.time() - next_obs = torch.Tensor(envs.reset()).to(device) - next_done = torch.zeros(args.num_envs).to(device) + next_obs, info = envs.reset(seed=args.seed) + next_obs = torch.Tensor(next_obs).to(device) + next_termination = torch.zeros(args.num_envs).to(device) + next_truncation = torch.zeros(args.num_envs).to(device) num_updates = args.total_timesteps // args.batch_size for update in range(1, num_updates + 1): @@ -193,7 +195,8 @@ def get_action_and_value(self, x, action=None): for step in range(0, args.num_steps): global_step += 1 * args.num_envs obs[step] = next_obs - dones[step] = next_done + terminations[step] = next_termination + truncations[step] = next_truncation # ALGO LOGIC: action logic with torch.no_grad(): @@ -203,10 +206,11 @@ def get_action_and_value(self, x, action=None): logprobs[step] = logprob # TRY NOT TO MODIFY: execute the game and log data. - next_obs, reward, done, info = envs.step(action.cpu().numpy()) + next_obs, reward, termination, truncation, info = envs.step(action.cpu().numpy()) rewards[step] = torch.tensor(reward).to(device).view(-1) - next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device) + next_obs, next_termination, next_truncation = torch.Tensor(next_obs).to(device), torch.Tensor(termination).to(device), torch.Tensor(truncation).to(device) + # TODO: fix this for idx, item in enumerate(info): player_idx = idx % 2 if "episode" in item.keys(): @@ -219,6 +223,8 @@ def get_action_and_value(self, x, action=None): next_value = agent.get_value(next_obs).reshape(1, -1) advantages = torch.zeros_like(rewards).to(device) lastgaelam = 0 + next_done = torch.maximum(next_termination, next_truncation) + dones = torch.maximum(terminations, truncations) for t in reversed(range(args.num_steps)): if t == args.num_steps - 1: nextnonterminal = 1.0 - next_done From edc79d64675a9fa99eb30828292d226935a331a8 Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Wed, 12 Jul 2023 23:05:32 -0400 Subject: [PATCH 02/20] Pre-commit --- cleanrl/ppo_pettingzoo_ma_atari.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cleanrl/ppo_pettingzoo_ma_atari.py b/cleanrl/ppo_pettingzoo_ma_atari.py index 4053c32fa..71c69d0d3 100644 --- a/cleanrl/ppo_pettingzoo_ma_atari.py +++ b/cleanrl/ppo_pettingzoo_ma_atari.py @@ -208,7 +208,11 @@ def get_action_and_value(self, x, action=None): # TRY NOT TO MODIFY: execute the game and log data. next_obs, reward, termination, truncation, info = envs.step(action.cpu().numpy()) rewards[step] = torch.tensor(reward).to(device).view(-1) - next_obs, next_termination, next_truncation = torch.Tensor(next_obs).to(device), torch.Tensor(termination).to(device), torch.Tensor(truncation).to(device) + next_obs, next_termination, next_truncation = ( + torch.Tensor(next_obs).to(device), + torch.Tensor(termination).to(device), + torch.Tensor(truncation).to(device), + ) # TODO: fix this for idx, item in enumerate(info): From d39da5e5222d02a59687750a56c4076fe333d9cf Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Thu, 13 Jul 2023 10:26:08 -0400 Subject: [PATCH 03/20] Update PZ version --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9bea94d94..1c5f1e932 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ flax = {version = "^0.6.0", optional = true} optuna = {version = "^3.0.1", optional = true} optuna-dashboard = {version = "^0.7.2", optional = true} envpool = {version = "^0.6.4", optional = true} -PettingZoo = {version = "1.18.1", optional = true} +PettingZoo = {version = "^1.23.0", optional = true} SuperSuit = {version = "3.4.0", optional = true} multi-agent-ale-py = {version = "0.1.11", optional = true} boto3 = {version = "^1.24.70", optional = true} @@ -105,4 +105,4 @@ qdagger_dqn_atari_impalacnn = [ qdagger_dqn_atari_jax_impalacnn = [ "ale-py", "AutoROM", "opencv-python", # atari "jax", "jaxlib", "flax", # jax -] \ No newline at end of file +] From 2b2dfcec0faee9276051b1bbab1a2f71e7f357b3 Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Thu, 13 Jul 2023 10:29:59 -0400 Subject: [PATCH 04/20] Update Super --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1c5f1e932..5bbd45014 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,11 +43,11 @@ optuna = {version = "^3.0.1", optional = true} optuna-dashboard = {version = "^0.7.2", optional = true} envpool = {version = "^0.6.4", optional = true} PettingZoo = {version = "^1.23.0", optional = true} -SuperSuit = {version = "3.4.0", optional = true} +SuperSuit = {version = "^3.8.1", optional = true} multi-agent-ale-py = {version = "0.1.11", optional = true} boto3 = {version = "^1.24.70", optional = true} awscli = {version = "^1.25.71", optional = true} -shimmy = {version = ">=1.0.0", extras = ["dm-control"], optional = true} +shimmy = {version = ">=1.1.0", extras = ["dm-control"], optional = true} [tool.poetry.group.dev.dependencies] pre-commit = "^2.20.0" From 6d373131117e636011fb893ab621fc87f6cf36d3 Mon Sep 17 00:00:00 2001 From: elliottower Date: Thu, 13 Jul 2023 10:57:45 -0400 Subject: [PATCH 05/20] Run pre-commit --hook-stage manual --all-files --- requirements/requirements-dm_control.txt | 8 -------- requirements/requirements-pettingzoo.txt | 3 --- 2 files changed, 11 deletions(-) diff --git a/requirements/requirements-dm_control.txt b/requirements/requirements-dm_control.txt index db4ed8cbf..6d01d9d96 100644 --- a/requirements/requirements-dm_control.txt +++ b/requirements/requirements-dm_control.txt @@ -9,9 +9,6 @@ colorama==0.4.4 ; python_full_version >= "3.7.1" and python_version < "3.11" commonmark==0.9.1 ; python_full_version >= "3.7.1" and python_version < "3.11" cycler==0.11.0 ; python_full_version >= "3.7.1" and python_version < "3.11" decorator==4.4.2 ; python_full_version >= "3.7.1" and python_version < "3.11" -dm-control==1.0.11 ; python_full_version >= "3.7.1" and python_version < "3.11" -dm-env==1.6 ; python_full_version >= "3.7.1" and python_version < "3.11" -dm-tree==0.1.8 ; python_full_version >= "3.7.1" and python_version < "3.11" docker-pycreds==0.4.0 ; python_full_version >= "3.7.1" and python_version < "3.11" farama-notifications==0.0.4 ; python_full_version >= "3.7.1" and python_version < "3.11" filelock==3.12.0 ; python_full_version >= "3.7.1" and python_version < "3.11" @@ -25,7 +22,6 @@ grpcio==1.54.0 ; python_full_version >= "3.7.1" and python_version < "3.11" gym-notices==0.0.8 ; python_full_version >= "3.7.1" and python_version < "3.11" gym==0.23.1 ; python_full_version >= "3.7.1" and python_version < "3.11" gymnasium==0.28.1 ; python_full_version >= "3.7.1" and python_version < "3.11" -h5py==3.8.0 ; python_full_version >= "3.7.1" and python_version < "3.11" huggingface-hub==0.11.1 ; python_full_version >= "3.7.1" and python_version < "3.11" idna==3.4 ; python_full_version >= "3.7.1" and python_version < "3.11" imageio-ffmpeg==0.3.0 ; python_full_version >= "3.7.1" and python_version < "3.11" @@ -33,8 +29,6 @@ imageio==2.28.1 ; python_full_version >= "3.7.1" and python_version < "3.11" importlib-metadata==5.2.0 ; python_full_version >= "3.7.1" and python_version < "3.10" jax-jumpy==1.0.0 ; python_full_version >= "3.7.1" and python_version < "3.11" kiwisolver==1.4.4 ; python_full_version >= "3.7.1" and python_version < "3.11" -labmaze==1.0.6 ; python_full_version >= "3.7.1" and python_version < "3.11" -lxml==4.9.2 ; python_full_version >= "3.7.1" and python_version < "3.11" markdown==3.3.7 ; python_full_version >= "3.7.1" and python_version < "3.11" markupsafe==2.1.2 ; python_full_version >= "3.7.1" and python_version < "3.11" matplotlib==3.5.3 ; python_full_version >= "3.7.1" and python_version < "3.11" @@ -62,11 +56,9 @@ requests-oauthlib==1.3.1 ; python_full_version >= "3.7.1" and python_version < " requests==2.30.0 ; python_full_version >= "3.7.1" and python_version < "3.11" rich==11.2.0 ; python_full_version >= "3.7.1" and python_version < "3.11" rsa==4.7.2 ; python_full_version >= "3.7.1" and python_version < "3.11" -scipy==1.7.3 ; python_full_version >= "3.7.1" and python_version < "3.11" sentry-sdk==1.22.2 ; python_full_version >= "3.7.1" and python_version < "3.11" setproctitle==1.3.2 ; python_full_version >= "3.7.1" and python_version < "3.11" setuptools==67.7.2 ; python_full_version >= "3.7.1" and python_version < "3.11" -shimmy[dm-control]==1.0.0 ; python_full_version >= "3.7.1" and python_version < "3.11" six==1.16.0 ; python_full_version >= "3.7.1" and python_version < "3.11" smmap==5.0.0 ; python_full_version >= "3.7.1" and python_version < "3.11" stable-baselines3==1.2.0 ; python_full_version >= "3.7.1" and python_version < "3.11" diff --git a/requirements/requirements-pettingzoo.txt b/requirements/requirements-pettingzoo.txt index 91af37e12..c3dc9d9ba 100644 --- a/requirements/requirements-pettingzoo.txt +++ b/requirements/requirements-pettingzoo.txt @@ -38,7 +38,6 @@ oauthlib==3.2.2 ; python_full_version >= "3.7.1" and python_version < "3.11" packaging==23.1 ; python_full_version >= "3.7.1" and python_version < "3.11" pandas==1.3.5 ; python_full_version >= "3.7.1" and python_version < "3.11" pathtools==0.1.2 ; python_full_version >= "3.7.1" and python_version < "3.11" -pettingzoo==1.18.1 ; python_full_version >= "3.7.1" and python_version < "3.11" pillow==9.5.0 ; python_full_version >= "3.7.1" and python_version < "3.11" proglog==0.1.10 ; python_full_version >= "3.7.1" and python_version < "3.11" protobuf==3.20.3 ; python_version < "3.11" and python_full_version >= "3.7.1" @@ -61,12 +60,10 @@ setuptools==67.7.2 ; python_full_version >= "3.7.1" and python_version < "3.11" six==1.16.0 ; python_full_version >= "3.7.1" and python_version < "3.11" smmap==5.0.0 ; python_full_version >= "3.7.1" and python_version < "3.11" stable-baselines3==1.2.0 ; python_full_version >= "3.7.1" and python_version < "3.11" -supersuit==3.4.0 ; python_full_version >= "3.7.1" and python_version < "3.11" tenacity==8.2.2 ; python_full_version >= "3.7.1" and python_version < "3.11" tensorboard-data-server==0.6.1 ; python_full_version >= "3.7.1" and python_version < "3.11" tensorboard-plugin-wit==1.8.1 ; python_full_version >= "3.7.1" and python_version < "3.11" tensorboard==2.11.2 ; python_full_version >= "3.7.1" and python_version < "3.11" -tinyscaler==1.2.5 ; python_full_version >= "3.7.1" and python_version < "3.11" torch==1.12.1 ; python_full_version >= "3.7.1" and python_version < "3.11" tqdm==4.65.0 ; python_full_version >= "3.7.1" and python_version < "3.11" typing-extensions==4.5.0 ; python_full_version >= "3.7.1" and python_version < "3.11" From 01689866a632d810bd507edc8aa0207e40cd9bc4 Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Thu, 13 Jul 2023 11:43:28 -0400 Subject: [PATCH 06/20] run poetry lock --no-update to fix inconsistencies with versions --- poetry.lock | 273 ++++++++-------------------------------------------- 1 file changed, 39 insertions(+), 234 deletions(-) diff --git a/poetry.lock b/poetry.lock index 0664aba35..3d00b35be 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,10 +1,9 @@ -# This file is automatically @generated by Poetry 1.4.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. [[package]] name = "absl-py" version = "1.4.0" description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -16,7 +15,6 @@ files = [ name = "aiosignal" version = "1.3.1" description = "aiosignal: a list of registered asynchronous callbacks" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -31,7 +29,6 @@ frozenlist = ">=1.1.0" name = "ale-py" version = "0.7.4" description = "The Arcade Learning Environment (ALE) - a platform for AI research." -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -64,7 +61,6 @@ test = ["gym", "pytest"] name = "alembic" version = "1.10.4" description = "A database migration tool for SQLAlchemy." -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -86,7 +82,6 @@ tz = ["python-dateutil"] name = "antlr4-python3-runtime" version = "4.9.3" description = "ANTLR 4.9.3 runtime for Python 3.7" -category = "dev" optional = false python-versions = "*" files = [ @@ -97,7 +92,6 @@ files = [ name = "appdirs" version = "1.4.4" description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." -category = "main" optional = false python-versions = "*" files = [ @@ -109,7 +103,6 @@ files = [ name = "attrs" version = "23.1.0" description = "Classes Without Boilerplate" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -131,7 +124,6 @@ tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pyte name = "autorom" version = "0.4.2" description = "Automated installation of Atari ROMs for Gym/ALE-Py" -category = "main" optional = true python-versions = ">=3.6" files = [ @@ -153,7 +145,6 @@ accept-rom-license = ["AutoROM.accept-rom-license"] name = "autorom-accept-rom-license" version = "0.6.1" description = "Automated installation of Atari ROMs for Gym/ALE-Py" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -172,7 +163,6 @@ tests = ["ale_py", "multi_agent_ale_py"] name = "awscli" version = "1.27.132" description = "Universal Command Line Environment for AWS." -category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -192,7 +182,6 @@ s3transfer = ">=0.6.0,<0.7.0" name = "bitmath" version = "1.3.3.1" description = "Pythonic module for representing and manipulating file sizes with different prefix notations (file size unit conversion)" -category = "main" optional = true python-versions = "*" files = [ @@ -203,7 +192,6 @@ files = [ name = "boto3" version = "1.26.132" description = "The AWS SDK for Python" -category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -223,7 +211,6 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] name = "botocore" version = "1.29.132" description = "Low-level, data-driven core of boto 3." -category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -243,7 +230,6 @@ crt = ["awscrt (==0.16.9)"] name = "bottle" version = "0.12.25" description = "Fast and simple WSGI-framework for small web-applications." -category = "main" optional = true python-versions = "*" files = [ @@ -255,7 +241,6 @@ files = [ name = "cached-property" version = "1.5.2" description = "A decorator for caching properties in classes." -category = "main" optional = true python-versions = "*" files = [ @@ -267,7 +252,6 @@ files = [ name = "cachetools" version = "5.3.0" description = "Extensible memoizing collections and decorators" -category = "main" optional = false python-versions = "~=3.7" files = [ @@ -279,7 +263,6 @@ files = [ name = "certifi" version = "2023.5.7" description = "Python package for providing Mozilla's CA Bundle." -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -291,7 +274,6 @@ files = [ name = "cffi" version = "1.15.1" description = "Foreign Function Interface for Python calling C code." -category = "main" optional = true python-versions = "*" files = [ @@ -368,7 +350,6 @@ pycparser = "*" name = "cfgv" version = "3.3.1" description = "Validate configuration and produce human readable error messages." -category = "dev" optional = false python-versions = ">=3.6.1" files = [ @@ -380,7 +361,6 @@ files = [ name = "chardet" version = "4.0.0" description = "Universal encoding detector for Python 2 and 3" -category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -392,7 +372,6 @@ files = [ name = "charset-normalizer" version = "3.1.0" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." -category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -477,7 +456,6 @@ files = [ name = "chex" version = "0.1.5" description = "Chex: Testing made fun, in JAX!" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -497,7 +475,6 @@ toolz = ">=0.9.0" name = "click" version = "8.1.3" description = "Composable command line interface toolkit" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -513,7 +490,6 @@ importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} name = "cloudpickle" version = "2.2.1" description = "Extended pickling support for Python objects" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -525,7 +501,6 @@ files = [ name = "cmaes" version = "0.9.1" description = "Lightweight Covariance Matrix Adaptation Evolution Strategy (CMA-ES) implementation for Python 3." -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -543,7 +518,6 @@ cmawm = ["scipy"] name = "colorama" version = "0.4.4" description = "Cross-platform colored terminal text." -category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -555,7 +529,6 @@ files = [ name = "colorlog" version = "6.7.0" description = "Add colours to the output of Python's logging module." -category = "main" optional = true python-versions = ">=3.6" files = [ @@ -573,7 +546,6 @@ development = ["black", "flake8", "mypy", "pytest", "types-colorama"] name = "commonmark" version = "0.9.1" description = "Python parser for the CommonMark Markdown spec" -category = "main" optional = false python-versions = "*" files = [ @@ -588,7 +560,6 @@ test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"] name = "cycler" version = "0.11.0" description = "Composable style cycles" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -600,7 +571,6 @@ files = [ name = "cython" version = "0.29.34" description = "The Cython compiler for writing C extensions for the Python language." -category = "main" optional = true python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -650,7 +620,6 @@ files = [ name = "dataclasses" version = "0.6" description = "A backport of the dataclasses module for Python 3.6" -category = "main" optional = true python-versions = "*" files = [ @@ -662,7 +631,6 @@ files = [ name = "decorator" version = "4.4.2" description = "Decorators for Humans" -category = "main" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*" files = [ @@ -674,7 +642,6 @@ files = [ name = "dill" version = "0.3.6" description = "serialize all of python" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -689,7 +656,6 @@ graph = ["objgraph (>=1.7.2)"] name = "distlib" version = "0.3.6" description = "Distribution utilities" -category = "dev" optional = false python-versions = "*" files = [ @@ -701,7 +667,6 @@ files = [ name = "dm-control" version = "1.0.11" description = "Continuous control environments and MuJoCo Python bindings." -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -733,7 +698,6 @@ hdf5 = ["h5py"] name = "dm-env" version = "1.6" description = "A Python interface for Reinforcement Learning environments." -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -750,7 +714,6 @@ numpy = "*" name = "dm-tree" version = "0.1.8" description = "Tree is a library for working with nested data structures." -category = "main" optional = true python-versions = "*" files = [ @@ -799,7 +762,6 @@ files = [ name = "docker-pycreds" version = "0.4.0" description = "Python bindings for the docker credentials store API" -category = "main" optional = false python-versions = "*" files = [ @@ -814,7 +776,6 @@ six = ">=1.4.0" name = "docutils" version = "0.16" description = "Docutils -- Python Documentation Utilities" -category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ @@ -826,7 +787,6 @@ files = [ name = "enum-tools" version = "0.9.0.post1" description = "Tools to expand Python's enum module." -category = "main" optional = true python-versions = ">=3.6" files = [ @@ -846,7 +806,6 @@ sphinx = ["sphinx (>=3.2.0)", "sphinx-toolbox (>=2.16.0)"] name = "envpool" version = "0.6.6" description = "\"C++-based high-performance parallel environment execution engine (vectorized env) for general RL environments.\"" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -869,7 +828,6 @@ typing-extensions = "*" name = "etils" version = "0.9.0" description = "Collection of common python utils" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -898,7 +856,6 @@ lazy-imports = ["etils[ecolab]"] name = "exceptiongroup" version = "1.1.1" description = "Backport of PEP 654 (exception groups)" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -913,7 +870,6 @@ test = ["pytest (>=6)"] name = "expt" version = "0.4.1" description = "EXperiment. Plot. Tabulate." -category = "main" optional = true python-versions = ">=3.6" files = [ @@ -936,7 +892,6 @@ test = ["mock (>=2.0.0)", "pytest (>=5.0)", "pytest-asyncio", "pytest-cov", "ten name = "farama-notifications" version = "0.0.4" description = "Notifications for all Farama Foundation maintained libraries." -category = "main" optional = false python-versions = "*" files = [ @@ -948,7 +903,6 @@ files = [ name = "fasteners" version = "0.15" description = "A python package that provides useful locks." -category = "main" optional = true python-versions = "*" files = [ @@ -964,7 +918,6 @@ six = "*" name = "filelock" version = "3.12.0" description = "A platform independent file lock." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -980,7 +933,6 @@ testing = ["covdefaults (>=2.3)", "coverage (>=7.2.3)", "diff-cover (>=7.5)", "p name = "flax" version = "0.6.4" description = "Flax: A neural network library for JAX designed for flexibility" -category = "main" optional = true python-versions = "*" files = [ @@ -1007,7 +959,6 @@ testing = ["atari-py (==0.2.5)", "clu", "gym (==0.18.3)", "jaxlib", "jraph (>=0. name = "fonttools" version = "4.38.0" description = "Tools to manipulate font files" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1033,7 +984,6 @@ woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] name = "free-mujoco-py" version = "2.1.6" description = "" -category = "main" optional = true python-versions = ">=3.7.1,<3.11" files = [ @@ -1053,7 +1003,6 @@ numpy = ">=1.21.3,<2.0.0" name = "frozenlist" version = "1.3.3" description = "A list-like structure which implements collections.abc.MutableSequence" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1137,7 +1086,6 @@ files = [ name = "ghp-import" version = "2.1.0" description = "Copy your docs directly to the gh-pages branch." -category = "main" optional = true python-versions = "*" files = [ @@ -1155,7 +1103,6 @@ dev = ["flake8", "markdown", "twine", "wheel"] name = "gitdb" version = "4.0.10" description = "Git Object Database" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1170,7 +1117,6 @@ smmap = ">=3.0.1,<6" name = "gitpython" version = "3.1.31" description = "GitPython is a Python library used to interact with Git repositories" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1186,7 +1132,6 @@ typing-extensions = {version = ">=3.7.4.3", markers = "python_version < \"3.8\"" name = "glcontext" version = "2.3.7" description = "Portable OpenGL Context" -category = "main" optional = true python-versions = "*" files = [ @@ -1246,7 +1191,6 @@ files = [ name = "glfw" version = "1.12.0" description = "A ctypes-based wrapper for GLFW3." -category = "main" optional = true python-versions = "*" files = [ @@ -1263,7 +1207,6 @@ files = [ name = "google-auth" version = "2.18.0" description = "Google Authentication Library" -category = "main" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*" files = [ @@ -1289,7 +1232,6 @@ requests = ["requests (>=2.20.0,<3.0.0dev)"] name = "google-auth-oauthlib" version = "0.4.6" description = "Google Authentication Library" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -1308,7 +1250,6 @@ tool = ["click (>=6.0.0)"] name = "graphviz" version = "0.20.1" description = "Simple Python interface for Graphviz" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1325,7 +1266,6 @@ test = ["coverage", "mock (>=4)", "pytest (>=7)", "pytest-cov", "pytest-mock (>= name = "greenlet" version = "2.0.2" description = "Lightweight in-process concurrent programming" -category = "main" optional = true python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*" files = [ @@ -1399,7 +1339,6 @@ test = ["objgraph", "psutil"] name = "grpcio" version = "1.54.0" description = "HTTP/2-based RPC framework" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1457,7 +1396,6 @@ protobuf = ["grpcio-tools (>=1.54.0)"] name = "gym" version = "0.23.1" description = "Gym: A universal API for reinforcement learning environments" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1485,7 +1423,6 @@ toy-text = ["pygame (==2.1.0)", "scipy (>=1.4.1)"] name = "gym-notices" version = "0.0.8" description = "Notices for gym" -category = "main" optional = false python-versions = "*" files = [ @@ -1497,7 +1434,6 @@ files = [ name = "gym3" version = "0.3.3" description = "Vectorized Reinforcement Learning Environment Interface" -category = "main" optional = true python-versions = ">=3.6.0" files = [ @@ -1519,7 +1455,6 @@ test = ["gym (==0.17.2)", "gym-retro (==0.8.0)", "mpi4py (==3.0.3)", "pytest (== name = "gymnasium" version = "0.28.1" description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1537,9 +1472,9 @@ typing-extensions = ">=4.3.0" [package.extras] accept-rom-license = ["autorom[accept-rom-license] (>=0.4.2,<0.5.0)"] -all = ["box2d-py (==2.3.5)", "imageio (>=2.14.1)", "jax (==0.3.24)", "jaxlib (==0.3.24)", "lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "mujoco (>=2.3.2)", "mujoco-py (>=2.1,<2.2)", "opencv-python (>=3.0)", "pygame (==2.1.3)", "shimmy[atari] (>=0.1.0,<1.0)", "swig (>=4.0.0,<5.0.0)", "torch (>=1.0.0)"] +all = ["box2d-py (==2.3.5)", "imageio (>=2.14.1)", "jax (==0.3.24)", "jaxlib (==0.3.24)", "lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "mujoco (>=2.3.2)", "mujoco-py (>=2.1,<2.2)", "opencv-python (>=3.0)", "pygame (==2.1.3)", "shimmy[atari] (>=0.1.0,<1.0)", "swig (==4.*)", "torch (>=1.0.0)"] atari = ["shimmy[atari] (>=0.1.0,<1.0)"] -box2d = ["box2d-py (==2.3.5)", "pygame (==2.1.3)", "swig (>=4.0.0,<5.0.0)"] +box2d = ["box2d-py (==2.3.5)", "pygame (==2.1.3)", "swig (==4.*)"] classic-control = ["pygame (==2.1.3)", "pygame (==2.1.3)"] jax = ["jax (==0.3.24)", "jaxlib (==0.3.24)"] mujoco = ["imageio (>=2.14.1)", "mujoco (>=2.3.2)"] @@ -1552,7 +1487,6 @@ toy-text = ["pygame (==2.1.3)", "pygame (==2.1.3)"] name = "h5py" version = "3.8.0" description = "Read and write HDF5 files from Python" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1590,7 +1524,6 @@ numpy = ">=1.14.5" name = "hbutils" version = "0.8.6" description = "Some useful functions and classes in Python infrastructure development." -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1613,7 +1546,6 @@ test = ["click (>=7.0.0)", "coverage (>=5)", "easydict (>=1.7,<2)", "faker", "fl name = "huggingface-hub" version = "0.11.1" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" -category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -1645,7 +1577,6 @@ typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "t name = "hydra-core" version = "1.3.2" description = "A framework for elegantly configuring complex applications" -category = "dev" optional = false python-versions = "*" files = [ @@ -1654,7 +1585,7 @@ files = [ ] [package.dependencies] -antlr4-python3-runtime = ">=4.9.0,<4.10.0" +antlr4-python3-runtime = "==4.9.*" importlib-resources = {version = "*", markers = "python_version < \"3.9\""} omegaconf = ">=2.2,<2.4" packaging = "*" @@ -1663,7 +1594,6 @@ packaging = "*" name = "identify" version = "2.5.24" description = "File identification library for Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1678,7 +1608,6 @@ license = ["ukkonen"] name = "idna" version = "3.4" description = "Internationalized Domain Names in Applications (IDNA)" -category = "main" optional = false python-versions = ">=3.5" files = [ @@ -1690,7 +1619,6 @@ files = [ name = "imageio" version = "2.28.1" description = "Library for reading and writing a wide range of image, video, scientific, and volumetric data formats." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1722,7 +1650,6 @@ tifffile = ["tifffile"] name = "imageio-ffmpeg" version = "0.3.0" description = "FFMPEG wrapper for Python" -category = "main" optional = false python-versions = "*" files = [ @@ -1737,7 +1664,6 @@ files = [ name = "importlib-metadata" version = "5.2.0" description = "Read metadata from Python packages" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1758,7 +1684,6 @@ testing = ["flake8 (<5)", "flufl.flake8", "importlib-resources (>=1.3)", "packag name = "importlib-resources" version = "5.12.0" description = "Read resources from Python packages" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1777,7 +1702,6 @@ testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-chec name = "iniconfig" version = "2.0.0" description = "brain-dead simple config-ini parsing" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1789,7 +1713,6 @@ files = [ name = "isaacgym" version = "1.0.preview4" description = "" -category = "dev" optional = false python-versions = ">=3.7.1" files = [] @@ -1814,7 +1737,6 @@ url = "cleanrl/ppo_continuous_action_isaacgym/isaacgym" name = "isaacgymenvs" version = "0.1.0" description = "" -category = "dev" optional = false python-versions = ">=3.7.1,<3.10" files = [] @@ -1839,7 +1761,6 @@ resolved_reference = "27cc130a811b2305056c2f03f5f4cc0819b7867c" name = "jax" version = "0.3.25" description = "Differentiate, compile, and transform Numpy code." -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1866,7 +1787,6 @@ tpu = ["jaxlib (==0.3.25)", "libtpu-nightly (==0.1.dev20221109)", "requests"] name = "jax-jumpy" version = "1.0.0" description = "Common backend for Jax or Numpy." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1885,7 +1805,6 @@ testing = ["pytest (==7.1.3)"] name = "jaxlib" version = "0.3.25" description = "XLA library for JAX" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1913,7 +1832,6 @@ scipy = ">=1.5" name = "jinja2" version = "3.1.2" description = "A very fast and expressive template engine." -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1931,7 +1849,6 @@ i18n = ["Babel (>=2.7)"] name = "jmespath" version = "1.0.1" description = "JSON Matching Expressions" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1943,7 +1860,6 @@ files = [ name = "joblib" version = "1.2.0" description = "Lightweight pipelining with Python functions" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -1955,7 +1871,6 @@ files = [ name = "jsonschema" version = "4.17.3" description = "An implementation of JSON Schema validation for Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1979,7 +1894,6 @@ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339- name = "kiwisolver" version = "1.4.4" description = "A fast implementation of the Cassowary constraint solver" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2060,7 +1974,6 @@ typing-extensions = {version = "*", markers = "python_version < \"3.8\""} name = "labmaze" version = "1.0.6" description = "LabMaze: DeepMind Lab's text maze generator." -category = "main" optional = true python-versions = "*" files = [ @@ -2100,7 +2013,6 @@ setuptools = "!=50.0.0" name = "lxml" version = "4.9.2" description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API." -category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, != 3.4.*" files = [ @@ -2193,7 +2105,6 @@ source = ["Cython (>=0.29.7)"] name = "mako" version = "1.2.4" description = "A super-fast templating language that borrows the best ideas from the existing templating languages." -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2214,7 +2125,6 @@ testing = ["pytest"] name = "markdown" version = "3.3.7" description = "Python implementation of Markdown." -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -2232,7 +2142,6 @@ testing = ["coverage", "pyyaml"] name = "markdown-include" version = "0.7.2" description = "A Python-Markdown extension which provides an 'include' function" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2247,7 +2156,6 @@ markdown = ">=3.0" name = "markupsafe" version = "2.1.2" description = "Safely add untrusted strings to HTML/XML markup." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2307,7 +2215,6 @@ files = [ name = "matplotlib" version = "3.5.3" description = "Python plotting package" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2362,7 +2269,6 @@ python-dateutil = ">=2.7" name = "mergedeep" version = "1.3.4" description = "A deep merge function for 🐍." -category = "main" optional = true python-versions = ">=3.6" files = [ @@ -2374,7 +2280,6 @@ files = [ name = "mkdocs" version = "1.4.3" description = "Project documentation with Markdown." -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2404,7 +2309,6 @@ min-versions = ["babel (==2.9.0)", "click (==7.0)", "colorama (==0.4)", "ghp-imp name = "mkdocs-material" version = "8.5.11" description = "Documentation that simply works" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2425,7 +2329,6 @@ requests = ">=2.26" name = "mkdocs-material-extensions" version = "1.1.1" description = "Extension pack for Python Markdown and MkDocs Material." -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2437,7 +2340,6 @@ files = [ name = "moderngl" version = "5.8.2" description = "ModernGL: High performance rendering for Python 3" -category = "main" optional = true python-versions = "*" files = [ @@ -2500,7 +2402,6 @@ glcontext = ">=2.3.6,<3" name = "monotonic" version = "1.6" description = "An implementation of time.monotonic() for Python 2 & < 3.3" -category = "main" optional = true python-versions = "*" files = [ @@ -2512,7 +2413,6 @@ files = [ name = "moviepy" version = "1.0.3" description = "Video editing with Python" -category = "main" optional = false python-versions = "*" files = [ @@ -2540,7 +2440,6 @@ test = ["coverage (<5.0)", "coveralls (>=1.1,<2.0)", "pytest (>=3.0.0,<4.0)", "p name = "msgpack" version = "1.0.5" description = "MessagePack serializer" -category = "main" optional = false python-versions = "*" files = [ @@ -2613,7 +2512,6 @@ files = [ name = "mujoco" version = "2.3.3" description = "MuJoCo Physics Simulator" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2654,7 +2552,6 @@ pyopengl = "*" name = "multi-agent-ale-py" version = "0.1.11" description = "Multi-Agent Arcade Learning Environment Python Interface" -category = "main" optional = true python-versions = "*" files = [ @@ -2677,7 +2574,6 @@ numpy = "*" name = "multiprocess" version = "0.70.14" description = "better multiprocessing and multithreading in python" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2704,7 +2600,6 @@ dill = ">=0.3.6" name = "ninja" version = "1.11.1" description = "Ninja is a small build system with a focus on speed" -category = "dev" optional = false python-versions = "*" files = [ @@ -2734,7 +2629,6 @@ test = ["codecov (>=2.0.5)", "coverage (>=4.2)", "flake8 (>=3.0.4)", "pytest (>= name = "nodeenv" version = "1.7.0" description = "Node.js virtual environment builder" -category = "dev" optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" files = [ @@ -2749,7 +2643,6 @@ setuptools = "*" name = "numpy" version = "1.21.6" description = "NumPy is the fundamental package for array computing with Python." -category = "main" optional = false python-versions = ">=3.7,<3.11" files = [ @@ -2790,7 +2683,6 @@ files = [ name = "oauthlib" version = "3.2.2" description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -2807,7 +2699,6 @@ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] name = "omegaconf" version = "2.3.0" description = "A flexible configuration library" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -2816,14 +2707,13 @@ files = [ ] [package.dependencies] -antlr4-python3-runtime = ">=4.9.0,<4.10.0" +antlr4-python3-runtime = "==4.9.*" PyYAML = ">=5.1.0" [[package]] name = "opencv-python" version = "4.7.0.72" description = "Wrapper package for OpenCV python bindings." -category = "main" optional = true python-versions = ">=3.6" files = [ @@ -2850,7 +2740,6 @@ numpy = [ name = "openrlbenchmark" version = "0.1.1b4" description = "" -category = "main" optional = true python-versions = ">=3.7.1,<4.0.0" files = [ @@ -2874,7 +2763,6 @@ wandb = ">=0.13.7,<0.14.0" name = "opt-einsum" version = "3.3.0" description = "Optimizing numpys einsum function" -category = "main" optional = true python-versions = ">=3.5" files = [ @@ -2893,7 +2781,6 @@ tests = ["pytest", "pytest-cov", "pytest-pep8"] name = "optax" version = "0.1.4" description = "A gradient processing and optimisation library in JAX." -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2913,7 +2800,6 @@ typing-extensions = ">=3.10.0" name = "optuna" version = "3.1.1" description = "A hyperparameter optimization framework" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2943,7 +2829,6 @@ test = ["codecov", "fakeredis[lua]", "kaleido", "pytest", "scipy (>=1.9.2)"] name = "optuna-dashboard" version = "0.7.3" description = "Real-time dashboard for Optuna" -category = "main" optional = true python-versions = ">=3.6" files = [ @@ -2962,7 +2847,6 @@ typing-extensions = {version = "*", markers = "python_version < \"3.8\""} name = "orbax" version = "0.1.0" description = "Orbax" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -2990,7 +2874,6 @@ dev = ["pytest-xdist"] name = "packaging" version = "23.1" description = "Core utilities for Python packages" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3002,7 +2885,6 @@ files = [ name = "pandas" version = "1.3.5" description = "Powerful data structures for data analysis, time series, and statistics" -category = "main" optional = false python-versions = ">=3.7.1" files = [ @@ -3038,7 +2920,7 @@ numpy = [ {version = ">=1.20.0", markers = "platform_machine == \"arm64\" and python_version < \"3.10\""}, {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, {version = ">=1.19.2", markers = "platform_machine == \"aarch64\" and python_version < \"3.10\""}, - {version = ">=1.17.3", markers = "platform_machine != \"aarch64\" and platform_machine != \"arm64\" and python_version < \"3.10\""}, + {version = ">=1.17.3", markers = "(platform_machine != \"aarch64\" and platform_machine != \"arm64\") and python_version < \"3.10\""}, ] python-dateutil = ">=2.7.3" pytz = ">=2017.3" @@ -3050,7 +2932,6 @@ test = ["hypothesis (>=3.58)", "pytest (>=6.0)", "pytest-xdist"] name = "pathtools" version = "0.1.2" description = "File system general utilities" -category = "main" optional = false python-versions = "*" files = [ @@ -3059,36 +2940,33 @@ files = [ [[package]] name = "pettingzoo" -version = "1.18.1" -description = "Gym for multi-agent reinforcement learning" -category = "main" +version = "1.23.1" +description = "Gymnasium for multi-agent reinforcement learning." optional = true -python-versions = ">=3.7, <3.11" +python-versions = ">=3.7" files = [ - {file = "PettingZoo-1.18.1-py3-none-any.whl", hash = "sha256:25ae45fcfa2c623800e1f81b98ae50f5f5a1af6caabc5946764248de71a0371d"}, - {file = "PettingZoo-1.18.1.tar.gz", hash = "sha256:7e6a3231dc3fc3801af83fe880f199f570d46a9acdcb990f2a223f121b6e5038"}, + {file = "pettingzoo-1.23.1-py3-none-any.whl", hash = "sha256:2a243b260b1801a3c0e4826f22aec5d94e5c748c8cc091c99953f758f34a1082"}, + {file = "pettingzoo-1.23.1.tar.gz", hash = "sha256:bbf12cbd798fc014043288b82710a8e668317797592f01f17005909926252094"}, ] [package.dependencies] -gym = ">=0.21.0" -numpy = ">=1.18.0" +gymnasium = ">=0.28.0" +numpy = ">=1.21.0" [package.extras] -all = ["box2d-py (==2.3.5)", "chess (==1.7.0)", "hanabi-learning-environment (==0.0.1)", "magent (==0.2.2)", "multi-agent-ale-py (==0.1.11)", "pillow (>=8.0.1)", "pygame (==2.1.0)", "pyglet (>=1.4.0)", "pymunk (==6.2.0)", "rlcard (==1.0.4)", "scipy (>=1.4.1)"] -atari = ["multi-agent-ale-py (==0.1.11)", "pygame (==2.1.0)"] -butterfly = ["pygame (==2.1.0)", "pymunk (==6.2.0)"] -classic = ["chess (==1.7.0)", "hanabi-learning-environment (==0.0.1)", "pygame (==2.1.0)", "rlcard (==1.0.4)"] -magent = ["magent (==0.2.2)"] -mpe = ["pyglet (>=1.4.0)"] +all = ["box2d-py (==2.3.5)", "chess (==1.7.0)", "hanabi-learning-environment (==0.0.4)", "multi-agent-ale-py (==0.1.11)", "pillow (>=8.0.1)", "pygame (==2.3.0)", "pymunk (==6.2.0)", "rlcard (==1.0.5)", "scipy (>=1.4.1)"] +atari = ["multi-agent-ale-py (==0.1.11)", "pygame (==2.3.0)"] +butterfly = ["pygame (==2.3.0)", "pymunk (==6.2.0)"] +classic = ["chess (==1.7.0)", "hanabi-learning-environment (==0.0.4)", "pygame (==2.3.0)", "rlcard (==1.0.5)"] +mpe = ["pygame (==2.3.0)"] other = ["pillow (>=8.0.1)"] -sisl = ["box2d-py (==2.3.5)", "pygame (==2.1.0)", "scipy (>=1.4.1)"] -tests = ["codespell", "flake8", "isort", "pynput", "pytest"] +sisl = ["box2d-py (==2.3.5)", "pygame (==2.3.0)", "scipy (>=1.4.1)"] +testing = ["AutoROM", "pre-commit", "pynput", "pytest", "pytest-cov"] [[package]] name = "pillow" version = "9.5.0" description = "Python Imaging Library (Fork)" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3168,7 +3046,6 @@ tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "pa name = "pip" version = "22.3.1" description = "The PyPA recommended tool for installing Python packages." -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3180,7 +3057,6 @@ files = [ name = "pkgutil-resolve-name" version = "1.3.10" description = "Resolve a name to an object." -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -3192,7 +3068,6 @@ files = [ name = "platformdirs" version = "3.5.0" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3211,7 +3086,6 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.3.1)", "pytest- name = "pluggy" version = "1.0.0" description = "plugin and hook calling mechanisms for python" -category = "main" optional = true python-versions = ">=3.6" files = [ @@ -3230,7 +3104,6 @@ testing = ["pytest", "pytest-benchmark"] name = "pre-commit" version = "2.21.0" description = "A framework for managing and maintaining multi-language pre-commit hooks." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3250,7 +3123,6 @@ virtualenv = ">=20.10.0" name = "procgen" version = "0.10.7" description = "Procedurally Generated Game-Like RL Environments" -category = "main" optional = true python-versions = ">=3.6.0" files = [ @@ -3281,7 +3153,6 @@ test = ["pytest (==6.2.5)", "pytest-benchmark (==3.4.1)"] name = "proglog" version = "0.1.10" description = "Log and progress bar manager for console, notebooks, web..." -category = "main" optional = false python-versions = "*" files = [ @@ -3296,7 +3167,6 @@ tqdm = "*" name = "protobuf" version = "3.20.3" description = "Protocol Buffers" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3328,7 +3198,6 @@ files = [ name = "psutil" version = "5.9.5" description = "Cross-platform lib for process and system monitoring in Python." -category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3355,7 +3224,6 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] name = "pyasn1" version = "0.5.0" description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" -category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ @@ -3367,7 +3235,6 @@ files = [ name = "pyasn1-modules" version = "0.3.0" description = "A collection of ASN.1-based protocols modules" -category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" files = [ @@ -3382,7 +3249,6 @@ pyasn1 = ">=0.4.6,<0.6.0" name = "pycparser" version = "2.21" description = "C parser in Python" -category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3394,7 +3260,6 @@ files = [ name = "pygame" version = "2.1.0" description = "Python Game Development" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -3462,7 +3327,6 @@ files = [ name = "pygments" version = "2.15.1" description = "Pygments is a syntax highlighting package written in Python." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3477,7 +3341,6 @@ plugins = ["importlib-metadata"] name = "pymdown-extensions" version = "9.11" description = "Extension pack for Python Markdown." -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3493,7 +3356,6 @@ pyyaml = "*" name = "pyopengl" version = "3.1.6" description = "Standard OpenGL bindings for Python" -category = "main" optional = true python-versions = "*" files = [ @@ -3506,7 +3368,6 @@ files = [ name = "pyparsing" version = "3.0.9" description = "pyparsing module - Classes and methods to define and execute parsing grammars" -category = "main" optional = false python-versions = ">=3.6.8" files = [ @@ -3521,7 +3382,6 @@ diagrams = ["jinja2", "railroad-diagrams"] name = "pyrsistent" version = "0.19.3" description = "Persistent/Functional/Immutable data structures" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -3558,7 +3418,6 @@ files = [ name = "pytest" version = "7.3.1" description = "pytest: simple powerful testing with Python" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -3582,7 +3441,6 @@ testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "no name = "python-dateutil" version = "2.8.2" description = "Extensions to the standard Python datetime module" -category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ @@ -3597,7 +3455,6 @@ six = ">=1.5" name = "pytimeparse" version = "1.1.8" description = "Time expression parser" -category = "main" optional = true python-versions = "*" files = [ @@ -3609,7 +3466,6 @@ files = [ name = "pytz" version = "2023.3" description = "World timezone definitions, modern and historical" -category = "main" optional = false python-versions = "*" files = [ @@ -3621,7 +3477,6 @@ files = [ name = "pyvirtualdisplay" version = "3.0" description = "python wrapper for Xvfb, Xephyr and Xvnc" -category = "dev" optional = false python-versions = "*" files = [ @@ -3633,7 +3488,6 @@ files = [ name = "pyyaml" version = "5.4.1" description = "YAML parser and emitter for Python" -category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ @@ -3672,7 +3526,6 @@ files = [ name = "pyyaml-env-tag" version = "0.1" description = "A custom YAML tag for referencing environment variables in YAML files. " -category = "main" optional = true python-versions = ">=3.6" files = [ @@ -3687,7 +3540,6 @@ pyyaml = "*" name = "ray" version = "2.2.0" description = "Ray provides a simple, universal API for building distributed applications." -category = "dev" optional = false python-versions = "*" files = [ @@ -3751,7 +3603,6 @@ tune = ["pandas", "requests", "tabulate", "tensorboardX (>=1.9)"] name = "ray" version = "2.3.1" description = "Ray provides a simple, universal API for building distributed applications." -category = "dev" optional = false python-versions = "*" files = [ @@ -3816,7 +3667,6 @@ tune = ["pandas", "requests", "tabulate", "tensorboardX (>=1.9)"] name = "ray" version = "2.4.0" description = "Ray provides a simple, universal API for building distributed applications." -category = "dev" optional = false python-versions = "*" files = [ @@ -3878,7 +3728,6 @@ tune = ["pandas", "requests", "tabulate", "tensorboardX (>=1.9)"] name = "requests" version = "2.30.0" description = "Python HTTP for Humans." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -3900,7 +3749,6 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] name = "requests-oauthlib" version = "1.3.1" description = "OAuthlib authentication support for Requests." -category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ @@ -3919,7 +3767,6 @@ rsa = ["oauthlib[signedtoken] (>=3.0.0)"] name = "rich" version = "11.2.0" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" -category = "main" optional = false python-versions = ">=3.6.2,<4.0.0" files = [ @@ -3940,7 +3787,6 @@ jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"] name = "rl-games" version = "1.5.2" description = "" -category = "dev" optional = false python-versions = "*" files = [ @@ -3963,7 +3809,6 @@ torch = ">=1.7.0" name = "rsa" version = "4.7.2" description = "Pure-Python RSA implementation" -category = "main" optional = false python-versions = ">=3.5, <4" files = [ @@ -3978,7 +3823,6 @@ pyasn1 = ">=0.1.3" name = "s3transfer" version = "0.6.1" description = "An Amazon S3 Transfer Manager" -category = "main" optional = true python-versions = ">= 3.7" files = [ @@ -3996,7 +3840,6 @@ crt = ["botocore[crt] (>=1.20.29,<2.0a.0)"] name = "scikit-learn" version = "1.0.2" description = "A set of python modules for machine learning and data mining" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4050,7 +3893,6 @@ tests = ["black (>=21.6b0)", "flake8 (>=3.8.2)", "matplotlib (>=2.2.3)", "mypy ( name = "scipy" version = "1.7.3" description = "SciPy: Scientific Library for Python" -category = "main" optional = false python-versions = ">=3.7,<3.11" files = [ @@ -4092,7 +3934,6 @@ numpy = ">=1.16.5,<1.23.0" name = "seaborn" version = "0.12.2" description = "Statistical data visualization" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4115,7 +3956,6 @@ stats = ["scipy (>=1.3)", "statsmodels (>=0.10)"] name = "sentry-sdk" version = "1.22.2" description = "Python client for Sentry (https://sentry.io)" -category = "main" optional = false python-versions = "*" files = [ @@ -4157,7 +3997,6 @@ tornado = ["tornado (>=5)"] name = "setproctitle" version = "1.3.2" description = "A Python module to customize the process title" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4242,7 +4081,6 @@ test = ["pytest"] name = "setuptools" version = "67.7.2" description = "Easily download, build, install, upgrade, and uninstall Python packages" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4257,14 +4095,13 @@ testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs ( [[package]] name = "shimmy" -version = "1.0.0" +version = "1.1.0" description = "An API conversion tool providing Gymnasium and PettingZoo bindings for popular external reinforcement learning environments." -category = "main" optional = true python-versions = ">=3.7" files = [ - {file = "Shimmy-1.0.0-py3-none-any.whl", hash = "sha256:f26540d595ad56c9d0e99462d6388dc0dbb7976a97095337365ec79668cdf836"}, - {file = "Shimmy-1.0.0.tar.gz", hash = "sha256:30b9473402e846149137d5d71a0fbe47787d309c7e3a0c1aca97c95375de5f26"}, + {file = "Shimmy-1.1.0-py3-none-any.whl", hash = "sha256:0d2f44cdc3384b792336eb54002d23eb8c0ddb67580760e9c4e234fdf6077a69"}, + {file = "Shimmy-1.1.0.tar.gz", hash = "sha256:028ff42861fd8fa168927631f8f8cb2bda4ffef67e65633c51bf3116792e1f88"}, ] [package.dependencies] @@ -4275,23 +4112,22 @@ imageio = {version = "*", optional = true, markers = "extra == \"dm-control\""} numpy = ">=1.18.0" [package.extras] -all = ["ale-py (>=0.8.1,<0.9.0)", "bsuite (>=0.3.5)", "dm-control (>=1.0.10)", "dm-env (>=1.6)", "gym (>=0.21.0)", "gym (>=0.26.2)", "h5py (>=3.7.0)", "imageio", "open-spiel (>=1.2)", "pettingzoo (>=1.22.3)", "pyglet (==1.5.11)"] +all = ["ale-py (>=0.8.1,<0.9.0)", "bsuite (>=0.3.5)", "dm-control (>=1.0.10)", "dm-env (>=1.6)", "gym (>=0.26.2)", "h5py (>=3.7.0)", "imageio", "open-spiel (>=1.2)", "pettingzoo (>=1.23)"] atari = ["ale-py (>=0.8.1,<0.9.0)"] bsuite = ["bsuite (>=0.3.5)"] dm-control = ["dm-control (>=1.0.10)", "h5py (>=3.7.0)", "imageio"] -dm-control-multi-agent = ["dm-control (>=1.0.10)", "h5py (>=3.7.0)", "imageio", "pettingzoo (>=1.22.3)"] +dm-control-multi-agent = ["dm-control (>=1.0.10)", "h5py (>=3.7.0)", "imageio", "pettingzoo (>=1.23)"] dm-lab = ["dm-env (>=1.6)"] -gym-v21 = ["gym (>=0.21.0)", "pyglet (==1.5.11)"] +gym-v21 = ["gym (>=0.21.0,<0.26)", "pyglet (==1.5.11)"] gym-v26 = ["gym (>=0.26.2)"] -meltingpot = ["pettingzoo (>=1.22.3)"] -openspiel = ["open-spiel (>=1.2)", "pettingzoo (>=1.22.3)"] +meltingpot = ["pettingzoo (>=1.23)"] +openspiel = ["open-spiel (>=1.2)", "pettingzoo (>=1.23)"] testing = ["autorom[accept-rom-license] (>=0.6.0,<0.7.0)", "pillow (>=9.3.0)", "pytest (==7.1.3)"] [[package]] name = "six" version = "1.16.0" description = "Python 2 and 3 compatibility utilities" -category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" files = [ @@ -4303,7 +4139,6 @@ files = [ name = "smmap" version = "5.0.0" description = "A pure Python implementation of a sliding window memory map manager" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -4315,7 +4150,6 @@ files = [ name = "sqlalchemy" version = "2.0.13" description = "Database Abstraction Library" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4363,7 +4197,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""} +greenlet = {version = "!=0.4.17", markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\""} importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} typing-extensions = ">=4.2.0" @@ -4394,7 +4228,6 @@ sqlcipher = ["sqlcipher3-binary"] name = "stable-baselines3" version = "1.2.0" description = "Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms." -category = "main" optional = false python-versions = "*" files = [ @@ -4417,26 +4250,24 @@ tests = ["black", "flake8 (>=3.8)", "flake8-bugbear", "isort (>=5.0)", "pytest", [[package]] name = "supersuit" -version = "3.4.0" -description = "Wrappers for Gym and PettingZoo" -category = "main" +version = "3.8.1" +description = "Wrappers for Gymnasium and PettingZoo" optional = true -python-versions = ">=3.7" +python-versions = ">=3.7, <3.12" files = [ - {file = "SuperSuit-3.4.0-py3-none-any.whl", hash = "sha256:45b541b2b29faffd6494b53d649c8d94889966f407fd380b3e3211f9e68a49e9"}, - {file = "SuperSuit-3.4.0.tar.gz", hash = "sha256:5999beec8d7923c11c9511eaa9dec8a38269cb0d7af029e17903c79234233409"}, + {file = "SuperSuit-3.8.1-py3-none-any.whl", hash = "sha256:e9b4bfb3f95b433ecff4eca72fae2ecf2768395bd2668d88fcfefcd08ccea3b5"}, + {file = "SuperSuit-3.8.1.tar.gz", hash = "sha256:80bf18bf4c74676e2a498bd88ac056af5b586b197663cd888c6c3dd6e80eda5e"}, ] [package.dependencies] -gym = ">=0.22.0" -pettingzoo = ">=1.15.0" -tinyscaler = ">=1.0.4" +gymnasium = ">=0.26.0" +numpy = ">=1.19.0" +tinyscaler = ">=1.2.5" [[package]] name = "tabulate" version = "0.9.0" description = "Pretty-print tabular data" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4451,7 +4282,6 @@ widechars = ["wcwidth"] name = "tenacity" version = "8.2.2" description = "Retry code until it succeeds" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -4466,7 +4296,6 @@ doc = ["reno", "sphinx", "tornado (>=4.5)"] name = "tensorboard" version = "2.11.2" description = "TensorBoard lets you watch Tensors Flow" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4492,7 +4321,6 @@ wheel = ">=0.26" name = "tensorboard-data-server" version = "0.6.1" description = "Fast data loading for TensorBoard" -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -4505,7 +4333,6 @@ files = [ name = "tensorboard-plugin-wit" version = "1.8.1" description = "What-If Tool TensorBoard plugin." -category = "main" optional = false python-versions = "*" files = [ @@ -4516,7 +4343,6 @@ files = [ name = "tensorboardx" version = "2.6" description = "TensorBoardX lets you watch Tensors Flow without Tensorflow" -category = "dev" optional = false python-versions = "*" files = [ @@ -4533,7 +4359,6 @@ protobuf = ">=3.8.0,<4" name = "tensorstore" version = "0.1.28" description = "Read and write large, multi-dimensional arrays" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4562,7 +4387,6 @@ numpy = ">=1.16.0" name = "termcolor" version = "1.1.0" description = "ANSII Color formatting for output in terminal." -category = "dev" optional = false python-versions = "*" files = [ @@ -4573,7 +4397,6 @@ files = [ name = "threadpoolctl" version = "3.1.0" description = "threadpoolctl" -category = "main" optional = true python-versions = ">=3.6" files = [ @@ -4585,7 +4408,6 @@ files = [ name = "tinyscaler" version = "1.2.5" description = "A tiny, simple image scaler" -category = "main" optional = true python-versions = ">=3.7, <3.11" files = [ @@ -4603,7 +4425,6 @@ numpy = "*" name = "tomli" version = "2.0.1" description = "A lil' TOML parser" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4615,7 +4436,6 @@ files = [ name = "toolz" version = "0.12.0" description = "List processing tools and functional utilities" -category = "main" optional = true python-versions = ">=3.5" files = [ @@ -4627,7 +4447,6 @@ files = [ name = "torch" version = "1.12.1" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" -category = "main" optional = false python-versions = ">=3.7.0" files = [ @@ -4660,7 +4479,6 @@ typing-extensions = "*" name = "torchvision" version = "0.13.1" description = "image and video datasets and models for torch deep learning" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4687,7 +4505,7 @@ files = [ [package.dependencies] numpy = "*" -pillow = ">=5.3.0,<8.3.0 || >=8.4.0" +pillow = ">=5.3.0,<8.3.dev0 || >=8.4.dev0" requests = "*" torch = "1.12.1" typing-extensions = "*" @@ -4699,7 +4517,6 @@ scipy = ["scipy"] name = "tqdm" version = "4.65.0" description = "Fast, Extensible Progress Meter" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4720,7 +4537,6 @@ telegram = ["requests"] name = "treevalue" version = "1.4.10" description = "A flexible, generalized tree-based data structure." -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4774,7 +4590,6 @@ test = ["coverage (>=5)", "easydict (>=1.7,<2)", "flake8 (>=3.5,<4.0)", "hbutils name = "tueplots" version = "0.0.4" description = "Scientific plotting made easy" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4793,7 +4608,6 @@ examples = ["jupyter"] name = "typeguard" version = "2.13.3" description = "Run-time type checker for Python" -category = "main" optional = true python-versions = ">=3.5.3" files = [ @@ -4809,7 +4623,6 @@ test = ["mypy", "pytest", "typing-extensions"] name = "types-protobuf" version = "4.23.0.1" description = "Typing stubs for protobuf" -category = "main" optional = true python-versions = "*" files = [ @@ -4821,7 +4634,6 @@ files = [ name = "typing-extensions" version = "4.5.0" description = "Backported and Experimental Type Hints for Python 3.7+" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4833,7 +4645,6 @@ files = [ name = "urllib3" version = "1.26.15" description = "HTTP library with thread-safe connection pooling, file post, and more." -category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" files = [ @@ -4850,7 +4661,6 @@ socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] name = "virtualenv" version = "20.21.0" description = "Virtual Python Environment builder" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -4872,7 +4682,6 @@ test = ["covdefaults (>=2.2.2)", "coverage (>=7.1)", "coverage-enable-subprocess name = "wandb" version = "0.13.11" description = "A CLI and library for interacting with the Weights and Biases API." -category = "main" optional = false python-versions = ">=3.6" files = [ @@ -4914,7 +4723,6 @@ sweeps = ["sweeps (>=0.2.0)"] name = "watchdog" version = "3.0.0" description = "Filesystem events monitoring" -category = "main" optional = true python-versions = ">=3.7" files = [ @@ -4954,7 +4762,6 @@ watchmedo = ["PyYAML (>=3.10)"] name = "werkzeug" version = "2.2.3" description = "The comprehensive WSGI web application library." -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4972,7 +4779,6 @@ watchdog = ["watchdog"] name = "wheel" version = "0.40.0" description = "A built-package format for Python" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -4987,7 +4793,6 @@ test = ["pytest (>=6.0.0)"] name = "zipp" version = "3.15.0" description = "Backport of pathlib-compatible object wrapper for zip files" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -5028,4 +4833,4 @@ qdagger-dqn-atari-jax-impalacnn = ["AutoROM", "ale-py", "flax", "jax", "jaxlib", [metadata] lock-version = "2.0" python-versions = ">=3.7.1,<3.11" -content-hash = "83763cefd7c948380a16349ea5ec80fd36816adace1f8101bc5a50fd686e5a81" +content-hash = "da2ec1d2aebe6ca270c122e0b3bedeeba36213a687b09b37af72e58e30b168ea" From b7bffe94d465747e6be4eb26145e38ba56599f89 Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Thu, 13 Jul 2023 11:46:02 -0400 Subject: [PATCH 07/20] re-run pre-commit with --hook-stage manual --- requirements/requirements-dm_control.txt | 8 ++++++++ requirements/requirements-optuna.txt | 2 +- requirements/requirements-pettingzoo.txt | 3 +++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/requirements/requirements-dm_control.txt b/requirements/requirements-dm_control.txt index 6d01d9d96..301268c2e 100644 --- a/requirements/requirements-dm_control.txt +++ b/requirements/requirements-dm_control.txt @@ -9,6 +9,9 @@ colorama==0.4.4 ; python_full_version >= "3.7.1" and python_version < "3.11" commonmark==0.9.1 ; python_full_version >= "3.7.1" and python_version < "3.11" cycler==0.11.0 ; python_full_version >= "3.7.1" and python_version < "3.11" decorator==4.4.2 ; python_full_version >= "3.7.1" and python_version < "3.11" +dm-control==1.0.11 ; python_full_version >= "3.7.1" and python_version < "3.11" +dm-env==1.6 ; python_full_version >= "3.7.1" and python_version < "3.11" +dm-tree==0.1.8 ; python_full_version >= "3.7.1" and python_version < "3.11" docker-pycreds==0.4.0 ; python_full_version >= "3.7.1" and python_version < "3.11" farama-notifications==0.0.4 ; python_full_version >= "3.7.1" and python_version < "3.11" filelock==3.12.0 ; python_full_version >= "3.7.1" and python_version < "3.11" @@ -22,6 +25,7 @@ grpcio==1.54.0 ; python_full_version >= "3.7.1" and python_version < "3.11" gym-notices==0.0.8 ; python_full_version >= "3.7.1" and python_version < "3.11" gym==0.23.1 ; python_full_version >= "3.7.1" and python_version < "3.11" gymnasium==0.28.1 ; python_full_version >= "3.7.1" and python_version < "3.11" +h5py==3.8.0 ; python_full_version >= "3.7.1" and python_version < "3.11" huggingface-hub==0.11.1 ; python_full_version >= "3.7.1" and python_version < "3.11" idna==3.4 ; python_full_version >= "3.7.1" and python_version < "3.11" imageio-ffmpeg==0.3.0 ; python_full_version >= "3.7.1" and python_version < "3.11" @@ -29,6 +33,8 @@ imageio==2.28.1 ; python_full_version >= "3.7.1" and python_version < "3.11" importlib-metadata==5.2.0 ; python_full_version >= "3.7.1" and python_version < "3.10" jax-jumpy==1.0.0 ; python_full_version >= "3.7.1" and python_version < "3.11" kiwisolver==1.4.4 ; python_full_version >= "3.7.1" and python_version < "3.11" +labmaze==1.0.6 ; python_full_version >= "3.7.1" and python_version < "3.11" +lxml==4.9.2 ; python_full_version >= "3.7.1" and python_version < "3.11" markdown==3.3.7 ; python_full_version >= "3.7.1" and python_version < "3.11" markupsafe==2.1.2 ; python_full_version >= "3.7.1" and python_version < "3.11" matplotlib==3.5.3 ; python_full_version >= "3.7.1" and python_version < "3.11" @@ -56,9 +62,11 @@ requests-oauthlib==1.3.1 ; python_full_version >= "3.7.1" and python_version < " requests==2.30.0 ; python_full_version >= "3.7.1" and python_version < "3.11" rich==11.2.0 ; python_full_version >= "3.7.1" and python_version < "3.11" rsa==4.7.2 ; python_full_version >= "3.7.1" and python_version < "3.11" +scipy==1.7.3 ; python_full_version >= "3.7.1" and python_version < "3.11" sentry-sdk==1.22.2 ; python_full_version >= "3.7.1" and python_version < "3.11" setproctitle==1.3.2 ; python_full_version >= "3.7.1" and python_version < "3.11" setuptools==67.7.2 ; python_full_version >= "3.7.1" and python_version < "3.11" +shimmy[dm-control]==1.1.0 ; python_full_version >= "3.7.1" and python_version < "3.11" six==1.16.0 ; python_full_version >= "3.7.1" and python_version < "3.11" smmap==5.0.0 ; python_full_version >= "3.7.1" and python_version < "3.11" stable-baselines3==1.2.0 ; python_full_version >= "3.7.1" and python_version < "3.11" diff --git a/requirements/requirements-optuna.txt b/requirements/requirements-optuna.txt index 799c8f289..aed9bf871 100644 --- a/requirements/requirements-optuna.txt +++ b/requirements/requirements-optuna.txt @@ -21,7 +21,7 @@ gitdb==4.0.10 ; python_full_version >= "3.7.1" and python_version < "3.11" gitpython==3.1.31 ; python_full_version >= "3.7.1" and python_version < "3.11" google-auth-oauthlib==0.4.6 ; python_full_version >= "3.7.1" and python_version < "3.11" google-auth==2.18.0 ; python_full_version >= "3.7.1" and python_version < "3.11" -greenlet==2.0.2 ; python_full_version >= "3.7.1" and python_version < "3.11" and platform_machine == "aarch64" or python_full_version >= "3.7.1" and python_version < "3.11" and platform_machine == "ppc64le" or python_full_version >= "3.7.1" and python_version < "3.11" and platform_machine == "x86_64" or python_full_version >= "3.7.1" and python_version < "3.11" and platform_machine == "amd64" or python_full_version >= "3.7.1" and python_version < "3.11" and platform_machine == "AMD64" or python_full_version >= "3.7.1" and python_version < "3.11" and platform_machine == "win32" or python_full_version >= "3.7.1" and python_version < "3.11" and platform_machine == "WIN32" +greenlet==2.0.2 ; python_full_version >= "3.7.1" and python_version < "3.11" and platform_machine == "win32" or python_full_version >= "3.7.1" and python_version < "3.11" and platform_machine == "WIN32" or python_full_version >= "3.7.1" and python_version < "3.11" and platform_machine == "AMD64" or python_full_version >= "3.7.1" and python_version < "3.11" and platform_machine == "amd64" or python_full_version >= "3.7.1" and python_version < "3.11" and platform_machine == "x86_64" or python_full_version >= "3.7.1" and python_version < "3.11" and platform_machine == "ppc64le" or python_full_version >= "3.7.1" and python_version < "3.11" and platform_machine == "aarch64" grpcio==1.54.0 ; python_full_version >= "3.7.1" and python_version < "3.11" gym-notices==0.0.8 ; python_full_version >= "3.7.1" and python_version < "3.11" gym==0.23.1 ; python_full_version >= "3.7.1" and python_version < "3.11" diff --git a/requirements/requirements-pettingzoo.txt b/requirements/requirements-pettingzoo.txt index c3dc9d9ba..9883b2f24 100644 --- a/requirements/requirements-pettingzoo.txt +++ b/requirements/requirements-pettingzoo.txt @@ -38,6 +38,7 @@ oauthlib==3.2.2 ; python_full_version >= "3.7.1" and python_version < "3.11" packaging==23.1 ; python_full_version >= "3.7.1" and python_version < "3.11" pandas==1.3.5 ; python_full_version >= "3.7.1" and python_version < "3.11" pathtools==0.1.2 ; python_full_version >= "3.7.1" and python_version < "3.11" +pettingzoo==1.23.1 ; python_full_version >= "3.7.1" and python_version < "3.11" pillow==9.5.0 ; python_full_version >= "3.7.1" and python_version < "3.11" proglog==0.1.10 ; python_full_version >= "3.7.1" and python_version < "3.11" protobuf==3.20.3 ; python_version < "3.11" and python_full_version >= "3.7.1" @@ -60,10 +61,12 @@ setuptools==67.7.2 ; python_full_version >= "3.7.1" and python_version < "3.11" six==1.16.0 ; python_full_version >= "3.7.1" and python_version < "3.11" smmap==5.0.0 ; python_full_version >= "3.7.1" and python_version < "3.11" stable-baselines3==1.2.0 ; python_full_version >= "3.7.1" and python_version < "3.11" +supersuit==3.8.1 ; python_full_version >= "3.7.1" and python_version < "3.11" tenacity==8.2.2 ; python_full_version >= "3.7.1" and python_version < "3.11" tensorboard-data-server==0.6.1 ; python_full_version >= "3.7.1" and python_version < "3.11" tensorboard-plugin-wit==1.8.1 ; python_full_version >= "3.7.1" and python_version < "3.11" tensorboard==2.11.2 ; python_full_version >= "3.7.1" and python_version < "3.11" +tinyscaler==1.2.5 ; python_full_version >= "3.7.1" and python_version < "3.11" torch==1.12.1 ; python_full_version >= "3.7.1" and python_version < "3.11" tqdm==4.65.0 ; python_full_version >= "3.7.1" and python_version < "3.11" typing-extensions==4.5.0 ; python_full_version >= "3.7.1" and python_version < "3.11" From 2c76bb15983d30cb8bbf81856d21b5ea251026d1 Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Mon, 17 Jul 2023 14:51:36 -0400 Subject: [PATCH 08/20] Change torch.maximum to torch.logical_or for dones --- cleanrl/ppo_pettingzoo_ma_atari.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cleanrl/ppo_pettingzoo_ma_atari.py b/cleanrl/ppo_pettingzoo_ma_atari.py index 71c69d0d3..8485f7005 100644 --- a/cleanrl/ppo_pettingzoo_ma_atari.py +++ b/cleanrl/ppo_pettingzoo_ma_atari.py @@ -227,8 +227,8 @@ def get_action_and_value(self, x, action=None): next_value = agent.get_value(next_obs).reshape(1, -1) advantages = torch.zeros_like(rewards).to(device) lastgaelam = 0 - next_done = torch.maximum(next_termination, next_truncation) - dones = torch.maximum(terminations, truncations) + next_done = torch.logical_or(next_termination, next_truncation) + dones = torch.logical_or(terminations, truncations) for t in reversed(range(args.num_steps)): if t == args.num_steps - 1: nextnonterminal = 1.0 - next_done From 025f491194a3aa1272eb783255dff8f9821f6b48 Mon Sep 17 00:00:00 2001 From: Elliot Tower Date: Tue, 18 Jul 2023 16:29:33 -0400 Subject: [PATCH 09/20] Use np.logical_or instead of torch (allows subtraction) --- cleanrl/ppo_pettingzoo_ma_atari.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cleanrl/ppo_pettingzoo_ma_atari.py b/cleanrl/ppo_pettingzoo_ma_atari.py index 8485f7005..2124f8c86 100644 --- a/cleanrl/ppo_pettingzoo_ma_atari.py +++ b/cleanrl/ppo_pettingzoo_ma_atari.py @@ -227,8 +227,8 @@ def get_action_and_value(self, x, action=None): next_value = agent.get_value(next_obs).reshape(1, -1) advantages = torch.zeros_like(rewards).to(device) lastgaelam = 0 - next_done = torch.logical_or(next_termination, next_truncation) - dones = torch.logical_or(terminations, truncations) + next_done = np.logical_or(next_termination, next_truncation) + dones = np.logical_or(terminations, truncations) for t in reversed(range(args.num_steps)): if t == args.num_steps - 1: nextnonterminal = 1.0 - next_done From 16e076430b30c7f054258cecb8acaeda4e447757 Mon Sep 17 00:00:00 2001 From: elliottower Date: Thu, 18 Jan 2024 11:12:41 -0500 Subject: [PATCH 10/20] Finish merge with upstream master --- ...CleanRL_Huggingface_Integration_Demo.ipynb | 9772 ++++++++--------- pyproject.toml | 2 +- 2 files changed, 4887 insertions(+), 4887 deletions(-) diff --git a/docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb b/docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb index 4cb022ec3..51775005f 100644 --- a/docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb +++ b/docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb @@ -1,4941 +1,4941 @@ { - "cells": [ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "oTuvowvgFQpm" + }, + "source": [ + "# CleanRL's Huggingface Integration Demo\n", + "\n", + "\n", + "\n", + "[](https://github.com/vwxyzjn/cleanrl)\n", + "[![tests](https://github.com/vwxyzjn/cleanrl/actions/workflows/tests.yaml/badge.svg)](https://github.com/vwxyzjn/cleanrl/actions/workflows/tests.yaml)\n", + "[![docs](https://img.shields.io/github/deployments/vwxyzjn/cleanrl/Production?label=docs&logo=vercel)](https://docs.cleanrl.dev/)\n", + "[](https://discord.gg/D6RCjA6sVT)\n", + "[](https://www.youtube.com/channel/UCDdC6BIFRI0jvcwuhi3aI6w/videos)\n", + "[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)\n", + "[![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)\n", + "[](https://huggingface.co/cleanrl)\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vwxyzjn/cleanrl/blob/master/docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb)\n", + "\n", + "\n", + "CleanRL is a Deep Reinforcement Learning library that provides high-quality single-file implementation with research-friendly features. It now has has 🧪 experimental support for saving and loading models from 🤗 HuggingFace's [Model Hub](https://huggingface.co/models). This notebook is a preliminary demo.\n", + "\n", + "\n", + "* 💾 [GitHub Repo](https://github.com/vwxyzjn/cleanrl)\n", + "* 📜 [Documentation](https://docs.cleanrl.dev/)\n", + "* 🤗 [HuggingFace Model Hub](https://huggingface.co/cleanrl)\n", + "* 🔗 [Open RL Benchmark reports](https://wandb.ai/openrlbenchmark/openrlbenchmark/reportlist)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "J0zhqyfea0If" + }, + "source": [ + "## Get Started\n", + "\n", + "CleanRL can be installed via `pip`. Let's say we are interested in pulling the model for [`dqn_atari_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari_jax.py), we can install the algorithm-variant-specific dependencies as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "lhnJkrYLOvcs", + "outputId": "381d9d0d-7e83-4f21-ef89-91d4e3b93c18" + }, + "outputs": [ { - "attachments": {}, - "cell_type": "markdown", - "metadata": { - "id": "oTuvowvgFQpm" - }, - "source": [ - "# CleanRL's Huggingface Integration Demo\n", - "\n", - "\n", - "\n", - "[](https://github.com/vwxyzjn/cleanrl)\n", - "[![tests](https://github.com/vwxyzjn/cleanrl/actions/workflows/tests.yaml/badge.svg)](https://github.com/vwxyzjn/cleanrl/actions/workflows/tests.yaml)\n", - "[![docs](https://img.shields.io/github/deployments/vwxyzjn/cleanrl/Production?label=docs&logo=vercel)](https://docs.cleanrl.dev/)\n", - "[](https://discord.gg/D6RCjA6sVT)\n", - "[](https://www.youtube.com/channel/UCDdC6BIFRI0jvcwuhi3aI6w/videos)\n", - "[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)\n", - "[![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)\n", - "[](https://huggingface.co/cleanrl)\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vwxyzjn/cleanrl/blob/master/docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb)\n", - "\n", - "\n", - "CleanRL is a Deep Reinforcement Learning library that provides high-quality single-file implementation with research-friendly features. It now has has 🧪 experimental support for saving and loading models from 🤗 HuggingFace's [Model Hub](https://huggingface.co/models). This notebook is a preliminary demo.\n", - "\n", - "\n", - "* 💾 [GitHub Repo](https://github.com/vwxyzjn/cleanrl)\n", - "* 📜 [Documentation](https://docs.cleanrl.dev/)\n", - "* 🤗 [HuggingFace Model Hub](https://huggingface.co/cleanrl)\n", - "* 🔗 [Open RL Benchmark reports](https://wandb.ai/openrlbenchmark/openrlbenchmark/reportlist)\n", - "\n" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Collecting cleanrl[dqn-atari-jax]\n", + " Downloading cleanrl-1.1.2-py3-none-any.whl (16.9 MB)\n", + "\u001B[K |████████████████████████████████| 16.9 MB 241 kB/s \n", + "\u001B[?25hCollecting pygame==2.1.0\n", + " Downloading pygame-2.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)\n", + "\u001B[K |████████████████████████████████| 18.3 MB 59.3 MB/s \n", + "\u001B[?25hCollecting huggingface-hub<0.12.0,>=0.11.1\n", + " Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)\n", + "\u001B[K |████████████████████████████████| 182 kB 74.9 MB/s \n", + "\u001B[?25hCollecting wandb<0.14.0,>=0.13.6\n", + " Downloading wandb-0.13.7-py2.py3-none-any.whl (1.9 MB)\n", + "\u001B[K |████████████████████████████████| 1.9 MB 63.6 MB/s \n", + "\u001B[?25hRequirement already satisfied: torch>=1.12.1 in /usr/local/lib/python3.8/dist-packages (from cleanrl[dqn-atari-jax]) (1.13.0+cu116)\n", + "Collecting stable-baselines3==1.2.0\n", + " Downloading stable_baselines3-1.2.0-py3-none-any.whl (161 kB)\n", + "\u001B[K |████████████████████████████████| 161 kB 64.7 MB/s \n", + "\u001B[?25hCollecting tensorboard<3.0.0,>=2.10.0\n", + " Downloading tensorboard-2.11.0-py3-none-any.whl (6.0 MB)\n", + "\u001B[K |████████████████████████████████| 6.0 MB 65.0 MB/s \n", + "\u001B[?25hCollecting moviepy<2.0.0,>=1.0.3\n", + " Downloading moviepy-1.0.3.tar.gz (388 kB)\n", + "\u001B[K |████████████████████████████████| 388 kB 59.5 MB/s \n", + "\u001B[?25hCollecting gym==0.23.1\n", + " Downloading gym-0.23.1.tar.gz (626 kB)\n", + "\u001B[K |████████████████████████████████| 626 kB 59.9 MB/s \n", + "\u001B[?25h Installing build dependencies ... \u001B[?25l\u001B[?25hdone\n", + " Getting requirements to build wheel ... \u001B[?25l\u001B[?25hdone\n", + " Preparing wheel metadata ... \u001B[?25l\u001B[?25hdone\n", + "Collecting gymnasium<0.27.0,>=0.26.3\n", + " Downloading Gymnasium-0.26.3-py3-none-any.whl (836 kB)\n", + "\u001B[K |████████████████████████████████| 836 kB 64.6 MB/s \n", + "\u001B[?25hCollecting AutoROM[accept-rom-license]<0.5.0,>=0.4.2\n", + " Downloading AutoROM-0.4.2-py3-none-any.whl (16 kB)\n", + "Collecting ale-py==0.7.4\n", + " Downloading ale_py-0.7.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.6 MB)\n", + "\u001B[K |████████████████████████████████| 1.6 MB 55.1 MB/s \n", + "\u001B[?25hRequirement already satisfied: opencv-python<5.0.0.0,>=4.6.0.66 in /usr/local/lib/python3.8/dist-packages (from cleanrl[dqn-atari-jax]) (4.6.0.66)\n", + "Requirement already satisfied: jax<0.4.0,>=0.3.17 in /usr/local/lib/python3.8/dist-packages (from cleanrl[dqn-atari-jax]) (0.3.25)\n", + "Collecting flax<0.7.0,>=0.6.0\n", + " Downloading flax-0.6.3-py3-none-any.whl (197 kB)\n", + "\u001B[K |████████████████████████████████| 197 kB 73.9 MB/s \n", + "\u001B[?25hRequirement already satisfied: jaxlib<0.4.0,>=0.3.15 in /usr/local/lib/python3.8/dist-packages (from cleanrl[dqn-atari-jax]) (0.3.25+cuda11.cudnn805)\n", + "Requirement already satisfied: importlib-metadata>=4.10.0 in /usr/local/lib/python3.8/dist-packages (from ale-py==0.7.4->cleanrl[dqn-atari-jax]) (5.1.0)\n", + "Requirement already satisfied: importlib-resources in /usr/local/lib/python3.8/dist-packages (from ale-py==0.7.4->cleanrl[dqn-atari-jax]) (5.10.1)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.8/dist-packages (from ale-py==0.7.4->cleanrl[dqn-atari-jax]) (1.21.6)\n", + "Requirement already satisfied: gym-notices>=0.0.4 in /usr/local/lib/python3.8/dist-packages (from gym==0.23.1->cleanrl[dqn-atari-jax]) (0.0.8)\n", + "Requirement already satisfied: cloudpickle>=1.2.0 in /usr/local/lib/python3.8/dist-packages (from gym==0.23.1->cleanrl[dqn-atari-jax]) (1.5.0)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.8/dist-packages (from stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (1.3.5)\n", + "Requirement already satisfied: matplotlib in /usr/local/lib/python3.8/dist-packages (from stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (3.2.2)\n", + "Requirement already satisfied: click in /usr/local/lib/python3.8/dist-packages (from AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (7.1.2)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.8/dist-packages (from AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (2.23.0)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.8/dist-packages (from AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (4.64.1)\n", + "Collecting AutoROM.accept-rom-license\n", + " Downloading AutoROM.accept-rom-license-0.5.0.tar.gz (10 kB)\n", + " Installing build dependencies ... \u001B[?25l\u001B[?25hdone\n", + " Getting requirements to build wheel ... \u001B[?25l\u001B[?25hdone\n", + " Preparing wheel metadata ... \u001B[?25l\u001B[?25hdone\n", + "Requirement already satisfied: PyYAML>=5.4.1 in /usr/local/lib/python3.8/dist-packages (from flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (6.0)\n", + "Collecting rich>=11.1\n", + " Downloading rich-13.0.0-py3-none-any.whl (238 kB)\n", + "\u001B[K |████████████████████████████████| 238 kB 76.7 MB/s \n", + "\u001B[?25hRequirement already satisfied: typing-extensions>=4.1.1 in /usr/local/lib/python3.8/dist-packages (from flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (4.4.0)\n", + "Collecting orbax\n", + " Downloading orbax-0.0.23-py3-none-any.whl (66 kB)\n", + "\u001B[K |████████████████████████████████| 66 kB 6.3 MB/s \n", + "\u001B[?25hCollecting optax\n", + " Downloading optax-0.1.4-py3-none-any.whl (154 kB)\n", + "\u001B[K |████████████████████████████████| 154 kB 82.0 MB/s \n", + "\u001B[?25hCollecting tensorstore\n", + " Downloading tensorstore-0.1.28-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.3 MB)\n", + "\u001B[K |████████████████████████████████| 8.3 MB 64.7 MB/s \n", + "\u001B[?25hRequirement already satisfied: msgpack in /usr/local/lib/python3.8/dist-packages (from flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (1.0.4)\n", + "Collecting gymnasium-notices>=0.0.1\n", + " Downloading gymnasium_notices-0.0.1-py3-none-any.whl (2.8 kB)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<0.12.0,>=0.11.1->cleanrl[dqn-atari-jax]) (3.8.2)\n", + "Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<0.12.0,>=0.11.1->cleanrl[dqn-atari-jax]) (21.3)\n", + "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.8/dist-packages (from importlib-metadata>=4.10.0->ale-py==0.7.4->cleanrl[dqn-atari-jax]) (3.11.0)\n", + "Requirement already satisfied: scipy>=1.5 in /usr/local/lib/python3.8/dist-packages (from jax<0.4.0,>=0.3.17->cleanrl[dqn-atari-jax]) (1.7.3)\n", + "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.8/dist-packages (from jax<0.4.0,>=0.3.17->cleanrl[dqn-atari-jax]) (3.3.0)\n", + "Requirement already satisfied: decorator<5.0,>=4.0.2 in /usr/local/lib/python3.8/dist-packages (from moviepy<2.0.0,>=1.0.3->cleanrl[dqn-atari-jax]) (4.4.2)\n", + "Collecting proglog<=1.0.0\n", + " Downloading proglog-0.1.10-py3-none-any.whl (6.1 kB)\n", + "Requirement already satisfied: imageio<3.0,>=2.5 in /usr/local/lib/python3.8/dist-packages (from moviepy<2.0.0,>=1.0.3->cleanrl[dqn-atari-jax]) (2.9.0)\n", + "Collecting imageio_ffmpeg>=0.2.0\n", + " Downloading imageio_ffmpeg-0.4.7-py3-none-manylinux2010_x86_64.whl (26.9 MB)\n", + "\u001B[K |████████████████████████████████| 26.9 MB 47.9 MB/s \n", + "\u001B[?25hRequirement already satisfied: pillow in /usr/local/lib/python3.8/dist-packages (from imageio<3.0,>=2.5->moviepy<2.0.0,>=1.0.3->cleanrl[dqn-atari-jax]) (7.1.2)\n", + "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.8/dist-packages (from packaging>=20.9->huggingface-hub<0.12.0,>=0.11.1->cleanrl[dqn-atari-jax]) (3.0.9)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests->AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (2022.12.7)\n", + "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests->AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (1.24.3)\n", + "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests->AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (2.10)\n", + "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests->AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (3.0.4)\n", + "Collecting commonmark<0.10.0,>=0.9.0\n", + " Downloading commonmark-0.9.1-py2.py3-none-any.whl (51 kB)\n", + "\u001B[K |████████████████████████████████| 51 kB 5.0 MB/s \n", + "\u001B[?25hRequirement already satisfied: pygments<3.0.0,>=2.6.0 in /usr/local/lib/python3.8/dist-packages (from rich>=11.1->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (2.6.1)\n", + "Requirement already satisfied: protobuf<4,>=3.9.2 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (3.19.6)\n", + "Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (2.15.0)\n", + "Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (0.6.1)\n", + "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (3.4.1)\n", + "Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (0.38.4)\n", + "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (0.4.6)\n", + "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.3.0)\n", + "Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (57.4.0)\n", + "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.8.1)\n", + "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.0.1)\n", + "Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.51.1)\n", + "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.8/dist-packages (from google-auth<3,>=1.6.3->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (0.2.8)\n", + "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from google-auth<3,>=1.6.3->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (5.2.0)\n", + "Requirement already satisfied: six>=1.9.0 in /usr/local/lib/python3.8/dist-packages (from google-auth<3,>=1.6.3->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.15.0)\n", + "Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.8/dist-packages (from google-auth<3,>=1.6.3->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (4.9)\n", + "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.8/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.3.1)\n", + "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.8/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (0.4.8)\n", + "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.8/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (3.2.2)\n", + "Collecting pathtools\n", + " Downloading pathtools-0.1.2.tar.gz (11 kB)\n", + "Requirement already satisfied: promise<3,>=2.0 in /usr/local/lib/python3.8/dist-packages (from wandb<0.14.0,>=0.13.6->cleanrl[dqn-atari-jax]) (2.3)\n", + "Collecting GitPython>=1.0.0\n", + " Downloading GitPython-3.1.30-py3-none-any.whl (184 kB)\n", + "\u001B[K |████████████████████████████████| 184 kB 71.4 MB/s \n", + "\u001B[?25hRequirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.8/dist-packages (from wandb<0.14.0,>=0.13.6->cleanrl[dqn-atari-jax]) (5.4.8)\n", + "Collecting shortuuid>=0.5.0\n", + " Downloading shortuuid-1.0.11-py3-none-any.whl (10 kB)\n", + "Collecting sentry-sdk>=1.0.0\n", + " Downloading sentry_sdk-1.12.1-py2.py3-none-any.whl (174 kB)\n", + "\u001B[K |████████████████████████████████| 174 kB 80.8 MB/s \n", + "\u001B[?25hCollecting docker-pycreds>=0.4.0\n", + " Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\n", + "Collecting setproctitle\n", + " Downloading setproctitle-1.3.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (31 kB)\n", + "Collecting gitdb<5,>=4.0.1\n", + " Downloading gitdb-4.0.10-py3-none-any.whl (62 kB)\n", + "\u001B[K |████████████████████████████████| 62 kB 1.7 MB/s \n", + "\u001B[?25hCollecting smmap<6,>=3.0.1\n", + " Downloading smmap-5.0.0-py3-none-any.whl (24 kB)\n", + "Collecting sentry-sdk>=1.0.0\n", + " Downloading sentry_sdk-1.12.0-py2.py3-none-any.whl (173 kB)\n", + "\u001B[K |████████████████████████████████| 173 kB 69.0 MB/s \n", + "\u001B[?25h Downloading sentry_sdk-1.11.1-py2.py3-none-any.whl (168 kB)\n", + "\u001B[K |████████████████████████████████| 168 kB 66.6 MB/s \n", + "\u001B[?25h Downloading sentry_sdk-1.11.0-py2.py3-none-any.whl (168 kB)\n", + "\u001B[K |████████████████████████████████| 168 kB 8.1 MB/s \n", + "\u001B[?25h Downloading sentry_sdk-1.10.1-py2.py3-none-any.whl (166 kB)\n", + "\u001B[K |████████████████████████████████| 166 kB 10.6 MB/s \n", + "\u001B[?25h Downloading sentry_sdk-1.10.0-py2.py3-none-any.whl (166 kB)\n", + "\u001B[K |████████████████████████████████| 166 kB 71.4 MB/s \n", + "\u001B[?25h Downloading sentry_sdk-1.9.10-py2.py3-none-any.whl (162 kB)\n", + "\u001B[K |████████████████████████████████| 162 kB 70.1 MB/s \n", + "\u001B[?25h Downloading sentry_sdk-1.9.9-py2.py3-none-any.whl (162 kB)\n", + "\u001B[K |████████████████████████████████| 162 kB 70.2 MB/s \n", + "\u001B[?25h Downloading sentry_sdk-1.9.8-py2.py3-none-any.whl (158 kB)\n", + "\u001B[K |████████████████████████████████| 158 kB 75.4 MB/s \n", + "\u001B[?25h Downloading sentry_sdk-1.9.7-py2.py3-none-any.whl (157 kB)\n", + "\u001B[K |████████████████████████████████| 157 kB 77.6 MB/s \n", + "\u001B[?25h Downloading sentry_sdk-1.9.6-py2.py3-none-any.whl (157 kB)\n", + "\u001B[K |████████████████████████████████| 157 kB 83.8 MB/s \n", + "\u001B[?25h Downloading sentry_sdk-1.9.5-py2.py3-none-any.whl (157 kB)\n", + "\u001B[K |████████████████████████████████| 157 kB 88.0 MB/s \n", + "\u001B[?25h Downloading sentry_sdk-1.9.4-py2.py3-none-any.whl (157 kB)\n", + "\u001B[K |████████████████████████████████| 157 kB 80.1 MB/s \n", + "\u001B[?25h Downloading sentry_sdk-1.9.3-py2.py3-none-any.whl (157 kB)\n", + "\u001B[K |████████████████████████████████| 157 kB 84.8 MB/s \n", + "\u001B[?25h Downloading sentry_sdk-1.9.2-py2.py3-none-any.whl (157 kB)\n", + "\u001B[K |████████████████████████████████| 157 kB 85.7 MB/s \n", + "\u001B[?25h Downloading sentry_sdk-1.9.1-py2.py3-none-any.whl (157 kB)\n", + "\u001B[K |████████████████████████████████| 157 kB 83.5 MB/s \n", + "\u001B[?25h Downloading sentry_sdk-1.9.0-py2.py3-none-any.whl (156 kB)\n", + "\u001B[K |████████████████████████████████| 156 kB 84.0 MB/s \n", + "\u001B[?25hCollecting libtorrent\n", + " Using cached libtorrent-2.0.7-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (8.6 MB)\n", + "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (2.8.2)\n", + "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib->stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (0.11.0)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (1.4.4)\n", + "Collecting chex>=0.1.5\n", + " Downloading chex-0.1.5-py3-none-any.whl (85 kB)\n", + "\u001B[K |████████████████████████████████| 85 kB 4.9 MB/s \n", + "\u001B[?25hRequirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.8/dist-packages (from chex>=0.1.5->optax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (0.1.7)\n", + "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.8/dist-packages (from chex>=0.1.5->optax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (0.12.0)\n", + "Requirement already satisfied: pytest in /usr/local/lib/python3.8/dist-packages (from orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (3.6.4)\n", + "Collecting cached_property\n", + " Downloading cached_property-1.5.2-py2.py3-none-any.whl (7.6 kB)\n", + "Requirement already satisfied: etils in /usr/local/lib/python3.8/dist-packages (from orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (0.9.0)\n", + "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.8/dist-packages (from pandas->stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (2022.6)\n", + "Requirement already satisfied: more-itertools>=4.0.0 in /usr/local/lib/python3.8/dist-packages (from pytest->orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (9.0.0)\n", + "Requirement already satisfied: pluggy<0.8,>=0.5 in /usr/local/lib/python3.8/dist-packages (from pytest->orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (0.7.1)\n", + "Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.8/dist-packages (from pytest->orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (22.1.0)\n", + "Requirement already satisfied: py>=1.5.0 in /usr/local/lib/python3.8/dist-packages (from pytest->orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (1.11.0)\n", + "Requirement already satisfied: atomicwrites>=1.0 in /usr/local/lib/python3.8/dist-packages (from pytest->orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (1.4.1)\n", + "Building wheels for collected packages: gym, moviepy, AutoROM.accept-rom-license, pathtools\n", + " Building wheel for gym (PEP 517) ... \u001B[?25l\u001B[?25hdone\n", + " Created wheel for gym: filename=gym-0.23.1-py3-none-any.whl size=701376 sha256=7b59f30aef873fc1494bd2f2eeac27b103b64ae6ee87d554c8b61b9ddbe35765\n", + " Stored in directory: /root/.cache/pip/wheels/78/28/77/b0c74e80a2a4faae0161d5c53bc4f8e436e77aedc79136ee13\n", + " Building wheel for moviepy (setup.py) ... \u001B[?25l\u001B[?25hdone\n", + " Created wheel for moviepy: filename=moviepy-1.0.3-py3-none-any.whl size=110742 sha256=640c1c0df827ed5835373acab4d2d7b93e98e33b5e6cb90e3d5e703933f9bcf8\n", + " Stored in directory: /root/.cache/pip/wheels/e4/a4/db/0368d3a04033da662e13926594b3a8cf1aa4ffeefe570cfac1\n", + " Building wheel for AutoROM.accept-rom-license (PEP 517) ... \u001B[?25l\u001B[?25hdone\n", + " Created wheel for AutoROM.accept-rom-license: filename=AutoROM.accept_rom_license-0.5.0-py3-none-any.whl size=440868 sha256=a3833e2c22c21355029cb083d9ea62b7abe329af3757ccdce9b0d2a5cc06949f\n", + " Stored in directory: /root/.cache/pip/wheels/bf/c9/25/578470ae932b494c313dc22e6c57afff192140fb3cd5acf185\n", + " Building wheel for pathtools (setup.py) ... \u001B[?25l\u001B[?25hdone\n", + " Created wheel for pathtools: filename=pathtools-0.1.2-py3-none-any.whl size=8806 sha256=57226a75b752bf852ac2f0f5ad878217a63376d6c44a4b29ccdf40b4921bf4bc\n", + " Stored in directory: /root/.cache/pip/wheels/4c/8e/7e/72fbc243e1aeecae64a96875432e70d4e92f3d2d18123be004\n", + "Successfully built gym moviepy AutoROM.accept-rom-license pathtools\n", + "Installing collected packages: smmap, gitdb, tensorstore, shortuuid, setproctitle, sentry-sdk, proglog, pathtools, libtorrent, imageio-ffmpeg, gymnasium-notices, gym, GitPython, docker-pycreds, commonmark, chex, cached-property, wandb, tensorboard, stable-baselines3, rich, pygame, orbax, optax, moviepy, huggingface-hub, gymnasium, AutoROM.accept-rom-license, AutoROM, flax, cleanrl-test, ale-py\n", + " Attempting uninstall: gym\n", + " Found existing installation: gym 0.25.2\n", + " Uninstalling gym-0.25.2:\n", + " Successfully uninstalled gym-0.25.2\n", + " Attempting uninstall: tensorboard\n", + " Found existing installation: tensorboard 2.9.1\n", + " Uninstalling tensorboard-2.9.1:\n", + " Successfully uninstalled tensorboard-2.9.1\n", + " Attempting uninstall: moviepy\n", + " Found existing installation: moviepy 0.2.3.5\n", + " Uninstalling moviepy-0.2.3.5:\n", + " Successfully uninstalled moviepy-0.2.3.5\n", + "\u001B[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "tensorflow 2.9.2 requires tensorboard<2.10,>=2.9, but you have tensorboard 2.11.0 which is incompatible.\u001B[0m\n", + "Successfully installed AutoROM-0.4.2 AutoROM.accept-rom-license-0.5.0 GitPython-3.1.30 ale-py-0.7.4 cached-property-1.5.2 chex-0.1.5 cleanrl-test-1.1.2 commonmark-0.9.1 docker-pycreds-0.4.0 flax-0.6.3 gitdb-4.0.10 gym-0.23.1 gymnasium-0.26.3 gymnasium-notices-0.0.1 huggingface-hub-0.11.1 imageio-ffmpeg-0.4.7 libtorrent-2.0.7 moviepy-1.0.3 optax-0.1.4 orbax-0.0.23 pathtools-0.1.2 proglog-0.1.10 pygame-2.1.0 rich-13.0.0 sentry-sdk-1.9.0 setproctitle-1.3.2 shortuuid-1.0.11 smmap-5.0.0 stable-baselines3-1.2.0 tensorboard-2.11.0 tensorstore-0.1.28 wandb-0.13.7\n" + ] + } + ], + "source": [ + "!pip install --upgrade \"cleanrl[dqn-atari-jax]\" # CAVEAT: the extra key is `dqn-atari-jax` with dashes instead of `dqn_atari_jax` with underscores" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xXQXZTh_AHZ0" + }, + "source": [ + "## Enjoy Utility\n", + "\n", + "We have a simple way to load the model by running our \"enjoy\" utility, which automatically pull the model from 🤗 HuggingFace and run for a few episodes. It also produces a rendered video through the `--capture_video` flag. See more at our [📜 Documentation](https://docs.cleanrl.dev/get-started/zoo/)." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "4H9VZKBC_3_1", + "outputId": "fc03fd9b-84f8-43dc-b4e3-041e7a201c12" + }, + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "J0zhqyfea0If" - }, - "source": [ - "## Get Started\n", - "\n", - "CleanRL can be installed via `pip`. Let's say we are interested in pulling the model for [`dqn_atari_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari_jax.py), we can install the algorithm-variant-specific dependencies as follows:" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.8/dist-packages/jupyter_client/connect.py:28: DeprecationWarning: Jupyter is migrating its paths to use standard platformdirs\n", + "given by the platformdirs library. To remove this warning and\n", + "see the appropriate new directories, set the environment variable\n", + "`JUPYTER_PLATFORM_DIRS=1` and then run `jupyter --paths`.\n", + "The use of platformdirs will be the default in `jupyter_core` v6\n", + " from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write\n", + "loading saved models from cleanrl/BreakoutNoFrameskip-v4-dqn_atari_jax-seed1...\n", + "Downloading: 100% 6.75M/6.75M [00:00<00:00, 62.6MB/s]\n", + "A.L.E: Arcade Learning Environment (version 0.7.4+069f8bd)\n", + "[Powered by Stella]\n", + "/usr/local/lib/python3.8/dist-packages/gym/utils/seeding.py:138: DeprecationWarning: \u001B[33mWARN: Function `hash_seed(seed, max_bytes)` is marked as deprecated and will be removed in the future. \u001B[0m\n", + " deprecation(\n", + "/usr/local/lib/python3.8/dist-packages/gym/utils/seeding.py:175: DeprecationWarning: \u001B[33mWARN: Function `_bigint_from_bytes(bytes)` is marked as deprecated and will be removed in the future. \u001B[0m\n", + " deprecation(\n", + "/usr/local/lib/python3.8/dist-packages/gym/wrappers/monitoring/video_recorder.py:43: DeprecationWarning: \u001B[33mWARN: `env.metadata[\"render.modes\"] is marked as deprecated and will be replaced with `env.metadata[\"render_modes\"]` see https://github.com/openai/gym/pull/2654 for more details\u001B[0m\n", + " logger.deprecation(\n", + "/usr/local/lib/python3.8/dist-packages/gym/utils/seeding.py:47: DeprecationWarning: \u001B[33mWARN: Function `rng.randint(low, [high, size, dtype])` is marked as deprecated and will be removed in the future. Please use `rng.integers(low, [high, size, dtype])` instead.\u001B[0m\n", + " deprecation(\n", + "/usr/local/lib/python3.8/dist-packages/gym/wrappers/monitoring/video_recorder.py:43: DeprecationWarning: \u001B[33mWARN: `env.metadata[\"render.modes\"] is marked as deprecated and will be replaced with `env.metadata[\"render_modes\"]` see https://github.com/openai/gym/pull/2654 for more details\u001B[0m\n", + " logger.deprecation(\n", + "/usr/local/lib/python3.8/dist-packages/gym/utils/seeding.py:47: DeprecationWarning: \u001B[33mWARN: Function `rng.randint(low, [high, size, dtype])` is marked as deprecated and will be removed in the future. Please use `rng.integers(low, [high, size, dtype])` instead.\u001B[0m\n", + " deprecation(\n", + "eval_episode=0, episodic_return=400.0\n", + "eval_episode=1, episodic_return=128.0\n" + ] + } + ], + "source": [ + "!python -m cleanrl_utils.enjoy --exp-name dqn_atari_jax --env-id BreakoutNoFrameskip-v4 --eval-episodes 2 --capture_video" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 231 }, + "id": "KpzdA4dkFbdT", + "outputId": "1b53628e-ac19-4f36-89e4-1a831b51f06b" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "lhnJkrYLOvcs", - "outputId": "381d9d0d-7e83-4f21-ef89-91d4e3b93c18" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Collecting cleanrl[dqn-atari-jax]\n", - " Downloading cleanrl-1.1.2-py3-none-any.whl (16.9 MB)\n", - "\u001b[K |████████████████████████████████| 16.9 MB 241 kB/s \n", - "\u001b[?25hCollecting pygame==2.1.0\n", - " Downloading pygame-2.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)\n", - "\u001b[K |████████████████████████████████| 18.3 MB 59.3 MB/s \n", - "\u001b[?25hCollecting huggingface-hub<0.12.0,>=0.11.1\n", - " Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)\n", - "\u001b[K |████████████████████████████████| 182 kB 74.9 MB/s \n", - "\u001b[?25hCollecting wandb<0.14.0,>=0.13.6\n", - " Downloading wandb-0.13.7-py2.py3-none-any.whl (1.9 MB)\n", - "\u001b[K |████████████████████████████████| 1.9 MB 63.6 MB/s \n", - "\u001b[?25hRequirement already satisfied: torch>=1.12.1 in /usr/local/lib/python3.8/dist-packages (from cleanrl[dqn-atari-jax]) (1.13.0+cu116)\n", - "Collecting stable-baselines3==1.2.0\n", - " Downloading stable_baselines3-1.2.0-py3-none-any.whl (161 kB)\n", - "\u001b[K |████████████████████████████████| 161 kB 64.7 MB/s \n", - "\u001b[?25hCollecting tensorboard<3.0.0,>=2.10.0\n", - " Downloading tensorboard-2.11.0-py3-none-any.whl (6.0 MB)\n", - "\u001b[K |████████████████████████████████| 6.0 MB 65.0 MB/s \n", - "\u001b[?25hCollecting moviepy<2.0.0,>=1.0.3\n", - " Downloading moviepy-1.0.3.tar.gz (388 kB)\n", - "\u001b[K |████████████████████████████████| 388 kB 59.5 MB/s \n", - "\u001b[?25hCollecting gym==0.23.1\n", - " Downloading gym-0.23.1.tar.gz (626 kB)\n", - "\u001b[K |████████████████████████████████| 626 kB 59.9 MB/s \n", - "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n", - "Collecting gymnasium<0.27.0,>=0.26.3\n", - " Downloading Gymnasium-0.26.3-py3-none-any.whl (836 kB)\n", - "\u001b[K |████████████████████████████████| 836 kB 64.6 MB/s \n", - "\u001b[?25hCollecting AutoROM[accept-rom-license]<0.5.0,>=0.4.2\n", - " Downloading AutoROM-0.4.2-py3-none-any.whl (16 kB)\n", - "Collecting ale-py==0.7.4\n", - " Downloading ale_py-0.7.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.6 MB)\n", - "\u001b[K |████████████████████████████████| 1.6 MB 55.1 MB/s \n", - "\u001b[?25hRequirement already satisfied: opencv-python<5.0.0.0,>=4.6.0.66 in /usr/local/lib/python3.8/dist-packages (from cleanrl[dqn-atari-jax]) (4.6.0.66)\n", - "Requirement already satisfied: jax<0.4.0,>=0.3.17 in /usr/local/lib/python3.8/dist-packages (from cleanrl[dqn-atari-jax]) (0.3.25)\n", - "Collecting flax<0.7.0,>=0.6.0\n", - " Downloading flax-0.6.3-py3-none-any.whl (197 kB)\n", - "\u001b[K |████████████████████████████████| 197 kB 73.9 MB/s \n", - "\u001b[?25hRequirement already satisfied: jaxlib<0.4.0,>=0.3.15 in /usr/local/lib/python3.8/dist-packages (from cleanrl[dqn-atari-jax]) (0.3.25+cuda11.cudnn805)\n", - "Requirement already satisfied: importlib-metadata>=4.10.0 in /usr/local/lib/python3.8/dist-packages (from ale-py==0.7.4->cleanrl[dqn-atari-jax]) (5.1.0)\n", - "Requirement already satisfied: importlib-resources in /usr/local/lib/python3.8/dist-packages (from ale-py==0.7.4->cleanrl[dqn-atari-jax]) (5.10.1)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.8/dist-packages (from ale-py==0.7.4->cleanrl[dqn-atari-jax]) (1.21.6)\n", - "Requirement already satisfied: gym-notices>=0.0.4 in /usr/local/lib/python3.8/dist-packages (from gym==0.23.1->cleanrl[dqn-atari-jax]) (0.0.8)\n", - "Requirement already satisfied: cloudpickle>=1.2.0 in /usr/local/lib/python3.8/dist-packages (from gym==0.23.1->cleanrl[dqn-atari-jax]) (1.5.0)\n", - "Requirement already satisfied: pandas in /usr/local/lib/python3.8/dist-packages (from stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (1.3.5)\n", - "Requirement already satisfied: matplotlib in /usr/local/lib/python3.8/dist-packages (from stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (3.2.2)\n", - "Requirement already satisfied: click in /usr/local/lib/python3.8/dist-packages (from AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (7.1.2)\n", - "Requirement already satisfied: requests in /usr/local/lib/python3.8/dist-packages (from AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (2.23.0)\n", - "Requirement already satisfied: tqdm in /usr/local/lib/python3.8/dist-packages (from AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (4.64.1)\n", - "Collecting AutoROM.accept-rom-license\n", - " Downloading AutoROM.accept-rom-license-0.5.0.tar.gz (10 kB)\n", - " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n", - "Requirement already satisfied: PyYAML>=5.4.1 in /usr/local/lib/python3.8/dist-packages (from flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (6.0)\n", - "Collecting rich>=11.1\n", - " Downloading rich-13.0.0-py3-none-any.whl (238 kB)\n", - "\u001b[K |████████████████████████████████| 238 kB 76.7 MB/s \n", - "\u001b[?25hRequirement already satisfied: typing-extensions>=4.1.1 in /usr/local/lib/python3.8/dist-packages (from flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (4.4.0)\n", - "Collecting orbax\n", - " Downloading orbax-0.0.23-py3-none-any.whl (66 kB)\n", - "\u001b[K |████████████████████████████████| 66 kB 6.3 MB/s \n", - "\u001b[?25hCollecting optax\n", - " Downloading optax-0.1.4-py3-none-any.whl (154 kB)\n", - "\u001b[K |████████████████████████████████| 154 kB 82.0 MB/s \n", - "\u001b[?25hCollecting tensorstore\n", - " Downloading tensorstore-0.1.28-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.3 MB)\n", - "\u001b[K |████████████████████████████████| 8.3 MB 64.7 MB/s \n", - "\u001b[?25hRequirement already satisfied: msgpack in /usr/local/lib/python3.8/dist-packages (from flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (1.0.4)\n", - "Collecting gymnasium-notices>=0.0.1\n", - " Downloading gymnasium_notices-0.0.1-py3-none-any.whl (2.8 kB)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<0.12.0,>=0.11.1->cleanrl[dqn-atari-jax]) (3.8.2)\n", - "Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<0.12.0,>=0.11.1->cleanrl[dqn-atari-jax]) (21.3)\n", - "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.8/dist-packages (from importlib-metadata>=4.10.0->ale-py==0.7.4->cleanrl[dqn-atari-jax]) (3.11.0)\n", - "Requirement already satisfied: scipy>=1.5 in /usr/local/lib/python3.8/dist-packages (from jax<0.4.0,>=0.3.17->cleanrl[dqn-atari-jax]) (1.7.3)\n", - "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.8/dist-packages (from jax<0.4.0,>=0.3.17->cleanrl[dqn-atari-jax]) (3.3.0)\n", - "Requirement already satisfied: decorator<5.0,>=4.0.2 in /usr/local/lib/python3.8/dist-packages (from moviepy<2.0.0,>=1.0.3->cleanrl[dqn-atari-jax]) (4.4.2)\n", - "Collecting proglog<=1.0.0\n", - " Downloading proglog-0.1.10-py3-none-any.whl (6.1 kB)\n", - "Requirement already satisfied: imageio<3.0,>=2.5 in /usr/local/lib/python3.8/dist-packages (from moviepy<2.0.0,>=1.0.3->cleanrl[dqn-atari-jax]) (2.9.0)\n", - "Collecting imageio_ffmpeg>=0.2.0\n", - " Downloading imageio_ffmpeg-0.4.7-py3-none-manylinux2010_x86_64.whl (26.9 MB)\n", - "\u001b[K |████████████████████████████████| 26.9 MB 47.9 MB/s \n", - "\u001b[?25hRequirement already satisfied: pillow in /usr/local/lib/python3.8/dist-packages (from imageio<3.0,>=2.5->moviepy<2.0.0,>=1.0.3->cleanrl[dqn-atari-jax]) (7.1.2)\n", - "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.8/dist-packages (from packaging>=20.9->huggingface-hub<0.12.0,>=0.11.1->cleanrl[dqn-atari-jax]) (3.0.9)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests->AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (2022.12.7)\n", - "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests->AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (1.24.3)\n", - "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests->AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (2.10)\n", - "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests->AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (3.0.4)\n", - "Collecting commonmark<0.10.0,>=0.9.0\n", - " Downloading commonmark-0.9.1-py2.py3-none-any.whl (51 kB)\n", - "\u001b[K |████████████████████████████████| 51 kB 5.0 MB/s \n", - "\u001b[?25hRequirement already satisfied: pygments<3.0.0,>=2.6.0 in /usr/local/lib/python3.8/dist-packages (from rich>=11.1->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (2.6.1)\n", - "Requirement already satisfied: protobuf<4,>=3.9.2 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (3.19.6)\n", - "Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (2.15.0)\n", - "Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (0.6.1)\n", - "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (3.4.1)\n", - "Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (0.38.4)\n", - "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (0.4.6)\n", - "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.3.0)\n", - "Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (57.4.0)\n", - "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.8.1)\n", - "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.0.1)\n", - "Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.51.1)\n", - "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.8/dist-packages (from google-auth<3,>=1.6.3->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (0.2.8)\n", - "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from google-auth<3,>=1.6.3->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (5.2.0)\n", - "Requirement already satisfied: six>=1.9.0 in /usr/local/lib/python3.8/dist-packages (from google-auth<3,>=1.6.3->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.15.0)\n", - "Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.8/dist-packages (from google-auth<3,>=1.6.3->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (4.9)\n", - "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.8/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.3.1)\n", - "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.8/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (0.4.8)\n", - "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.8/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (3.2.2)\n", - "Collecting pathtools\n", - " Downloading pathtools-0.1.2.tar.gz (11 kB)\n", - "Requirement already satisfied: promise<3,>=2.0 in /usr/local/lib/python3.8/dist-packages (from wandb<0.14.0,>=0.13.6->cleanrl[dqn-atari-jax]) (2.3)\n", - "Collecting GitPython>=1.0.0\n", - " Downloading GitPython-3.1.30-py3-none-any.whl (184 kB)\n", - "\u001b[K |████████████████████████████████| 184 kB 71.4 MB/s \n", - "\u001b[?25hRequirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.8/dist-packages (from wandb<0.14.0,>=0.13.6->cleanrl[dqn-atari-jax]) (5.4.8)\n", - "Collecting shortuuid>=0.5.0\n", - " Downloading shortuuid-1.0.11-py3-none-any.whl (10 kB)\n", - "Collecting sentry-sdk>=1.0.0\n", - " Downloading sentry_sdk-1.12.1-py2.py3-none-any.whl (174 kB)\n", - "\u001b[K |████████████████████████████████| 174 kB 80.8 MB/s \n", - "\u001b[?25hCollecting docker-pycreds>=0.4.0\n", - " Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\n", - "Collecting setproctitle\n", - " Downloading setproctitle-1.3.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (31 kB)\n", - "Collecting gitdb<5,>=4.0.1\n", - " Downloading gitdb-4.0.10-py3-none-any.whl (62 kB)\n", - "\u001b[K |████████████████████████████████| 62 kB 1.7 MB/s \n", - "\u001b[?25hCollecting smmap<6,>=3.0.1\n", - " Downloading smmap-5.0.0-py3-none-any.whl (24 kB)\n", - "Collecting sentry-sdk>=1.0.0\n", - " Downloading sentry_sdk-1.12.0-py2.py3-none-any.whl (173 kB)\n", - "\u001b[K |████████████████████████████████| 173 kB 69.0 MB/s \n", - "\u001b[?25h Downloading sentry_sdk-1.11.1-py2.py3-none-any.whl (168 kB)\n", - "\u001b[K |████████████████████████████████| 168 kB 66.6 MB/s \n", - "\u001b[?25h Downloading sentry_sdk-1.11.0-py2.py3-none-any.whl (168 kB)\n", - "\u001b[K |████████████████████████████████| 168 kB 8.1 MB/s \n", - "\u001b[?25h Downloading sentry_sdk-1.10.1-py2.py3-none-any.whl (166 kB)\n", - "\u001b[K |████████████████████████████████| 166 kB 10.6 MB/s \n", - "\u001b[?25h Downloading sentry_sdk-1.10.0-py2.py3-none-any.whl (166 kB)\n", - "\u001b[K |████████████████████████████████| 166 kB 71.4 MB/s \n", - "\u001b[?25h Downloading sentry_sdk-1.9.10-py2.py3-none-any.whl (162 kB)\n", - "\u001b[K |████████████████████████████████| 162 kB 70.1 MB/s \n", - "\u001b[?25h Downloading sentry_sdk-1.9.9-py2.py3-none-any.whl (162 kB)\n", - "\u001b[K |████████████████████████████████| 162 kB 70.2 MB/s \n", - "\u001b[?25h Downloading sentry_sdk-1.9.8-py2.py3-none-any.whl (158 kB)\n", - "\u001b[K |████████████████████████████████| 158 kB 75.4 MB/s \n", - "\u001b[?25h Downloading sentry_sdk-1.9.7-py2.py3-none-any.whl (157 kB)\n", - "\u001b[K |████████████████████████████████| 157 kB 77.6 MB/s \n", - "\u001b[?25h Downloading sentry_sdk-1.9.6-py2.py3-none-any.whl (157 kB)\n", - "\u001b[K |████████████████████████████████| 157 kB 83.8 MB/s \n", - "\u001b[?25h Downloading sentry_sdk-1.9.5-py2.py3-none-any.whl (157 kB)\n", - "\u001b[K |████████████████████████████████| 157 kB 88.0 MB/s \n", - "\u001b[?25h Downloading sentry_sdk-1.9.4-py2.py3-none-any.whl (157 kB)\n", - "\u001b[K |████████████████████████████████| 157 kB 80.1 MB/s \n", - "\u001b[?25h Downloading sentry_sdk-1.9.3-py2.py3-none-any.whl (157 kB)\n", - "\u001b[K |████████████████████████████████| 157 kB 84.8 MB/s \n", - "\u001b[?25h Downloading sentry_sdk-1.9.2-py2.py3-none-any.whl (157 kB)\n", - "\u001b[K |████████████████████████████████| 157 kB 85.7 MB/s \n", - "\u001b[?25h Downloading sentry_sdk-1.9.1-py2.py3-none-any.whl (157 kB)\n", - "\u001b[K |████████████████████████████████| 157 kB 83.5 MB/s \n", - "\u001b[?25h Downloading sentry_sdk-1.9.0-py2.py3-none-any.whl (156 kB)\n", - "\u001b[K |████████████████████████████████| 156 kB 84.0 MB/s \n", - "\u001b[?25hCollecting libtorrent\n", - " Using cached libtorrent-2.0.7-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (8.6 MB)\n", - "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (2.8.2)\n", - "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib->stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (0.11.0)\n", - "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (1.4.4)\n", - "Collecting chex>=0.1.5\n", - " Downloading chex-0.1.5-py3-none-any.whl (85 kB)\n", - "\u001b[K |████████████████████████████████| 85 kB 4.9 MB/s \n", - "\u001b[?25hRequirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.8/dist-packages (from chex>=0.1.5->optax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (0.1.7)\n", - "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.8/dist-packages (from chex>=0.1.5->optax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (0.12.0)\n", - "Requirement already satisfied: pytest in /usr/local/lib/python3.8/dist-packages (from orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (3.6.4)\n", - "Collecting cached_property\n", - " Downloading cached_property-1.5.2-py2.py3-none-any.whl (7.6 kB)\n", - "Requirement already satisfied: etils in /usr/local/lib/python3.8/dist-packages (from orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (0.9.0)\n", - "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.8/dist-packages (from pandas->stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (2022.6)\n", - "Requirement already satisfied: more-itertools>=4.0.0 in /usr/local/lib/python3.8/dist-packages (from pytest->orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (9.0.0)\n", - "Requirement already satisfied: pluggy<0.8,>=0.5 in /usr/local/lib/python3.8/dist-packages (from pytest->orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (0.7.1)\n", - "Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.8/dist-packages (from pytest->orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (22.1.0)\n", - "Requirement already satisfied: py>=1.5.0 in /usr/local/lib/python3.8/dist-packages (from pytest->orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (1.11.0)\n", - "Requirement already satisfied: atomicwrites>=1.0 in /usr/local/lib/python3.8/dist-packages (from pytest->orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (1.4.1)\n", - "Building wheels for collected packages: gym, moviepy, AutoROM.accept-rom-license, pathtools\n", - " Building wheel for gym (PEP 517) ... \u001b[?25l\u001b[?25hdone\n", - " Created wheel for gym: filename=gym-0.23.1-py3-none-any.whl size=701376 sha256=7b59f30aef873fc1494bd2f2eeac27b103b64ae6ee87d554c8b61b9ddbe35765\n", - " Stored in directory: /root/.cache/pip/wheels/78/28/77/b0c74e80a2a4faae0161d5c53bc4f8e436e77aedc79136ee13\n", - " Building wheel for moviepy (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Created wheel for moviepy: filename=moviepy-1.0.3-py3-none-any.whl size=110742 sha256=640c1c0df827ed5835373acab4d2d7b93e98e33b5e6cb90e3d5e703933f9bcf8\n", - " Stored in directory: /root/.cache/pip/wheels/e4/a4/db/0368d3a04033da662e13926594b3a8cf1aa4ffeefe570cfac1\n", - " Building wheel for AutoROM.accept-rom-license (PEP 517) ... \u001b[?25l\u001b[?25hdone\n", - " Created wheel for AutoROM.accept-rom-license: filename=AutoROM.accept_rom_license-0.5.0-py3-none-any.whl size=440868 sha256=a3833e2c22c21355029cb083d9ea62b7abe329af3757ccdce9b0d2a5cc06949f\n", - " Stored in directory: /root/.cache/pip/wheels/bf/c9/25/578470ae932b494c313dc22e6c57afff192140fb3cd5acf185\n", - " Building wheel for pathtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Created wheel for pathtools: filename=pathtools-0.1.2-py3-none-any.whl size=8806 sha256=57226a75b752bf852ac2f0f5ad878217a63376d6c44a4b29ccdf40b4921bf4bc\n", - " Stored in directory: /root/.cache/pip/wheels/4c/8e/7e/72fbc243e1aeecae64a96875432e70d4e92f3d2d18123be004\n", - "Successfully built gym moviepy AutoROM.accept-rom-license pathtools\n", - "Installing collected packages: smmap, gitdb, tensorstore, shortuuid, setproctitle, sentry-sdk, proglog, pathtools, libtorrent, imageio-ffmpeg, gymnasium-notices, gym, GitPython, docker-pycreds, commonmark, chex, cached-property, wandb, tensorboard, stable-baselines3, rich, pygame, orbax, optax, moviepy, huggingface-hub, gymnasium, AutoROM.accept-rom-license, AutoROM, flax, cleanrl-test, ale-py\n", - " Attempting uninstall: gym\n", - " Found existing installation: gym 0.25.2\n", - " Uninstalling gym-0.25.2:\n", - " Successfully uninstalled gym-0.25.2\n", - " Attempting uninstall: tensorboard\n", - " Found existing installation: tensorboard 2.9.1\n", - " Uninstalling tensorboard-2.9.1:\n", - " Successfully uninstalled tensorboard-2.9.1\n", - " Attempting uninstall: moviepy\n", - " Found existing installation: moviepy 0.2.3.5\n", - " Uninstalling moviepy-0.2.3.5:\n", - " Successfully uninstalled moviepy-0.2.3.5\n", - "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", - "tensorflow 2.9.2 requires tensorboard<2.10,>=2.9, but you have tensorboard 2.11.0 which is incompatible.\u001b[0m\n", - "Successfully installed AutoROM-0.4.2 AutoROM.accept-rom-license-0.5.0 GitPython-3.1.30 ale-py-0.7.4 cached-property-1.5.2 chex-0.1.5 cleanrl-test-1.1.2 commonmark-0.9.1 docker-pycreds-0.4.0 flax-0.6.3 gitdb-4.0.10 gym-0.23.1 gymnasium-0.26.3 gymnasium-notices-0.0.1 huggingface-hub-0.11.1 imageio-ffmpeg-0.4.7 libtorrent-2.0.7 moviepy-1.0.3 optax-0.1.4 orbax-0.0.23 pathtools-0.1.2 proglog-0.1.10 pygame-2.1.0 rich-13.0.0 sentry-sdk-1.9.0 setproctitle-1.3.2 shortuuid-1.0.11 smmap-5.0.0 stable-baselines3-1.2.0 tensorboard-2.11.0 tensorstore-0.1.28 wandb-0.13.7\n" - ] - } + "data": { + "text/html": [ + "" ], - "source": [ - "!pip install --upgrade \"cleanrl[dqn-atari-jax]\" # CAVEAT: the extra key is `dqn-atari-jax` with dashes instead of `dqn_atari_jax` with underscores" + "text/plain": [ + "" ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from IPython.display import Video\n", + "Video('videos/eval/rl-video-episode-0.mp4', embed=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WU29XP1ICwxv" + }, + "source": [ + "## Diving Deeper\n", + "\n", + "What happened above was achieved by a simple wrapper for [cleanrl_utils/evals/dqn_eval.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl_utils/evals/dqn_eval.py), which is pretty succinct and may give you a more fine-grained control and access to the model. Its content is roughly as follows, where it attempts to download a model from https://huggingface.co/cleanrl/BreakoutNoFrameskip-v4-dqn_atari_jax-seed1 and run an evaluation pass. " + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "eZY6cAxkDJF5", + "outputId": "0144efd9-5d8e-4631-8a07-6385d8365558" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.8/dist-packages/gym/utils/seeding.py:138: DeprecationWarning: \u001B[33mWARN: Function `hash_seed(seed, max_bytes)` is marked as deprecated and will be removed in the future. \u001B[0m\n", + " deprecation(\n", + "/usr/local/lib/python3.8/dist-packages/gym/utils/seeding.py:175: DeprecationWarning: \u001B[33mWARN: Function `_bigint_from_bytes(bytes)` is marked as deprecated and will be removed in the future. \u001B[0m\n", + " deprecation(\n", + "/usr/local/lib/python3.8/dist-packages/gym/utils/seeding.py:47: DeprecationWarning: \u001B[33mWARN: Function `rng.randint(low, [high, size, dtype])` is marked as deprecated and will be removed in the future. Please use `rng.integers(low, [high, size, dtype])` instead.\u001B[0m\n", + " deprecation(\n" + ] }, { - "cell_type": "markdown", - "metadata": { - "id": "xXQXZTh_AHZ0" + "name": "stdout", + "output_type": "stream", + "text": [ + "eval_episode=0, episodic_return=340.0\n", + "eval_episode=1, episodic_return=399.0\n" + ] + }, + { + "data": { + "text/plain": [ + "[340.0, 399.0]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import random\n", + "from typing import Callable\n", + "\n", + "import flax\n", + "import flax.linen as nn\n", + "import gym\n", + "import jax\n", + "import numpy as np\n", + "\n", + "\n", + "def evaluate(\n", + " model_path: str,\n", + " make_env: Callable,\n", + " env_id: str,\n", + " eval_episodes: int,\n", + " run_name: str,\n", + " Model: nn.Module,\n", + " epsilon: float = 0.05,\n", + " capture_video: bool = True,\n", + " seed=1,\n", + "):\n", + " envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, capture_video, run_name)])\n", + " obs = envs.reset()\n", + " model = Model(action_dim=envs.single_action_space.n)\n", + " q_key = jax.random.PRNGKey(seed)\n", + " params = model.init(q_key, obs)\n", + " with open(model_path, \"rb\") as f:\n", + " params = flax.serialization.from_bytes(params, f.read())\n", + " model.apply = jax.jit(model.apply)\n", + "\n", + " episodic_returns = []\n", + " while len(episodic_returns) < eval_episodes:\n", + " if random.random() < epsilon:\n", + " actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])\n", + " else:\n", + " q_values = model.apply(params, obs)\n", + " actions = q_values.argmax(axis=-1)\n", + " actions = jax.device_get(actions)\n", + " next_obs, _, _, infos = envs.step(actions)\n", + " for info in infos:\n", + " if \"episode\" in info.keys():\n", + " print(f\"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}\")\n", + " episodic_returns += [info[\"episode\"][\"r\"]]\n", + " obs = next_obs\n", + "\n", + " return episodic_returns\n", + "\n", + "\n", + "from huggingface_hub import hf_hub_download\n", + "\n", + "from cleanrl.dqn_atari_jax import QNetwork, make_env\n", + "\n", + "model_path = hf_hub_download(repo_id=\"cleanrl/BreakoutNoFrameskip-v4-dqn_atari_jax-seed1\", filename=\"dqn_atari_jax.cleanrl_model\")\n", + "evaluate(\n", + " model_path,\n", + " make_env,\n", + " \"BreakoutNoFrameskip-v4\",\n", + " eval_episodes=2,\n", + " run_name=f\"eval\",\n", + " Model=QNetwork,\n", + " capture_video=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZxM0A6LmQtnn" + }, + "source": [ + "## More Examples\n", + "\n", + "Now let's get going with more examples!" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "TrQae62Y70H0" + }, + "outputs": [], + "source": [ + "import argparse\n", + "from dataclasses import dataclass\n", + "\n", + "from huggingface_hub import hf_hub_download\n", + "\n", + "try:\n", + " from pip import main as pipmain\n", + "except ImportError:\n", + " from pip._internal import main as pipmain\n", + "\n", + "@dataclass\n", + "class Args:\n", + " exp_name: str = \"dqn_atari_jax\"\n", + " seed: int = 1\n", + " hf_entity: str = \"cleanrl\"\n", + " hf_repository: str = \"\"\n", + " env_id: str = \"BreakoutNoFrameskip-v4\"\n", + "\n", + "\n", + "def dqn():\n", + " import cleanrl.dqn\n", + " import cleanrl_utils.evals.dqn_eval\n", + " return cleanrl.dqn.QNetwork, cleanrl.dqn.make_env, cleanrl_utils.evals.dqn_eval.evaluate\n", + "\n", + "def dqn_atari():\n", + " import cleanrl.dqn_atari\n", + " import cleanrl_utils.evals.dqn_eval\n", + " return cleanrl.dqn_atari.QNetwork, cleanrl.dqn_atari.make_env, cleanrl_utils.evals.dqn_eval.evaluate\n", + "\n", + "def dqn_jax():\n", + " import cleanrl.dqn_jax\n", + " import cleanrl_utils.evals.dqn_jax_eval\n", + " return cleanrl.dqn_jax.QNetwork, cleanrl.dqn_jax.make_env, cleanrl_utils.evals.dqn_jax_eval.evaluate\n", + "\n", + "def dqn_atari_jax():\n", + " import cleanrl.dqn_atari_jax\n", + " import cleanrl_utils.evals.dqn_jax_eval\n", + " return cleanrl.dqn_atari_jax.QNetwork, cleanrl.dqn_atari_jax.make_env, cleanrl_utils.evals.dqn_jax_eval.evaluate\n", + "\n", + "MODELS = {\n", + " \"dqn\": dqn,\n", + " \"dqn_atari\": dqn_atari,\n", + " \"dqn_jax\": dqn_jax,\n", + " \"dqn_atari_jax\": dqn_atari_jax,\n", + "}\n", + "\n", + "\n", + "\n", + "exp_names = [\"dqn\", \"dqn_jax\", \"dqn_atari_jax\", \"dqn_atari\"]\n", + "env_idss = [\n", + " [\n", + " \"CartPole-v1\",\n", + " \"Acrobot-v1\",\n", + " \"MountainCar-v0\",\n", + " ],\n", + " [\n", + " \"CartPole-v1\",\n", + " \"Acrobot-v1\",\n", + " \"MountainCar-v0\",\n", + " ],\n", + " [\n", + " \"BreakoutNoFrameskip-v4\",\n", + " \"PongNoFrameskip-v4\",\n", + " \"BeamRiderNoFrameskip-v4\"\n", + " ],\n", + " [\n", + " \"BreakoutNoFrameskip-v4\",\n", + " \"PongNoFrameskip-v4\",\n", + " \"BeamRiderNoFrameskip-v4\"\n", + " ]\n", + " ]\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "IeksFU1me8q8" + }, + "source": [ + "### Install dependencies for each variant" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "dnvpgpWWfABl", + "outputId": "1e41abbf-d9c4-4adf-fe05-40e8e31962f4" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.8/dist-packages/pip/_vendor/packaging/version.py:127: DeprecationWarning: Creating a LegacyVersion has been deprecated and will be removed in the next major release\n", + " warnings.warn(\n", + "/usr/local/lib/python3.8/dist-packages/pip/_vendor/packaging/version.py:127: DeprecationWarning: Creating a LegacyVersion has been deprecated and will be removed in the next major release\n", + " warnings.warn(\n", + "WARNING: pip is being invoked by an old script wrapper. This will fail in a future version of pip.\n", + "Please see https://github.com/pypa/pip/issues/5599 for advice on fixing the underlying issue.\n", + "To avoid this problem you can invoke Python with '-m pip' instead of running pip directly.\n", + "/usr/local/lib/python3.8/dist-packages/pip/_vendor/packaging/version.py:127: DeprecationWarning: Creating a LegacyVersion has been deprecated and will be removed in the next major release\n", + " warnings.warn(\n", + "WARNING: pip is being invoked by an old script wrapper. This will fail in a future version of pip.\n", + "Please see https://github.com/pypa/pip/issues/5599 for advice on fixing the underlying issue.\n", + "To avoid this problem you can invoke Python with '-m pip' instead of running pip directly.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "==== ['install', '--upgrade', 'cleanrl[dqn]', '--quiet']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.8/dist-packages/pip/_vendor/packaging/version.py:127: DeprecationWarning: Creating a LegacyVersion has been deprecated and will be removed in the next major release\n", + " warnings.warn(\n", + "WARNING: pip is being invoked by an old script wrapper. This will fail in a future version of pip.\n", + "Please see https://github.com/pypa/pip/issues/5599 for advice on fixing the underlying issue.\n", + "To avoid this problem you can invoke Python with '-m pip' instead of running pip directly.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "==== ['install', '--upgrade', 'cleanrl[dqn-jax]', '--quiet']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.8/dist-packages/pip/_vendor/packaging/version.py:127: DeprecationWarning: Creating a LegacyVersion has been deprecated and will be removed in the next major release\n", + " warnings.warn(\n", + "WARNING: pip is being invoked by an old script wrapper. This will fail in a future version of pip.\n", + "Please see https://github.com/pypa/pip/issues/5599 for advice on fixing the underlying issue.\n", + "To avoid this problem you can invoke Python with '-m pip' instead of running pip directly.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "==== ['install', '--upgrade', 'cleanrl[dqn-atari-jax]', '--quiet']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.8/dist-packages/pip/_vendor/packaging/version.py:127: DeprecationWarning: Creating a LegacyVersion has been deprecated and will be removed in the next major release\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "==== ['install', '--upgrade', 'cleanrl[dqn-atari]', '--quiet']\n" + ] + } + ], + "source": [ + "for exp_name, env_ids in zip(exp_names, env_idss):\n", + " # install dependencies for the algorithm variant\n", + " pipmain(['install', '--upgrade', f'cleanrl[{exp_name.replace(\"_\", \"-\")}]', \"--quiet\"])\n", + " print(\"====\", ['install', '--upgrade', f'cleanrl[{exp_name.replace(\"_\", \"-\")}]', \"--quiet\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "van2E4jFfC2f" + }, + "source": [ + "# Enjoy!" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000, + "referenced_widgets": [ + "1af264779a2442e596aed8e620561248", + "21bff070d7b342a3b1ca5c9976746a6f", + "82dc6a3d16e24066a8531f1390eb0450", + "df36fcd7bf4c47f48914e43866da1edf", + "f31718702d1a433eac7388ef1612a149", + "e6179240e76b4a2abd79e31b44c8bcf5", + "6e8a545a6d8e43a4ac9cc3a5cd8c06b6", + "1bbd3551eda24347b078a78c54ef9808", + "a3e4a40738274a48ad76090b7e3593ec", + "11d8ea24816b4d9487b5cf701c77368f", + "7bcc89fefa314206aef5e24d1b105b50", + "0c074497102c45aab5db63d863a493a5", + "8ffa468c4cee423ea37eb5c6a3008a8d", + "71728d572be04c4a816e09c4682dd254", + "54f10fce2aa44083bdddd67209f75097", + "b49b12ca89c5448a8994373aabb8c2de", + "3ac10a61d2394023b867276369d75d94", + "8676b5de0bb24b49a2624ce56e57b041", + "d537dbcb8a604e96ab1445a44f4f5795", + "3c2f669e4f5a472288f275d3c383df38", + "a538ea3a39db4b62b90e8fa84cc96ae0", + "96e4392eab1e468e990ca5e3ce239ee2", + "d47a0fe9206746618d6513a725629706", + "a69f93596efd47ebb3daa28aa4193333", + "f47b25f7ba814ad79fc49d5469f073f1", + "42d1d9807ecd4ac79a8a339ebec90c7d", + "0f846fe59b4f4f9595640cd537d5d359", + "242d6e4726bc4b44b51fbefbe9e71d89", + "f55821cfe861470d91254ca6fb1686d8", + "60c2723becc441f490db7ac57f0db14f", + "335346c51b404ec38d28451cd64bfb1d", + "b082ad5929d44cfb873d686178af22f6", + "707972069edb433d808fde341ae797bf", + "85d693c15c7f402d8cbeb3b2dfa1204b", + "ff5d66561809464cba45624ac5c6db4d", + "a833d25fce5c49d89babd5efa6a9e7a1", + "bbf1f621e72e46b5bb34f83de5aad104", + "c0701fd6a60241449580c09929d8a23c", + "5d2f9340fa314329846b680590f7b983", + "c5a49382f0fd40f2989b36d6ad63e5d5", + "4df7f8fbc6c646dc81a1c63fe5167618", + "92a92c1cec1445bca8bda94543c31091", + "7d57323e580e48c789a834ba8f324609", + "5ef4b0f2e682408ea74d44f07863e726", + "03848b157e164490a7a509028df7cad8", + "744e9685a89749738871f14718a58c56", + "442311424f9449cda8f697eec946ba04", + "0677049b6f5f4cfaa662300b8063bf8d", + "6887b96a72134b579780ce8fb4fec51c", + "1ae630c1bb96429595e4c8f64a8cd978", + "54333bb25bf74274bf37e62295bc5a90", + "10a31b06ec11402ea041ee5e048bd9be", + "a72a928130e3435892a666a59ce3f9ae", + "198c4e87657e4a02bcbc0db1306f842c", + "c12145376a554a9dab28277865509cc5", + "352eb266cd374d0f91a3d6628fe4df97", + "47a86cb59e43492b9d38ade53cc5bfa7", + "2a63eecc839f49f7ae8f2d974a922664", + "82a3c7c1a51c473cbe5c8c3fc3f6c7ce", + "b2af0dd940ae4b6893a86778df1b7877", + "1d0c0d2e16b046209c80fcd0a9392a41", + "31d97d0b38414773953861efb7d10afb", + "da450c588af44c78af9d665e5ef9bed9", + "e4124e064efb43bb8acfda5052874b5d", + "73d5d725a38e457eb078999c64171a82", + "cac8aede09734dac8e1a638b67da49cc", + "f31d98b28d5d4174b6fff125d0cfc169", + "d2708e38223d4f5397e1856896560c4d", + "90fddc80c9ba4b669065f2af725a814b", + "a71ba0aed6394d8ab1ca46dc21ee03ca", + "2a79f07e770c40cf821d9f7ab860e99b", + "8ec004e73c7f4b28b85365dcb959663c", + "651abec81a2b47b99fd28db7774f6b3f", + "65efc653defb40778d74abeed961ddf8", + "bb0018f331e647ab9e0b40b6b5e683b3", + "ad7717b0980140c58befed77c4f70250", + "aa85a5bb82e3445b929a198e246a19a0", + "29cf4ba923dc4bb4979b74035a1f50bf", + "6409943fc9bc4fc687bb16b4cda9d07c", + "45d8390f2e26447e81c2522f837c5a9f", + "11be98dd8811486dac47267f48244247", + "7db3fca11e5740e6801d26d40bf91076", + "a98d0644e9a047199464e16d46962088", + "af20581002334cd9af3833c59cd4c070", + "4a25a274b4034181b7d307490c3d5db1", + "78f2427ad9c64f56bf1481c3190e8c6c", + "7438c142a37748548fb84e732855d543", + "7f30a20de166487781da41a13bc3c0e4", + "1100f289ac0e45fb965d6221438c37f3", + "0e990a8ce0b744dfb2a0b53e98eca805", + "8b0029aeefbd4192a896f49607f1ed87", + "4d2c7cdb2eee47f9ac4e59d0fed9b4cd", + "90670ba8064c47fc9d84bcf889ba8b23", + "0e4b221ec94b4404a8a6f7b429154855", + "0b749bec4b2646a28d6f5f532149a8be", + "bac297e2dfc946509091a298e5d602b9", + "52d1eee0f09c4aee9454294a4379bece", + "f0f16ec872c44321a64387ffb0621ad4", + "1542a1c6fc944d219f70149236484bb0", + "0c1c34ffcc3c4071a45d9e48dc01459a", + "89357399364d43a493d7f1ad58f2ea5c", + "1f543ff2d15b4bcd9ca5ea227bf6025e", + "4d1d17e87b0b48a6bbe3f901640d5ae6", + "67e5ec2a25bd41f0b9e62fc6f0395a80", + "18f6e921b61d45fdb26bbff5531d3eb1", + "11df9e9525fb4616ae9fc563c104d2e7", + "5f54fef47da64724b7e7bf59cbcf3f35", + "1cbf2c8edb9246b49ff9334c655a3849", + "8e67dc3a52574897abf9dcb6bbc7c8ec", + "321a70d9c81f4d17b497f13328589153", + "f5fb6c114393429f8ad158d916d2b837", + "54a5be97eea9498ebb1722a65529006f", + "3273622ad1ce442188d54df94c6f737c", + "a34b45c0b3c449e9a914588cfaa375b7", + "64d68f2d555f47e8abe502dd212c8ff4", + "9db1d2d8d8b0433ca67629e00bb977bd", + "8a6e2d8cb9c5404d8e3f5ff2543b4409", + "a11b2e1f12f5460595aded3f4181d64f", + "3d947d7e8dd1446ca6107edccba6a183", + "3585b4c2b37547b38a386003f9370aa3", + "7c82ac27f8fb4fc1a87d89f415f2e2c5" + ] + }, + "id": "NkHbB0R4ewnE", + "outputId": "61d1d3d5-5114-456d-b3ff-8333b52d9f9a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loading models from cleanrl/CartPole-v1-dqn-seed1\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1af264779a2442e596aed8e620561248", + "version_major": 2, + "version_minor": 0 }, - "source": [ - "## Enjoy Utility\n", - "\n", - "We have a simple way to load the model by running our \"enjoy\" utility, which automatically pull the model from 🤗 HuggingFace and run for a few episodes. It also produces a rendered video through the `--capture_video` flag. See more at our [📜 Documentation](https://docs.cleanrl.dev/get-started/zoo/)." + "text/plain": [ + "Downloading: 0%| | 0.00/45.8k [00:00\n", - " \n", - " Your browser does not support the video tag.\n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from IPython.display import Video\n", - "Video('videos/eval/rl-video-episode-0.mp4', embed=True)" + "text/plain": [ + "Downloading: 0%| | 0.00/45.1k [00:00 Date: Thu, 18 Jan 2024 11:14:51 -0500 Subject: [PATCH 11/20] Fix SuperSuit to most recent version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index eeca1634c..be0bd68e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ optuna = {version = "^3.0.1", optional = true} optuna-dashboard = {version = "^0.7.2", optional = true} envpool = {version = "^0.6.4", optional = true} PettingZoo = {version = "^1.24.2", optional = true} -SuperSuit = {version = "^3.8.1", optional = true} +SuperSuit = {version = "^3.9.1", optional = true} multi-agent-ale-py = {version = "0.1.11", optional = true} boto3 = {version = "^1.24.70", optional = true} awscli = {version = "^1.31.0", optional = true} From d7a2aa2132444c226009f85371acff97bc9a885e Mon Sep 17 00:00:00 2001 From: elliottower Date: Thu, 18 Jan 2024 11:25:04 -0500 Subject: [PATCH 12/20] Fix SuperSuit version in poetry lockfile and tinyscaler in pettingzoo reqs (subdependency of supersuit) --- poetry.lock | 92 ++++++++++++++---------- requirements/requirements-pettingzoo.txt | 2 +- 2 files changed, 54 insertions(+), 40 deletions(-) diff --git a/poetry.lock b/poetry.lock index f30baf28c..785b1aaf9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -1080,7 +1080,6 @@ files = [ {file = "greenlet-2.0.2-cp27-cp27m-win32.whl", hash = "sha256:6c3acb79b0bfd4fe733dff8bc62695283b57949ebcca05ae5c129eb606ff2d74"}, {file = "greenlet-2.0.2-cp27-cp27m-win_amd64.whl", hash = "sha256:283737e0da3f08bd637b5ad058507e578dd462db259f7f6e4c5c365ba4ee9343"}, {file = "greenlet-2.0.2-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:d27ec7509b9c18b6d73f2f5ede2622441de812e7b1a80bbd446cb0633bd3d5ae"}, - {file = "greenlet-2.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d967650d3f56af314b72df7089d96cda1083a7fc2da05b375d2bc48c82ab3f3c"}, {file = "greenlet-2.0.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:30bcf80dda7f15ac77ba5af2b961bdd9dbc77fd4ac6105cee85b0d0a5fcf74df"}, {file = "greenlet-2.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:26fbfce90728d82bc9e6c38ea4d038cba20b7faf8a0ca53a9c07b67318d46088"}, {file = "greenlet-2.0.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9190f09060ea4debddd24665d6804b995a9c122ef5917ab26e1566dcc712ceeb"}, @@ -1089,7 +1088,6 @@ files = [ {file = "greenlet-2.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:76ae285c8104046b3a7f06b42f29c7b73f77683df18c49ab5af7983994c2dd91"}, {file = "greenlet-2.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:2d4686f195e32d36b4d7cf2d166857dbd0ee9f3d20ae349b6bf8afc8485b3645"}, {file = "greenlet-2.0.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c4302695ad8027363e96311df24ee28978162cdcdd2006476c43970b384a244c"}, - {file = "greenlet-2.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d4606a527e30548153be1a9f155f4e283d109ffba663a15856089fb55f933e47"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c48f54ef8e05f04d6eff74b8233f6063cb1ed960243eacc474ee73a2ea8573ca"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a1846f1b999e78e13837c93c778dcfc3365902cfb8d1bdb7dd73ead37059f0d0"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a06ad5312349fec0ab944664b01d26f8d1f05009566339ac6f63f56589bc1a2"}, @@ -1119,7 +1117,6 @@ files = [ {file = "greenlet-2.0.2-cp37-cp37m-win32.whl", hash = "sha256:3f6ea9bd35eb450837a3d80e77b517ea5bc56b4647f5502cd28de13675ee12f7"}, {file = "greenlet-2.0.2-cp37-cp37m-win_amd64.whl", hash = "sha256:7492e2b7bd7c9b9916388d9df23fa49d9b88ac0640db0a5b4ecc2b653bf451e3"}, {file = "greenlet-2.0.2-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b864ba53912b6c3ab6bcb2beb19f19edd01a6bfcbdfe1f37ddd1778abfe75a30"}, - {file = "greenlet-2.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1087300cf9700bbf455b1b97e24db18f2f77b55302a68272c56209d5587c12d1"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:ba2956617f1c42598a308a84c6cf021a90ff3862eddafd20c3333d50f0edb45b"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc3a569657468b6f3fb60587e48356fe512c1754ca05a564f11366ac9e306526"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8eab883b3b2a38cc1e050819ef06a7e6344d4a990d24d45bc6f2cf959045a45b"}, @@ -1128,7 +1125,6 @@ files = [ {file = "greenlet-2.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b0ef99cdbe2b682b9ccbb964743a6aca37905fda5e0452e5ee239b1654d37f2a"}, {file = "greenlet-2.0.2-cp38-cp38-win32.whl", hash = "sha256:b80f600eddddce72320dbbc8e3784d16bd3fb7b517e82476d8da921f27d4b249"}, {file = "greenlet-2.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:4d2e11331fc0c02b6e84b0d28ece3a36e0548ee1a1ce9ddde03752d9b79bba40"}, - {file = "greenlet-2.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8512a0c38cfd4e66a858ddd1b17705587900dd760c6003998e9472b77b56d417"}, {file = "greenlet-2.0.2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:88d9ab96491d38a5ab7c56dd7a3cc37d83336ecc564e4e8816dbed12e5aaefc8"}, {file = "greenlet-2.0.2-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:561091a7be172ab497a3527602d467e2b3fbe75f9e783d8b8ce403fa414f71a6"}, {file = "greenlet-2.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:971ce5e14dc5e73715755d0ca2975ac88cfdaefcaab078a284fea6cfabf866df"}, @@ -1707,6 +1703,11 @@ files = [ {file = "labmaze-1.0.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70635d1cdb0147a02efb6b3f607a52cdc51723bc3dcc42717a0d4ef55fa0a987"}, {file = "labmaze-1.0.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff472793238bd9b6dabea8094594d6074ad3c111455de3afcae72f6c40c6817e"}, {file = "labmaze-1.0.6-cp311-cp311-win_amd64.whl", hash = "sha256:2317e65e12fa3d1abecda7e0488dab15456cee8a2e717a586bfc8f02a91579e7"}, + {file = "labmaze-1.0.6-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:e36b6fadcd78f22057b597c1c77823e806a0987b3bdfbf850e14b6b5b502075e"}, + {file = "labmaze-1.0.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d1a4f8de29c2c3d7f14163759b69cd3f237093b85334c983619c1db5403a223b"}, + {file = "labmaze-1.0.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a394f8bb857fcaa2884b809d63e750841c2662a106cfe8c045f2112d201ac7d5"}, + {file = "labmaze-1.0.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d17abb69d4dfc56183afb5c317e8b2eaca0587abb3aabd2326efd3143c81f4e"}, + {file = "labmaze-1.0.6-cp312-cp312-win_amd64.whl", hash = "sha256:5af997598cc46b1929d1c5a1febc32fd56c75874fe481a2a5982c65cee8450c9"}, {file = "labmaze-1.0.6-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:a4c5bc6e56baa55ce63b97569afec2f80cab0f6b952752a131e1f83eed190a53"}, {file = "labmaze-1.0.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3955f24fe5f708e1e97495b4cfe284b70ae4fd51be5e17b75a6fc04ffbd67bca"}, {file = "labmaze-1.0.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ed96ddc0bb8d66df36428c94db83949fd84a15867e8250763a4c5e3d82104c54"}, @@ -2645,29 +2646,28 @@ files = [ [[package]] name = "pettingzoo" -version = "1.18.1" -description = "Gym for multi-agent reinforcement learning" +version = "1.24.2" +description = "Gymnasium for multi-agent reinforcement learning." optional = true -python-versions = ">=3.7, <3.11" +python-versions = ">=3.8" files = [ - {file = "PettingZoo-1.18.1-py3-none-any.whl", hash = "sha256:25ae45fcfa2c623800e1f81b98ae50f5f5a1af6caabc5946764248de71a0371d"}, - {file = "PettingZoo-1.18.1.tar.gz", hash = "sha256:7e6a3231dc3fc3801af83fe880f199f570d46a9acdcb990f2a223f121b6e5038"}, + {file = "pettingzoo-1.24.2-py3-none-any.whl", hash = "sha256:00268cf990d243654c2bbbbf8c88322c12b041dc0a879b74747f14ee8aa93dd6"}, + {file = "pettingzoo-1.24.2.tar.gz", hash = "sha256:0a5856d47de78ab20feddfdac4940959dc892f6becc92107247b1c3a210c0984"}, ] [package.dependencies] -gym = ">=0.21.0" -numpy = ">=1.18.0" +gymnasium = ">=0.28.0" +numpy = ">=1.21.0" [package.extras] -all = ["box2d-py (==2.3.5)", "chess (==1.7.0)", "hanabi-learning-environment (==0.0.1)", "magent (==0.2.2)", "multi-agent-ale-py (==0.1.11)", "pillow (>=8.0.1)", "pygame (==2.1.0)", "pyglet (>=1.4.0)", "pymunk (==6.2.0)", "rlcard (==1.0.4)", "scipy (>=1.4.1)"] -atari = ["multi-agent-ale-py (==0.1.11)", "pygame (==2.1.0)"] -butterfly = ["pygame (==2.1.0)", "pymunk (==6.2.0)"] -classic = ["chess (==1.7.0)", "hanabi-learning-environment (==0.0.1)", "pygame (==2.1.0)", "rlcard (==1.0.4)"] -magent = ["magent (==0.2.2)"] -mpe = ["pyglet (>=1.4.0)"] +all = ["box2d-py (==2.3.5)", "chess (==1.9.4)", "multi-agent-ale-py (==0.1.11)", "pillow (>=8.0.1)", "pygame (==2.3.0)", "pymunk (==6.2.0)", "rlcard (==1.0.5)", "scipy (>=1.4.1)", "shimmy[openspiel] (>=1.2.0)"] +atari = ["multi-agent-ale-py (==0.1.11)", "pygame (==2.3.0)"] +butterfly = ["pygame (==2.3.0)", "pymunk (==6.2.0)"] +classic = ["chess (==1.9.4)", "pygame (==2.3.0)", "rlcard (==1.0.5)", "shimmy[openspiel] (>=1.2.0)"] +mpe = ["pygame (==2.3.0)"] other = ["pillow (>=8.0.1)"] -sisl = ["box2d-py (==2.3.5)", "pygame (==2.1.0)", "scipy (>=1.4.1)"] -tests = ["codespell", "flake8", "isort", "pynput", "pytest"] +sisl = ["box2d-py (==2.3.5)", "pygame (==2.3.0)", "pymunk (==6.2.0)", "scipy (>=1.4.1)"] +testing = ["AutoROM", "pre-commit", "pynput", "pytest", "pytest-cov", "pytest-markdown-docs", "pytest-xdist"] [[package]] name = "pillow" @@ -3707,19 +3707,22 @@ tests = ["black", "isort (>=5.0)", "mypy", "pytest", "pytest-cov", "pytest-env", [[package]] name = "supersuit" -version = "3.4.0" -description = "Wrappers for Gym and PettingZoo" +version = "3.9.1" +description = "Wrappers for Gymnasium and PettingZoo" optional = true -python-versions = ">=3.7" +python-versions = "<3.12,>=3.8" files = [ - {file = "SuperSuit-3.4.0-py3-none-any.whl", hash = "sha256:45b541b2b29faffd6494b53d649c8d94889966f407fd380b3e3211f9e68a49e9"}, - {file = "SuperSuit-3.4.0.tar.gz", hash = "sha256:5999beec8d7923c11c9511eaa9dec8a38269cb0d7af029e17903c79234233409"}, + {file = "SuperSuit-3.9.1-py3-none-any.whl", hash = "sha256:24907f8edb9578c8b35eb374e53fdde96daf37c006d8e929c7bf485e5c52f356"}, + {file = "SuperSuit-3.9.1.tar.gz", hash = "sha256:536732019e5f00420a17a7e3078a73824191515b6b0af37b06322d4846cda655"}, ] [package.dependencies] -gym = ">=0.22.0" -pettingzoo = ">=1.15.0" -tinyscaler = ">=1.0.4" +gymnasium = ">=0.28.1" +numpy = ">=1.19.0" +tinyscaler = ">=1.2.6" + +[package.extras] +testing = ["pettingzoo[butterfly,classic] (>=1.23.1)", "pytest"] [[package]] name = "tabulate" @@ -3837,20 +3840,31 @@ files = [ [[package]] name = "tinyscaler" -version = "1.2.5" -description = "A tiny, simple image scaler" +version = "1.2.7" +description = "A tiny, simple image scaler." optional = true -python-versions = ">=3.7, <3.11" -files = [ - {file = "tinyscaler-1.2.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f32d794fc2b9c5e4aa3b435d040f9e75b11f55ab41b32580f2c8e8dfb350f25"}, - {file = "tinyscaler-1.2.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4af0a9502e9ef118c84de80b09544407c8dbbe815af215b1abb8eb170271ab71"}, - {file = "tinyscaler-1.2.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0bde14fb15027d73f4cc5ac837e849feb1cbedbfc0a0c0928f11756f08f6626"}, - {file = "tinyscaler-1.2.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46c75799068330ff7c28fd01f10409d4f12c22f1adbe732f1699228449a4d712"}, - {file = "tinyscaler-1.2.5.tar.gz", hash = "sha256:deb47df1a53a55b53f0ae15b89b4814af184d149a8149385e54e11afc57364a5"}, +python-versions = "<3.12,>=3.7" +files = [ + {file = "tinyscaler-1.2.7-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:bbb98ced396d4829a41aa9c7c895df4bcb3801a3bbe963978c90d12b07110731"}, + {file = "tinyscaler-1.2.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d062e0e33f6104d625fff9b57aa53511c39d2dc3bb711686f6992a7fbfe41336"}, + {file = "tinyscaler-1.2.7-cp310-cp310-win_amd64.whl", hash = "sha256:a96f008975d4d167102a2671fb54fb6ace6ff2580fede3b79daeca99a01e5d6e"}, + {file = "tinyscaler-1.2.7-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bf63243a08e214e3db149435741b779db357c376636e17ddf153bf9f6ada041c"}, + {file = "tinyscaler-1.2.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d2d129f71c518d9c0c25f5e9f3a7f7a31af62e7e7e6f8750ddf0154ed76a58a"}, + {file = "tinyscaler-1.2.7-cp311-cp311-win_amd64.whl", hash = "sha256:c14d302cd609d8c8e53ddf15b3ab43fa3c975d648ffcf16276c8b131ab849f85"}, + {file = "tinyscaler-1.2.7-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:3ef723fbe119614dfdd8a7bd40d73c17defaac6765f60c44693858bd5cd70fbc"}, + {file = "tinyscaler-1.2.7-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6700a37bd42615944099994f2aa473be215e25d79803fcac9de849205c7b149"}, + {file = "tinyscaler-1.2.7-cp37-cp37m-win_amd64.whl", hash = "sha256:26d488778686392a0441e598df7ebc45ad014663e60384ef6170dd793f80d275"}, + {file = "tinyscaler-1.2.7-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:f80203589d883896c86fe94165967be453fbb0fe47c9bc64521aee15e125f202"}, + {file = "tinyscaler-1.2.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef4843aaa2647d7ae7a26ba66dbd1d1b31d161ca558f2c385bc6b02277d27fdb"}, + {file = "tinyscaler-1.2.7-cp38-cp38-win_amd64.whl", hash = "sha256:8e6b605ef00fc65a27f294742514f67d9b4c37d41bfe586e2609ab03a41f2e74"}, + {file = "tinyscaler-1.2.7-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:2847420c81064c8bd3397bdcd83e2706cd914cdb9cbde5300ed968c14954b9d3"}, + {file = "tinyscaler-1.2.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d86fe85fa37cfaedb521c9eb3a804b6ab202924be221049b784220c5ca49546"}, + {file = "tinyscaler-1.2.7-cp39-cp39-win_amd64.whl", hash = "sha256:ce1e10fc54d02bb49ea1f72f76d320c50739eb4ff3e6cbb82148b4f84272220b"}, + {file = "tinyscaler-1.2.7.tar.gz", hash = "sha256:1c0b34b41cca3ae9b09c20fee27499833345b9264617bdd23c896733676d82d8"}, ] [package.dependencies] -numpy = "*" +numpy = ">=1.21.0" [[package]] name = "tomli" @@ -4245,4 +4259,4 @@ qdagger-dqn-atari-jax-impalacnn = ["AutoROM", "ale-py", "flax", "jax", "jaxlib", [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.11" -content-hash = "ce1dd6a428e94e30643d2fb0a3fd13f0132d176185a91f7685392d4ec0e7892b" +content-hash = "9300d30d6b5fa37bd0f2f5604f78c758c988830bc7df8f75df7b903e63921952" diff --git a/requirements/requirements-pettingzoo.txt b/requirements/requirements-pettingzoo.txt index 461c6023f..48fdf76e3 100644 --- a/requirements/requirements-pettingzoo.txt +++ b/requirements/requirements-pettingzoo.txt @@ -68,7 +68,7 @@ tenacity==8.2.3 ; python_version >= "3.8" and python_version < "3.11" tensorboard-data-server==0.6.1 ; python_version >= "3.8" and python_version < "3.11" tensorboard-plugin-wit==1.8.1 ; python_version >= "3.8" and python_version < "3.11" tensorboard==2.11.2 ; python_version >= "3.8" and python_version < "3.11" -tinyscaler==1.2.5 ; python_version >= "3.8" and python_version < "3.11" +tinyscaler==1.2.7 ; python_version >= "3.8" and python_version < "3.11" torch==1.12.1 ; python_version >= "3.8" and python_version < "3.11" tqdm==4.65.0 ; python_version >= "3.8" and python_version < "3.11" typing-extensions==4.5.0 ; python_version >= "3.8" and python_version < "3.11" From d77cca060c43a987b3c78be73e5e65e10d953c32 Mon Sep 17 00:00:00 2001 From: elliottower Date: Thu, 18 Jan 2024 11:50:07 -0500 Subject: [PATCH 13/20] Fix pettingzoo-requirements export (pre-commit hooks) --- requirements/requirements-pettingzoo.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/requirements-pettingzoo.txt b/requirements/requirements-pettingzoo.txt index 48fdf76e3..2dacec44d 100644 --- a/requirements/requirements-pettingzoo.txt +++ b/requirements/requirements-pettingzoo.txt @@ -39,7 +39,7 @@ oauthlib==3.2.2 ; python_version >= "3.8" and python_version < "3.11" packaging==23.1 ; python_version >= "3.8" and python_version < "3.11" pandas==1.3.5 ; python_version >= "3.8" and python_version < "3.11" pathtools==0.1.2 ; python_version >= "3.8" and python_version < "3.11" -pettingzoo==1.18.1 ; python_version >= "3.8" and python_version < "3.11" +pettingzoo==1.24.2 ; python_version >= "3.8" and python_version < "3.11" pillow==9.5.0 ; python_version >= "3.8" and python_version < "3.11" proglog==0.1.10 ; python_version >= "3.8" and python_version < "3.11" protobuf==3.20.3 ; python_version < "3.11" and python_version >= "3.8" @@ -63,7 +63,7 @@ shtab==1.6.4 ; python_version >= "3.8" and python_version < "3.11" six==1.16.0 ; python_version >= "3.8" and python_version < "3.11" smmap==5.0.0 ; python_version >= "3.8" and python_version < "3.11" stable-baselines3==2.0.0 ; python_version >= "3.8" and python_version < "3.11" -supersuit==3.4.0 ; python_version >= "3.8" and python_version < "3.11" +supersuit==3.9.1 ; python_version >= "3.8" and python_version < "3.11" tenacity==8.2.3 ; python_version >= "3.8" and python_version < "3.11" tensorboard-data-server==0.6.1 ; python_version >= "3.8" and python_version < "3.11" tensorboard-plugin-wit==1.8.1 ; python_version >= "3.8" and python_version < "3.11" From afba4e85d2dd20959176fd8b1ea0069ebb5ab3a1 Mon Sep 17 00:00:00 2001 From: elliottower Date: Thu, 18 Jan 2024 12:40:35 -0500 Subject: [PATCH 14/20] Test updating pettingzoo to new version 1.24.3 --- poetry.lock | 8 ++++---- pyproject.toml | 2 +- requirements/requirements-pettingzoo.txt | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/poetry.lock b/poetry.lock index 785b1aaf9..657c56357 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2646,13 +2646,13 @@ files = [ [[package]] name = "pettingzoo" -version = "1.24.2" +version = "1.24.3" description = "Gymnasium for multi-agent reinforcement learning." optional = true python-versions = ">=3.8" files = [ - {file = "pettingzoo-1.24.2-py3-none-any.whl", hash = "sha256:00268cf990d243654c2bbbbf8c88322c12b041dc0a879b74747f14ee8aa93dd6"}, - {file = "pettingzoo-1.24.2.tar.gz", hash = "sha256:0a5856d47de78ab20feddfdac4940959dc892f6becc92107247b1c3a210c0984"}, + {file = "pettingzoo-1.24.3-py3-none-any.whl", hash = "sha256:23ed90517d2e8a7098bdaf5e31234b3a7f7b73ca578d70d1ca7b9d0cb0e37982"}, + {file = "pettingzoo-1.24.3.tar.gz", hash = "sha256:91f9094f18e06fb74b98f4099cd22e8ae4396125e51719d50b30c9f1c7ab07e6"}, ] [package.dependencies] @@ -4259,4 +4259,4 @@ qdagger-dqn-atari-jax-impalacnn = ["AutoROM", "ale-py", "flax", "jax", "jaxlib", [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.11" -content-hash = "9300d30d6b5fa37bd0f2f5604f78c758c988830bc7df8f75df7b903e63921952" +content-hash = "453c8d2c113d81cb529f771d450301c0a4fa6d5ab0bfc3964a110d650ee7db39" diff --git a/pyproject.toml b/pyproject.toml index be0bd68e6..1b45385a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ flax = {version = "0.6.8", optional = true} optuna = {version = "^3.0.1", optional = true} optuna-dashboard = {version = "^0.7.2", optional = true} envpool = {version = "^0.6.4", optional = true} -PettingZoo = {version = "^1.24.2", optional = true} +PettingZoo = {version = "^1.24.3", optional = true} SuperSuit = {version = "^3.9.1", optional = true} multi-agent-ale-py = {version = "0.1.11", optional = true} boto3 = {version = "^1.24.70", optional = true} diff --git a/requirements/requirements-pettingzoo.txt b/requirements/requirements-pettingzoo.txt index 2dacec44d..abcff76d2 100644 --- a/requirements/requirements-pettingzoo.txt +++ b/requirements/requirements-pettingzoo.txt @@ -39,7 +39,7 @@ oauthlib==3.2.2 ; python_version >= "3.8" and python_version < "3.11" packaging==23.1 ; python_version >= "3.8" and python_version < "3.11" pandas==1.3.5 ; python_version >= "3.8" and python_version < "3.11" pathtools==0.1.2 ; python_version >= "3.8" and python_version < "3.11" -pettingzoo==1.24.2 ; python_version >= "3.8" and python_version < "3.11" +pettingzoo==1.24.3 ; python_version >= "3.8" and python_version < "3.11" pillow==9.5.0 ; python_version >= "3.8" and python_version < "3.11" proglog==0.1.10 ; python_version >= "3.8" and python_version < "3.11" protobuf==3.20.3 ; python_version < "3.11" and python_version >= "3.8" From 86711541f699fa631f4ae742e27eecb8dedc6220 Mon Sep 17 00:00:00 2001 From: elliottower Date: Thu, 18 Jan 2024 13:33:12 -0500 Subject: [PATCH 15/20] Update ma_atari to match regular atari (tyro, minor code style changes) --- benchmark/zoo.sh | 10 +-- cleanrl/ppo_pettingzoo_ma_atari.py | 139 +++++++++++++++-------------- 2 files changed, 76 insertions(+), 73 deletions(-) diff --git a/benchmark/zoo.sh b/benchmark/zoo.sh index a5ab38e14..95779fe37 100644 --- a/benchmark/zoo.sh +++ b/benchmark/zoo.sh @@ -1,25 +1,25 @@ poetry run python cleanrl/dqn_jax.py --env-id CartPole-v1 --save-model --upload-model --hf-entity cleanrl poetry run python cleanrl/dqn_atari_jax.py --env-id SeaquestNoFrameskip-v4 --save-model --upload-model --hf-entity cleanrl -xvfb-run -a poetry run python -m cleanrl_utils.benchmark \ +poetry run python -m cleanrl_utils.benchmark \ --env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \ --command "poetry run python cleanrl/dqn.py --no_cuda --track --capture_video --save-model --upload-model --hf-entity cleanrl" \ --num-seeds 1 \ --workers 1 -CUDA_VISIBLE_DEVICES="-1" xvfb-run -a poetry run python -m cleanrl_utils.benchmark \ +CUDA_VISIBLE_DEVICES="-1" poetry run python -m cleanrl_utils.benchmark \ --env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \ --command "poetry run python cleanrl/dqn_jax.py --track --capture_video --save-model --upload-model --hf-entity cleanrl" \ --num-seeds 1 \ --workers 1 -xvfb-run -a python -m cleanrl_utils.benchmark \ +python -m cleanrl_utils.benchmark \ --env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \ --command "poetry run python cleanrl/dqn_atari_jax.py --track --capture_video --save-model --upload-model --hf-entity cleanrl" \ --num-seeds 1 \ --workers 1 -xvfb-run -a python -m cleanrl_utils.benchmark \ +python -m cleanrl_utils.benchmark \ --env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \ --command "poetry run python cleanrl/dqn_atari.py --track --capture_video --save-model --upload-model --hf-entity cleanrl" \ --num-seeds 1 \ @@ -31,7 +31,7 @@ python -m cleanrl_utils.benchmark \ --num-seeds 1 \ --workers 1 -CUDA_VISIBLE_DEVICES="1" taskset --cpu-list 16,17,18,19,20,21,22,23 python -m cleanrl_utils.benchmark \ +CUDA_VISIBLE_DEVICES="1" python -m cleanrl_utils.benchmark \ --env-ids Breakout-v5 \ --command "poetry run python cleanrl/ppo_atari_envpool_xla_jax_scan.py --track --save-model --upload-model --hf-entity cleanrl" \ --num-seeds 1 \ diff --git a/cleanrl/ppo_pettingzoo_ma_atari.py b/cleanrl/ppo_pettingzoo_ma_atari.py index 0b042b00e..c6bf2490b 100644 --- a/cleanrl/ppo_pettingzoo_ma_atari.py +++ b/cleanrl/ppo_pettingzoo_ma_atari.py @@ -1,10 +1,9 @@ # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_pettingzoo_ma_ataripy -import argparse import importlib import os import random import time -from distutils.util import strtobool +from dataclasses import dataclass import gymnasium as gym import numpy as np @@ -12,70 +11,72 @@ import torch import torch.nn as nn import torch.optim as optim +import tyro from torch.distributions.categorical import Categorical from torch.utils.tensorboard import SummaryWriter - -def parse_args(): - # fmt: off - parser = argparse.ArgumentParser() - parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), - help="the name of this experiment") - parser.add_argument("--seed", type=int, default=1, - help="seed of the experiment") - parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="if toggled, `torch.backends.cudnn.deterministic=False`") - parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="if toggled, cuda will be enabled by default") - parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, - help="if toggled, this experiment will be tracked with Weights and Biases") - parser.add_argument("--wandb-project-name", type=str, default="cleanRL", - help="the wandb's project name") - parser.add_argument("--wandb-entity", type=str, default=None, - help="the entity (team) of wandb's project") - parser.add_argument("--capture_video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, - help="whether to capture videos of the agent performances (check out `videos` folder)") +@dataclass +class Args: + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + torch_deterministic: bool = True + """if toggled, `torch.backends.cudnn.deterministic=False`""" + cuda: bool = True + """if toggled, cuda will be enabled by default""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanRL" + """the wandb's project name""" + wandb_entity: str = None + """the entity (team) of wandb's project""" + capture_video: bool = False + """whether to capture videos of the agent performances (check out `videos` folder)""" # Algorithm specific arguments - parser.add_argument("--env-id", type=str, default="pong_v3", - help="the id of the environment") - parser.add_argument("--total-timesteps", type=int, default=20000000, - help="total timesteps of the experiments") - parser.add_argument("--learning-rate", type=float, default=2.5e-4, - help="the learning rate of the optimizer") - parser.add_argument("--num-envs", type=int, default=16, - help="the number of parallel game environments") - parser.add_argument("--num-steps", type=int, default=128, - help="the number of steps to run in each environment per policy rollout") - parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="Toggle learning rate annealing for policy and value networks") - parser.add_argument("--gamma", type=float, default=0.99, - help="the discount factor gamma") - parser.add_argument("--gae-lambda", type=float, default=0.95, - help="the lambda for the general advantage estimation") - parser.add_argument("--num-minibatches", type=int, default=4, - help="the number of mini-batches") - parser.add_argument("--update-epochs", type=int, default=4, - help="the K epochs to update the policy") - parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="Toggles advantages normalization") - parser.add_argument("--clip-coef", type=float, default=0.1, - help="the surrogate clipping coefficient") - parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, - help="Toggles whether or not to use a clipped loss for the value function, as per the paper.") - parser.add_argument("--ent-coef", type=float, default=0.01, - help="coefficient of the entropy") - parser.add_argument("--vf-coef", type=float, default=0.5, - help="coefficient of the value function") - parser.add_argument("--max-grad-norm", type=float, default=0.5, - help="the maximum norm for the gradient clipping") - parser.add_argument("--target-kl", type=float, default=None, - help="the target KL divergence threshold") - args = parser.parse_args() - args.batch_size = int(args.num_envs * args.num_steps) - args.minibatch_size = int(args.batch_size // args.num_minibatches) - # fmt: on - return args + env_id: str = "pong_v3" + """the id of the environment""" + total_timesteps: int = 20000000 + """total timesteps of the experiments""" + learning_rate: float = 2.5e-4 + """the learning rate of the optimizer""" + num_envs: int = 16 + """the number of parallel game environments""" + num_steps: int = 128 + """the number of steps to run in each environment per policy rollout""" + anneal_lr: bool = True + """Toggle learning rate annealing for policy and value networks""" + gamma: float = 0.99 + """the discount factor gamma""" + gae_lambda: float = 0.95 + """the lambda for the general advantage estimation""" + num_minibatches: int = 4 + """the number of mini-batches""" + update_epochs: int = 4 + """the K epochs to update the policy""" + norm_adv: bool = True + """Toggles advantages normalization""" + clip_coef: float = 0.1 + """the surrogate clipping coefficient""" + clip_vloss: bool = True + """Toggles whether or not to use a clipped loss for the value function, as per the paper.""" + ent_coef: float = 0.01 + """coefficient of the entropy""" + vf_coef: float = 0.5 + """coefficient of the value function""" + max_grad_norm: float = 0.5 + """the maximum norm for the gradient clipping""" + target_kl: float = None + """the target KL divergence threshold""" + + # to be filled in runtime + batch_size: int = 0 + """the batch size (computed in runtime)""" + minibatch_size: int = 0 + """the mini-batch size (computed in runtime)""" + num_iterations: int = 0 + """the number of iterations (computed in runtime)""" def layer_init(layer, std=np.sqrt(2), bias_const=0.0): @@ -118,7 +119,10 @@ def get_action_and_value(self, x, action=None): if __name__ == "__main__": - args = parse_args() + args = tyro.cli(Args) + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + args.num_iterations = args.total_timesteps // args.batch_size run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" if args.track: import wandb @@ -185,15 +189,15 @@ def get_action_and_value(self, x, action=None): next_truncation = torch.zeros(args.num_envs).to(device) num_updates = args.total_timesteps // args.batch_size - for update in range(1, num_updates + 1): + for iteration in range(1, args.num_iterations + 1): # Annealing the rate if instructed to do so. if args.anneal_lr: - frac = 1.0 - (update - 1.0) / num_updates + frac = 1.0 - (iteration - 1.0) / args.num_iterations lrnow = frac * args.learning_rate optimizer.param_groups[0]["lr"] = lrnow for step in range(0, args.num_steps): - global_step += 1 * args.num_envs + global_step += args.num_envs obs[step] = next_obs terminations[step] = next_termination truncations[step] = next_truncation @@ -299,9 +303,8 @@ def get_action_and_value(self, x, action=None): nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) optimizer.step() - if args.target_kl is not None: - if approx_kl > args.target_kl: - break + if args.target_kl is not None and approx_kl > args.target_kl: + break y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() var_y = np.var(y_true) From d2cf1a5429c9d23d4ad9f07013c62341afdcf01a Mon Sep 17 00:00:00 2001 From: elliottower Date: Thu, 18 Jan 2024 13:33:36 -0500 Subject: [PATCH 16/20] pre-commit --- cleanrl/ppo_pettingzoo_ma_atari.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cleanrl/ppo_pettingzoo_ma_atari.py b/cleanrl/ppo_pettingzoo_ma_atari.py index c6bf2490b..d92ce6c58 100644 --- a/cleanrl/ppo_pettingzoo_ma_atari.py +++ b/cleanrl/ppo_pettingzoo_ma_atari.py @@ -15,6 +15,7 @@ from torch.distributions.categorical import Categorical from torch.utils.tensorboard import SummaryWriter + @dataclass class Args: exp_name: str = os.path.basename(__file__)[: -len(".py")] From 981bc637d83d34a4dce3b74c76b1cf8d22a672ae Mon Sep 17 00:00:00 2001 From: elliottower Date: Thu, 18 Jan 2024 13:34:53 -0500 Subject: [PATCH 17/20] Revert accidentally changed files (zoo and ipynb, which randomly seems to change --- benchmark/zoo.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/benchmark/zoo.sh b/benchmark/zoo.sh index 95779fe37..a5ab38e14 100644 --- a/benchmark/zoo.sh +++ b/benchmark/zoo.sh @@ -1,25 +1,25 @@ poetry run python cleanrl/dqn_jax.py --env-id CartPole-v1 --save-model --upload-model --hf-entity cleanrl poetry run python cleanrl/dqn_atari_jax.py --env-id SeaquestNoFrameskip-v4 --save-model --upload-model --hf-entity cleanrl -poetry run python -m cleanrl_utils.benchmark \ +xvfb-run -a poetry run python -m cleanrl_utils.benchmark \ --env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \ --command "poetry run python cleanrl/dqn.py --no_cuda --track --capture_video --save-model --upload-model --hf-entity cleanrl" \ --num-seeds 1 \ --workers 1 -CUDA_VISIBLE_DEVICES="-1" poetry run python -m cleanrl_utils.benchmark \ +CUDA_VISIBLE_DEVICES="-1" xvfb-run -a poetry run python -m cleanrl_utils.benchmark \ --env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \ --command "poetry run python cleanrl/dqn_jax.py --track --capture_video --save-model --upload-model --hf-entity cleanrl" \ --num-seeds 1 \ --workers 1 -python -m cleanrl_utils.benchmark \ +xvfb-run -a python -m cleanrl_utils.benchmark \ --env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \ --command "poetry run python cleanrl/dqn_atari_jax.py --track --capture_video --save-model --upload-model --hf-entity cleanrl" \ --num-seeds 1 \ --workers 1 -python -m cleanrl_utils.benchmark \ +xvfb-run -a python -m cleanrl_utils.benchmark \ --env-ids PongNoFrameskip-v4 BeamRiderNoFrameskip-v4 BreakoutNoFrameskip-v4 \ --command "poetry run python cleanrl/dqn_atari.py --track --capture_video --save-model --upload-model --hf-entity cleanrl" \ --num-seeds 1 \ @@ -31,7 +31,7 @@ python -m cleanrl_utils.benchmark \ --num-seeds 1 \ --workers 1 -CUDA_VISIBLE_DEVICES="1" python -m cleanrl_utils.benchmark \ +CUDA_VISIBLE_DEVICES="1" taskset --cpu-list 16,17,18,19,20,21,22,23 python -m cleanrl_utils.benchmark \ --env-ids Breakout-v5 \ --command "poetry run python cleanrl/ppo_atari_envpool_xla_jax_scan.py --track --save-model --upload-model --hf-entity cleanrl" \ --num-seeds 1 \ From 454364d970494000a86612cc97d3a73543e34940 Mon Sep 17 00:00:00 2001 From: elliottower Date: Thu, 18 Jan 2024 13:40:08 -0500 Subject: [PATCH 18/20] Revert ipynb change --- ...CleanRL_Huggingface_Integration_Demo.ipynb | 9772 ++++++++--------- 1 file changed, 4886 insertions(+), 4886 deletions(-) diff --git a/docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb b/docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb index 51775005f..4cb022ec3 100644 --- a/docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb +++ b/docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb @@ -1,4941 +1,4941 @@ { - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": { - "id": "oTuvowvgFQpm" - }, - "source": [ - "# CleanRL's Huggingface Integration Demo\n", - "\n", - "\n", - "\n", - "[](https://github.com/vwxyzjn/cleanrl)\n", - "[![tests](https://github.com/vwxyzjn/cleanrl/actions/workflows/tests.yaml/badge.svg)](https://github.com/vwxyzjn/cleanrl/actions/workflows/tests.yaml)\n", - "[![docs](https://img.shields.io/github/deployments/vwxyzjn/cleanrl/Production?label=docs&logo=vercel)](https://docs.cleanrl.dev/)\n", - "[](https://discord.gg/D6RCjA6sVT)\n", - "[](https://www.youtube.com/channel/UCDdC6BIFRI0jvcwuhi3aI6w/videos)\n", - "[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)\n", - "[![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)\n", - "[](https://huggingface.co/cleanrl)\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vwxyzjn/cleanrl/blob/master/docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb)\n", - "\n", - "\n", - "CleanRL is a Deep Reinforcement Learning library that provides high-quality single-file implementation with research-friendly features. It now has has 🧪 experimental support for saving and loading models from 🤗 HuggingFace's [Model Hub](https://huggingface.co/models). This notebook is a preliminary demo.\n", - "\n", - "\n", - "* 💾 [GitHub Repo](https://github.com/vwxyzjn/cleanrl)\n", - "* 📜 [Documentation](https://docs.cleanrl.dev/)\n", - "* 🤗 [HuggingFace Model Hub](https://huggingface.co/cleanrl)\n", - "* 🔗 [Open RL Benchmark reports](https://wandb.ai/openrlbenchmark/openrlbenchmark/reportlist)\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "J0zhqyfea0If" - }, - "source": [ - "## Get Started\n", - "\n", - "CleanRL can be installed via `pip`. Let's say we are interested in pulling the model for [`dqn_atari_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari_jax.py), we can install the algorithm-variant-specific dependencies as follows:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "lhnJkrYLOvcs", - "outputId": "381d9d0d-7e83-4f21-ef89-91d4e3b93c18" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Collecting cleanrl[dqn-atari-jax]\n", - " Downloading cleanrl-1.1.2-py3-none-any.whl (16.9 MB)\n", - "\u001B[K |████████████████████████████████| 16.9 MB 241 kB/s \n", - "\u001B[?25hCollecting pygame==2.1.0\n", - " Downloading pygame-2.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)\n", - "\u001B[K |████████████████████████████████| 18.3 MB 59.3 MB/s \n", - "\u001B[?25hCollecting huggingface-hub<0.12.0,>=0.11.1\n", - " Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)\n", - "\u001B[K |████████████████████████████████| 182 kB 74.9 MB/s \n", - "\u001B[?25hCollecting wandb<0.14.0,>=0.13.6\n", - " Downloading wandb-0.13.7-py2.py3-none-any.whl (1.9 MB)\n", - "\u001B[K |████████████████████████████████| 1.9 MB 63.6 MB/s \n", - "\u001B[?25hRequirement already satisfied: torch>=1.12.1 in /usr/local/lib/python3.8/dist-packages (from cleanrl[dqn-atari-jax]) (1.13.0+cu116)\n", - "Collecting stable-baselines3==1.2.0\n", - " Downloading stable_baselines3-1.2.0-py3-none-any.whl (161 kB)\n", - "\u001B[K |████████████████████████████████| 161 kB 64.7 MB/s \n", - "\u001B[?25hCollecting tensorboard<3.0.0,>=2.10.0\n", - " Downloading tensorboard-2.11.0-py3-none-any.whl (6.0 MB)\n", - "\u001B[K |████████████████████████████████| 6.0 MB 65.0 MB/s \n", - "\u001B[?25hCollecting moviepy<2.0.0,>=1.0.3\n", - " Downloading moviepy-1.0.3.tar.gz (388 kB)\n", - "\u001B[K |████████████████████████████████| 388 kB 59.5 MB/s \n", - "\u001B[?25hCollecting gym==0.23.1\n", - " Downloading gym-0.23.1.tar.gz (626 kB)\n", - "\u001B[K |████████████████████████████████| 626 kB 59.9 MB/s \n", - "\u001B[?25h Installing build dependencies ... \u001B[?25l\u001B[?25hdone\n", - " Getting requirements to build wheel ... \u001B[?25l\u001B[?25hdone\n", - " Preparing wheel metadata ... \u001B[?25l\u001B[?25hdone\n", - "Collecting gymnasium<0.27.0,>=0.26.3\n", - " Downloading Gymnasium-0.26.3-py3-none-any.whl (836 kB)\n", - "\u001B[K |████████████████████████████████| 836 kB 64.6 MB/s \n", - "\u001B[?25hCollecting AutoROM[accept-rom-license]<0.5.0,>=0.4.2\n", - " Downloading AutoROM-0.4.2-py3-none-any.whl (16 kB)\n", - "Collecting ale-py==0.7.4\n", - " Downloading ale_py-0.7.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.6 MB)\n", - "\u001B[K |████████████████████████████████| 1.6 MB 55.1 MB/s \n", - "\u001B[?25hRequirement already satisfied: opencv-python<5.0.0.0,>=4.6.0.66 in /usr/local/lib/python3.8/dist-packages (from cleanrl[dqn-atari-jax]) (4.6.0.66)\n", - "Requirement already satisfied: jax<0.4.0,>=0.3.17 in /usr/local/lib/python3.8/dist-packages (from cleanrl[dqn-atari-jax]) (0.3.25)\n", - "Collecting flax<0.7.0,>=0.6.0\n", - " Downloading flax-0.6.3-py3-none-any.whl (197 kB)\n", - "\u001B[K |████████████████████████████████| 197 kB 73.9 MB/s \n", - "\u001B[?25hRequirement already satisfied: jaxlib<0.4.0,>=0.3.15 in /usr/local/lib/python3.8/dist-packages (from cleanrl[dqn-atari-jax]) (0.3.25+cuda11.cudnn805)\n", - "Requirement already satisfied: importlib-metadata>=4.10.0 in /usr/local/lib/python3.8/dist-packages (from ale-py==0.7.4->cleanrl[dqn-atari-jax]) (5.1.0)\n", - "Requirement already satisfied: importlib-resources in /usr/local/lib/python3.8/dist-packages (from ale-py==0.7.4->cleanrl[dqn-atari-jax]) (5.10.1)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.8/dist-packages (from ale-py==0.7.4->cleanrl[dqn-atari-jax]) (1.21.6)\n", - "Requirement already satisfied: gym-notices>=0.0.4 in /usr/local/lib/python3.8/dist-packages (from gym==0.23.1->cleanrl[dqn-atari-jax]) (0.0.8)\n", - "Requirement already satisfied: cloudpickle>=1.2.0 in /usr/local/lib/python3.8/dist-packages (from gym==0.23.1->cleanrl[dqn-atari-jax]) (1.5.0)\n", - "Requirement already satisfied: pandas in /usr/local/lib/python3.8/dist-packages (from stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (1.3.5)\n", - "Requirement already satisfied: matplotlib in /usr/local/lib/python3.8/dist-packages (from stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (3.2.2)\n", - "Requirement already satisfied: click in /usr/local/lib/python3.8/dist-packages (from AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (7.1.2)\n", - "Requirement already satisfied: requests in /usr/local/lib/python3.8/dist-packages (from AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (2.23.0)\n", - "Requirement already satisfied: tqdm in /usr/local/lib/python3.8/dist-packages (from AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (4.64.1)\n", - "Collecting AutoROM.accept-rom-license\n", - " Downloading AutoROM.accept-rom-license-0.5.0.tar.gz (10 kB)\n", - " Installing build dependencies ... \u001B[?25l\u001B[?25hdone\n", - " Getting requirements to build wheel ... \u001B[?25l\u001B[?25hdone\n", - " Preparing wheel metadata ... \u001B[?25l\u001B[?25hdone\n", - "Requirement already satisfied: PyYAML>=5.4.1 in /usr/local/lib/python3.8/dist-packages (from flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (6.0)\n", - "Collecting rich>=11.1\n", - " Downloading rich-13.0.0-py3-none-any.whl (238 kB)\n", - "\u001B[K |████████████████████████████████| 238 kB 76.7 MB/s \n", - "\u001B[?25hRequirement already satisfied: typing-extensions>=4.1.1 in /usr/local/lib/python3.8/dist-packages (from flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (4.4.0)\n", - "Collecting orbax\n", - " Downloading orbax-0.0.23-py3-none-any.whl (66 kB)\n", - "\u001B[K |████████████████████████████████| 66 kB 6.3 MB/s \n", - "\u001B[?25hCollecting optax\n", - " Downloading optax-0.1.4-py3-none-any.whl (154 kB)\n", - "\u001B[K |████████████████████████████████| 154 kB 82.0 MB/s \n", - "\u001B[?25hCollecting tensorstore\n", - " Downloading tensorstore-0.1.28-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.3 MB)\n", - "\u001B[K |████████████████████████████████| 8.3 MB 64.7 MB/s \n", - "\u001B[?25hRequirement already satisfied: msgpack in /usr/local/lib/python3.8/dist-packages (from flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (1.0.4)\n", - "Collecting gymnasium-notices>=0.0.1\n", - " Downloading gymnasium_notices-0.0.1-py3-none-any.whl (2.8 kB)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<0.12.0,>=0.11.1->cleanrl[dqn-atari-jax]) (3.8.2)\n", - "Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<0.12.0,>=0.11.1->cleanrl[dqn-atari-jax]) (21.3)\n", - "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.8/dist-packages (from importlib-metadata>=4.10.0->ale-py==0.7.4->cleanrl[dqn-atari-jax]) (3.11.0)\n", - "Requirement already satisfied: scipy>=1.5 in /usr/local/lib/python3.8/dist-packages (from jax<0.4.0,>=0.3.17->cleanrl[dqn-atari-jax]) (1.7.3)\n", - "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.8/dist-packages (from jax<0.4.0,>=0.3.17->cleanrl[dqn-atari-jax]) (3.3.0)\n", - "Requirement already satisfied: decorator<5.0,>=4.0.2 in /usr/local/lib/python3.8/dist-packages (from moviepy<2.0.0,>=1.0.3->cleanrl[dqn-atari-jax]) (4.4.2)\n", - "Collecting proglog<=1.0.0\n", - " Downloading proglog-0.1.10-py3-none-any.whl (6.1 kB)\n", - "Requirement already satisfied: imageio<3.0,>=2.5 in /usr/local/lib/python3.8/dist-packages (from moviepy<2.0.0,>=1.0.3->cleanrl[dqn-atari-jax]) (2.9.0)\n", - "Collecting imageio_ffmpeg>=0.2.0\n", - " Downloading imageio_ffmpeg-0.4.7-py3-none-manylinux2010_x86_64.whl (26.9 MB)\n", - "\u001B[K |████████████████████████████████| 26.9 MB 47.9 MB/s \n", - "\u001B[?25hRequirement already satisfied: pillow in /usr/local/lib/python3.8/dist-packages (from imageio<3.0,>=2.5->moviepy<2.0.0,>=1.0.3->cleanrl[dqn-atari-jax]) (7.1.2)\n", - "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.8/dist-packages (from packaging>=20.9->huggingface-hub<0.12.0,>=0.11.1->cleanrl[dqn-atari-jax]) (3.0.9)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests->AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (2022.12.7)\n", - "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests->AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (1.24.3)\n", - "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests->AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (2.10)\n", - "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests->AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (3.0.4)\n", - "Collecting commonmark<0.10.0,>=0.9.0\n", - " Downloading commonmark-0.9.1-py2.py3-none-any.whl (51 kB)\n", - "\u001B[K |████████████████████████████████| 51 kB 5.0 MB/s \n", - "\u001B[?25hRequirement already satisfied: pygments<3.0.0,>=2.6.0 in /usr/local/lib/python3.8/dist-packages (from rich>=11.1->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (2.6.1)\n", - "Requirement already satisfied: protobuf<4,>=3.9.2 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (3.19.6)\n", - "Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (2.15.0)\n", - "Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (0.6.1)\n", - "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (3.4.1)\n", - "Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (0.38.4)\n", - "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (0.4.6)\n", - "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.3.0)\n", - "Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (57.4.0)\n", - "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.8.1)\n", - "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.0.1)\n", - "Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.51.1)\n", - "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.8/dist-packages (from google-auth<3,>=1.6.3->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (0.2.8)\n", - "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from google-auth<3,>=1.6.3->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (5.2.0)\n", - "Requirement already satisfied: six>=1.9.0 in /usr/local/lib/python3.8/dist-packages (from google-auth<3,>=1.6.3->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.15.0)\n", - "Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.8/dist-packages (from google-auth<3,>=1.6.3->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (4.9)\n", - "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.8/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.3.1)\n", - "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.8/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (0.4.8)\n", - "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.8/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (3.2.2)\n", - "Collecting pathtools\n", - " Downloading pathtools-0.1.2.tar.gz (11 kB)\n", - "Requirement already satisfied: promise<3,>=2.0 in /usr/local/lib/python3.8/dist-packages (from wandb<0.14.0,>=0.13.6->cleanrl[dqn-atari-jax]) (2.3)\n", - "Collecting GitPython>=1.0.0\n", - " Downloading GitPython-3.1.30-py3-none-any.whl (184 kB)\n", - "\u001B[K |████████████████████████████████| 184 kB 71.4 MB/s \n", - "\u001B[?25hRequirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.8/dist-packages (from wandb<0.14.0,>=0.13.6->cleanrl[dqn-atari-jax]) (5.4.8)\n", - "Collecting shortuuid>=0.5.0\n", - " Downloading shortuuid-1.0.11-py3-none-any.whl (10 kB)\n", - "Collecting sentry-sdk>=1.0.0\n", - " Downloading sentry_sdk-1.12.1-py2.py3-none-any.whl (174 kB)\n", - "\u001B[K |████████████████████████████████| 174 kB 80.8 MB/s \n", - "\u001B[?25hCollecting docker-pycreds>=0.4.0\n", - " Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\n", - "Collecting setproctitle\n", - " Downloading setproctitle-1.3.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (31 kB)\n", - "Collecting gitdb<5,>=4.0.1\n", - " Downloading gitdb-4.0.10-py3-none-any.whl (62 kB)\n", - "\u001B[K |████████████████████████████████| 62 kB 1.7 MB/s \n", - "\u001B[?25hCollecting smmap<6,>=3.0.1\n", - " Downloading smmap-5.0.0-py3-none-any.whl (24 kB)\n", - "Collecting sentry-sdk>=1.0.0\n", - " Downloading sentry_sdk-1.12.0-py2.py3-none-any.whl (173 kB)\n", - "\u001B[K |████████████████████████████████| 173 kB 69.0 MB/s \n", - "\u001B[?25h Downloading sentry_sdk-1.11.1-py2.py3-none-any.whl (168 kB)\n", - "\u001B[K |████████████████████████████████| 168 kB 66.6 MB/s \n", - "\u001B[?25h Downloading sentry_sdk-1.11.0-py2.py3-none-any.whl (168 kB)\n", - "\u001B[K |████████████████████████████████| 168 kB 8.1 MB/s \n", - "\u001B[?25h Downloading sentry_sdk-1.10.1-py2.py3-none-any.whl (166 kB)\n", - "\u001B[K |████████████████████████████████| 166 kB 10.6 MB/s \n", - "\u001B[?25h Downloading sentry_sdk-1.10.0-py2.py3-none-any.whl (166 kB)\n", - "\u001B[K |████████████████████████████████| 166 kB 71.4 MB/s \n", - "\u001B[?25h Downloading sentry_sdk-1.9.10-py2.py3-none-any.whl (162 kB)\n", - "\u001B[K |████████████████████████████████| 162 kB 70.1 MB/s \n", - "\u001B[?25h Downloading sentry_sdk-1.9.9-py2.py3-none-any.whl (162 kB)\n", - "\u001B[K |████████████████████████████████| 162 kB 70.2 MB/s \n", - "\u001B[?25h Downloading sentry_sdk-1.9.8-py2.py3-none-any.whl (158 kB)\n", - "\u001B[K |████████████████████████████████| 158 kB 75.4 MB/s \n", - "\u001B[?25h Downloading sentry_sdk-1.9.7-py2.py3-none-any.whl (157 kB)\n", - "\u001B[K |████████████████████████████████| 157 kB 77.6 MB/s \n", - "\u001B[?25h Downloading sentry_sdk-1.9.6-py2.py3-none-any.whl (157 kB)\n", - "\u001B[K |████████████████████████████████| 157 kB 83.8 MB/s \n", - "\u001B[?25h Downloading sentry_sdk-1.9.5-py2.py3-none-any.whl (157 kB)\n", - "\u001B[K |████████████████████████████████| 157 kB 88.0 MB/s \n", - "\u001B[?25h Downloading sentry_sdk-1.9.4-py2.py3-none-any.whl (157 kB)\n", - "\u001B[K |████████████████████████████████| 157 kB 80.1 MB/s \n", - "\u001B[?25h Downloading sentry_sdk-1.9.3-py2.py3-none-any.whl (157 kB)\n", - "\u001B[K |████████████████████████████████| 157 kB 84.8 MB/s \n", - "\u001B[?25h Downloading sentry_sdk-1.9.2-py2.py3-none-any.whl (157 kB)\n", - "\u001B[K |████████████████████████████████| 157 kB 85.7 MB/s \n", - "\u001B[?25h Downloading sentry_sdk-1.9.1-py2.py3-none-any.whl (157 kB)\n", - "\u001B[K |████████████████████████████████| 157 kB 83.5 MB/s \n", - "\u001B[?25h Downloading sentry_sdk-1.9.0-py2.py3-none-any.whl (156 kB)\n", - "\u001B[K |████████████████████████████████| 156 kB 84.0 MB/s \n", - "\u001B[?25hCollecting libtorrent\n", - " Using cached libtorrent-2.0.7-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (8.6 MB)\n", - "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (2.8.2)\n", - "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib->stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (0.11.0)\n", - "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (1.4.4)\n", - "Collecting chex>=0.1.5\n", - " Downloading chex-0.1.5-py3-none-any.whl (85 kB)\n", - "\u001B[K |████████████████████████████████| 85 kB 4.9 MB/s \n", - "\u001B[?25hRequirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.8/dist-packages (from chex>=0.1.5->optax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (0.1.7)\n", - "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.8/dist-packages (from chex>=0.1.5->optax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (0.12.0)\n", - "Requirement already satisfied: pytest in /usr/local/lib/python3.8/dist-packages (from orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (3.6.4)\n", - "Collecting cached_property\n", - " Downloading cached_property-1.5.2-py2.py3-none-any.whl (7.6 kB)\n", - "Requirement already satisfied: etils in /usr/local/lib/python3.8/dist-packages (from orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (0.9.0)\n", - "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.8/dist-packages (from pandas->stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (2022.6)\n", - "Requirement already satisfied: more-itertools>=4.0.0 in /usr/local/lib/python3.8/dist-packages (from pytest->orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (9.0.0)\n", - "Requirement already satisfied: pluggy<0.8,>=0.5 in /usr/local/lib/python3.8/dist-packages (from pytest->orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (0.7.1)\n", - "Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.8/dist-packages (from pytest->orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (22.1.0)\n", - "Requirement already satisfied: py>=1.5.0 in /usr/local/lib/python3.8/dist-packages (from pytest->orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (1.11.0)\n", - "Requirement already satisfied: atomicwrites>=1.0 in /usr/local/lib/python3.8/dist-packages (from pytest->orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (1.4.1)\n", - "Building wheels for collected packages: gym, moviepy, AutoROM.accept-rom-license, pathtools\n", - " Building wheel for gym (PEP 517) ... \u001B[?25l\u001B[?25hdone\n", - " Created wheel for gym: filename=gym-0.23.1-py3-none-any.whl size=701376 sha256=7b59f30aef873fc1494bd2f2eeac27b103b64ae6ee87d554c8b61b9ddbe35765\n", - " Stored in directory: /root/.cache/pip/wheels/78/28/77/b0c74e80a2a4faae0161d5c53bc4f8e436e77aedc79136ee13\n", - " Building wheel for moviepy (setup.py) ... \u001B[?25l\u001B[?25hdone\n", - " Created wheel for moviepy: filename=moviepy-1.0.3-py3-none-any.whl size=110742 sha256=640c1c0df827ed5835373acab4d2d7b93e98e33b5e6cb90e3d5e703933f9bcf8\n", - " Stored in directory: /root/.cache/pip/wheels/e4/a4/db/0368d3a04033da662e13926594b3a8cf1aa4ffeefe570cfac1\n", - " Building wheel for AutoROM.accept-rom-license (PEP 517) ... \u001B[?25l\u001B[?25hdone\n", - " Created wheel for AutoROM.accept-rom-license: filename=AutoROM.accept_rom_license-0.5.0-py3-none-any.whl size=440868 sha256=a3833e2c22c21355029cb083d9ea62b7abe329af3757ccdce9b0d2a5cc06949f\n", - " Stored in directory: /root/.cache/pip/wheels/bf/c9/25/578470ae932b494c313dc22e6c57afff192140fb3cd5acf185\n", - " Building wheel for pathtools (setup.py) ... \u001B[?25l\u001B[?25hdone\n", - " Created wheel for pathtools: filename=pathtools-0.1.2-py3-none-any.whl size=8806 sha256=57226a75b752bf852ac2f0f5ad878217a63376d6c44a4b29ccdf40b4921bf4bc\n", - " Stored in directory: /root/.cache/pip/wheels/4c/8e/7e/72fbc243e1aeecae64a96875432e70d4e92f3d2d18123be004\n", - "Successfully built gym moviepy AutoROM.accept-rom-license pathtools\n", - "Installing collected packages: smmap, gitdb, tensorstore, shortuuid, setproctitle, sentry-sdk, proglog, pathtools, libtorrent, imageio-ffmpeg, gymnasium-notices, gym, GitPython, docker-pycreds, commonmark, chex, cached-property, wandb, tensorboard, stable-baselines3, rich, pygame, orbax, optax, moviepy, huggingface-hub, gymnasium, AutoROM.accept-rom-license, AutoROM, flax, cleanrl-test, ale-py\n", - " Attempting uninstall: gym\n", - " Found existing installation: gym 0.25.2\n", - " Uninstalling gym-0.25.2:\n", - " Successfully uninstalled gym-0.25.2\n", - " Attempting uninstall: tensorboard\n", - " Found existing installation: tensorboard 2.9.1\n", - " Uninstalling tensorboard-2.9.1:\n", - " Successfully uninstalled tensorboard-2.9.1\n", - " Attempting uninstall: moviepy\n", - " Found existing installation: moviepy 0.2.3.5\n", - " Uninstalling moviepy-0.2.3.5:\n", - " Successfully uninstalled moviepy-0.2.3.5\n", - "\u001B[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", - "tensorflow 2.9.2 requires tensorboard<2.10,>=2.9, but you have tensorboard 2.11.0 which is incompatible.\u001B[0m\n", - "Successfully installed AutoROM-0.4.2 AutoROM.accept-rom-license-0.5.0 GitPython-3.1.30 ale-py-0.7.4 cached-property-1.5.2 chex-0.1.5 cleanrl-test-1.1.2 commonmark-0.9.1 docker-pycreds-0.4.0 flax-0.6.3 gitdb-4.0.10 gym-0.23.1 gymnasium-0.26.3 gymnasium-notices-0.0.1 huggingface-hub-0.11.1 imageio-ffmpeg-0.4.7 libtorrent-2.0.7 moviepy-1.0.3 optax-0.1.4 orbax-0.0.23 pathtools-0.1.2 proglog-0.1.10 pygame-2.1.0 rich-13.0.0 sentry-sdk-1.9.0 setproctitle-1.3.2 shortuuid-1.0.11 smmap-5.0.0 stable-baselines3-1.2.0 tensorboard-2.11.0 tensorstore-0.1.28 wandb-0.13.7\n" - ] - } - ], - "source": [ - "!pip install --upgrade \"cleanrl[dqn-atari-jax]\" # CAVEAT: the extra key is `dqn-atari-jax` with dashes instead of `dqn_atari_jax` with underscores" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xXQXZTh_AHZ0" - }, - "source": [ - "## Enjoy Utility\n", - "\n", - "We have a simple way to load the model by running our \"enjoy\" utility, which automatically pull the model from 🤗 HuggingFace and run for a few episodes. It also produces a rendered video through the `--capture_video` flag. See more at our [📜 Documentation](https://docs.cleanrl.dev/get-started/zoo/)." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "4H9VZKBC_3_1", - "outputId": "fc03fd9b-84f8-43dc-b4e3-041e7a201c12" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.8/dist-packages/jupyter_client/connect.py:28: DeprecationWarning: Jupyter is migrating its paths to use standard platformdirs\n", - "given by the platformdirs library. To remove this warning and\n", - "see the appropriate new directories, set the environment variable\n", - "`JUPYTER_PLATFORM_DIRS=1` and then run `jupyter --paths`.\n", - "The use of platformdirs will be the default in `jupyter_core` v6\n", - " from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write\n", - "loading saved models from cleanrl/BreakoutNoFrameskip-v4-dqn_atari_jax-seed1...\n", - "Downloading: 100% 6.75M/6.75M [00:00<00:00, 62.6MB/s]\n", - "A.L.E: Arcade Learning Environment (version 0.7.4+069f8bd)\n", - "[Powered by Stella]\n", - "/usr/local/lib/python3.8/dist-packages/gym/utils/seeding.py:138: DeprecationWarning: \u001B[33mWARN: Function `hash_seed(seed, max_bytes)` is marked as deprecated and will be removed in the future. \u001B[0m\n", - " deprecation(\n", - "/usr/local/lib/python3.8/dist-packages/gym/utils/seeding.py:175: DeprecationWarning: \u001B[33mWARN: Function `_bigint_from_bytes(bytes)` is marked as deprecated and will be removed in the future. \u001B[0m\n", - " deprecation(\n", - "/usr/local/lib/python3.8/dist-packages/gym/wrappers/monitoring/video_recorder.py:43: DeprecationWarning: \u001B[33mWARN: `env.metadata[\"render.modes\"] is marked as deprecated and will be replaced with `env.metadata[\"render_modes\"]` see https://github.com/openai/gym/pull/2654 for more details\u001B[0m\n", - " logger.deprecation(\n", - "/usr/local/lib/python3.8/dist-packages/gym/utils/seeding.py:47: DeprecationWarning: \u001B[33mWARN: Function `rng.randint(low, [high, size, dtype])` is marked as deprecated and will be removed in the future. Please use `rng.integers(low, [high, size, dtype])` instead.\u001B[0m\n", - " deprecation(\n", - "/usr/local/lib/python3.8/dist-packages/gym/wrappers/monitoring/video_recorder.py:43: DeprecationWarning: \u001B[33mWARN: `env.metadata[\"render.modes\"] is marked as deprecated and will be replaced with `env.metadata[\"render_modes\"]` see https://github.com/openai/gym/pull/2654 for more details\u001B[0m\n", - " logger.deprecation(\n", - "/usr/local/lib/python3.8/dist-packages/gym/utils/seeding.py:47: DeprecationWarning: \u001B[33mWARN: Function `rng.randint(low, [high, size, dtype])` is marked as deprecated and will be removed in the future. Please use `rng.integers(low, [high, size, dtype])` instead.\u001B[0m\n", - " deprecation(\n", - "eval_episode=0, episodic_return=400.0\n", - "eval_episode=1, episodic_return=128.0\n" - ] - } - ], - "source": [ - "!python -m cleanrl_utils.enjoy --exp-name dqn_atari_jax --env-id BreakoutNoFrameskip-v4 --eval-episodes 2 --capture_video" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 231 - }, - "id": "KpzdA4dkFbdT", - "outputId": "1b53628e-ac19-4f36-89e4-1a831b51f06b" - }, - "outputs": [ + "cells": [ { - "data": { - "text/html": [ - "" - ], - "text/plain": [ - "" + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "id": "oTuvowvgFQpm" + }, + "source": [ + "# CleanRL's Huggingface Integration Demo\n", + "\n", + "\n", + "\n", + "[](https://github.com/vwxyzjn/cleanrl)\n", + "[![tests](https://github.com/vwxyzjn/cleanrl/actions/workflows/tests.yaml/badge.svg)](https://github.com/vwxyzjn/cleanrl/actions/workflows/tests.yaml)\n", + "[![docs](https://img.shields.io/github/deployments/vwxyzjn/cleanrl/Production?label=docs&logo=vercel)](https://docs.cleanrl.dev/)\n", + "[](https://discord.gg/D6RCjA6sVT)\n", + "[](https://www.youtube.com/channel/UCDdC6BIFRI0jvcwuhi3aI6w/videos)\n", + "[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)\n", + "[![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)\n", + "[](https://huggingface.co/cleanrl)\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vwxyzjn/cleanrl/blob/master/docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb)\n", + "\n", + "\n", + "CleanRL is a Deep Reinforcement Learning library that provides high-quality single-file implementation with research-friendly features. It now has has 🧪 experimental support for saving and loading models from 🤗 HuggingFace's [Model Hub](https://huggingface.co/models). This notebook is a preliminary demo.\n", + "\n", + "\n", + "* 💾 [GitHub Repo](https://github.com/vwxyzjn/cleanrl)\n", + "* 📜 [Documentation](https://docs.cleanrl.dev/)\n", + "* 🤗 [HuggingFace Model Hub](https://huggingface.co/cleanrl)\n", + "* 🔗 [Open RL Benchmark reports](https://wandb.ai/openrlbenchmark/openrlbenchmark/reportlist)\n", + "\n" ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from IPython.display import Video\n", - "Video('videos/eval/rl-video-episode-0.mp4', embed=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WU29XP1ICwxv" - }, - "source": [ - "## Diving Deeper\n", - "\n", - "What happened above was achieved by a simple wrapper for [cleanrl_utils/evals/dqn_eval.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl_utils/evals/dqn_eval.py), which is pretty succinct and may give you a more fine-grained control and access to the model. Its content is roughly as follows, where it attempts to download a model from https://huggingface.co/cleanrl/BreakoutNoFrameskip-v4-dqn_atari_jax-seed1 and run an evaluation pass. " - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "eZY6cAxkDJF5", - "outputId": "0144efd9-5d8e-4631-8a07-6385d8365558" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.8/dist-packages/gym/utils/seeding.py:138: DeprecationWarning: \u001B[33mWARN: Function `hash_seed(seed, max_bytes)` is marked as deprecated and will be removed in the future. \u001B[0m\n", - " deprecation(\n", - "/usr/local/lib/python3.8/dist-packages/gym/utils/seeding.py:175: DeprecationWarning: \u001B[33mWARN: Function `_bigint_from_bytes(bytes)` is marked as deprecated and will be removed in the future. \u001B[0m\n", - " deprecation(\n", - "/usr/local/lib/python3.8/dist-packages/gym/utils/seeding.py:47: DeprecationWarning: \u001B[33mWARN: Function `rng.randint(low, [high, size, dtype])` is marked as deprecated and will be removed in the future. Please use `rng.integers(low, [high, size, dtype])` instead.\u001B[0m\n", - " deprecation(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "eval_episode=0, episodic_return=340.0\n", - "eval_episode=1, episodic_return=399.0\n" - ] }, { - "data": { - "text/plain": [ - "[340.0, 399.0]" + "cell_type": "markdown", + "metadata": { + "id": "J0zhqyfea0If" + }, + "source": [ + "## Get Started\n", + "\n", + "CleanRL can be installed via `pip`. Let's say we are interested in pulling the model for [`dqn_atari_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari_jax.py), we can install the algorithm-variant-specific dependencies as follows:" ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import random\n", - "from typing import Callable\n", - "\n", - "import flax\n", - "import flax.linen as nn\n", - "import gym\n", - "import jax\n", - "import numpy as np\n", - "\n", - "\n", - "def evaluate(\n", - " model_path: str,\n", - " make_env: Callable,\n", - " env_id: str,\n", - " eval_episodes: int,\n", - " run_name: str,\n", - " Model: nn.Module,\n", - " epsilon: float = 0.05,\n", - " capture_video: bool = True,\n", - " seed=1,\n", - "):\n", - " envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, capture_video, run_name)])\n", - " obs = envs.reset()\n", - " model = Model(action_dim=envs.single_action_space.n)\n", - " q_key = jax.random.PRNGKey(seed)\n", - " params = model.init(q_key, obs)\n", - " with open(model_path, \"rb\") as f:\n", - " params = flax.serialization.from_bytes(params, f.read())\n", - " model.apply = jax.jit(model.apply)\n", - "\n", - " episodic_returns = []\n", - " while len(episodic_returns) < eval_episodes:\n", - " if random.random() < epsilon:\n", - " actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])\n", - " else:\n", - " q_values = model.apply(params, obs)\n", - " actions = q_values.argmax(axis=-1)\n", - " actions = jax.device_get(actions)\n", - " next_obs, _, _, infos = envs.step(actions)\n", - " for info in infos:\n", - " if \"episode\" in info.keys():\n", - " print(f\"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}\")\n", - " episodic_returns += [info[\"episode\"][\"r\"]]\n", - " obs = next_obs\n", - "\n", - " return episodic_returns\n", - "\n", - "\n", - "from huggingface_hub import hf_hub_download\n", - "\n", - "from cleanrl.dqn_atari_jax import QNetwork, make_env\n", - "\n", - "model_path = hf_hub_download(repo_id=\"cleanrl/BreakoutNoFrameskip-v4-dqn_atari_jax-seed1\", filename=\"dqn_atari_jax.cleanrl_model\")\n", - "evaluate(\n", - " model_path,\n", - " make_env,\n", - " \"BreakoutNoFrameskip-v4\",\n", - " eval_episodes=2,\n", - " run_name=f\"eval\",\n", - " Model=QNetwork,\n", - " capture_video=False,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZxM0A6LmQtnn" - }, - "source": [ - "## More Examples\n", - "\n", - "Now let's get going with more examples!" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "TrQae62Y70H0" - }, - "outputs": [], - "source": [ - "import argparse\n", - "from dataclasses import dataclass\n", - "\n", - "from huggingface_hub import hf_hub_download\n", - "\n", - "try:\n", - " from pip import main as pipmain\n", - "except ImportError:\n", - " from pip._internal import main as pipmain\n", - "\n", - "@dataclass\n", - "class Args:\n", - " exp_name: str = \"dqn_atari_jax\"\n", - " seed: int = 1\n", - " hf_entity: str = \"cleanrl\"\n", - " hf_repository: str = \"\"\n", - " env_id: str = \"BreakoutNoFrameskip-v4\"\n", - "\n", - "\n", - "def dqn():\n", - " import cleanrl.dqn\n", - " import cleanrl_utils.evals.dqn_eval\n", - " return cleanrl.dqn.QNetwork, cleanrl.dqn.make_env, cleanrl_utils.evals.dqn_eval.evaluate\n", - "\n", - "def dqn_atari():\n", - " import cleanrl.dqn_atari\n", - " import cleanrl_utils.evals.dqn_eval\n", - " return cleanrl.dqn_atari.QNetwork, cleanrl.dqn_atari.make_env, cleanrl_utils.evals.dqn_eval.evaluate\n", - "\n", - "def dqn_jax():\n", - " import cleanrl.dqn_jax\n", - " import cleanrl_utils.evals.dqn_jax_eval\n", - " return cleanrl.dqn_jax.QNetwork, cleanrl.dqn_jax.make_env, cleanrl_utils.evals.dqn_jax_eval.evaluate\n", - "\n", - "def dqn_atari_jax():\n", - " import cleanrl.dqn_atari_jax\n", - " import cleanrl_utils.evals.dqn_jax_eval\n", - " return cleanrl.dqn_atari_jax.QNetwork, cleanrl.dqn_atari_jax.make_env, cleanrl_utils.evals.dqn_jax_eval.evaluate\n", - "\n", - "MODELS = {\n", - " \"dqn\": dqn,\n", - " \"dqn_atari\": dqn_atari,\n", - " \"dqn_jax\": dqn_jax,\n", - " \"dqn_atari_jax\": dqn_atari_jax,\n", - "}\n", - "\n", - "\n", - "\n", - "exp_names = [\"dqn\", \"dqn_jax\", \"dqn_atari_jax\", \"dqn_atari\"]\n", - "env_idss = [\n", - " [\n", - " \"CartPole-v1\",\n", - " \"Acrobot-v1\",\n", - " \"MountainCar-v0\",\n", - " ],\n", - " [\n", - " \"CartPole-v1\",\n", - " \"Acrobot-v1\",\n", - " \"MountainCar-v0\",\n", - " ],\n", - " [\n", - " \"BreakoutNoFrameskip-v4\",\n", - " \"PongNoFrameskip-v4\",\n", - " \"BeamRiderNoFrameskip-v4\"\n", - " ],\n", - " [\n", - " \"BreakoutNoFrameskip-v4\",\n", - " \"PongNoFrameskip-v4\",\n", - " \"BeamRiderNoFrameskip-v4\"\n", - " ]\n", - " ]\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": { - "id": "IeksFU1me8q8" - }, - "source": [ - "### Install dependencies for each variant" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "dnvpgpWWfABl", - "outputId": "1e41abbf-d9c4-4adf-fe05-40e8e31962f4" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.8/dist-packages/pip/_vendor/packaging/version.py:127: DeprecationWarning: Creating a LegacyVersion has been deprecated and will be removed in the next major release\n", - " warnings.warn(\n", - "/usr/local/lib/python3.8/dist-packages/pip/_vendor/packaging/version.py:127: DeprecationWarning: Creating a LegacyVersion has been deprecated and will be removed in the next major release\n", - " warnings.warn(\n", - "WARNING: pip is being invoked by an old script wrapper. This will fail in a future version of pip.\n", - "Please see https://github.com/pypa/pip/issues/5599 for advice on fixing the underlying issue.\n", - "To avoid this problem you can invoke Python with '-m pip' instead of running pip directly.\n", - "/usr/local/lib/python3.8/dist-packages/pip/_vendor/packaging/version.py:127: DeprecationWarning: Creating a LegacyVersion has been deprecated and will be removed in the next major release\n", - " warnings.warn(\n", - "WARNING: pip is being invoked by an old script wrapper. This will fail in a future version of pip.\n", - "Please see https://github.com/pypa/pip/issues/5599 for advice on fixing the underlying issue.\n", - "To avoid this problem you can invoke Python with '-m pip' instead of running pip directly.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "==== ['install', '--upgrade', 'cleanrl[dqn]', '--quiet']\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.8/dist-packages/pip/_vendor/packaging/version.py:127: DeprecationWarning: Creating a LegacyVersion has been deprecated and will be removed in the next major release\n", - " warnings.warn(\n", - "WARNING: pip is being invoked by an old script wrapper. This will fail in a future version of pip.\n", - "Please see https://github.com/pypa/pip/issues/5599 for advice on fixing the underlying issue.\n", - "To avoid this problem you can invoke Python with '-m pip' instead of running pip directly.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "==== ['install', '--upgrade', 'cleanrl[dqn-jax]', '--quiet']\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.8/dist-packages/pip/_vendor/packaging/version.py:127: DeprecationWarning: Creating a LegacyVersion has been deprecated and will be removed in the next major release\n", - " warnings.warn(\n", - "WARNING: pip is being invoked by an old script wrapper. This will fail in a future version of pip.\n", - "Please see https://github.com/pypa/pip/issues/5599 for advice on fixing the underlying issue.\n", - "To avoid this problem you can invoke Python with '-m pip' instead of running pip directly.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "==== ['install', '--upgrade', 'cleanrl[dqn-atari-jax]', '--quiet']\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.8/dist-packages/pip/_vendor/packaging/version.py:127: DeprecationWarning: Creating a LegacyVersion has been deprecated and will be removed in the next major release\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "==== ['install', '--upgrade', 'cleanrl[dqn-atari]', '--quiet']\n" - ] - } - ], - "source": [ - "for exp_name, env_ids in zip(exp_names, env_idss):\n", - " # install dependencies for the algorithm variant\n", - " pipmain(['install', '--upgrade', f'cleanrl[{exp_name.replace(\"_\", \"-\")}]', \"--quiet\"])\n", - " print(\"====\", ['install', '--upgrade', f'cleanrl[{exp_name.replace(\"_\", \"-\")}]', \"--quiet\"])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "van2E4jFfC2f" - }, - "source": [ - "# Enjoy!" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000, - "referenced_widgets": [ - "1af264779a2442e596aed8e620561248", - "21bff070d7b342a3b1ca5c9976746a6f", - "82dc6a3d16e24066a8531f1390eb0450", - "df36fcd7bf4c47f48914e43866da1edf", - "f31718702d1a433eac7388ef1612a149", - "e6179240e76b4a2abd79e31b44c8bcf5", - "6e8a545a6d8e43a4ac9cc3a5cd8c06b6", - "1bbd3551eda24347b078a78c54ef9808", - "a3e4a40738274a48ad76090b7e3593ec", - "11d8ea24816b4d9487b5cf701c77368f", - "7bcc89fefa314206aef5e24d1b105b50", - "0c074497102c45aab5db63d863a493a5", - "8ffa468c4cee423ea37eb5c6a3008a8d", - "71728d572be04c4a816e09c4682dd254", - "54f10fce2aa44083bdddd67209f75097", - "b49b12ca89c5448a8994373aabb8c2de", - "3ac10a61d2394023b867276369d75d94", - "8676b5de0bb24b49a2624ce56e57b041", - "d537dbcb8a604e96ab1445a44f4f5795", - "3c2f669e4f5a472288f275d3c383df38", - "a538ea3a39db4b62b90e8fa84cc96ae0", - "96e4392eab1e468e990ca5e3ce239ee2", - "d47a0fe9206746618d6513a725629706", - "a69f93596efd47ebb3daa28aa4193333", - "f47b25f7ba814ad79fc49d5469f073f1", - "42d1d9807ecd4ac79a8a339ebec90c7d", - "0f846fe59b4f4f9595640cd537d5d359", - "242d6e4726bc4b44b51fbefbe9e71d89", - "f55821cfe861470d91254ca6fb1686d8", - "60c2723becc441f490db7ac57f0db14f", - "335346c51b404ec38d28451cd64bfb1d", - "b082ad5929d44cfb873d686178af22f6", - "707972069edb433d808fde341ae797bf", - "85d693c15c7f402d8cbeb3b2dfa1204b", - "ff5d66561809464cba45624ac5c6db4d", - "a833d25fce5c49d89babd5efa6a9e7a1", - "bbf1f621e72e46b5bb34f83de5aad104", - "c0701fd6a60241449580c09929d8a23c", - "5d2f9340fa314329846b680590f7b983", - "c5a49382f0fd40f2989b36d6ad63e5d5", - "4df7f8fbc6c646dc81a1c63fe5167618", - "92a92c1cec1445bca8bda94543c31091", - "7d57323e580e48c789a834ba8f324609", - "5ef4b0f2e682408ea74d44f07863e726", - "03848b157e164490a7a509028df7cad8", - "744e9685a89749738871f14718a58c56", - "442311424f9449cda8f697eec946ba04", - "0677049b6f5f4cfaa662300b8063bf8d", - "6887b96a72134b579780ce8fb4fec51c", - "1ae630c1bb96429595e4c8f64a8cd978", - "54333bb25bf74274bf37e62295bc5a90", - "10a31b06ec11402ea041ee5e048bd9be", - "a72a928130e3435892a666a59ce3f9ae", - "198c4e87657e4a02bcbc0db1306f842c", - "c12145376a554a9dab28277865509cc5", - "352eb266cd374d0f91a3d6628fe4df97", - "47a86cb59e43492b9d38ade53cc5bfa7", - "2a63eecc839f49f7ae8f2d974a922664", - "82a3c7c1a51c473cbe5c8c3fc3f6c7ce", - "b2af0dd940ae4b6893a86778df1b7877", - "1d0c0d2e16b046209c80fcd0a9392a41", - "31d97d0b38414773953861efb7d10afb", - "da450c588af44c78af9d665e5ef9bed9", - "e4124e064efb43bb8acfda5052874b5d", - "73d5d725a38e457eb078999c64171a82", - "cac8aede09734dac8e1a638b67da49cc", - "f31d98b28d5d4174b6fff125d0cfc169", - "d2708e38223d4f5397e1856896560c4d", - "90fddc80c9ba4b669065f2af725a814b", - "a71ba0aed6394d8ab1ca46dc21ee03ca", - "2a79f07e770c40cf821d9f7ab860e99b", - "8ec004e73c7f4b28b85365dcb959663c", - "651abec81a2b47b99fd28db7774f6b3f", - "65efc653defb40778d74abeed961ddf8", - "bb0018f331e647ab9e0b40b6b5e683b3", - "ad7717b0980140c58befed77c4f70250", - "aa85a5bb82e3445b929a198e246a19a0", - "29cf4ba923dc4bb4979b74035a1f50bf", - "6409943fc9bc4fc687bb16b4cda9d07c", - "45d8390f2e26447e81c2522f837c5a9f", - "11be98dd8811486dac47267f48244247", - "7db3fca11e5740e6801d26d40bf91076", - "a98d0644e9a047199464e16d46962088", - "af20581002334cd9af3833c59cd4c070", - "4a25a274b4034181b7d307490c3d5db1", - "78f2427ad9c64f56bf1481c3190e8c6c", - "7438c142a37748548fb84e732855d543", - "7f30a20de166487781da41a13bc3c0e4", - "1100f289ac0e45fb965d6221438c37f3", - "0e990a8ce0b744dfb2a0b53e98eca805", - "8b0029aeefbd4192a896f49607f1ed87", - "4d2c7cdb2eee47f9ac4e59d0fed9b4cd", - "90670ba8064c47fc9d84bcf889ba8b23", - "0e4b221ec94b4404a8a6f7b429154855", - "0b749bec4b2646a28d6f5f532149a8be", - "bac297e2dfc946509091a298e5d602b9", - "52d1eee0f09c4aee9454294a4379bece", - "f0f16ec872c44321a64387ffb0621ad4", - "1542a1c6fc944d219f70149236484bb0", - "0c1c34ffcc3c4071a45d9e48dc01459a", - "89357399364d43a493d7f1ad58f2ea5c", - "1f543ff2d15b4bcd9ca5ea227bf6025e", - "4d1d17e87b0b48a6bbe3f901640d5ae6", - "67e5ec2a25bd41f0b9e62fc6f0395a80", - "18f6e921b61d45fdb26bbff5531d3eb1", - "11df9e9525fb4616ae9fc563c104d2e7", - "5f54fef47da64724b7e7bf59cbcf3f35", - "1cbf2c8edb9246b49ff9334c655a3849", - "8e67dc3a52574897abf9dcb6bbc7c8ec", - "321a70d9c81f4d17b497f13328589153", - "f5fb6c114393429f8ad158d916d2b837", - "54a5be97eea9498ebb1722a65529006f", - "3273622ad1ce442188d54df94c6f737c", - "a34b45c0b3c449e9a914588cfaa375b7", - "64d68f2d555f47e8abe502dd212c8ff4", - "9db1d2d8d8b0433ca67629e00bb977bd", - "8a6e2d8cb9c5404d8e3f5ff2543b4409", - "a11b2e1f12f5460595aded3f4181d64f", - "3d947d7e8dd1446ca6107edccba6a183", - "3585b4c2b37547b38a386003f9370aa3", - "7c82ac27f8fb4fc1a87d89f415f2e2c5" - ] - }, - "id": "NkHbB0R4ewnE", - "outputId": "61d1d3d5-5114-456d-b3ff-8333b52d9f9a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "loading models from cleanrl/CartPole-v1-dqn-seed1\n" - ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "1af264779a2442e596aed8e620561248", - "version_major": 2, - "version_minor": 0 + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "lhnJkrYLOvcs", + "outputId": "381d9d0d-7e83-4f21-ef89-91d4e3b93c18" }, - "text/plain": [ - "Downloading: 0%| | 0.00/45.8k [00:00=0.11.1\n", + " Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)\n", + "\u001b[K |████████████████████████████████| 182 kB 74.9 MB/s \n", + "\u001b[?25hCollecting wandb<0.14.0,>=0.13.6\n", + " Downloading wandb-0.13.7-py2.py3-none-any.whl (1.9 MB)\n", + "\u001b[K |████████████████████████████████| 1.9 MB 63.6 MB/s \n", + "\u001b[?25hRequirement already satisfied: torch>=1.12.1 in /usr/local/lib/python3.8/dist-packages (from cleanrl[dqn-atari-jax]) (1.13.0+cu116)\n", + "Collecting stable-baselines3==1.2.0\n", + " Downloading stable_baselines3-1.2.0-py3-none-any.whl (161 kB)\n", + "\u001b[K |████████████████████████████████| 161 kB 64.7 MB/s \n", + "\u001b[?25hCollecting tensorboard<3.0.0,>=2.10.0\n", + " Downloading tensorboard-2.11.0-py3-none-any.whl (6.0 MB)\n", + "\u001b[K |████████████████████████████████| 6.0 MB 65.0 MB/s \n", + "\u001b[?25hCollecting moviepy<2.0.0,>=1.0.3\n", + " Downloading moviepy-1.0.3.tar.gz (388 kB)\n", + "\u001b[K |████████████████████████████████| 388 kB 59.5 MB/s \n", + "\u001b[?25hCollecting gym==0.23.1\n", + " Downloading gym-0.23.1.tar.gz (626 kB)\n", + "\u001b[K |████████████████████████████████| 626 kB 59.9 MB/s \n", + "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n", + "Collecting gymnasium<0.27.0,>=0.26.3\n", + " Downloading Gymnasium-0.26.3-py3-none-any.whl (836 kB)\n", + "\u001b[K |████████████████████████████████| 836 kB 64.6 MB/s \n", + "\u001b[?25hCollecting AutoROM[accept-rom-license]<0.5.0,>=0.4.2\n", + " Downloading AutoROM-0.4.2-py3-none-any.whl (16 kB)\n", + "Collecting ale-py==0.7.4\n", + " Downloading ale_py-0.7.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.6 MB)\n", + "\u001b[K |████████████████████████████████| 1.6 MB 55.1 MB/s \n", + "\u001b[?25hRequirement already satisfied: opencv-python<5.0.0.0,>=4.6.0.66 in /usr/local/lib/python3.8/dist-packages (from cleanrl[dqn-atari-jax]) (4.6.0.66)\n", + "Requirement already satisfied: jax<0.4.0,>=0.3.17 in /usr/local/lib/python3.8/dist-packages (from cleanrl[dqn-atari-jax]) (0.3.25)\n", + "Collecting flax<0.7.0,>=0.6.0\n", + " Downloading flax-0.6.3-py3-none-any.whl (197 kB)\n", + "\u001b[K |████████████████████████████████| 197 kB 73.9 MB/s \n", + "\u001b[?25hRequirement already satisfied: jaxlib<0.4.0,>=0.3.15 in /usr/local/lib/python3.8/dist-packages (from cleanrl[dqn-atari-jax]) (0.3.25+cuda11.cudnn805)\n", + "Requirement already satisfied: importlib-metadata>=4.10.0 in /usr/local/lib/python3.8/dist-packages (from ale-py==0.7.4->cleanrl[dqn-atari-jax]) (5.1.0)\n", + "Requirement already satisfied: importlib-resources in /usr/local/lib/python3.8/dist-packages (from ale-py==0.7.4->cleanrl[dqn-atari-jax]) (5.10.1)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.8/dist-packages (from ale-py==0.7.4->cleanrl[dqn-atari-jax]) (1.21.6)\n", + "Requirement already satisfied: gym-notices>=0.0.4 in /usr/local/lib/python3.8/dist-packages (from gym==0.23.1->cleanrl[dqn-atari-jax]) (0.0.8)\n", + "Requirement already satisfied: cloudpickle>=1.2.0 in /usr/local/lib/python3.8/dist-packages (from gym==0.23.1->cleanrl[dqn-atari-jax]) (1.5.0)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.8/dist-packages (from stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (1.3.5)\n", + "Requirement already satisfied: matplotlib in /usr/local/lib/python3.8/dist-packages (from stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (3.2.2)\n", + "Requirement already satisfied: click in /usr/local/lib/python3.8/dist-packages (from AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (7.1.2)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.8/dist-packages (from AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (2.23.0)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.8/dist-packages (from AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (4.64.1)\n", + "Collecting AutoROM.accept-rom-license\n", + " Downloading AutoROM.accept-rom-license-0.5.0.tar.gz (10 kB)\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: PyYAML>=5.4.1 in /usr/local/lib/python3.8/dist-packages (from flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (6.0)\n", + "Collecting rich>=11.1\n", + " Downloading rich-13.0.0-py3-none-any.whl (238 kB)\n", + "\u001b[K |████████████████████████████████| 238 kB 76.7 MB/s \n", + "\u001b[?25hRequirement already satisfied: typing-extensions>=4.1.1 in /usr/local/lib/python3.8/dist-packages (from flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (4.4.0)\n", + "Collecting orbax\n", + " Downloading orbax-0.0.23-py3-none-any.whl (66 kB)\n", + "\u001b[K |████████████████████████████████| 66 kB 6.3 MB/s \n", + "\u001b[?25hCollecting optax\n", + " Downloading optax-0.1.4-py3-none-any.whl (154 kB)\n", + "\u001b[K |████████████████████████████████| 154 kB 82.0 MB/s \n", + "\u001b[?25hCollecting tensorstore\n", + " Downloading tensorstore-0.1.28-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.3 MB)\n", + "\u001b[K |████████████████████████████████| 8.3 MB 64.7 MB/s \n", + "\u001b[?25hRequirement already satisfied: msgpack in /usr/local/lib/python3.8/dist-packages (from flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (1.0.4)\n", + "Collecting gymnasium-notices>=0.0.1\n", + " Downloading gymnasium_notices-0.0.1-py3-none-any.whl (2.8 kB)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<0.12.0,>=0.11.1->cleanrl[dqn-atari-jax]) (3.8.2)\n", + "Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<0.12.0,>=0.11.1->cleanrl[dqn-atari-jax]) (21.3)\n", + "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.8/dist-packages (from importlib-metadata>=4.10.0->ale-py==0.7.4->cleanrl[dqn-atari-jax]) (3.11.0)\n", + "Requirement already satisfied: scipy>=1.5 in /usr/local/lib/python3.8/dist-packages (from jax<0.4.0,>=0.3.17->cleanrl[dqn-atari-jax]) (1.7.3)\n", + "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.8/dist-packages (from jax<0.4.0,>=0.3.17->cleanrl[dqn-atari-jax]) (3.3.0)\n", + "Requirement already satisfied: decorator<5.0,>=4.0.2 in /usr/local/lib/python3.8/dist-packages (from moviepy<2.0.0,>=1.0.3->cleanrl[dqn-atari-jax]) (4.4.2)\n", + "Collecting proglog<=1.0.0\n", + " Downloading proglog-0.1.10-py3-none-any.whl (6.1 kB)\n", + "Requirement already satisfied: imageio<3.0,>=2.5 in /usr/local/lib/python3.8/dist-packages (from moviepy<2.0.0,>=1.0.3->cleanrl[dqn-atari-jax]) (2.9.0)\n", + "Collecting imageio_ffmpeg>=0.2.0\n", + " Downloading imageio_ffmpeg-0.4.7-py3-none-manylinux2010_x86_64.whl (26.9 MB)\n", + "\u001b[K |████████████████████████████████| 26.9 MB 47.9 MB/s \n", + "\u001b[?25hRequirement already satisfied: pillow in /usr/local/lib/python3.8/dist-packages (from imageio<3.0,>=2.5->moviepy<2.0.0,>=1.0.3->cleanrl[dqn-atari-jax]) (7.1.2)\n", + "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.8/dist-packages (from packaging>=20.9->huggingface-hub<0.12.0,>=0.11.1->cleanrl[dqn-atari-jax]) (3.0.9)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests->AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (2022.12.7)\n", + "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests->AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (1.24.3)\n", + "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests->AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (2.10)\n", + "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests->AutoROM[accept-rom-license]<0.5.0,>=0.4.2->cleanrl[dqn-atari-jax]) (3.0.4)\n", + "Collecting commonmark<0.10.0,>=0.9.0\n", + " Downloading commonmark-0.9.1-py2.py3-none-any.whl (51 kB)\n", + "\u001b[K |████████████████████████████████| 51 kB 5.0 MB/s \n", + "\u001b[?25hRequirement already satisfied: pygments<3.0.0,>=2.6.0 in /usr/local/lib/python3.8/dist-packages (from rich>=11.1->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (2.6.1)\n", + "Requirement already satisfied: protobuf<4,>=3.9.2 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (3.19.6)\n", + "Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (2.15.0)\n", + "Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (0.6.1)\n", + "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (3.4.1)\n", + "Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (0.38.4)\n", + "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (0.4.6)\n", + "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.3.0)\n", + "Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (57.4.0)\n", + "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.8.1)\n", + "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.0.1)\n", + "Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.8/dist-packages (from tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.51.1)\n", + "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.8/dist-packages (from google-auth<3,>=1.6.3->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (0.2.8)\n", + "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from google-auth<3,>=1.6.3->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (5.2.0)\n", + "Requirement already satisfied: six>=1.9.0 in /usr/local/lib/python3.8/dist-packages (from google-auth<3,>=1.6.3->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.15.0)\n", + "Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.8/dist-packages (from google-auth<3,>=1.6.3->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (4.9)\n", + "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.8/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (1.3.1)\n", + "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.8/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (0.4.8)\n", + "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.8/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<3.0.0,>=2.10.0->cleanrl[dqn-atari-jax]) (3.2.2)\n", + "Collecting pathtools\n", + " Downloading pathtools-0.1.2.tar.gz (11 kB)\n", + "Requirement already satisfied: promise<3,>=2.0 in /usr/local/lib/python3.8/dist-packages (from wandb<0.14.0,>=0.13.6->cleanrl[dqn-atari-jax]) (2.3)\n", + "Collecting GitPython>=1.0.0\n", + " Downloading GitPython-3.1.30-py3-none-any.whl (184 kB)\n", + "\u001b[K |████████████████████████████████| 184 kB 71.4 MB/s \n", + "\u001b[?25hRequirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.8/dist-packages (from wandb<0.14.0,>=0.13.6->cleanrl[dqn-atari-jax]) (5.4.8)\n", + "Collecting shortuuid>=0.5.0\n", + " Downloading shortuuid-1.0.11-py3-none-any.whl (10 kB)\n", + "Collecting sentry-sdk>=1.0.0\n", + " Downloading sentry_sdk-1.12.1-py2.py3-none-any.whl (174 kB)\n", + "\u001b[K |████████████████████████████████| 174 kB 80.8 MB/s \n", + "\u001b[?25hCollecting docker-pycreds>=0.4.0\n", + " Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\n", + "Collecting setproctitle\n", + " Downloading setproctitle-1.3.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (31 kB)\n", + "Collecting gitdb<5,>=4.0.1\n", + " Downloading gitdb-4.0.10-py3-none-any.whl (62 kB)\n", + "\u001b[K |████████████████████████████████| 62 kB 1.7 MB/s \n", + "\u001b[?25hCollecting smmap<6,>=3.0.1\n", + " Downloading smmap-5.0.0-py3-none-any.whl (24 kB)\n", + "Collecting sentry-sdk>=1.0.0\n", + " Downloading sentry_sdk-1.12.0-py2.py3-none-any.whl (173 kB)\n", + "\u001b[K |████████████████████████████████| 173 kB 69.0 MB/s \n", + "\u001b[?25h Downloading sentry_sdk-1.11.1-py2.py3-none-any.whl (168 kB)\n", + "\u001b[K |████████████████████████████████| 168 kB 66.6 MB/s \n", + "\u001b[?25h Downloading sentry_sdk-1.11.0-py2.py3-none-any.whl (168 kB)\n", + "\u001b[K |████████████████████████████████| 168 kB 8.1 MB/s \n", + "\u001b[?25h Downloading sentry_sdk-1.10.1-py2.py3-none-any.whl (166 kB)\n", + "\u001b[K |████████████████████████████████| 166 kB 10.6 MB/s \n", + "\u001b[?25h Downloading sentry_sdk-1.10.0-py2.py3-none-any.whl (166 kB)\n", + "\u001b[K |████████████████████████████████| 166 kB 71.4 MB/s \n", + "\u001b[?25h Downloading sentry_sdk-1.9.10-py2.py3-none-any.whl (162 kB)\n", + "\u001b[K |████████████████████████████████| 162 kB 70.1 MB/s \n", + "\u001b[?25h Downloading sentry_sdk-1.9.9-py2.py3-none-any.whl (162 kB)\n", + "\u001b[K |████████████████████████████████| 162 kB 70.2 MB/s \n", + "\u001b[?25h Downloading sentry_sdk-1.9.8-py2.py3-none-any.whl (158 kB)\n", + "\u001b[K |████████████████████████████████| 158 kB 75.4 MB/s \n", + "\u001b[?25h Downloading sentry_sdk-1.9.7-py2.py3-none-any.whl (157 kB)\n", + "\u001b[K |████████████████████████████████| 157 kB 77.6 MB/s \n", + "\u001b[?25h Downloading sentry_sdk-1.9.6-py2.py3-none-any.whl (157 kB)\n", + "\u001b[K |████████████████████████████████| 157 kB 83.8 MB/s \n", + "\u001b[?25h Downloading sentry_sdk-1.9.5-py2.py3-none-any.whl (157 kB)\n", + "\u001b[K |████████████████████████████████| 157 kB 88.0 MB/s \n", + "\u001b[?25h Downloading sentry_sdk-1.9.4-py2.py3-none-any.whl (157 kB)\n", + "\u001b[K |████████████████████████████████| 157 kB 80.1 MB/s \n", + "\u001b[?25h Downloading sentry_sdk-1.9.3-py2.py3-none-any.whl (157 kB)\n", + "\u001b[K |████████████████████████████████| 157 kB 84.8 MB/s \n", + "\u001b[?25h Downloading sentry_sdk-1.9.2-py2.py3-none-any.whl (157 kB)\n", + "\u001b[K |████████████████████████████████| 157 kB 85.7 MB/s \n", + "\u001b[?25h Downloading sentry_sdk-1.9.1-py2.py3-none-any.whl (157 kB)\n", + "\u001b[K |████████████████████████████████| 157 kB 83.5 MB/s \n", + "\u001b[?25h Downloading sentry_sdk-1.9.0-py2.py3-none-any.whl (156 kB)\n", + "\u001b[K |████████████████████████████████| 156 kB 84.0 MB/s \n", + "\u001b[?25hCollecting libtorrent\n", + " Using cached libtorrent-2.0.7-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (8.6 MB)\n", + "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (2.8.2)\n", + "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib->stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (0.11.0)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (1.4.4)\n", + "Collecting chex>=0.1.5\n", + " Downloading chex-0.1.5-py3-none-any.whl (85 kB)\n", + "\u001b[K |████████████████████████████████| 85 kB 4.9 MB/s \n", + "\u001b[?25hRequirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.8/dist-packages (from chex>=0.1.5->optax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (0.1.7)\n", + "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.8/dist-packages (from chex>=0.1.5->optax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (0.12.0)\n", + "Requirement already satisfied: pytest in /usr/local/lib/python3.8/dist-packages (from orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (3.6.4)\n", + "Collecting cached_property\n", + " Downloading cached_property-1.5.2-py2.py3-none-any.whl (7.6 kB)\n", + "Requirement already satisfied: etils in /usr/local/lib/python3.8/dist-packages (from orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (0.9.0)\n", + "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.8/dist-packages (from pandas->stable-baselines3==1.2.0->cleanrl[dqn-atari-jax]) (2022.6)\n", + "Requirement already satisfied: more-itertools>=4.0.0 in /usr/local/lib/python3.8/dist-packages (from pytest->orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (9.0.0)\n", + "Requirement already satisfied: pluggy<0.8,>=0.5 in /usr/local/lib/python3.8/dist-packages (from pytest->orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (0.7.1)\n", + "Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.8/dist-packages (from pytest->orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (22.1.0)\n", + "Requirement already satisfied: py>=1.5.0 in /usr/local/lib/python3.8/dist-packages (from pytest->orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (1.11.0)\n", + "Requirement already satisfied: atomicwrites>=1.0 in /usr/local/lib/python3.8/dist-packages (from pytest->orbax->flax<0.7.0,>=0.6.0->cleanrl[dqn-atari-jax]) (1.4.1)\n", + "Building wheels for collected packages: gym, moviepy, AutoROM.accept-rom-license, pathtools\n", + " Building wheel for gym (PEP 517) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for gym: filename=gym-0.23.1-py3-none-any.whl size=701376 sha256=7b59f30aef873fc1494bd2f2eeac27b103b64ae6ee87d554c8b61b9ddbe35765\n", + " Stored in directory: /root/.cache/pip/wheels/78/28/77/b0c74e80a2a4faae0161d5c53bc4f8e436e77aedc79136ee13\n", + " Building wheel for moviepy (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for moviepy: filename=moviepy-1.0.3-py3-none-any.whl size=110742 sha256=640c1c0df827ed5835373acab4d2d7b93e98e33b5e6cb90e3d5e703933f9bcf8\n", + " Stored in directory: /root/.cache/pip/wheels/e4/a4/db/0368d3a04033da662e13926594b3a8cf1aa4ffeefe570cfac1\n", + " Building wheel for AutoROM.accept-rom-license (PEP 517) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for AutoROM.accept-rom-license: filename=AutoROM.accept_rom_license-0.5.0-py3-none-any.whl size=440868 sha256=a3833e2c22c21355029cb083d9ea62b7abe329af3757ccdce9b0d2a5cc06949f\n", + " Stored in directory: /root/.cache/pip/wheels/bf/c9/25/578470ae932b494c313dc22e6c57afff192140fb3cd5acf185\n", + " Building wheel for pathtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for pathtools: filename=pathtools-0.1.2-py3-none-any.whl size=8806 sha256=57226a75b752bf852ac2f0f5ad878217a63376d6c44a4b29ccdf40b4921bf4bc\n", + " Stored in directory: /root/.cache/pip/wheels/4c/8e/7e/72fbc243e1aeecae64a96875432e70d4e92f3d2d18123be004\n", + "Successfully built gym moviepy AutoROM.accept-rom-license pathtools\n", + "Installing collected packages: smmap, gitdb, tensorstore, shortuuid, setproctitle, sentry-sdk, proglog, pathtools, libtorrent, imageio-ffmpeg, gymnasium-notices, gym, GitPython, docker-pycreds, commonmark, chex, cached-property, wandb, tensorboard, stable-baselines3, rich, pygame, orbax, optax, moviepy, huggingface-hub, gymnasium, AutoROM.accept-rom-license, AutoROM, flax, cleanrl-test, ale-py\n", + " Attempting uninstall: gym\n", + " Found existing installation: gym 0.25.2\n", + " Uninstalling gym-0.25.2:\n", + " Successfully uninstalled gym-0.25.2\n", + " Attempting uninstall: tensorboard\n", + " Found existing installation: tensorboard 2.9.1\n", + " Uninstalling tensorboard-2.9.1:\n", + " Successfully uninstalled tensorboard-2.9.1\n", + " Attempting uninstall: moviepy\n", + " Found existing installation: moviepy 0.2.3.5\n", + " Uninstalling moviepy-0.2.3.5:\n", + " Successfully uninstalled moviepy-0.2.3.5\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "tensorflow 2.9.2 requires tensorboard<2.10,>=2.9, but you have tensorboard 2.11.0 which is incompatible.\u001b[0m\n", + "Successfully installed AutoROM-0.4.2 AutoROM.accept-rom-license-0.5.0 GitPython-3.1.30 ale-py-0.7.4 cached-property-1.5.2 chex-0.1.5 cleanrl-test-1.1.2 commonmark-0.9.1 docker-pycreds-0.4.0 flax-0.6.3 gitdb-4.0.10 gym-0.23.1 gymnasium-0.26.3 gymnasium-notices-0.0.1 huggingface-hub-0.11.1 imageio-ffmpeg-0.4.7 libtorrent-2.0.7 moviepy-1.0.3 optax-0.1.4 orbax-0.0.23 pathtools-0.1.2 proglog-0.1.10 pygame-2.1.0 rich-13.0.0 sentry-sdk-1.9.0 setproctitle-1.3.2 shortuuid-1.0.11 smmap-5.0.0 stable-baselines3-1.2.0 tensorboard-2.11.0 tensorstore-0.1.28 wandb-0.13.7\n" + ] + } + ], + "source": [ + "!pip install --upgrade \"cleanrl[dqn-atari-jax]\" # CAVEAT: the extra key is `dqn-atari-jax` with dashes instead of `dqn_atari_jax` with underscores" ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.8/dist-packages/gym/core.py:172: DeprecationWarning: \u001B[33mWARN: Function `env.seed(seed)` is marked as deprecated and will be removed in the future. Please use `env.reset(seed=seed) instead.\u001B[0m\n", - " deprecation(\n" - ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "eval_episode=0, episodic_return=500.0\n", - "eval_episode=1, episodic_return=500.0\n", - "loading models from cleanrl/Acrobot-v1-dqn-seed1\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0c074497102c45aab5db63d863a493a5", - "version_major": 2, - "version_minor": 0 + "cell_type": "markdown", + "metadata": { + "id": "xXQXZTh_AHZ0" }, - "text/plain": [ - "Downloading: 0%| | 0.00/47.1k [00:00\n", + " \n", + " Your browser does not support the video tag.\n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from IPython.display import Video\n", + "Video('videos/eval/rl-video-episode-0.mp4', embed=True)" ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "eval_episode=0, episodic_return=500.0\n", - "eval_episode=1, episodic_return=500.0\n", - "loading models from cleanrl/Acrobot-v1-dqn_jax-seed1\n" - ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "03848b157e164490a7a509028df7cad8", - "version_major": 2, - "version_minor": 0 + "cell_type": "markdown", + "metadata": { + "id": "WU29XP1ICwxv" }, - "text/plain": [ - "Downloading: 0%| | 0.00/45.2k [00:00 Date: Thu, 18 Jan 2024 13:52:59 -0500 Subject: [PATCH 19/20] Update dead pettingzoo.ml links to Farama foundation links --- docs/rl-algorithms/ppo.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/rl-algorithms/ppo.md b/docs/rl-algorithms/ppo.md index e83b38e63..c2cc19181 100644 --- a/docs/rl-algorithms/ppo.md +++ b/docs/rl-algorithms/ppo.md @@ -1029,7 +1029,7 @@ Tracked experiments and game play videos: ## `ppo_pettingzoo_ma_atari.py` -[ppo_pettingzoo_ma_atari.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_pettingzoo_ma_atari.py) trains an agent to learn playing Atari games via selfplay. The selfplay environment is implemented as a vectorized environment from [PettingZoo.ml](https://www.pettingzoo.ml/atari). The basic idea is to create vectorized environment $E$ with `num_envs = N`, where $N$ is the number of players in the game. Say $N = 2$, then the 0-th sub environment of $E$ will return the observation for player 0 and 1-th sub environment will return the observation of player 1. Then the two environments takes a batch of 2 actions and execute them for player 0 and player 1, respectively. See "Vectorized architecture" in [The 37 Implementation Details of Proximal Policy Optimization](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/) for more detail. +[ppo_pettingzoo_ma_atari.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_pettingzoo_ma_atari.py) trains an agent to learn playing Atari games via selfplay. The selfplay environment is implemented as a vectorized environment from [PettingZoo](https://pettingzoo.farama.org/environments/atari/). The basic idea is to create vectorized environment $E$ with `num_envs = N`, where $N$ is the number of players in the game. Say $N = 2$, then the 0-th sub environment of $E$ will return the observation for player 0 and 1-th sub environment will return the observation of player 1. Then the two environments takes a batch of 2 actions and execute them for player 0 and player 1, respectively. See "Vectorized architecture" in [The 37 Implementation Details of Proximal Policy Optimization](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/) for more detail. `ppo_pettingzoo_ma_atari.py` has the following features: @@ -1064,7 +1064,7 @@ Tracked experiments and game play videos: python cleanrl/ppo_pettingzoo_ma_atari.py --env-id surround_v2 ``` -See [https://www.pettingzoo.ml/atari](https://www.pettingzoo.ml/atari) for a full-list of supported environments such as `basketball_pong_v3`. Notice pettingzoo sometimes introduces breaking changes, so make sure to install the pinned dependencies via `poetry`. +See [https://pettingzoo.farama.org/environments/atari/](https://pettingzoo.farama.org/environments/atari/) for a full-list of supported environments such as `basketball_pong_v3`. Notice pettingzoo sometimes introduces breaking changes, so make sure to install the pinned dependencies via `poetry`. ### Explanation of the logged metrics From 1b725cfbac36b884de33e169554ce662ce79889b Mon Sep 17 00:00:00 2001 From: elliottower Date: Thu, 18 Jan 2024 14:26:44 -0500 Subject: [PATCH 20/20] Update to newly release SuperSuit 3.9.2 (minor bugfixes but best to keep going forward) --- poetry.lock | 14 +++++++++----- pyproject.toml | 2 +- requirements/requirements-pettingzoo.txt | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/poetry.lock b/poetry.lock index 657c56357..4cc2106a0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1080,6 +1080,7 @@ files = [ {file = "greenlet-2.0.2-cp27-cp27m-win32.whl", hash = "sha256:6c3acb79b0bfd4fe733dff8bc62695283b57949ebcca05ae5c129eb606ff2d74"}, {file = "greenlet-2.0.2-cp27-cp27m-win_amd64.whl", hash = "sha256:283737e0da3f08bd637b5ad058507e578dd462db259f7f6e4c5c365ba4ee9343"}, {file = "greenlet-2.0.2-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:d27ec7509b9c18b6d73f2f5ede2622441de812e7b1a80bbd446cb0633bd3d5ae"}, + {file = "greenlet-2.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d967650d3f56af314b72df7089d96cda1083a7fc2da05b375d2bc48c82ab3f3c"}, {file = "greenlet-2.0.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:30bcf80dda7f15ac77ba5af2b961bdd9dbc77fd4ac6105cee85b0d0a5fcf74df"}, {file = "greenlet-2.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:26fbfce90728d82bc9e6c38ea4d038cba20b7faf8a0ca53a9c07b67318d46088"}, {file = "greenlet-2.0.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9190f09060ea4debddd24665d6804b995a9c122ef5917ab26e1566dcc712ceeb"}, @@ -1088,6 +1089,7 @@ files = [ {file = "greenlet-2.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:76ae285c8104046b3a7f06b42f29c7b73f77683df18c49ab5af7983994c2dd91"}, {file = "greenlet-2.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:2d4686f195e32d36b4d7cf2d166857dbd0ee9f3d20ae349b6bf8afc8485b3645"}, {file = "greenlet-2.0.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c4302695ad8027363e96311df24ee28978162cdcdd2006476c43970b384a244c"}, + {file = "greenlet-2.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d4606a527e30548153be1a9f155f4e283d109ffba663a15856089fb55f933e47"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c48f54ef8e05f04d6eff74b8233f6063cb1ed960243eacc474ee73a2ea8573ca"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a1846f1b999e78e13837c93c778dcfc3365902cfb8d1bdb7dd73ead37059f0d0"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a06ad5312349fec0ab944664b01d26f8d1f05009566339ac6f63f56589bc1a2"}, @@ -1117,6 +1119,7 @@ files = [ {file = "greenlet-2.0.2-cp37-cp37m-win32.whl", hash = "sha256:3f6ea9bd35eb450837a3d80e77b517ea5bc56b4647f5502cd28de13675ee12f7"}, {file = "greenlet-2.0.2-cp37-cp37m-win_amd64.whl", hash = "sha256:7492e2b7bd7c9b9916388d9df23fa49d9b88ac0640db0a5b4ecc2b653bf451e3"}, {file = "greenlet-2.0.2-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b864ba53912b6c3ab6bcb2beb19f19edd01a6bfcbdfe1f37ddd1778abfe75a30"}, + {file = "greenlet-2.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1087300cf9700bbf455b1b97e24db18f2f77b55302a68272c56209d5587c12d1"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:ba2956617f1c42598a308a84c6cf021a90ff3862eddafd20c3333d50f0edb45b"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc3a569657468b6f3fb60587e48356fe512c1754ca05a564f11366ac9e306526"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8eab883b3b2a38cc1e050819ef06a7e6344d4a990d24d45bc6f2cf959045a45b"}, @@ -1125,6 +1128,7 @@ files = [ {file = "greenlet-2.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b0ef99cdbe2b682b9ccbb964743a6aca37905fda5e0452e5ee239b1654d37f2a"}, {file = "greenlet-2.0.2-cp38-cp38-win32.whl", hash = "sha256:b80f600eddddce72320dbbc8e3784d16bd3fb7b517e82476d8da921f27d4b249"}, {file = "greenlet-2.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:4d2e11331fc0c02b6e84b0d28ece3a36e0548ee1a1ce9ddde03752d9b79bba40"}, + {file = "greenlet-2.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8512a0c38cfd4e66a858ddd1b17705587900dd760c6003998e9472b77b56d417"}, {file = "greenlet-2.0.2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:88d9ab96491d38a5ab7c56dd7a3cc37d83336ecc564e4e8816dbed12e5aaefc8"}, {file = "greenlet-2.0.2-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:561091a7be172ab497a3527602d467e2b3fbe75f9e783d8b8ce403fa414f71a6"}, {file = "greenlet-2.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:971ce5e14dc5e73715755d0ca2975ac88cfdaefcaab078a284fea6cfabf866df"}, @@ -3707,13 +3711,13 @@ tests = ["black", "isort (>=5.0)", "mypy", "pytest", "pytest-cov", "pytest-env", [[package]] name = "supersuit" -version = "3.9.1" +version = "3.9.2" description = "Wrappers for Gymnasium and PettingZoo" optional = true python-versions = "<3.12,>=3.8" files = [ - {file = "SuperSuit-3.9.1-py3-none-any.whl", hash = "sha256:24907f8edb9578c8b35eb374e53fdde96daf37c006d8e929c7bf485e5c52f356"}, - {file = "SuperSuit-3.9.1.tar.gz", hash = "sha256:536732019e5f00420a17a7e3078a73824191515b6b0af37b06322d4846cda655"}, + {file = "SuperSuit-3.9.2-py3-none-any.whl", hash = "sha256:1dcecd419100eeed19c51444a341dd7ab14deaf3cd775ba475de4e63eba6159c"}, + {file = "SuperSuit-3.9.2.tar.gz", hash = "sha256:60e384fe63ab6752acbfc34f991f48d6346592b1dd3475138e3599ab41eaaf24"}, ] [package.dependencies] @@ -3722,7 +3726,7 @@ numpy = ">=1.19.0" tinyscaler = ">=1.2.6" [package.extras] -testing = ["pettingzoo[butterfly,classic] (>=1.23.1)", "pytest"] +testing = ["moviepy (>=1.0.0)", "pettingzoo[butterfly,classic] (>=1.23.1)", "pytest", "stable-baselines3 (>=2.0.0)"] [[package]] name = "tabulate" @@ -4259,4 +4263,4 @@ qdagger-dqn-atari-jax-impalacnn = ["AutoROM", "ale-py", "flax", "jax", "jaxlib", [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.11" -content-hash = "453c8d2c113d81cb529f771d450301c0a4fa6d5ab0bfc3964a110d650ee7db39" +content-hash = "aaa9d84a456774e5f9ecf02beade2b4f42d71980872f1f929305b05c57d73958" diff --git a/pyproject.toml b/pyproject.toml index 1b45385a1..e5b4282c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ optuna = {version = "^3.0.1", optional = true} optuna-dashboard = {version = "^0.7.2", optional = true} envpool = {version = "^0.6.4", optional = true} PettingZoo = {version = "^1.24.3", optional = true} -SuperSuit = {version = "^3.9.1", optional = true} +SuperSuit = {version = ">=3.9.2", optional = true} multi-agent-ale-py = {version = "0.1.11", optional = true} boto3 = {version = "^1.24.70", optional = true} awscli = {version = "^1.31.0", optional = true} diff --git a/requirements/requirements-pettingzoo.txt b/requirements/requirements-pettingzoo.txt index abcff76d2..f997ba402 100644 --- a/requirements/requirements-pettingzoo.txt +++ b/requirements/requirements-pettingzoo.txt @@ -63,7 +63,7 @@ shtab==1.6.4 ; python_version >= "3.8" and python_version < "3.11" six==1.16.0 ; python_version >= "3.8" and python_version < "3.11" smmap==5.0.0 ; python_version >= "3.8" and python_version < "3.11" stable-baselines3==2.0.0 ; python_version >= "3.8" and python_version < "3.11" -supersuit==3.9.1 ; python_version >= "3.8" and python_version < "3.11" +supersuit==3.9.2 ; python_version >= "3.8" and python_version < "3.11" tenacity==8.2.3 ; python_version >= "3.8" and python_version < "3.11" tensorboard-data-server==0.6.1 ; python_version >= "3.8" and python_version < "3.11" tensorboard-plugin-wit==1.8.1 ; python_version >= "3.8" and python_version < "3.11"