Skip to content

Commit

Permalink
tianshou tuts fixed (#852)
Browse files Browse the repository at this point in the history
  • Loading branch information
WillDudley authored Nov 7, 2022
1 parent 4936f19 commit c082f1a
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 6 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/linux-tutorials-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
fail-fast: false
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10']
tutorial: ['CleanRL']
tutorial: ['CleanRL', 'Tianshou']
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -33,4 +33,4 @@ jobs:
pip install -r requirements.txt
pip uninstall -y pettingzoo
pip install -e ../..
for f in *.py; do xvfb-run -s "-screen 0 1024x768x24" python "$f"; done
for f in *.py; do xvfb-run -a -s "-screen 0 1024x768x24" python "$f"; done
3 changes: 3 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ environments/third_party_envs
:caption: Tutorials
tutorials/cleanrl/implementing_PPO
tutorials/tianshou/beginner
tutorials/tianshou/intermediate
tutorials/tianshou/advanced
```

```{toctree}
Expand Down
3 changes: 2 additions & 1 deletion tutorials/Tianshou/2_training_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def _get_agents(
if agent_learn is None:
# model
net = Net(
state_shape=observation_space.shape or observation_space.n,
state_shape=observation_space["observation"].shape
or observation_space["observation"].n,
action_shape=env.action_space.shape or env.action_space.n,
hidden_sizes=[128, 128, 128, 128],
device="cuda" if torch.cuda.is_available() else "cpu",
Expand Down
4 changes: 3 additions & 1 deletion tutorials/Tianshou/3_cli_and_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ def get_agents(
if isinstance(env.observation_space, gym.spaces.Dict)
else env.observation_space
)
args.state_shape = observation_space.shape or observation_space.n
args.state_shape = (
observation_space["observation"].shape or observation_space["observation"].n
)
args.action_shape = env.action_space.shape or env.action_space.n
if agent_learn is None:
# model
Expand Down
5 changes: 3 additions & 2 deletions tutorials/Tianshou/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pettingzoo==1.22.0
git+https://github.com/thu-ml/tianshou
pettingzoo[classic]==1.22.1
packaging==21.3
git+https://github.com/WillDudley/tianshou.git

0 comments on commit c082f1a

Please sign in to comment.