From 43d94272430f19569e04306048530a0c34c32935 Mon Sep 17 00:00:00 2001
From: Elliot Tower
Date: Fri, 13 Oct 2023 13:28:46 -0400
Subject: [PATCH] Update Ray tutorials to RLlib 2.7.0 (#1112)

---
 .github/workflows/linux-tutorials-test.yml |  2 +-
 tutorials/Ray/render_rllib_leduc_holdem.py |  5 ++++
 tutorials/Ray/render_rllib_pistonball.py   | 35 +++++++++++++++++++++-
 tutorials/Ray/requirements.txt             |  3 +-
 4 files changed, 41 insertions(+), 4 deletions(-)

diff --git a/.github/workflows/linux-tutorials-test.yml b/.github/workflows/linux-tutorials-test.yml
index 7732a1bdc..4de3f2d42 100644
--- a/.github/workflows/linux-tutorials-test.yml
+++ b/.github/workflows/linux-tutorials-test.yml
@@ -17,7 +17,7 @@ jobs:
       fail-fast: false
       matrix:
         python-version: ['3.8', '3.9', '3.10', '3.11']
-        tutorial: [Tianshou, CustomEnvironment, CleanRL, SB3/kaz, SB3/waterworld, SB3/connect_four, SB3/test, AgileRL] # TODO: add back Ray once next release after 2.6.2
+        tutorial: [Tianshou, CustomEnvironment, CleanRL, SB3/kaz, SB3/waterworld, SB3/connect_four, SB3/test, AgileRL]
     steps:
       - uses: actions/checkout@v3
       - name: Set up Python ${{ matrix.python-version }}
diff --git a/tutorials/Ray/render_rllib_leduc_holdem.py b/tutorials/Ray/render_rllib_leduc_holdem.py
index b514872ff..ac1a7921c 100644
--- a/tutorials/Ray/render_rllib_leduc_holdem.py
+++ b/tutorials/Ray/render_rllib_leduc_holdem.py
@@ -28,6 +28,11 @@
 
 args = parser.parse_args()
 
+
+if args.checkpoint_path is None:
+    print("The following arguments are required: --checkpoint-path")
+    exit(0)
+
 checkpoint_path = os.path.expanduser(args.checkpoint_path)
 
 
diff --git a/tutorials/Ray/render_rllib_pistonball.py b/tutorials/Ray/render_rllib_pistonball.py
index a15edd3ea..4e29ec3ba 100644
--- a/tutorials/Ray/render_rllib_pistonball.py
+++ b/tutorials/Ray/render_rllib_pistonball.py
@@ -12,11 +12,40 @@
 from ray.rllib.algorithms.ppo import PPO
 from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
 from ray.rllib.models import ModelCatalog
+from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
 from ray.tune.registry import register_env
-from tutorials.Ray.rllib_pistonball import CNNModelV2
+from torch import nn
 
 from pettingzoo.butterfly import pistonball_v6
 
+
+class CNNModelV2(TorchModelV2, nn.Module):
+    def __init__(self, obs_space, act_space, num_outputs, *args, **kwargs):
+        TorchModelV2.__init__(self, obs_space, act_space, num_outputs, *args, **kwargs)
+        nn.Module.__init__(self)
+        self.model = nn.Sequential(
+            nn.Conv2d(3, 32, [8, 8], stride=(4, 4)),
+            nn.ReLU(),
+            nn.Conv2d(32, 64, [4, 4], stride=(2, 2)),
+            nn.ReLU(),
+            nn.Conv2d(64, 64, [3, 3], stride=(1, 1)),
+            nn.ReLU(),
+            nn.Flatten(),
+            (nn.Linear(3136, 512)),
+            nn.ReLU(),
+        )
+        self.policy_fn = nn.Linear(512, num_outputs)
+        self.value_fn = nn.Linear(512, 1)
+
+    def forward(self, input_dict, state, seq_lens):
+        model_out = self.model(input_dict["obs"].permute(0, 3, 1, 2))
+        self._value_out = self.value_fn(model_out)
+        return self.policy_fn(model_out), state
+
+    def value_function(self):
+        return self._value_out.flatten()
+
+
 os.environ["SDL_VIDEODRIVER"] = "dummy"
 
 parser = argparse.ArgumentParser(
@@ -29,6 +58,10 @@
 
 args = parser.parse_args()
 
+if args.checkpoint_path is None:
+    print("The following arguments are required: --checkpoint-path")
+    exit(0)
+
 checkpoint_path = os.path.expanduser(args.checkpoint_path)
 
 ModelCatalog.register_custom_model("CNNModelV2", CNNModelV2)
diff --git a/tutorials/Ray/requirements.txt b/tutorials/Ray/requirements.txt
index 3cd41dae8..df0623c4b 100644
--- a/tutorials/Ray/requirements.txt
+++ b/tutorials/Ray/requirements.txt
@@ -1,7 +1,6 @@
 PettingZoo[classic,butterfly]>=1.24.0
 Pillow>=9.4.0
-# note: currently requires nightly release, see https://docs.ray.io/en/latest/ray-overview/installation.html#daily-releases-nightlies
-ray[rllib]>2.6.3
+ray[rllib]>=2.7.0
 SuperSuit>=3.9.0
 torch>=1.13.1
 tensorflow-probability>=0.19.0