Skip to content

Commit

Permalink
Update Ray tutorials to RLlib 2.7.0 (#1112)
Browse files Browse the repository at this point in the history
  • Loading branch information
elliottower authored Oct 13, 2023
1 parent c90f947 commit 43d9427
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/linux-tutorials-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
fail-fast: false
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11']
tutorial: [Tianshou, CustomEnvironment, CleanRL, SB3/kaz, SB3/waterworld, SB3/connect_four, SB3/test, AgileRL] # TODO: add back Ray once next release after 2.6.2
tutorial: [Tianshou, CustomEnvironment, CleanRL, SB3/kaz, SB3/waterworld, SB3/connect_four, SB3/test, AgileRL]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down
5 changes: 5 additions & 0 deletions tutorials/Ray/render_rllib_leduc_holdem.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@

args = parser.parse_args()


if args.checkpoint_path is None:
print("The following arguments are required: --checkpoint-path")
exit(0)

checkpoint_path = os.path.expanduser(args.checkpoint_path)


Expand Down
35 changes: 34 additions & 1 deletion tutorials/Ray/render_rllib_pistonball.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,40 @@
from ray.rllib.algorithms.ppo import PPO
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.tune.registry import register_env
from tutorials.Ray.rllib_pistonball import CNNModelV2
from torch import nn

from pettingzoo.butterfly import pistonball_v6


class CNNModelV2(TorchModelV2, nn.Module):
def __init__(self, obs_space, act_space, num_outputs, *args, **kwargs):
TorchModelV2.__init__(self, obs_space, act_space, num_outputs, *args, **kwargs)
nn.Module.__init__(self)
self.model = nn.Sequential(
nn.Conv2d(3, 32, [8, 8], stride=(4, 4)),
nn.ReLU(),
nn.Conv2d(32, 64, [4, 4], stride=(2, 2)),
nn.ReLU(),
nn.Conv2d(64, 64, [3, 3], stride=(1, 1)),
nn.ReLU(),
nn.Flatten(),
(nn.Linear(3136, 512)),
nn.ReLU(),
)
self.policy_fn = nn.Linear(512, num_outputs)
self.value_fn = nn.Linear(512, 1)

def forward(self, input_dict, state, seq_lens):
model_out = self.model(input_dict["obs"].permute(0, 3, 1, 2))
self._value_out = self.value_fn(model_out)
return self.policy_fn(model_out), state

def value_function(self):
return self._value_out.flatten()


os.environ["SDL_VIDEODRIVER"] = "dummy"

parser = argparse.ArgumentParser(
Expand All @@ -29,6 +58,10 @@

args = parser.parse_args()

if args.checkpoint_path is None:
print("The following arguments are required: --checkpoint-path")
exit(0)

checkpoint_path = os.path.expanduser(args.checkpoint_path)

ModelCatalog.register_custom_model("CNNModelV2", CNNModelV2)
Expand Down
3 changes: 1 addition & 2 deletions tutorials/Ray/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
PettingZoo[classic,butterfly]>=1.24.0
Pillow>=9.4.0
# note: currently requires nightly release, see https://docs.ray.io/en/latest/ray-overview/installation.html#daily-releases-nightlies
ray[rllib]>2.6.3
ray[rllib]>=2.7.0
SuperSuit>=3.9.0
torch>=1.13.1
tensorflow-probability>=0.19.0

0 comments on commit 43d9427

Please sign in to comment.