Add type checking with mypy (#331)
* Add type checking with mypy

* Install mypy types

* Fix install types

* ALGO type hint

* Fix type errors

* Ignore enjoy, record_video and callbacks

* Fix type checking

Co-authored-by: Quentin Gallouédec <[email protected]>
araffin and qgallouedec authored Jan 2, 2023
1 parent 5fe9a5d commit 1aa0644
Showing 18 changed files with 94 additions and 52 deletions.
1 change: 1 addition & 0 deletions .coveragerc
@@ -3,6 +3,7 @@ branch = False
omit =
tests/*
rl_zoo3/plots/*
rl_zoo3/push_to_hub.py
scripts/*

[report]
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
@@ -11,6 +11,9 @@ on:

jobs:
build:
env:
TERM: xterm-256color
FORCE_COLOR: 1
# Skip CI if [ci skip] in the commit message
if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')"
runs-on: ubuntu-latest
3 changes: 3 additions & 0 deletions .github/workflows/trained_agents.yml
@@ -11,6 +11,9 @@ on:

jobs:
build:
env:
TERM: xterm-256color
FORCE_COLOR: 1
# Skip CI if [ci skip] in the commit message
if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')"
runs-on: ubuntu-latest
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -20,6 +20,7 @@

### Other
- `scripts/plot_train.py` plots models such that newer models appear on top of older ones.
- Added additional type checking using mypy


## Release 1.6.3 (2022-10-13)
10 changes: 7 additions & 3 deletions Makefile
@@ -8,9 +8,13 @@ pytest:
check-trained-agents:
python -m pytest -v tests/test_enjoy.py -k trained_agent --color=yes

# Type check
type:
pytype -j auto rl_zoo3/ tests/ scripts/ -d import-error
pytype:
pytype -j auto ${LINT_PATHS} -d import-error

mypy:
mypy ${LINT_PATHS} --install-types --non-interactive

type: pytype mypy

lint:
# stop the build if there are Python syntax errors or undefined names
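The Makefile now splits type checking into separate `pytype` and `mypy` targets over `${LINT_PATHS}` (defined elsewhere in the Makefile, not shown in this diff), with `type` running both. Outside of make, mypy can also be driven from Python through its public `mypy.api` module; the sketch below is a hedged illustration, and the path list is an assumption rather than the repo's actual `LINT_PATHS` value.

```python
# Hedged sketch: invoking mypy programmatically instead of via `make mypy`.
# The paths are assumptions; the repo's ${LINT_PATHS} definition is not in this diff.
from mypy import api

stdout, stderr, exit_status = api.run(
    ["rl_zoo3/", "scripts/", "tests/", "--install-types", "--non-interactive"]
)
print(stdout)
if exit_status != 0:
    raise SystemExit(exit_status)
```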
3 changes: 2 additions & 1 deletion rl_zoo3/benchmark.py
@@ -3,6 +3,7 @@
import os
import shutil
import subprocess
from typing import Dict, List

import numpy as np
import pandas as pd
@@ -30,7 +31,7 @@
trained_models.update(get_hf_trained_models())

n_experiments = len(trained_models)
results = {
results: Dict[str, List] = {
"algo": [],
"env_id": [],
"mean_reward": [],
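The only change benchmark.py needs is an explicit annotation on the empty `results` dict: mypy cannot infer an element type from empty literals and typically asks for an annotation. A minimal, self-contained sketch of the same pattern (field names are illustrative):

```python
from typing import Dict, List

# Without the annotation, mypy reports something like
# 'Need type annotation for "results"' because the empty list
# literals give it nothing to infer the value type from.
results: Dict[str, List] = {
    "algo": [],
    "env_id": [],
    "mean_reward": [],
}
results["algo"].append("ppo")  # fine: bare List means List[Any]
```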
34 changes: 22 additions & 12 deletions rl_zoo3/callbacks.py
@@ -4,7 +4,7 @@
from copy import deepcopy
from functools import wraps
from threading import Thread
from typing import Optional
from typing import Optional, Type, Union

import optuna
from sb3_contrib import TQC
@@ -80,13 +80,16 @@ def _init_callback(self) -> None:
os.makedirs(self.save_path, exist_ok=True)

def _on_step(self) -> bool:
# make mypy happy
assert self.model is not None

if self.n_calls % self.save_freq == 0:
if self.name_prefix is not None:
path = os.path.join(self.save_path, f"{self.name_prefix}_{self.num_timesteps}_steps.pkl")
else:
path = os.path.join(self.save_path, "vecnormalize.pkl")
if self.model.get_vec_normalize_env() is not None:
self.model.get_vec_normalize_env().save(path)
self.model.get_vec_normalize_env().save(path) # type: ignore[union-attr]
if self.verbose > 1:
print(f"Saving VecNormalize to {path}")
return True
@@ -114,10 +117,10 @@ def __init__(self, gradient_steps: int = 100, verbose: int = 0, sleep_time: floa
super().__init__(verbose)
self.batch_size = 0
self._model_ready = True
self._model = None
self._model: Union[SAC, TQC]
self.gradient_steps = gradient_steps
self.process = None
self.model_class = None
self.process: Thread
self.model_class: Union[Type[SAC], Type[TQC]]
self.sleep_time = sleep_time

def _init_callback(self) -> None:
@@ -126,18 +129,21 @@ def _init_callback(self) -> None:
# Windows TemporaryFile is not a io Buffer
# we save the model in the logs/ folder
if os.name == "nt":
temp_file = os.path.join("logs", "model_tmp.zip")
temp_file = os.path.join("logs", "model_tmp.zip") # type: ignore[arg-type,assignment]

# make mypy happy
assert isinstance(self.model, (SAC, TQC)), f"{self.model} is not supported for parallel training"

self.model.save(temp_file)
self.model.save(temp_file) # type: ignore[arg-type]

# TODO: add support for other algorithms
for model_class in [SAC, TQC]:
if isinstance(self.model, model_class):
self.model_class = model_class
self.model_class = model_class # type: ignore[assignment]
break

assert self.model_class is not None, f"{self.model} is not supported for parallel training"
self._model = self.model_class.load(temp_file)
self._model = self.model_class.load(temp_file) # type: ignore[arg-type]

self.batch_size = self._model.batch_size

@@ -151,7 +157,7 @@ def wrapper(*args, **kwargs):

# Add logger for parallel training
self._model.set_logger(self.model.logger)
self.model.train = patch_train(self.model.train)
self.model.train = patch_train(self.model.train) # type: ignore[assignment]

# Hack: Re-add correct values at save time
def patch_save(function):
@@ -161,7 +167,7 @@ def wrapper(*args, **kwargs):

return wrapper

self.model.save = patch_save(self.model.save)
self.model.save = patch_save(self.model.save) # type: ignore[assignment]

def train(self) -> None:
self._model_ready = False
@@ -179,10 +185,13 @@ def _on_step(self) -> bool:
return True

def _on_rollout_end(self) -> None:
# Make mypy happy
assert isinstance(self.model, (SAC, TQC))

if self._model_ready:
self._model.replay_buffer = deepcopy(self.model.replay_buffer)
self.model.set_parameters(deepcopy(self._model.get_parameters()))
self.model.actor = self.model.policy.actor
self.model.actor = self.model.policy.actor # type: ignore[union-attr, attr-defined]
if self.num_timesteps >= self._model.learning_starts:
self.train()
# Do not wait for the training loop to finish
@@ -209,6 +218,7 @@ def __init__(self, verbose=0):
self._tensorboard_writer = None

def _init_callback(self) -> None:
assert self.logger is not None
# Retrieve tensorboard writer to not flood the logger output
for out_format in self.logger.output_formats:
if isinstance(out_format, TensorBoardOutputFormat):
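Most of the callbacks.py changes follow one pattern: `self.model` is declared Optional on the base class, so an `assert self.model is not None` (the "make mypy happy" comments) narrows it before attribute access, while attributes mypy cannot prove get a targeted `# type: ignore[...]`. A standalone sketch of the narrowing behaviour, with illustrative names that are not taken from stable-baselines3:

```python
from typing import Optional


class Model:
    def save(self, path: str) -> None:
        print(f"saved to {path}")


class Callback:
    # Illustrative stand-in for a base-class attribute that starts out as None.
    model: Optional[Model] = None

    def on_step(self) -> bool:
        # The assert narrows Optional[Model] to Model for mypy; without it,
        # the call below is flagged as:
        #   Item "None" of "Optional[Model]" has no attribute "save"
        assert self.model is not None
        self.model.save("vecnormalize.pkl")
        return True
```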
8 changes: 4 additions & 4 deletions rl_zoo3/enjoy.py
@@ -17,7 +17,7 @@
from rl_zoo3.utils import StoreDict, get_model_path


def enjoy(): # noqa: C901
def enjoy() -> None: # noqa: C901
parser = argparse.ArgumentParser()
parser.add_argument("--env", help="environment ID", type=EnvironmentName, default="CartPole-v1")
parser.add_argument("-f", "--folder", help="Log folder", type=str, default="rl-trained-agents")
@@ -139,7 +139,7 @@ def enjoy(): # noqa: C901
is_atari = ExperimentManager.is_atari(env_name.gym_id)

stats_path = os.path.join(log_path, env_name)
hyperparams, stats_path = get_saved_hyperparams(stats_path, norm_reward=args.norm_reward, test_mode=True)
hyperparams, maybe_stats_path = get_saved_hyperparams(stats_path, norm_reward=args.norm_reward, test_mode=True)

# load env_kwargs if existing
env_kwargs = {}
@@ -158,7 +158,7 @@ def enjoy(): # noqa: C901
env = create_test_env(
env_name.gym_id,
n_envs=args.n_envs,
stats_path=stats_path,
stats_path=maybe_stats_path,
seed=args.seed,
log_dir=log_dir,
should_render=not args.no_render,
@@ -213,7 +213,7 @@ def enjoy(): # noqa: C901
try:
for _ in generator:
action, lstm_states = model.predict(
obs,
obs, # type: ignore[arg-type]
state=lstm_states,
episode_start=episode_start,
deterministic=deterministic,
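The `stats_path` to `maybe_stats_path` rename in enjoy.py (mirrored in push_to_hub.py and record_video.py below) avoids re-binding one variable to a different type: the stats path returned by `get_saved_hyperparams` can be `None`, and mypy flags assigning an `Optional[str]` result to a name it already inferred as `str`. A hedged sketch of the same issue, using a made-up stand-in for the helper:

```python
from typing import Dict, Optional, Tuple


def get_saved_hyperparams(stats_path: str) -> Tuple[Dict, Optional[str]]:
    # Illustrative stand-in: returns None when no saved VecNormalize stats exist.
    return {}, None


stats_path = "logs/ppo/CartPole-v1_1/CartPole-v1"
# Reusing the same name here would make mypy report:
#   Incompatible types in assignment (expression has type "Optional[str]",
#   variable has type "str")
hyperparams, maybe_stats_path = get_saved_hyperparams(stats_path)
```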
14 changes: 7 additions & 7 deletions rl_zoo3/exp_manager.py
@@ -110,10 +110,10 @@ def __init__(
default_path = Path(__file__).parent.parent

self.config = config or str(default_path / f"hyperparams/{self.algo}.yml")
self.env_kwargs = {} if env_kwargs is None else env_kwargs
self.env_kwargs: Dict[str, Any] = {} if env_kwargs is None else env_kwargs
self.n_timesteps = n_timesteps
self.normalize = False
self.normalize_kwargs = {}
self.normalize_kwargs: Dict[str, Any] = {}
self.env_wrapper = None
self.frame_stack = None
self.seed = seed
@@ -122,21 +122,21 @@
self.vec_env_class = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}[vec_env_type]
self.vec_env_wrapper = None

self.vec_env_kwargs = {}
self.vec_env_kwargs: Dict[str, Any] = {}
# self.vec_env_kwargs = {} if vec_env_type == "dummy" else {"start_method": "fork"}

# Callbacks
self.specified_callbacks = []
self.callbacks = []
self.specified_callbacks: List = []
self.callbacks: List[BaseCallback] = []
self.save_freq = save_freq
self.eval_freq = eval_freq
self.n_eval_episodes = n_eval_episodes
self.n_eval_envs = n_eval_envs

self.n_envs = 1 # it will be updated when reading hyperparams
self.n_actions = None # For DDPG/TD3 action noise objects
self._hyperparams = {}
self.monitor_kwargs = {}
self._hyperparams: Dict[str, Any] = {}
self.monitor_kwargs: Dict[str, Any] = {}

self.trained_agent = trained_agent
self.continue_training = trained_agent.endswith(".zip") and os.path.isfile(trained_agent)
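The exp_manager.py changes are all variable annotations on attributes initialised to empty containers, so the code that fills them later type-checks cleanly. The sketch below shows the pattern in isolation; class and attribute names are illustrative, not the repo's:

```python
from typing import Any, Dict, List


class Manager:
    def __init__(self) -> None:
        # Annotating at the point of the empty-literal assignment tells mypy
        # what may be stored later, instead of triggering
        # 'Need type annotation for "..."' on the bare {} / [].
        self.env_kwargs: Dict[str, Any] = {}
        self.callbacks: List[str] = []

    def configure(self) -> None:
        self.env_kwargs["render_mode"] = "human"
        self.callbacks.append("checkpoint")
```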
6 changes: 3 additions & 3 deletions rl_zoo3/push_to_hub.py
@@ -341,7 +341,7 @@ def package_to_hub(
is_atari = ExperimentManager.is_atari(env_name.gym_id)

stats_path = os.path.join(log_path, env_name)
hyperparams, stats_path = get_saved_hyperparams(stats_path, test_mode=True)
hyperparams, maybe_stats_path = get_saved_hyperparams(stats_path, test_mode=True)

# load env_kwargs if existing
env_kwargs = {}
@@ -358,7 +358,7 @@
eval_env = create_test_env(
env_name.gym_id,
n_envs=args.n_envs,
stats_path=stats_path,
stats_path=maybe_stats_path,
seed=args.seed,
log_dir=None,
should_render=not args.no_render,
Expand All @@ -373,7 +373,7 @@ def package_to_hub(

# Note: we assume that we push models using the same machine (same python version)
# that trained them, if not, we would need to pass custom object as in enjoy.py
custom_objects = {}
custom_objects: Dict[str, Any] = {}
model = ALGOS[algo].load(model_path, env=eval_env, custom_objects=custom_objects, device=args.device, **kwargs)

# Deterministic by default except for atari games
6 changes: 3 additions & 3 deletions rl_zoo3/record_training.py
@@ -51,8 +51,8 @@

# record a video of every model
models_dir_entries = [dir_ent.name for dir_ent in os.scandir(log_path) if dir_ent.is_file()]
checkpoints = list(filter(lambda x: x.startswith("rl_model_"), models_dir_entries))
checkpoints = list(map(lambda x: int(re.findall(r"\d+", x)[0]), checkpoints))
checkpoints_names = list(filter(lambda x: x.startswith("rl_model_"), models_dir_entries))
checkpoints = list(map(lambda x: int(re.findall(r"\d+", x)[0]), checkpoints_names))
checkpoints.sort()

args_final_model = [
@@ -102,7 +102,7 @@
# sort checkpoints by the number of steps
def get_number_from_checkpoint_filename(filename: str) -> int:
match = re.search("checkpoint-(.*?)-", filename)
number = 0
number = "0"
if match is not None:
number = match.group(1)

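Both record_training.py fixes address a variable whose type changed between assignments: `checkpoints` was first a list of filenames and then a list of ints, and `number` was initialised to the int `0` before possibly receiving `match.group(1)` (a `str`). Giving the intermediate list its own name and starting from the string `"0"` keeps each name at a single type. A sketch of the second case, following the hunk above; the closing return and the example call are assumptions, since the diff cuts off before them:

```python
import re


def get_number_from_checkpoint_filename(filename: str) -> int:
    match = re.search("checkpoint-(.*?)-", filename)
    # Starting from the string "0" keeps `number` a str on both branches;
    # initialising it to the int 0 would make `number = match.group(1)`
    # an incompatible str-to-int assignment for mypy.
    number = "0"
    if match is not None:
        number = match.group(1)
    return int(number)  # assumed closing line, not shown in the hunk


print(get_number_from_checkpoint_filename("checkpoint-500-steps.zip"))  # 500
```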
8 changes: 4 additions & 4 deletions rl_zoo3/record_video.py
@@ -75,7 +75,7 @@
is_atari = ExperimentManager.is_atari(env_name.gym_id)

stats_path = os.path.join(log_path, env_name)
hyperparams, stats_path = get_saved_hyperparams(stats_path)
hyperparams, maybe_stats_path = get_saved_hyperparams(stats_path)

# load env_kwargs if existing
env_kwargs = {}
@@ -92,7 +92,7 @@
env = create_test_env(
env_name.gym_id,
n_envs=n_envs,
stats_path=stats_path,
stats_path=maybe_stats_path,
seed=seed,
log_dir=None,
should_render=not args.no_render,
@@ -148,12 +148,12 @@
try:
for _ in range(video_length + 1):
action, lstm_states = model.predict(
obs,
obs, # type: ignore[arg-type]
state=lstm_states,
episode_start=episode_starts,
deterministic=deterministic,
)
obs, _, dones, _ = env.step(action)
obs, _, dones, _ = env.step(action) # type: ignore[assignment]
episode_starts = dones
if not args.no_render:
env.render()
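record_video.py, like the other call sites in this commit, uses ignores with explicit error codes (`# type: ignore[arg-type]`, `# type: ignore[assignment]`) rather than blanket `# type: ignore`, so only the named error is suppressed on that line and anything else mypy finds there still surfaces. A small illustration of the difference, with made-up types and values:

```python
from typing import Optional


def normalize(x: float) -> float:
    return x / 10.0


value: Optional[float] = None

# Blanket ignore: silences every mypy error on this line.
result_a = normalize(value)  # type: ignore

# Scoped ignore: silences only the named error code; other problems on the
# same line are still reported, and --warn-unused-ignores can flag ignores
# that no longer suppress anything.
result_b = normalize(value)  # type: ignore[arg-type]
```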
5 changes: 3 additions & 2 deletions rl_zoo3/train.py
@@ -16,7 +16,7 @@
from rl_zoo3.utils import ALGOS, StoreDict


def train():
def train() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--algo", help="RL Algorithm", default="ppo", type=str, required=False, choices=list(ALGOS.keys()))
parser.add_argument("--env", type=str, default="CartPole-v1", help="environment ID")
@@ -179,7 +179,7 @@ def train():
uuid_str = f"_{uuid.uuid4()}" if args.uuid else ""
if args.seed < 0:
# Seed but with a random one
args.seed = np.random.randint(2**32 - 1, dtype="int64").item()
args.seed = np.random.randint(2**32 - 1, dtype="int64").item() # type: ignore[attr-defined]

set_random_seed(args.seed)

@@ -262,6 +262,7 @@ def train():
if args.track:
# we need to save the loaded hyperparameters
args.saved_hyperparams = saved_hyperparams
assert run is not None # make mypy happy
run.config.setdefaults(vars(args))

# Normal training
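Besides the `assert run is not None` narrowing and the `# type: ignore[attr-defined]` on `np.random.randint`, train.py gains a `-> None` return annotation on `train()` itself. That matters because, by default, mypy does not check the body of a function that has no annotations at all; annotating the return (or any parameter) turns checking on for everything inside. A minimal illustration (neither function is called, the point is only what mypy reports):

```python
def untyped():
    # With default settings mypy skips this body entirely (no annotations on
    # the function), so the bad expression below goes unreported unless
    # --check-untyped-defs is enabled.
    "abc" + 1


def typed() -> None:
    # Annotated function: mypy checks the body and flags the same expression:
    #   Unsupported operand types for + ("str" and "int")
    "abc" + 1
```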