Skip to content
This repository has been archived by the owner on Sep 1, 2024. It is now read-only.

Fix format and ci #176

Merged
merged 8 commits into from
Feb 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
python-version: ['3.8', '3.9', '3.10']

steps:
- uses: actions/checkout@v2
Expand Down
7 changes: 3 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@ repos:
hooks:
- id: black
files: 'mbrl'
language_version: python3.7

- repo: https://gitlab.com/pycqa/flake8
- repo: https://github.com/pycqa/flake8
rev: 3.9.2
hooks:
- id: flake8
files: 'mbrl'

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.971
rev: v0.991
hooks:
- id: mypy
files: 'mbrl'
Expand All @@ -22,7 +21,7 @@ repos:
exclude: setup.py

- repo: https://github.com/pycqa/isort
rev: 5.10.1
rev: 5.12.0
hooks:
- id: isort
args: ["--profile", "black"]
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ See also our companion [paper](https://arxiv.org/abs/2104.10159).

#### Standard Installation

``mbrl`` requires Python 3.7+ library and [PyTorch (>= 1.7)](https://pytorch.org).
``mbrl`` requires Python 3.8+ library and [PyTorch (>= 1.7)](https://pytorch.org).
To install the latest stable version, run

pip install mbrl
Expand Down
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Installation

Standard Installation
^^^^^^^^^^^^^^^^^^^^^
``mbrl`` requires Python 3.7+ and `PyTorch (>= 1.7) <https://pytorch.org/>`_.
``mbrl`` requires Python 3.8+ and `PyTorch (>= 1.7) <https://pytorch.org/>`_.

To install the latest stable version, run

Expand Down
4 changes: 2 additions & 2 deletions mbrl/algorithms/mbpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ def evaluate(
num_episodes: int,
video_recorder: VideoRecorder,
) -> float:
avg_episode_reward = 0
avg_episode_reward = 0.0
for episode in range(num_episodes):
obs = env.reset()
video_recorder.init(enabled=(episode == 0))
done = False
episode_reward = 0
episode_reward = 0.0
while not done:
action = agent.act(obs)
obs, reward, done, _ = env.step(action)
Expand Down
2 changes: 1 addition & 1 deletion mbrl/env/cartpole_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class CartPoleEnv(gym.Env):
# This is a continuous version of gym's cartpole environment, with the only difference
# being valid actions are any numbers in the range [-1, 1], and the are applied as
# a multiplicative factor to the total force.
metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 50}
metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": [50]}

def __init__(self):
self.gravity = 9.8
Expand Down
2 changes: 1 addition & 1 deletion mbrl/third_party/pytorch_sac_pranz24/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@
)

for i_episode in itertools.count(1):
episode_reward = 0
episode_reward = 0.0
episode_steps = 0
done = False
state = env.reset()
Expand Down
20 changes: 0 additions & 20 deletions pyproyect.toml

This file was deleted.

2 changes: 1 addition & 1 deletion requirements/dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ nbsphinx>=0.8.0
sphinx-rtd-theme>=0.5.0
flake8>=3.8.4
mypy>=0.902
black>=21.4b2
black>=22.6.0
importlib_metadata<5
pytest>=6.0.1
types-pyyaml>=0.1.6
Expand Down
6 changes: 4 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def parse_requirements_file(path):
classifiers=[
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
Expand All @@ -41,6 +43,6 @@ def parse_requirements_file(path):
"dev": reqs_main + reqs_dev,
},
include_package_data=True,
python_requires=">=3.7",
python_requires=">=3.8",
zip_safe=False,
)
2 changes: 1 addition & 1 deletion tests/pybullet/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _is_eq(a, b) -> bool:
if not type(a) == type(b):
return False
if isinstance(a, np.ndarray):
return np.all(a == b)
return all(a == b)
elif isinstance(a, dict):
if not set(a.keys()) == set(b.keys()):
return False
Expand Down