diff --git a/.circleci/config.yml b/.circleci/config.yml index 9f90432c6..bd0009f24 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -28,15 +28,28 @@ executors: resource_class: xlarge environment: # If you change these, also change ci/code_checks.sh - SRC_FILES: src/ tests/ experiments/ examples/ docs/conf.py setup.py + SRC_FILES: src/ tests/ experiments/ examples/ docs/conf.py setup.py ci/ NUM_CPUS: 8 static-analysis-medium: <<: *defaults resource_class: medium environment: # If you change these, also change ci/code_checks.sh - SRC_FILES: src/ tests/ experiments/ examples/ docs/conf.py setup.py + SRC_FILES: src/ tests/ experiments/ examples/ docs/conf.py setup.py ci/ NUM_CPUS: 2 + EXCLUDE_MYPY: | + (?x)( + src/imitation/algorithms/preference_comparisons.py$ + | src/imitation/rewards/reward_nets.py$ + | src/imitation/util/sacred.py$ + | src/imitation/algorithms/base.py$ + | src/imitation/scripts/train_preference_comparisons.py$ + | src/imitation/rewards/serialize.py$ + | src/imitation/scripts/common/train.py$ + | src/imitation/algorithms/mce_irl.py$ + | src/imitation/algorithms/density.py$ + | tests/algorithms/test_bc.py$ + ) commands: dependencies-linux: @@ -229,6 +242,14 @@ jobs: # since they'll just get checked in a separate shellcheck invocation. exclude: SC1091 + - run: + name: ipynb-check + command: ./ci/clean_notebooks.py --check ./docs/tutorials + + - run: + name: typeignore-check + command: ./ci/check_typeignore.py ${SRC_FILES} + - run: name: flake8 command: flake8 --version && flake8 -j "${NUM_CPUS}" ${SRC_FILES} @@ -263,6 +284,10 @@ jobs: name: pytype command: pytype --version && pytype -j "${NUM_CPUS}" ${SRC_FILES[@]} + - run: + name: mypy + command: mypy --version && mypy ${SRC_FILES[@]} --exclude "${EXCLUDE_MYPY}" --follow-imports=silent --show-error-codes + unit-test-linux: executor: unit-test-linux steps: diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 000000000..a312bb801 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,6 @@ +[report] +exclude_lines = + pragma: no cover + @overload + @typing.overload + raise NotImplementedError \ No newline at end of file diff --git a/ci/Xdummy-entrypoint.py b/ci/Xdummy-entrypoint.py index 2fd2e6c5d..2ffd1d9e1 100755 --- a/ci/Xdummy-entrypoint.py +++ b/ci/Xdummy-entrypoint.py @@ -1,6 +1,7 @@ #!/usr/bin/python3 -# This script starts an X server and sets DISPLAY, then runs wrapped command. +"""This script starts an X server and sets DISPLAY, then runs wrapped command.""" + # Usage: ./Xdummy-entrypoint.py [command] # # Adapted from https://github.com/openai/mujoco-py/blob/master/vendor/Xdummy-entrypoint @@ -31,7 +32,7 @@ "-config", "/etc/dummy_xorg.conf", ":0", - ] + ], ) os.environ["DISPLAY"] = ":0" diff --git a/ci/check_typeignore.py b/ci/check_typeignore.py new file mode 100755 index 000000000..f042f81e9 --- /dev/null +++ b/ci/check_typeignore.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python + +"""Check for invalid "# type: ignore" comments. + +This script checks that no files in our source code have a "#type: ignore" comment +without explicitly indicating the reason for the ignore. This is to ensure that we +don't accidentally ignore errors that we should be fixing. +""" +import argparse +import os +import pathlib +import re +import sys +from typing import List + +# Regex to match a "# type: ignore" comment not followed by a reason. +TYPE_IGNORE_COMMENT = re.compile(r"#\s*type:\s*ignore\s*(?![^\[]*\[)") + +# Regex to match a "# type: ignore[]" comment. 
+TYPE_IGNORE_REASON_COMMENT = re.compile(r"#\s*type:\s*ignore\[(?P<reason>.*)\]")
+
+
+class InvalidTypeIgnore(ValueError):
+    """Raised when a file has an invalid "# type: ignore" comment."""
+
+
+def check_file(file: pathlib.Path):
+    """Checks that the given file has no "# type: ignore" comments without a reason."""
+    with open(file, "r") as f:
+        for i, line in enumerate(f):
+            if TYPE_IGNORE_COMMENT.search(line):
+                raise InvalidTypeIgnore(
+                    f"{file}:{i+1}: Found a '# type: ignore' comment without a reason.",
+                )
+
+            if search := TYPE_IGNORE_REASON_COMMENT.search(line):
+                reason = search.group("reason")
+                if reason == "":
+                    raise InvalidTypeIgnore(
+                        f"{file}:{i+1}: Found a '# type: ignore[]' "
+                        "comment without a reason.",
+                    )
+
+
+def check_files(files: List[pathlib.Path]):
+    """Checks that the given files have no type: ignore comments without a reason."""
+    for file in files:
+        if file == pathlib.Path(__file__):
+            continue
+        check_file(file)
+
+
+def get_files_to_check(root_dirs: List[pathlib.Path]) -> List[pathlib.Path]:
+    """Returns a list of files that should be checked for "# type: ignore" comments."""
+    # Get the list of files that should be checked.
+    files = []
+    for root_dir in root_dirs:
+        for root, _, filenames in os.walk(root_dir):
+            for filename in filenames:
+                if filename.endswith(".py"):
+                    files.append(pathlib.Path(root) / filename)
+
+    return files
+
+
+def parse_args():
+    """Parse command-line arguments."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "files",
+        nargs="+",
+        type=pathlib.Path,
+        help="List of files or paths to check for invalid '# type: ignore' comments.",
+    )
+    args = parser.parse_args()
+    return parser, args
+
+
+def main():
+    """Check for invalid "# type: ignore" comments."""
+    parser, args = parse_args()
+    file_list = get_files_to_check(args.files)
+    try:
+        check_files(file_list)
+    except InvalidTypeIgnore as e:
+        print(e)
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/ci/clean_notebooks.py b/ci/clean_notebooks.py
new file mode 100755
index 000000000..b9e394a20
--- /dev/null
+++ b/ci/clean_notebooks.py
@@ -0,0 +1,166 @@
+#!/usr/bin/env python
+"""Clean all notebooks in the repository."""
+import argparse
+import pathlib
+import sys
+import traceback
+from typing import Any, Dict, List
+
+import nbformat
+
+
+class UncleanNotebookError(Exception):
+    """Raised when a notebook is unclean."""
+
+
+markdown_structure: Dict[str, Dict[str, Any]] = {
+    "cell_type": {"do": "keep"},
+    "metadata": {"do": "constant", "value": dict()},
+    "source": {"do": "keep"},
+    "id": {"do": "keep"},
+}
+
+code_structure: Dict[str, Dict[str, Any]] = {
+    "cell_type": {"do": "keep"},
+    "metadata": {"do": "constant", "value": dict()},
+    "source": {"do": "keep"},
+    "outputs": {"do": "constant", "value": list()},
+    "execution_count": {"do": "constant", "value": None},
+    "id": {"do": "keep"},
+}
+
+structure: Dict[str, Dict[str, Dict[str, Any]]] = {
+    "markdown": markdown_structure,
+    "code": code_structure,
+}
+
+
+def clean_notebook(file: pathlib.Path, check_only=False) -> None:
+    """Clean an ipynb notebook.
+
+    "Cleaning" means removing all output and metadata, as well as any other unnecessary
+    or vendor-dependent information or fields, so that it can be committed to the
+    repository, and so that artificial diffs are not introduced when the notebook is
+    executed.
+
+    Args:
+        file: Path to the notebook to clean.
+        check_only: If True, only check if the notebook is clean, and raise an
+            exception if it is not. If False, clean the notebook in-place.
+ + Raises: + UncleanNotebookError: If `check_only` is True and the notebook is not clean. + Message contains brief description of the reason for the failure. + ValueError: unknown cell structure action. + """ + # Read the notebook + with open(file) as f: + nb = nbformat.read(f, as_version=4) + + was_dirty = False + + if check_only: + print(f"Checking {file}") + + for cell in nb.cells: + + # Remove empty cells + if cell["cell_type"] == "code" and not cell["source"]: + if check_only: + raise UncleanNotebookError(f"Notebook {file} has empty code cell") + nb.cells.remove(cell) + was_dirty = True + + # Clean the cell + # (copy the cell keys list so we can iterate over it while modifying it) + for key in list(cell): + if key not in structure[cell["cell_type"]]: + if check_only: + raise UncleanNotebookError( + f"Notebook {file} has unknown cell key {key}", + ) + del cell[key] + was_dirty = True + else: + cell_structure = structure[cell["cell_type"]][key] + if cell_structure["do"] == "keep": + continue + elif cell_structure["do"] == "constant": + constant_value = cell_structure["value"] + if cell[key] != constant_value: + if check_only: + raise UncleanNotebookError( + f"Notebook {file} has illegal cell value for key {key}" + f" (value: {cell[key]}, " + f"expected: {constant_value})", + ) + cell[key] = constant_value + was_dirty = True + else: + raise ValueError( + f"Unknown cell structure action {cell_structure['do']}", + ) + + if not check_only and was_dirty: + # Write the notebook + with open(file, "w") as f: + nbformat.write(nb, f) + print(f"Cleaned {file}") + + +def parse_args(): + """Parse command-line arguments.""" + # if the argument --check has been passed, check if the notebooks are clean + # otherwise, clean them in-place + parser = argparse.ArgumentParser() + # capture files and paths to clean + parser.add_argument( + "files", + nargs="+", + type=pathlib.Path, + help="List of files or paths to clean", + ) + parser.add_argument("--check", action="store_true") + args = parser.parse_args() + return parser, args + + +def get_files(input_paths: List): + """Build list of files to scan from list of paths and files.""" + files = [] + for file in input_paths: + if file.is_dir(): + files.extend(file.glob("**/*.ipynb")) + else: + if file.suffix == ".ipynb": + files.append(file) + else: + print(f"Skipping {file} (not a notebook)") + if not files: + print("No notebooks found") + sys.exit(1) + return files + + +def main(): + """Clean all notebooks in the repository, or check that they are clean.""" + parser, args = parse_args() + check_only = args.check + input_paths = args.files + + if len(input_paths) == 0: + parser.print_help() + sys.exit(1) + + files = get_files(input_paths) + + for file in files: + try: + clean_notebook(file, check_only=check_only) + except UncleanNotebookError: + traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/ci/code_checks.sh b/ci/code_checks.sh index 771ed2cee..3b7dd0610 100755 --- a/ci/code_checks.sh +++ b/ci/code_checks.sh @@ -1,12 +1,25 @@ #!/usr/bin/env bash # If you change these, also change .circleci/config.yml. 
-SRC_FILES=(src/ tests/ experiments/ examples/ docs/conf.py setup.py) +SRC_FILES=(src/ tests/ experiments/ examples/ docs/conf.py setup.py ci/) +EXCLUDE_MYPY="(?x)( + src/imitation/algorithms/preference_comparisons.py$ + | src/imitation/rewards/reward_nets.py$ + | src/imitation/util/sacred.py$ + | src/imitation/algorithms/base.py$ + | src/imitation/scripts/train_preference_comparisons.py$ + | src/imitation/rewards/serialize.py$ + | src/imitation/scripts/common/train.py$ + | src/imitation/algorithms/mce_irl.py$ + | src/imitation/algorithms/density.py$ + | tests/algorithms/test_bc.py$ +)" set -x # echo commands set -e # quit immediately on error echo "Source format checking" +./ci/clean_notebooks.py --check ./docs/tutorials/ flake8 --darglint-ignore-regex '.*' "${SRC_FILES[@]}" black --check --diff "${SRC_FILES[@]}" codespell -I .codespell.skip --skip='*.pyc,tests/testdata/*,*.ipynb,*.csv' "${SRC_FILES[@]}" @@ -22,6 +35,7 @@ fi if [ "${skipexpensive:-}" != "true" ]; then echo "Type checking" pytype -j auto "${SRC_FILES[@]}" + mypy "${SRC_FILES[@]}" --exclude "${EXCLUDE_MYPY}" --follow-imports=silent --show-error-codes echo "Building docs (validates docstrings)" pushd docs/ diff --git a/docs/algorithms/airl.rst b/docs/algorithms/airl.rst index 13fbec8b2..fea05abed 100644 --- a/docs/algorithms/airl.rst +++ b/docs/algorithms/airl.rst @@ -21,6 +21,7 @@ Detailed example notebook: :doc:`../tutorials/4_train_airl` .. testcode:: + import numpy as np import gym from stable_baselines3 import PPO from stable_baselines3.common.evaluation import evaluate_policy @@ -34,6 +35,8 @@ Detailed example notebook: :doc:`../tutorials/4_train_airl` from imitation.util.networks import RunningNorm from imitation.util.util import make_vec_env + rng = np.random.default_rng(0) + env = gym.make("seals/CartPole-v0") expert = PPO(policy=MlpPolicy, env=env) expert.learn(1000) @@ -42,13 +45,15 @@ Detailed example notebook: :doc:`../tutorials/4_train_airl` expert, make_vec_env( "seals/CartPole-v0", + rng=rng, n_envs=5, post_wrappers=[lambda env, _: RolloutInfoWrapper(env)], ), rollout.make_sample_until(min_timesteps=None, min_episodes=60), + rng=rng, ) - venv = make_vec_env("seals/CartPole-v0", n_envs=8) + venv = make_vec_env("seals/CartPole-v0", rng=rng, n_envs=8) learner = PPO(env=venv, policy=MlpPolicy) reward_net = BasicShapedRewardNet( venv.observation_space, diff --git a/docs/algorithms/bc.rst b/docs/algorithms/bc.rst index 4e254be0c..b7026894a 100644 --- a/docs/algorithms/bc.rst +++ b/docs/algorithms/bc.rst @@ -18,7 +18,8 @@ Example Detailed example notebook: :doc:`../tutorials/1_train_bc` .. 
testcode:: - + + import numpy as np import gym from stable_baselines3 import PPO from stable_baselines3.common.evaluation import evaluate_policy @@ -29,6 +30,7 @@ Detailed example notebook: :doc:`../tutorials/1_train_bc` from imitation.data import rollout from imitation.data.wrappers import RolloutInfoWrapper + rng = np.random.default_rng(0) env = gym.make("CartPole-v1") expert = PPO(policy=MlpPolicy, env=env) expert.learn(1000) @@ -37,6 +39,7 @@ Detailed example notebook: :doc:`../tutorials/1_train_bc` expert, DummyVecEnv([lambda: RolloutInfoWrapper(env)]), rollout.make_sample_until(min_timesteps=None, min_episodes=50), + rng=rng, ) transitions = rollout.flatten_trajectories(rollouts) @@ -44,6 +47,7 @@ Detailed example notebook: :doc:`../tutorials/1_train_bc` observation_space=env.observation_space, action_space=env.action_space, demonstrations=transitions, + rng=rng, ) bc_trainer.train(n_epochs=1) reward, _ = evaluate_policy(bc_trainer.policy, env, 10) diff --git a/docs/algorithms/dagger.rst b/docs/algorithms/dagger.rst index 7b0521784..5be16dc96 100644 --- a/docs/algorithms/dagger.rst +++ b/docs/algorithms/dagger.rst @@ -24,7 +24,7 @@ Detailed example notebook: :doc:`../tutorials/2_train_dagger` .. testcode:: import tempfile - + import numpy as np import gym from stable_baselines3 import PPO from stable_baselines3.common.evaluation import evaluate_policy @@ -34,6 +34,7 @@ Detailed example notebook: :doc:`../tutorials/2_train_dagger` from imitation.algorithms import bc from imitation.algorithms.dagger import SimpleDAggerTrainer + rng = np.random.default_rng(0) env = gym.make("CartPole-v1") expert = PPO(policy=MlpPolicy, env=env) expert.learn(1000) @@ -42,11 +43,16 @@ Detailed example notebook: :doc:`../tutorials/2_train_dagger` bc_trainer = bc.BC( observation_space=env.observation_space, action_space=env.action_space, + rng=rng, ) with tempfile.TemporaryDirectory(prefix="dagger_example_") as tmpdir: print(tmpdir) dagger_trainer = SimpleDAggerTrainer( - venv=venv, scratch_dir=tmpdir, expert_policy=expert, bc_trainer=bc_trainer, + venv=venv, + scratch_dir=tmpdir, + expert_policy=expert, + bc_trainer=bc_trainer, + rng=rng, ) dagger_trainer.train(2000) diff --git a/docs/algorithms/density.rst b/docs/algorithms/density.rst index ed822bbbc..c06436644 100644 --- a/docs/algorithms/density.rst +++ b/docs/algorithms/density.rst @@ -10,6 +10,7 @@ Detailed example notebook: :doc:`../tutorials/7_train_density` .. testcode:: import pprint + import numpy as np from stable_baselines3 import PPO from stable_baselines3.common.policies import ActorCriticPolicy @@ -18,7 +19,9 @@ Detailed example notebook: :doc:`../tutorials/7_train_density` from imitation.data import types from imitation.util import util - env = util.make_vec_env("Pendulum-v1", 2) + rng = np.random.default_rng(0) + + env = util.make_vec_env("Pendulum-v1", rng=rng, n_envs=2) rollouts = types.load("../tests/testdata/expert_models/pendulum_0/rollouts/final.pkl") imitation_trainer = PPO(ActorCriticPolicy, env) @@ -26,6 +29,7 @@ Detailed example notebook: :doc:`../tutorials/7_train_density` venv=env, demonstrations=rollouts, rl_algo=imitation_trainer, + rng=rng, ) density_trainer.train() diff --git a/docs/algorithms/gail.rst b/docs/algorithms/gail.rst index 584cb36f9..75cc55fe1 100644 --- a/docs/algorithms/gail.rst +++ b/docs/algorithms/gail.rst @@ -19,6 +19,7 @@ Detailed example notebook: :doc:`../tutorials/3_train_gail` .. 
testcode:: + import numpy as np import gym from stable_baselines3 import PPO from stable_baselines3.common.evaluation import evaluate_policy @@ -32,6 +33,8 @@ Detailed example notebook: :doc:`../tutorials/3_train_gail` from imitation.util.networks import RunningNorm from imitation.util.util import make_vec_env + rng = np.random.default_rng(0) + env = gym.make("seals/CartPole-v0") expert = PPO(policy=MlpPolicy, env=env, n_steps=64) expert.learn(1000) @@ -42,11 +45,13 @@ Detailed example notebook: :doc:`../tutorials/3_train_gail` "seals/CartPole-v0", n_envs=5, post_wrappers=[lambda env, _: RolloutInfoWrapper(env)], + rng=rng, ), rollout.make_sample_until(min_timesteps=None, min_episodes=60), + rng=rng, ) - venv = make_vec_env("seals/CartPole-v0", n_envs=8) + venv = make_vec_env("seals/CartPole-v0", n_envs=8, rng=rng) learner = PPO(env=venv, policy=MlpPolicy) reward_net = BasicRewardNet( venv.observation_space, diff --git a/docs/algorithms/mce_irl.rst b/docs/algorithms/mce_irl.rst index 0b2e8ce21..353640f6e 100644 --- a/docs/algorithms/mce_irl.rst +++ b/docs/algorithms/mce_irl.rst @@ -13,6 +13,8 @@ Detailed example notebook: :doc:`../tutorials/6_train_mce` from functools import partial + import numpy as np + from stable_baselines3.common.vec_env import DummyVecEnv from imitation.algorithms.mce_irl import ( @@ -25,6 +27,8 @@ Detailed example notebook: :doc:`../tutorials/6_train_mce` from imitation.envs.examples.model_envs import CliffWorld from imitation.rewards import reward_nets + rng = np.random.default_rng(0) + env_creator = partial(CliffWorld, height=4, horizon=8, width=7, use_xy_obs=True) env_single = env_creator() @@ -51,6 +55,7 @@ Detailed example notebook: :doc:`../tutorials/6_train_mce` reward_net, log_interval=250, optimizer_kwargs={"lr": 0.01}, + rng=rng, ) occ_measure = mce_irl.train() @@ -58,6 +63,7 @@ Detailed example notebook: :doc:`../tutorials/6_train_mce` policy=mce_irl.policy, venv=state_venv, sample_until=rollout.make_min_timesteps(5000), + rng=rng, ) print("Imitation stats: ", rollout.rollout_stats(imitation_trajs)) diff --git a/docs/algorithms/preference_comparisons.rst b/docs/algorithms/preference_comparisons.rst index 16a74fb89..4f9a4ac1d 100644 --- a/docs/algorithms/preference_comparisons.rst +++ b/docs/algorithms/preference_comparisons.rst @@ -22,6 +22,8 @@ Detailed example notebook: :doc:`../tutorials/5_train_preference_comparisons` .. 
testcode:: + import numpy as np + from stable_baselines3 import PPO from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.ppo import MlpPolicy @@ -33,19 +35,22 @@ Detailed example notebook: :doc:`../tutorials/5_train_preference_comparisons` from imitation.util.networks import RunningNorm from imitation.util.util import make_vec_env - venv = make_vec_env("Pendulum-v1") + rng = np.random.default_rng(0) + + venv = make_vec_env("Pendulum-v1", rng=rng) reward_net = BasicRewardNet( venv.observation_space, venv.action_space, normalize_input_layer=RunningNorm, ) - fragmenter = preference_comparisons.RandomFragmenter(warning_threshold=0, seed=0) - gatherer = preference_comparisons.SyntheticGatherer(seed=0) + fragmenter = preference_comparisons.RandomFragmenter(warning_threshold=0, rng=rng) + gatherer = preference_comparisons.SyntheticGatherer(rng=rng) preference_model = preference_comparisons.PreferenceModel(reward_net) reward_trainer = preference_comparisons.BasicRewardTrainer( preference_model=preference_model, loss=preference_comparisons.CrossEntropyRewardLoss(), epochs=3, + rng=rng, ) agent = PPO( @@ -63,7 +68,7 @@ Detailed example notebook: :doc:`../tutorials/5_train_preference_comparisons` reward_fn=reward_net, venv=venv, exploration_frac=0.0, - seed=0, + rng=rng, ) pref_comparisons = preference_comparisons.PreferenceComparisons( @@ -73,7 +78,6 @@ Detailed example notebook: :doc:`../tutorials/5_train_preference_comparisons` fragmenter=fragmenter, preference_gatherer=gatherer, reward_trainer=reward_trainer, - seed=0, initial_epoch_multiplier=1, ) pref_comparisons.train(total_timesteps=5_000, total_comparisons=200) diff --git a/docs/tutorials/1_train_bc.ipynb b/docs/tutorials/1_train_bc.ipynb index 30e9126c1..500f0a3db 100644 --- a/docs/tutorials/1_train_bc.ipynb +++ b/docs/tutorials/1_train_bc.ipynb @@ -27,11 +27,10 @@ "metadata": {}, "outputs": [], "source": [ + "import gym\n", "from stable_baselines3 import PPO\n", "from stable_baselines3.ppo import MlpPolicy\n", "\n", - "import gym\n", - "\n", "env = gym.make(\"CartPole-v1\")\n", "expert = PPO(\n", " policy=MlpPolicy,\n", @@ -85,11 +84,14 @@ "from imitation.data import rollout\n", "from imitation.data.wrappers import RolloutInfoWrapper\n", "from stable_baselines3.common.vec_env import DummyVecEnv\n", + "import numpy as np\n", "\n", + "rng = np.random.default_rng()\n", "rollouts = rollout.rollout(\n", " expert,\n", " DummyVecEnv([lambda: RolloutInfoWrapper(env)]),\n", " rollout.make_sample_until(min_timesteps=None, min_episodes=50),\n", + " rng=rng,\n", ")\n", "transitions = rollout.flatten_trajectories(rollouts)" ] @@ -134,6 +136,7 @@ " observation_space=env.observation_space,\n", " action_space=env.action_space,\n", " demonstrations=transitions,\n", + " rng=rng,\n", ")" ] }, @@ -193,8 +196,7 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" - }, - "orig_nbformat": 4 + } }, "nbformat": 4, "nbformat_minor": 2 diff --git a/docs/tutorials/2_train_dagger.ipynb b/docs/tutorials/2_train_dagger.ipynb index c1fee45c8..ed8c84a2b 100644 --- a/docs/tutorials/2_train_dagger.ipynb +++ b/docs/tutorials/2_train_dagger.ipynb @@ -26,11 +26,10 @@ "metadata": {}, "outputs": [], "source": [ + "import gym\n", "from stable_baselines3 import PPO\n", "from stable_baselines3.ppo import MlpPolicy\n", "\n", - "import gym\n", - "\n", "env = gym.make(\"CartPole-v1\")\n", "expert = PPO(\n", " policy=MlpPolicy,\n", @@ -60,6 +59,7 @@ "source": [ "import tempfile\n", "import gym\n", + "import numpy as 
np\n", "from stable_baselines3.common.vec_env import DummyVecEnv\n", "\n", "from imitation.algorithms import bc\n", @@ -71,12 +71,17 @@ "bc_trainer = bc.BC(\n", " observation_space=env.observation_space,\n", " action_space=env.action_space,\n", + " rng=np.random.default_rng(),\n", ")\n", "\n", "with tempfile.TemporaryDirectory(prefix=\"dagger_example_\") as tmpdir:\n", " print(tmpdir)\n", " dagger_trainer = SimpleDAggerTrainer(\n", - " venv=venv, scratch_dir=tmpdir, expert_policy=expert, bc_trainer=bc_trainer\n", + " venv=venv,\n", + " scratch_dir=tmpdir,\n", + " expert_policy=expert,\n", + " bc_trainer=bc_trainer,\n", + " rng=np.random.default_rng(),\n", " )\n", "\n", " dagger_trainer.train(2000)" @@ -122,8 +127,7 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" - }, - "orig_nbformat": 4 + } }, "nbformat": 4, "nbformat_minor": 2 diff --git a/docs/tutorials/3_train_gail.ipynb b/docs/tutorials/3_train_gail.ipynb index 33125614b..200c0fbed 100644 --- a/docs/tutorials/3_train_gail.ipynb +++ b/docs/tutorials/3_train_gail.ipynb @@ -25,10 +25,10 @@ "metadata": {}, "outputs": [], "source": [ + "import gym\n", "from stable_baselines3 import PPO\n", "from stable_baselines3.ppo import MlpPolicy\n", - "import gym\n", - "import seals\n", + "import seals # needed to load environments\n", "\n", "env = gym.make(\"seals/CartPole-v0\")\n", "expert = PPO(\n", @@ -61,15 +61,19 @@ "from imitation.data.wrappers import RolloutInfoWrapper\n", "from imitation.util.util import make_vec_env\n", "from stable_baselines3.common.vec_env import DummyVecEnv\n", + "import numpy as np\n", "\n", + "rng = np.random.default_rng()\n", "rollouts = rollout.rollout(\n", " expert,\n", " make_vec_env(\n", " \"seals/CartPole-v0\",\n", " n_envs=5,\n", " post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],\n", + " rng=rng,\n", " ),\n", " rollout.make_sample_until(min_timesteps=None, min_episodes=60),\n", + " rng=rng,\n", ")" ] }, @@ -94,13 +98,12 @@ "from imitation.util.util import make_vec_env\n", "from stable_baselines3 import PPO\n", "from stable_baselines3.common.evaluation import evaluate_policy\n", - "from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv\n", + "from stable_baselines3.common.vec_env import DummyVecEnv\n", "\n", "import gym\n", - "import seals\n", "\n", "\n", - "venv = make_vec_env(\"seals/CartPole-v0\", n_envs=8)\n", + "venv = make_vec_env(\"seals/CartPole-v0\", n_envs=8, rng=rng)\n", "learner = PPO(\n", " env=venv,\n", " policy=MlpPolicy,\n", @@ -180,8 +183,7 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" - }, - "orig_nbformat": 4 + } }, "nbformat": 4, "nbformat_minor": 2 diff --git a/docs/tutorials/4_train_airl.ipynb b/docs/tutorials/4_train_airl.ipynb index a9f7fbb4c..27ac1c332 100644 --- a/docs/tutorials/4_train_airl.ipynb +++ b/docs/tutorials/4_train_airl.ipynb @@ -22,10 +22,10 @@ "metadata": {}, "outputs": [], "source": [ + "import gym\n", "from stable_baselines3 import PPO\n", "from stable_baselines3.ppo import MlpPolicy\n", - "import gym\n", - "import seals\n", + "import seals # needed to load environments\n", "\n", "env = gym.make(\"seals/CartPole-v0\")\n", "expert = PPO(\n", @@ -58,15 +58,19 @@ "from imitation.data.wrappers import RolloutInfoWrapper\n", "from imitation.util.util import make_vec_env\n", "from stable_baselines3.common.vec_env import DummyVecEnv\n", + "import numpy as np\n", "\n", + "rng = np.random.default_rng()\n", "rollouts = rollout.rollout(\n", " expert,\n", " make_vec_env(\n", " 
\"seals/CartPole-v0\",\n", " n_envs=5,\n", " post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],\n", + " rng=rng,\n", " ),\n", " rollout.make_sample_until(min_timesteps=None, min_episodes=60),\n", + " rng=rng,\n", ")" ] }, @@ -91,13 +95,12 @@ "from imitation.util.util import make_vec_env\n", "from stable_baselines3 import PPO\n", "from stable_baselines3.common.evaluation import evaluate_policy\n", - "from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv\n", "\n", "import gym\n", "import seals\n", "\n", "\n", - "venv = make_vec_env(\"seals/CartPole-v0\", n_envs=8)\n", + "venv = make_vec_env(\"seals/CartPole-v0\", n_envs=8, rng=rng)\n", "learner = PPO(\n", " env=venv,\n", " policy=MlpPolicy,\n", @@ -159,7 +162,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.9.12 ('imitation')", + "display_name": "Python 3.8.10 ('venv': venv)", "language": "python", "name": "python3" }, @@ -173,13 +176,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "3cfa84a08340a0535ac1870c24c62bfe5d4563d20dd3264de6ae795b45772dcc" - } + "version": "3.8.10" } }, "nbformat": 4, diff --git a/docs/tutorials/5_train_preference_comparisons.ipynb b/docs/tutorials/5_train_preference_comparisons.ipynb index 1d13d8863..fe63732d6 100644 --- a/docs/tutorials/5_train_preference_comparisons.ipynb +++ b/docs/tutorials/5_train_preference_comparisons.ipynb @@ -23,6 +23,7 @@ "metadata": {}, "outputs": [], "source": [ + "import random\n", "from imitation.algorithms import preference_comparisons\n", "from imitation.rewards.reward_nets import BasicRewardNet\n", "from imitation.util.networks import RunningNorm\n", @@ -30,20 +31,27 @@ "from imitation.policies.base import FeedForward32Policy, NormalizeFeaturesExtractor\n", "import gym\n", "from stable_baselines3 import PPO\n", + "import numpy as np\n", "\n", - "venv = make_vec_env(\"Pendulum-v1\")\n", + "rng = np.random.default_rng(0)\n", + "\n", + "venv = make_vec_env(\"Pendulum-v1\", rng=rng)\n", "\n", "reward_net = BasicRewardNet(\n", " venv.observation_space, venv.action_space, normalize_input_layer=RunningNorm\n", ")\n", "\n", - "fragmenter = preference_comparisons.RandomFragmenter(warning_threshold=0, seed=0)\n", - "gatherer = preference_comparisons.SyntheticGatherer(seed=0)\n", + "fragmenter = preference_comparisons.RandomFragmenter(\n", + " warning_threshold=0,\n", + " rng=rng,\n", + ")\n", + "gatherer = preference_comparisons.SyntheticGatherer(rng=rng)\n", "preference_model = preference_comparisons.PreferenceModel(reward_net)\n", "reward_trainer = preference_comparisons.BasicRewardTrainer(\n", " preference_model=preference_model,\n", " loss=preference_comparisons.CrossEntropyRewardLoss(),\n", " epochs=3,\n", + " rng=rng,\n", ")\n", "\n", "agent = PPO(\n", @@ -66,7 +74,7 @@ " reward_fn=reward_net,\n", " venv=venv,\n", " exploration_frac=0.0,\n", - " seed=0,\n", + " rng=rng,\n", ")\n", "\n", "pref_comparisons = preference_comparisons.PreferenceComparisons(\n", @@ -80,7 +88,6 @@ " transition_oversampling=1,\n", " initial_comparison_frac=0.1,\n", " allow_variable_horizon=False,\n", - " seed=0,\n", " initial_epoch_multiplier=1,\n", ")" ] diff --git a/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb b/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb index e3b408215..6c26c8f32 100644 --- a/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb +++ 
b/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb @@ -24,14 +24,13 @@ "cell_type": "code", "execution_count": null, "id": "93187e19", - "metadata": { - "scrolled": true - }, + "metadata": {}, "outputs": [], "source": [ "import torch as th\n", "import gym\n", "from gym.wrappers import TimeLimit\n", + "import numpy as np\n", "\n", "from seals.util import AutoResetWrapper\n", "\n", @@ -47,6 +46,8 @@ "\n", "device = th.device(\"cuda\" if th.cuda.is_available() else \"cpu\")\n", "\n", + "rng = np.random.default_rng()\n", + "\n", "# Here we ensure that our environment has constant-length episodes by resetting\n", "# it when done, and running until 100 timesteps have elapsed.\n", "# For real training, you will want a much longer time limit.\n", @@ -67,13 +68,14 @@ " venv.action_space,\n", ").to(device)\n", "\n", - "fragmenter = preference_comparisons.RandomFragmenter(warning_threshold=0, seed=0)\n", - "gatherer = preference_comparisons.SyntheticGatherer(seed=0)\n", + "fragmenter = preference_comparisons.RandomFragmenter(warning_threshold=0, rng=rng)\n", + "gatherer = preference_comparisons.SyntheticGatherer(rng=rng)\n", "preference_model = preference_comparisons.PreferenceModel(reward_net)\n", "reward_trainer = preference_comparisons.BasicRewardTrainer(\n", " preference_model=preference_model,\n", " loss=preference_comparisons.CrossEntropyRewardLoss(),\n", " epochs=3,\n", + " rng=rng,\n", ")\n", "\n", "agent = PPO(\n", @@ -92,7 +94,7 @@ " reward_fn=reward_net,\n", " venv=venv,\n", " exploration_frac=0.0,\n", - " seed=0,\n", + " rng=rng,\n", ")\n", "\n", "pref_comparisons = preference_comparisons.PreferenceComparisons(\n", @@ -106,7 +108,6 @@ " transition_oversampling=1,\n", " initial_comparison_frac=0.1,\n", " allow_variable_horizon=False,\n", - " seed=0,\n", " initial_epoch_multiplier=1,\n", ")" ] @@ -123,9 +124,7 @@ "cell_type": "code", "execution_count": null, "id": "1c2c4d3a", - "metadata": { - "scrolled": true - }, + "metadata": {}, "outputs": [], "source": [ "pref_comparisons.train(\n", diff --git a/docs/tutorials/6_train_mce.ipynb b/docs/tutorials/6_train_mce.ipynb index cb215747e..51a8c4a66 100644 --- a/docs/tutorials/6_train_mce.ipynb +++ b/docs/tutorials/6_train_mce.ipynb @@ -24,20 +24,21 @@ "outputs": [], "source": [ "from functools import partial\n", + "\n", + "import numpy as np\n", + "from stable_baselines3.common.vec_env import DummyVecEnv\n", + "\n", "from imitation.algorithms.mce_irl import (\n", " MCEIRL,\n", " mce_occupancy_measures,\n", " mce_partition_fh,\n", " TabularPolicy,\n", ")\n", - "\n", "from imitation.data import rollout\n", "from imitation.envs import resettable_env\n", "from imitation.envs.examples.model_envs import CliffWorld\n", - "from stable_baselines3.common.vec_env import DummyVecEnv\n", "from imitation.rewards import reward_nets\n", "\n", - "\n", "env_creator = partial(CliffWorld, height=4, horizon=8, width=7, use_xy_obs=True)\n", "env_single = env_creator()\n", "\n", @@ -62,17 +63,19 @@ "\n", "_, om = mce_occupancy_measures(env_single, pi=pi)\n", "\n", + "rng = np.random.default_rng()\n", "expert = TabularPolicy(\n", " state_space=env_single.pomdp_state_space,\n", " action_space=env_single.action_space,\n", " pi=pi,\n", - " rng=None,\n", + " rng=rng,\n", ")\n", "\n", "expert_trajs = rollout.generate_trajectories(\n", " policy=expert,\n", " venv=state_venv,\n", " sample_until=rollout.make_min_timesteps(5000),\n", + " rng=rng,\n", ")\n", "\n", "print(\"Expert stats: \", rollout.rollout_stats(expert_trajs))" @@ -108,7 +111,12 @@ " )\n", "\n", " 
mce_irl = MCEIRL(\n", - " demos, env_single, reward_net, log_interval=250, optimizer_kwargs=dict(lr=lr)\n", + " demos,\n", + " env_single,\n", + " reward_net,\n", + " log_interval=250,\n", + " optimizer_kwargs=dict(lr=lr),\n", + " rng=rng,\n", " )\n", " occ_measure = mce_irl.train(**kwargs)\n", "\n", @@ -116,6 +124,7 @@ " policy=mce_irl.policy,\n", " venv=state_venv,\n", " sample_until=rollout.make_min_timesteps(5000),\n", + " rng=rng,\n", " )\n", " print(\"Imitation stats: \", rollout.rollout_stats(imitation_trajs))\n", "\n", diff --git a/docs/tutorials/7_train_density.ipynb b/docs/tutorials/7_train_density.ipynb index df75e7057..d4c1c3f85 100644 --- a/docs/tutorials/7_train_density.ipynb +++ b/docs/tutorials/7_train_density.ipynb @@ -31,12 +31,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "is_executing": false, - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "# Set FAST = False for longer training. Use True for testing and CI.\n", @@ -58,11 +53,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "is_executing": false - } - }, + "metadata": {}, "outputs": [], "source": [ "from stable_baselines3.common.policies import ActorCriticPolicy\n", @@ -72,8 +63,10 @@ "from stable_baselines3.common.vec_env import DummyVecEnv\n", "from imitation.data.wrappers import RolloutInfoWrapper\n", "import gym\n", + "import numpy as np\n", "\n", "\n", + "rng = np.random.default_rng()\n", "env_name = \"Pendulum-v1\"\n", "expert = PPO.load(\n", " load_from_hub(\"HumanCompatibleAI/ppo-Pendulum-v1\", \"ppo-Pendulum-v1.zip\")\n", @@ -82,15 +75,19 @@ " [lambda: RolloutInfoWrapper(gym.make(env_name)) for _ in range(N_VEC)]\n", ")\n", "rollouts = rollout.rollout(\n", - " expert, rollout_env, rollout.make_sample_until(min_timesteps=2000, min_episodes=57)\n", + " expert,\n", + " rollout_env,\n", + " rollout.make_sample_until(min_timesteps=2000, min_episodes=57),\n", + " rng=rng,\n", ")\n", "\n", - "env = util.make_vec_env(env_name, N_VEC)\n", + "env = util.make_vec_env(env_name, n_envs=N_VEC, rng=rng)\n", "\n", "\n", "imitation_trainer = PPO(ActorCriticPolicy, env, learning_rate=3e-4, n_steps=2048)\n", "density_trainer = db.DensityAlgorithm(\n", " venv=env,\n", + " rng=rng,\n", " demonstrations=rollouts,\n", " rl_algo=imitation_trainer,\n", " density_type=db.DensityType.STATE_ACTION_DENSITY,\n", @@ -105,12 +102,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "is_executing": false, - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "def print_stats(density_trainer, n_trajectories, epoch=\"\"):\n", diff --git a/examples/quickstart.py b/examples/quickstart.py index 98f68c53b..379fdcc88 100644 --- a/examples/quickstart.py +++ b/examples/quickstart.py @@ -4,6 +4,7 @@ """ import gym +import numpy as np from stable_baselines3 import PPO from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.vec_env import DummyVecEnv @@ -14,6 +15,7 @@ from imitation.data.wrappers import RolloutInfoWrapper env = gym.make("CartPole-v1") +rng = np.random.default_rng(0) def train_expert(): @@ -40,6 +42,7 @@ def sample_expert_transitions(): expert, DummyVecEnv([lambda: RolloutInfoWrapper(env)]), rollout.make_sample_until(min_timesteps=None, min_episodes=50), + rng=rng, ) return rollout.flatten_trajectories(rollouts) @@ -49,13 +52,24 @@ def sample_expert_transitions(): observation_space=env.observation_space, action_space=env.action_space, 
demonstrations=transitions, + rng=rng, ) -reward, _ = evaluate_policy(bc_trainer.policy, env, n_eval_episodes=3, render=True) +reward, _ = evaluate_policy( + bc_trainer.policy, # type: ignore[arg-type] + env, + n_eval_episodes=3, + render=True, +) print(f"Reward before training: {reward}") print("Training a policy using Behavior Cloning") bc_trainer.train(n_epochs=1) -reward, _ = evaluate_policy(bc_trainer.policy, env, n_eval_episodes=3, render=True) +reward, _ = evaluate_policy( + bc_trainer.policy, # type: ignore[arg-type] + env, + n_eval_episodes=3, + render=True, +) print(f"Reward after training: {reward}") diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 000000000..e886c0858 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,2 @@ +[mypy] +ignore_missing_imports = true \ No newline at end of file diff --git a/setup.py b/setup.py index f15479dad..ec67e70bc 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,7 @@ # TODO: upgrade jupyter-client once # https://github.com/jupyter/jupyter_client/issues/637 is fixed "jupyter-client~=6.1.12", + "mypy==0.971", "pandas~=1.4.3", "pytest~=7.1.2", "pytest-cov~=3.0.0", diff --git a/src/imitation/algorithms/adversarial/airl.py b/src/imitation/algorithms/adversarial/airl.py index 7676b4548..7f0a9d9d6 100644 --- a/src/imitation/algorithms/adversarial/airl.py +++ b/src/imitation/algorithms/adversarial/airl.py @@ -1,4 +1,5 @@ """Adversarial Inverse Reinforcement Learning (AIRL).""" +from typing import Optional import torch as th from stable_baselines3.common import base_class, policies, vec_env @@ -69,7 +70,7 @@ def logits_expert_is_high( action: th.Tensor, next_state: th.Tensor, done: th.Tensor, - log_policy_act_prob: th.Tensor, + log_policy_act_prob: Optional[th.Tensor] = None, ) -> th.Tensor: r"""Compute the discriminator's logits for each state-action sample. 
diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py index eee6937e4..41c1b4129 100644 --- a/src/imitation/algorithms/adversarial/common.py +++ b/src/imitation/algorithms/adversarial/common.py @@ -4,13 +4,23 @@ import dataclasses import logging import os -from typing import Callable, Mapping, Optional, Sequence, Tuple, Type +from typing import ( + Callable, + Iterable, + Iterator, + Mapping, + Optional, + Sequence, + Tuple, + Type, + overload, +) import numpy as np import torch as th import torch.utils.tensorboard as thboard import tqdm -from stable_baselines3.common import base_class, policies, vec_env +from stable_baselines3.common import base_class, on_policy_algorithm, policies, vec_env from stable_baselines3.sac import policies as sac_policies from torch.nn import functional as F @@ -63,7 +73,7 @@ def compute_train_stats( else: # float() is defensive, since we cannot divide Torch tensors by # Python ints - expert_acc = _n_pred_expert / float(n_expert) + expert_acc = _n_pred_expert.item() / float(n_expert) _n_pred_gen = th.sum(th.logical_and(bin_is_generated_true, correct_vec)) _n_gen_or_1 = max(1, n_generated) @@ -103,6 +113,11 @@ class AdversarialTrainer(base.DemonstrationAlgorithm[types.Transitions]): If `debug_use_ground_truth=True` was passed into the initializer then `self.venv_train` is the same as `self.venv`.""" + _demo_data_loader: Optional[Iterable[base.TransitionMapping]] + _endless_expert_iterator: Optional[Iterator[base.TransitionMapping]] + + venv_wrapped: vec_env.VecEnvWrapper + def __init__( self, *, @@ -205,15 +220,15 @@ def __init__( os.makedirs(summary_dir, exist_ok=True) self._summary_writer = thboard.SummaryWriter(summary_dir) - venv = self.venv_buffering = wrappers.BufferingWrapper(self.venv) + self.venv_buffering = wrappers.BufferingWrapper(self.venv) if debug_use_ground_truth: # Would use an identity reward fn here, but RewardFns can't see rewards. - self.venv_wrapped = venv + self.venv_wrapped = self.venv_buffering self.gen_callback = None else: - venv = self.venv_wrapped = reward_wrapper.RewardVecEnvWrapper( - venv, + self.venv_wrapped = reward_wrapper.RewardVecEnvWrapper( + self.venv_buffering, reward_fn=self.reward_train.predict_processed, ) self.gen_callback = self.venv_wrapped.make_log_callback() @@ -226,7 +241,7 @@ def __init__( gen_algo_env = self.gen_algo.get_env() assert gen_algo_env is not None self.gen_train_timesteps = gen_algo_env.num_envs - if hasattr(self.gen_algo, "n_steps"): # on policy + if isinstance(self.gen_algo, on_policy_algorithm.OnPolicyAlgorithm): self.gen_train_timesteps *= self.gen_algo.n_steps else: self.gen_train_timesteps = gen_train_timesteps @@ -240,7 +255,9 @@ def __init__( @property def policy(self) -> policies.BasePolicy: - return self.gen_algo.policy + policy = self.gen_algo.policy + assert policy is not None + return policy @abc.abstractmethod def logits_expert_is_high( @@ -288,6 +305,7 @@ def set_demonstrations(self, demonstrations: base.AnyTransitions) -> None: self._endless_expert_iterator = util.endless_iter(self._demo_data_loader) def _next_expert_batch(self) -> Mapping: + assert self._endless_expert_iterator is not None return next(self._endless_expert_iterator) def train_disc( @@ -429,9 +447,18 @@ def train( callback(r) self.logger.dump(self._global_step) + @overload + def _torchify_array(self, ndarray: np.ndarray) -> th.Tensor: + ... + + @overload + def _torchify_array(self, ndarray: None) -> None: + ... 
+ def _torchify_array(self, ndarray: Optional[np.ndarray]) -> Optional[th.Tensor]: if ndarray is not None: return th.as_tensor(ndarray, device=self.reward_train.device) + return None def _get_log_policy_act_prob( self, @@ -502,8 +529,8 @@ def _make_disc_train_batch( raise RuntimeError( "No generator samples for training. " "Call `train_gen()` first.", ) - gen_samples = self._gen_replay_buffer.sample(self.demo_batch_size) - gen_samples = types.dataclass_quick_asdict(gen_samples) + gen_samples_dataclass = self._gen_replay_buffer.sample(self.demo_batch_size) + gen_samples = types.dataclass_quick_asdict(gen_samples_dataclass) n_gen = len(gen_samples["obs"]) n_expert = len(expert_samples["obs"]) diff --git a/src/imitation/algorithms/base.py b/src/imitation/algorithms/base.py index a07d35793..861927de2 100644 --- a/src/imitation/algorithms/base.py +++ b/src/imitation/algorithms/base.py @@ -1,7 +1,7 @@ """Module of base classes and helper methods for imitation learning algorithms.""" import abc -from typing import Any, Generic, Iterable, Mapping, Optional, TypeVar, Union +from typing import Any, Generic, Iterable, Mapping, Optional, TypeVar, Union, cast import numpy as np import torch as th @@ -10,6 +10,7 @@ from imitation.data import rollout, types from imitation.util import logger as imit_logger +from imitation.util import util class BaseImitationAlgorithm(abc.ABC): @@ -116,7 +117,7 @@ def __setstate__(self, state): AnyTransitions = Union[ Iterable[types.Trajectory], Iterable[TransitionMapping], - TransitionKind, + types.TransitionsMinimal, ] @@ -242,11 +243,9 @@ def make_data_loader( raise ValueError(f"batch_size={batch_size} must be positive.") if isinstance(transitions, Iterable): - try: - first_item = next(iter(transitions)) - except StopIteration: - first_item = None + first_item, transitions = util.get_first_iter_element(transitions) if isinstance(first_item, types.Trajectory): + transitions = cast(Iterable[types.Trajectory], transitions) transitions = rollout.flatten_trajectories(list(transitions)) if isinstance(transitions, types.TransitionsMinimal): @@ -256,14 +255,16 @@ def make_data_loader( f"is smaller than batch size {batch_size}.", ) - extra_kwargs = dict(shuffle=True, drop_last=True) - if data_loader_kwargs is not None: - extra_kwargs.update(data_loader_kwargs) + kwargs: Mapping[str, Any] = { + "shuffle": True, + "drop_last": True, + **(data_loader_kwargs or {}), + } return th_data.DataLoader( transitions, batch_size=batch_size, collate_fn=types.transitions_collate_fn, - **extra_kwargs, + **kwargs, ) elif isinstance(transitions, Iterable): return _WrappedDataLoader(transitions, batch_size) diff --git a/src/imitation/algorithms/bc.py b/src/imitation/algorithms/bc.py index e789a33c5..cc9a20cb2 100644 --- a/src/imitation/algorithms/bc.py +++ b/src/imitation/algorithms/bc.py @@ -28,6 +28,7 @@ from imitation.data import rollout, types from imitation.policies import base as policy_base from imitation.util import logger as imit_logger +from imitation.util import util @dataclasses.dataclass(frozen=True) @@ -113,6 +114,8 @@ def __call__( A BCTrainingMetrics object with the loss and all the components it consists of. 
""" + obs = util.safe_to_tensor(obs) + acts = util.safe_to_tensor(acts) _, log_prob, entropy = policy.evaluate_actions(obs, acts) prob_true_act = th.exp(log_prob).mean() log_prob = log_prob.mean() @@ -120,6 +123,8 @@ def __call__( l2_norms = [th.sum(th.square(w)) for w in policy.parameters()] l2_norm = sum(l2_norms) / 2 # divide by 2 to cancel with gradient of square + # sum of list defaults to float(0) if len == 0. + assert isinstance(l2_norm, th.Tensor) ent_loss = -self.ent_weight * entropy neglogp = -log_prob @@ -177,7 +182,7 @@ class RolloutStatsComputer: n_episodes: The number of episodes to base the statistics on. """ - venv: vec_env.VecEnv + venv: Optional[vec_env.VecEnv] n_episodes: int # TODO(shwang): Maybe instead use a callback that can be shared between @@ -185,12 +190,17 @@ class RolloutStatsComputer: # EvalCallback could be a good fit: # https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback - def __call__(self, policy: policies.ActorCriticPolicy) -> Mapping[str, float]: + def __call__( + self, + policy: policies.ActorCriticPolicy, + rng: np.random.Generator, + ) -> Mapping[str, float]: if self.venv is not None and self.n_episodes > 0: trajs = rollout.generate_trajectories( policy, self.venv, rollout.make_min_episodes(self.n_episodes), + rng=rng, ) return rollout.rollout_stats(trajs) else: @@ -272,6 +282,7 @@ def __init__( *, observation_space: gym.Space, action_space: gym.Space, + rng: np.random.Generator, policy: Optional[policies.ActorCriticPolicy] = None, demonstrations: Optional[algo_base.AnyTransitions] = None, batch_size: int = 32, @@ -287,6 +298,7 @@ def __init__( Args: observation_space: the observation space of the environment. action_space: the action space of the environment. + rng: the random state to use for the random number generator. policy: a Stable Baselines3 policy; if unspecified, defaults to `FeedForward32Policy`. demonstrations: Demonstrations from an expert (optional). Transitions @@ -317,6 +329,8 @@ def __init__( self.action_space = action_space self.observation_space = observation_space + self.rng = rng + if policy is None: policy = policy_base.FeedForward32Policy( observation_space=observation_space, @@ -417,6 +431,7 @@ def _on_epoch_end(epoch_number: int): if on_epoch_end is not None: on_epoch_end() + assert self._demo_data_loader is not None demonstration_batches = BatchIteratorWithEpochEndCallback( self._demo_data_loader, n_epochs, @@ -438,7 +453,7 @@ def _on_epoch_end(epoch_number: int): loss = self.trainer(batch) if batch_num % log_interval == 0: - rollout_stats = compute_rollout_stats(self.policy) + rollout_stats = compute_rollout_stats(self.policy, self.rng) self._bc_logger.log_batch( batch_num, @@ -457,4 +472,4 @@ def save_policy(self, policy_path: types.AnyPath) -> None: Args: policy_path: path to save policy to. 
""" - th.save(self.policy, policy_path) + th.save(self.policy, types.path_to_str(policy_path)) diff --git a/src/imitation/algorithms/dagger.py b/src/imitation/algorithms/dagger.py index 8700e0b12..9a3b086f8 100644 --- a/src/imitation/algorithms/dagger.py +++ b/src/imitation/algorithms/dagger.py @@ -21,7 +21,8 @@ from imitation.algorithms import base, bc from imitation.data import rollout, types -from imitation.util import logger, util +from imitation.util import logger as imit_logger +from imitation.util import util class BetaSchedule(abc.ABC): @@ -69,7 +70,7 @@ def __call__(self, round_num: int) -> float: def reconstruct_trainer( scratch_dir: types.AnyPath, venv: vec_env.VecEnv, - custom_logger: Optional[logger.HierarchicalLogger] = None, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, device: Union[th.device, str] = "auto", ) -> "DAggerTrainer": """Reconstruct trainer from the latest snapshot in some working directory. @@ -88,8 +89,11 @@ def reconstruct_trainer( Returns: A deserialized `DAggerTrainer`. """ - custom_logger = custom_logger or logger.configure() - checkpoint_path = pathlib.Path(scratch_dir, "checkpoint-latest.pt") + custom_logger = custom_logger or imit_logger.configure() + checkpoint_path = pathlib.Path( + types.path_to_str(scratch_dir), + "checkpoint-latest.pt", + ) trainer = th.load(checkpoint_path, map_location=utils.get_device(device)) trainer.venv = venv trainer._logger = custom_logger @@ -105,14 +109,14 @@ def _save_dagger_demo( # however that NPZ save here is likely more space efficient than # pickle from types.save(), and types.save only accepts # TrajectoryWithRew right now (subclass of Trajectory). - save_dir = pathlib.Path(save_dir) + save_dir_obj = pathlib.Path(types.path_to_str(save_dir)) assert isinstance(trajectory, types.Trajectory) actual_prefix = f"{prefix}-" if prefix else "" timestamp = util.make_unique_timestamp() filename = f"{actual_prefix}dagger-demo-{timestamp}.npz" - save_dir.mkdir(parents=True, exist_ok=True) - npz_path = pathlib.Path(save_dir, filename) + save_dir_obj.mkdir(parents=True, exist_ok=True) + npz_path = save_dir_obj / filename np.savez_compressed(npz_path, **dataclasses.asdict(trajectory)) logging.info(f"Saved demo at '{npz_path}'") @@ -146,12 +150,17 @@ class InteractiveTrajectoryCollector(vec_env.VecEnvWrapper): of every episode. """ + traj_accum: Optional[rollout.TrajectoryAccumulator] + _last_obs: Optional[np.ndarray] + _last_user_actions: Optional[np.ndarray] + def __init__( self, venv: vec_env.VecEnv, get_robot_acts: Callable[[np.ndarray], np.ndarray], beta: float, save_dir: types.AnyPath, + rng: np.random.Generator, ): """Builds InteractiveTrajectoryCollector. @@ -164,6 +173,7 @@ def __init__( robot action. The choice of robot or human action is independently randomized for each individual `Env` at every timestep. save_dir: directory to save collected trajectories in. + rng: random state for random number generation. """ super().__init__(venv) self.get_robot_acts = get_robot_acts @@ -175,7 +185,7 @@ def __init__( self._done_before = True self._is_reset = False self._last_user_actions = None - self.rng = np.random.RandomState() + self.rng = rng def seed(self, seed=Optional[int]) -> List[Union[None, int]]: """Set the seed for the DAgger random number generator and wrapped VecEnv. @@ -190,7 +200,7 @@ def seed(self, seed=Optional[int]) -> List[Union[None, int]]: A list containing the seeds for each individual env. Note that all list elements may be None, if the env does not return anything when seeded. 
""" - self.rng = np.random.RandomState(seed=seed) + self.rng = np.random.default_rng(seed=seed) return self.venv.seed(seed) def reset(self) -> np.ndarray: @@ -201,6 +211,7 @@ def reset(self) -> np.ndarray: """ self.traj_accum = rollout.TrajectoryAccumulator() obs = self.venv.reset() + assert isinstance(obs, np.ndarray) for i, ob in enumerate(obs): self.traj_accum.add_step({"obs": ob}, key=i) self._last_obs = obs @@ -228,6 +239,7 @@ def step_async(self, actions: np.ndarray) -> None: and executed instead via `self.get_robot_act`. """ assert self._is_reset, "call .reset() before .step()" + assert self._last_obs is not None # Replace each given action with a robot action 100*(1-beta)% of the time. actual_acts = np.array(actions) @@ -248,6 +260,9 @@ def step_wait(self) -> VecEnvStepReturn: Observation, reward, dones (is terminal?) and info dict. """ next_obs, rews, dones, infos = self.venv.step_wait() + assert isinstance(next_obs, np.ndarray) + assert self.traj_accum is not None + assert self._last_user_actions is not None self._last_obs = next_obs fresh_demos = self.traj_accum.add_steps_and_auto_finish( obs=next_obs, @@ -296,6 +311,8 @@ class DAggerTrainer(base.BaseImitationAlgorithm): … """ + _all_demos: List[types.Trajectory] + DEFAULT_N_EPOCHS: int = 4 """The default number of BC training epochs in `extend_and_update`.""" @@ -304,9 +321,10 @@ def __init__( *, venv: vec_env.VecEnv, scratch_dir: types.AnyPath, + rng: np.random.Generator, beta_schedule: Optional[Callable[[int], float]] = None, bc_trainer: bc.BC, - custom_logger: Optional[logger.HierarchicalLogger] = None, + custom_logger: Optional[imit_logger.HierarchicalLogger] = None, ): """Builds DAggerTrainer. @@ -314,6 +332,7 @@ def __init__( venv: Vectorized training environment. scratch_dir: Directory to use to store intermediate training information (e.g. for resuming training). + rng: random state for random number generation. beta_schedule: Provides a value of `beta` (the probability of taking expert action in any given state) at each round of training. If `None`, then `linear_beta_schedule` will be used instead. 
@@ -325,11 +344,12 @@ def __init__( if beta_schedule is None: beta_schedule = LinearBetaSchedule(15) self.beta_schedule = beta_schedule - self.scratch_dir = pathlib.Path(scratch_dir) + self.scratch_dir = pathlib.Path(types.path_to_str(scratch_dir)) self.venv = venv self.round_num = 0 self._last_loaded_round = -1 self._all_demos = [] + self.rng = rng utils.check_for_correct_spaces( self.venv, @@ -346,8 +366,13 @@ def __getstate__(self): del d["_logger"] return d - @base.BaseImitationAlgorithm.logger.setter - def logger(self, value: logger.HierarchicalLogger) -> None: + @property + def logger(self) -> imit_logger.HierarchicalLogger: + """Returns logger for this object.""" + return super().logger + + @logger.setter + def logger(self, value: imit_logger.HierarchicalLogger) -> None: # DAgger and inner-BC logger should stay in sync self._logger = value self.bc_trainer.logger = value @@ -471,6 +496,7 @@ def create_trajectory_collector(self) -> InteractiveTrajectoryCollector: get_robot_acts=lambda acts: self.bc_trainer.policy.predict(acts)[0], beta=beta, save_dir=save_dir, + rng=self.rng, ) return collector @@ -525,6 +551,7 @@ def __init__( venv: vec_env.VecEnv, scratch_dir: types.AnyPath, expert_policy: policies.BasePolicy, + rng: np.random.Generator, expert_trajs: Optional[Sequence[types.Trajectory]] = None, **dagger_trainer_kwargs, ): @@ -538,6 +565,7 @@ def __init__( scratch_dir: Directory to use to store intermediate training information (e.g. for resuming training). expert_policy: The expert policy used to generate synthetic demonstrations. + rng: Random state to use for the random number generator. expert_trajs: Optional starting dataset that is inserted into the round 0 dataset. dagger_trainer_kwargs: Other keyword arguments passed to the @@ -547,7 +575,12 @@ def __init__( ValueError: The observation or action space does not match between `venv` and `expert_policy`. """ - super().__init__(venv=venv, scratch_dir=scratch_dir, **dagger_trainer_kwargs) + super().__init__( + venv=venv, + scratch_dir=scratch_dir, + rng=rng, + **dagger_trainer_kwargs, + ) self.expert_policy = expert_policy if expert_policy.observation_space != self.venv.observation_space: raise ValueError( diff --git a/src/imitation/algorithms/density.py b/src/imitation/algorithms/density.py index ad52e7ad3..c585b396a 100644 --- a/src/imitation/algorithms/density.py +++ b/src/imitation/algorithms/density.py @@ -6,7 +6,8 @@ import enum import itertools -from typing import Iterable, Mapping, Optional +from collections.abc import Mapping +from typing import Dict, Iterable, List, Optional, cast import numpy as np from gym.spaces.utils import flatten @@ -17,6 +18,7 @@ from imitation.data import rollout, types, wrappers from imitation.rewards import reward_wrapper from imitation.util import logger as imit_logger +from imitation.util import util class DensityType(enum.Enum): @@ -39,13 +41,27 @@ class DensityAlgorithm(base.DemonstrationAlgorithm): and then computes a reward using the log of these probabilities. 
""" - transitions: Mapping[Optional[int], np.ndarray] + is_stationary: bool + density_type: DensityType + venv: vec_env.VecEnv + transitions: Dict[Optional[int], np.ndarray] + kernel: str + kernel_bandwidth: float + standardise: bool + + _scaler: Optional[preprocessing.StandardScaler] + _density_models: Dict[Optional[int], neighbors.KernelDensity] + rl_algo: Optional[base_class.BaseAlgorithm] + buffering_wrapper: wrappers.BufferingWrapper + venv_wrapped: reward_wrapper.RewardVecEnvWrapper + wrapper_callback: reward_wrapper.WrappedRewardCallback def __init__( self, *, demonstrations: Optional[base.AnyTransitions], venv: vec_env.VecEnv, + rng: np.random.Generator, density_type: DensityType = DensityType.STATE_ACTION_DENSITY, kernel: str = "gaussian", kernel_bandwidth: float = 0.5, @@ -69,6 +85,7 @@ def __init__( any environment interaction to fit the reward model, but we use this to extract the observation and action space, and to train the RL algorithm `rl_algo` (if specified). + rng: random state for sampling from demonstrations. rl_algo: An RL algorithm to train on the resulting reward model (optional). is_stationary: if True, share same density models for all timesteps; if False, use a different density model for each timestep. @@ -93,7 +110,7 @@ def __init__( self.is_stationary = is_stationary self.density_type = density_type self.venv = venv - self.transitions = {} + self.transitions = dict() super().__init__( demonstrations=demonstrations, custom_logger=custom_logger, @@ -104,7 +121,8 @@ def __init__( self.kernel_bandwidth = kernel_bandwidth self.standardise = standardise_inputs self._scaler = None - self._density_models = {} + self._density_models = dict() + self.rng = rng self.rl_algo = rl_algo self.buffering_wrapper = wrappers.BufferingWrapper(self.venv) @@ -114,53 +132,82 @@ def __init__( ) self.wrapper_callback = self.venv_wrapped.make_log_callback() - def _set_demo_from_batch( + def _get_demo_from_batch( self, obs_b: np.ndarray, act_b: np.ndarray, next_obs_b: Optional[np.ndarray], - ) -> None: - next_obs_b = next_obs_b or itertools.repeat(None) - for obs, act, next_obs in zip(obs_b, act_b, next_obs_b): + ) -> Dict[Optional[int], List[np.ndarray]]: + if next_obs_b is None and self.density_type == DensityType.STATE_STATE_DENSITY: + raise ValueError( + "STATE_STATE_DENSITY requires next_obs_b " + "to be provided, but it was None", + ) + + assert act_b.shape[1:] == self.venv.action_space.shape + assert obs_b.shape[1:] == self.venv.observation_space.shape + assert len(act_b) == len(obs_b) + if next_obs_b is not None: + assert next_obs_b.shape[1:] == self.venv.observation_space.shape + assert len(next_obs_b) == len(obs_b) + + transitions: Dict[Optional[int], List[np.ndarray]] = {} + next_obs_b_iterator = ( + next_obs_b if next_obs_b is not None else itertools.repeat(None) + ) + for obs, act, next_obs in zip(obs_b, act_b, next_obs_b_iterator): flat_trans = self._preprocess_transition(obs, act, next_obs) - self.transitions.setdefault(None, []).append(flat_trans) + transitions.setdefault(None, []).append(flat_trans) + return transitions def set_demonstrations(self, demonstrations: base.AnyTransitions) -> None: """Sets the demonstration data.""" - self.transitions = {} - - if isinstance(demonstrations, Iterable): - first_item = next(iter(demonstrations)) + transitions: Dict[Optional[int], List[np.ndarray]] = {} + + if isinstance(demonstrations, types.TransitionsMinimal): + next_obs_b = getattr(demonstrations, "next_obs", None) + transitions.update( + self._get_demo_from_batch( + 
demonstrations.obs, + demonstrations.acts, + next_obs_b, + ), + ) + elif isinstance(demonstrations, Iterable): + first_item, demonstrations = util.get_first_iter_element(demonstrations) if isinstance(first_item, types.Trajectory): - # Demonstrations are trajectories. - # We have timestep information. + # we assume that all elements are also types.Trajectory. + # (this means we have timestamp information) + # It's not perfectly type safe, but it allows for the flexibility of + # using iterables, which is useful for large data structures. + demonstrations = cast(Iterable[types.Trajectory], demonstrations) + for traj in demonstrations: for i, (obs, act, next_obs) in enumerate( zip(traj.obs[:-1], traj.acts, traj.obs[1:]), ): flat_trans = self._preprocess_transition(obs, act, next_obs) - self.transitions.setdefault(i, []).append(flat_trans) - else: - # Demonstrations are a Torch DataLoader or other Mapping iterable + transitions.setdefault(i, []).append(flat_trans) + elif isinstance(first_item, Mapping): + # analogous to cast above. + demonstrations = cast(Iterable[base.TransitionMapping], demonstrations) + for batch in demonstrations: - self._set_demo_from_batch( - batch["obs"], - batch["acts"], - batch.get("next_obs"), + transitions.update( + self._get_demo_from_batch( + util.safe_to_numpy(batch["obs"], warn=True), + util.safe_to_numpy(batch["acts"], warn=True), + util.safe_to_numpy(batch.get("next_obs"), warn=True), + ), ) - elif isinstance(demonstrations, types.TransitionsMinimal): - next_obs_b = ( - demonstrations.next_obs if hasattr(demonstrations, "next_obs") else None - ) - self._set_demo_from_batch( - demonstrations.obs, - demonstrations.acts, - next_obs_b, - ) + else: + raise TypeError( + f"Unsupported demonstration type {type(demonstrations)}", + ) else: raise TypeError(f"Unsupported demonstration type {type(demonstrations)}") - self.transitions = {k: np.stack(v, axis=0) for k, v in self.transitions.items()} + self.transitions = {k: np.stack(v, axis=0) for k, v in transitions.items()} if not self.is_stationary and None in self.transitions: raise ValueError( @@ -174,7 +221,7 @@ def set_demonstrations(self, demonstrations: base.AnyTransitions) -> None: def train(self): """Fits the density model to demonstration data `self.transitions`.""" # if requested, we'll scale demonstration transitions so that they have - # zero mean and unit variance (i.e all components are equally important) + # zero mean and unit variance (i.e. all components are equally important) self._scaler = preprocessing.StandardScaler( with_mean=self.standardise, with_std=self.standardise, @@ -200,7 +247,7 @@ def _preprocess_transition( self, obs: np.ndarray, act: np.ndarray, - next_obs: np.ndarray, + next_obs: Optional[np.ndarray], ) -> np.ndarray: """Compute flattened transition on subset specified by `self.density_type`.""" if self.density_type == DensityType.STATE_DENSITY: @@ -213,6 +260,7 @@ def _preprocess_transition( ], ) elif self.density_type == DensityType.STATE_STATE_DENSITY: + assert next_obs is not None return np.concatenate( [ flatten(self.venv.observation_space, obs), @@ -224,10 +272,10 @@ def _preprocess_transition( def __call__( self, - obs_b: np.ndarray, - act_b: np.ndarray, - next_obs_b: np.ndarray, - dones: np.ndarray, + state: np.ndarray, + action: np.ndarray, + next_state: np.ndarray, + done: np.ndarray, steps: Optional[np.ndarray] = None, ) -> np.ndarray: r"""Compute reward from given (s,a,s') transition batch. @@ -236,12 +284,12 @@ def __call__( VecEnvs. 
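# Usage sketch for the refactored DensityAlgorithm above (assumes `venv` is a
# VecEnv and `trajectories` is a Sequence[types.Trajectory]). Demonstrations may
# be an iterable of trajectories (keeping timestep information), a
# TransitionsMinimal instance, or an iterable of dict-like batches with
# "obs"/"acts" (and optionally "next_obs") keys:
import numpy as np
from imitation.algorithms.density import DensityAlgorithm

density_trainer = DensityAlgorithm(
    demonstrations=trajectories,
    venv=venv,
    rng=np.random.default_rng(0),     # required by the new signature
)
density_trainer.train()               # fits the kernel density models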
Args: - obs_b: current batch of observations. - act_b: batch of actions that agent took in response to those + state: current batch of observations. + action: batch of actions that agent took in response to those observations. - next_obs_b: batch of observations encountered after the + next_state: batch of observations encountered after the agent took those actions. - dones: is it terminal state? + done: is it terminal state? steps: What timestep is this from? Used if `self.is_stationary` is false, otherwise ignored. @@ -258,16 +306,18 @@ def __call__( if not self.is_stationary and steps is None: raise ValueError("steps must be provided with non-stationary models") - del dones # TODO(adam): should we handle terminal state specially in any way? + del done # TODO(adam): should we handle terminal state specially in any way? rew_list = [] - assert len(obs_b) == len(act_b) and len(obs_b) == len(next_obs_b) - for idx, (obs, act, next_obs) in enumerate(zip(obs_b, act_b, next_obs_b)): + assert len(state) == len(action) and len(state) == len(next_state) + for idx, (obs, act, next_obs) in enumerate(zip(state, action, next_state)): flat_trans = self._preprocess_transition(obs, act, next_obs) + assert self._scaler is not None scaled_padded_trans = self._scaler.transform(flat_trans[np.newaxis]) if self.is_stationary: rew = self._density_models[None].score(scaled_padded_trans) else: + assert steps is not None time = steps[idx] if time >= len(self._density_models): # Can't do anything sensible here yet. Correct solution is to use @@ -294,6 +344,7 @@ def train_policy(self, n_timesteps: int = int(1e6), **kwargs): method of the imitation RL model. Refer to Stable Baselines docs for details. """ + assert self.rl_algo is not None self.rl_algo.set_env(self.venv_wrapped) self.rl_algo.learn( n_timesteps, @@ -322,6 +373,7 @@ def test_policy(self, *, n_trajectories: int = 10, true_reward: bool = True): self.rl_algo, self.venv if true_reward else self.venv_wrapped, sample_until=rollout.make_min_episodes(n_trajectories), + rng=self.rng, ) # We collect `trajs` above so disregard return value from `pop_trajectories`, # but still call it to clear out any saved trajectories. @@ -332,4 +384,6 @@ def test_policy(self, *, n_trajectories: int = 10, true_reward: bool = True): @property def policy(self) -> base_class.BasePolicy: + assert self.rl_algo is not None + assert self.rl_algo.policy is not None return self.rl_algo.policy diff --git a/src/imitation/algorithms/mce_irl.py b/src/imitation/algorithms/mce_irl.py index 2c7e29378..0a3984dd5 100644 --- a/src/imitation/algorithms/mce_irl.py +++ b/src/imitation/algorithms/mce_irl.py @@ -7,7 +7,7 @@ """ import collections import warnings -from typing import Any, Iterable, Mapping, Optional, Tuple, Type, Union +from typing import Any, Iterable, List, Mapping, Optional, Tuple, Type, Union import gym import numpy as np @@ -141,12 +141,15 @@ def squeeze_r(r_output: th.Tensor) -> th.Tensor: class TabularPolicy(policies.BasePolicy): """A tabular policy. Cannot be trained -- prediction only.""" + pi: np.ndarray + rng: np.random.Generator + def __init__( self, state_space: gym.Space, action_space: gym.Space, pi: np.ndarray, - rng: Optional[np.random.RandomState], + rng: np.random.Generator, ): """Builds TabularPolicy. @@ -162,8 +165,7 @@ def __init__( assert isinstance(action_space, gym.spaces.Discrete), "action not tabular" # What we call state space here is observation space in SB3 nomenclature. 
super().__init__(observation_space=state_space, action_space=action_space) - self.rng = rng or np.random - self.pi = None + self.rng = rng self.set_pi(pi) def set_pi(self, pi: np.ndarray) -> None: @@ -174,18 +176,24 @@ def set_pi(self, pi: np.ndarray) -> None: self.pi = pi def _predict(self, observation: th.Tensor, deterministic: bool = False): - raise NotImplementedError("Should never be called as predict overridden.") + raise NotImplementedError( + "Should never be called as predict overridden.", + ) - def forward(self, observation: th.Tensor, deterministic: bool = False): - raise NotImplementedError("Should never be called.") + def forward( # type: ignore[override] + self, + observation: th.Tensor, + deterministic: bool = False, + ): + raise NotImplementedError("Should never be called.") # pragma: no cover def predict( self, - observation: np.ndarray, - state: Optional[np.ndarray] = None, - mask: Optional[np.ndarray] = None, + observation: Union[np.ndarray, Mapping[str, np.ndarray]], + state: Optional[Tuple[np.ndarray, ...]] = None, + episode_start: Optional[np.ndarray] = None, deterministic: bool = False, - ) -> Tuple[np.ndarray, np.ndarray]: + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: """Predict action to take in given state. Arguments follow SB3 naming convention as this is an SB3 policy. @@ -199,36 +207,35 @@ def predict( Args: observation: States in the underlying MDP. state: Hidden states of the policy -- used to represent timesteps by us. - mask: Has episode completed? + episode_start: Has episode completed? deterministic: If true, pick action with highest probability; otherwise, sample. Returns: Tuple of the actions and new hidden states. """ - timesteps = state # rename to avoid confusion - del state - - if timesteps is None: + if state is None: timesteps = np.zeros(len(observation), dtype=int) else: - timesteps = np.array(timesteps) + assert len(state) == 1 + timesteps = state[0] assert len(timesteps) == len(observation), "timestep and obs batch size differ" - if mask is not None: - timesteps[mask] = 0 + if episode_start is not None: + timesteps[episode_start] = 0 - actions = [] + actions: List[int] = [] for obs, t in zip(observation, timesteps): assert self.observation_space.contains(obs), "illegal state" dist = self.pi[t, obs, :] if deterministic: - actions.append(dist.argmax()) + actions.append(int(dist.argmax())) else: actions.append(self.rng.choice(len(dist), p=dist)) timesteps += 1 # increment timestep - return np.array(actions), timesteps + state = (timesteps,) + return np.array(actions), state MCEDemonstrations = Union[np.ndarray, base.AnyTransitions] @@ -254,6 +261,7 @@ def __init__( demonstrations: Optional[MCEDemonstrations], env: resettable_env.TabularModelEnv, reward_net: reward_nets.RewardNet, + rng: np.random.Generator, optimizer_cls: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Mapping[str, Any]] = None, discount: float = 1.0, @@ -262,7 +270,6 @@ def __init__( # TODO(adam): do we need log_interval or can just use record_mean...? log_interval: Optional[int] = 100, *, - rng: Optional[np.random.RandomState] = None, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, ): r"""Creates MCE IRL. @@ -274,10 +281,11 @@ def __init__( The demonstrations must have observations one-hot coded unless demonstrations is a state-occupancy measure. env: a tabular MDP. + rng: random state used for sampling from policy. reward_net: a neural network that computes rewards for the supplied observations. 
- optimizer_cls: optimiser to use for supervised training. - optimizer_kwargs: keyword arguments for optimiser construction. + optimizer_cls: optimizer to use for supervised training. + optimizer_kwargs: keyword arguments for optimizer construction. discount: the discount factor to use when computing occupancy measure. If not 1.0 (undiscounted), then `demonstrations` must either be a (discounted) state-occupancy measure, or trajectories. Transitions @@ -290,7 +298,6 @@ def __init__( MCE IRL gradient falls below this value. log_interval: how often to log current loss stats (using `logging`). None to disable. - rng: random state used for sampling from policy. custom_logger: Where to log to; if None (default), creates a new logger. """ self.discount = discount @@ -350,7 +357,7 @@ def _set_demo_from_obs( # then possibly shuffled. So add next observations for terminal states, # as they will not appear anywhere else; but ignore next observations # for all other states as they occur elsewhere in dataset. - if next_obses is not None: + if dones is not None and next_obses is not None: for done, obs in zip(dones, next_obses): if isinstance(done, th.Tensor): done = done.item() # must be scalar @@ -376,7 +383,7 @@ def set_demonstrations(self, demonstrations: MCEDemonstrations) -> None: # Demonstrations are either trajectories or transitions; # we must compute occupancy measure from this. if isinstance(demonstrations, Iterable): - first_item = next(iter(demonstrations)) + first_item, demonstrations = util.get_first_iter_element(demonstrations) if isinstance(first_item, types.Trajectory): self._set_demo_from_trajectories(demonstrations) return @@ -401,14 +408,13 @@ def set_demonstrations(self, demonstrations: MCEDemonstrations) -> None: # Collect them together into one big NumPy array. This is inefficient, # we could compute the running statistics instead, but in practice do # not expect large dataset sizes together with MCE IRL. - collated = collections.defaultdict(list) + collated_list = collections.defaultdict(list) for batch in demonstrations: assert isinstance(batch, Mapping) for k in ("obs", "dones", "next_obs"): if k in batch: - collated[k].append(batch[k]) - for k, v in collated.items(): - collated[k] = np.concatenate(v) + collated_list[k].append(batch[k]) + collated = {k: np.concatenate(v) for k, v in collated_list.items()} assert "obs" in collated for k, v in collated.items(): @@ -474,6 +480,7 @@ def train(self, max_iter: int = 1000) -> np.ndarray: dtype=self.reward_net.dtype, device=self.reward_net.device, ) + assert self.demo_state_om is not None assert self.demo_state_om.shape == (len(obs_mat),) with networks.training(self.reward_net): diff --git a/src/imitation/algorithms/preference_comparisons.py b/src/imitation/algorithms/preference_comparisons.py index 7d37b337b..58897113f 100644 --- a/src/imitation/algorithms/preference_comparisons.py +++ b/src/imitation/algorithms/preference_comparisons.py @@ -6,7 +6,6 @@ import abc import math import pickle -import random import re from collections import defaultdict from typing import ( @@ -20,6 +19,7 @@ Sequence, Tuple, Union, + cast, ) import numpy as np @@ -100,19 +100,19 @@ class TrajectoryDataset(TrajectoryGenerator): def __init__( self, trajectories: Sequence[TrajectoryWithRew], - seed: Optional[int] = None, + rng: np.random.Generator, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, ): """Creates a dataset loaded from `path`. Args: trajectories: the dataset of rollouts. - seed: Seed for RNG used for shuffling dataset. 
+ rng: RNG used for shuffling dataset. custom_logger: Where to log to; if None (default), creates a new logger. """ super().__init__(custom_logger=custom_logger) self._trajectories = trajectories - self.rng = random.Random(seed) + self.rng = rng def sample(self, steps: int) -> Sequence[TrajectoryWithRew]: # make a copy before shuffling @@ -129,10 +129,10 @@ def __init__( algorithm: base_class.BaseAlgorithm, reward_fn: Union[reward_function.RewardFn, reward_nets.RewardNet], venv: vec_env.VecEnv, + rng: np.random.Generator, exploration_frac: float = 0.0, switch_prob: float = 0.5, random_prob: float = 0.5, - seed: Optional[int] = None, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, ): """Initialize the agent trainer. @@ -142,13 +142,13 @@ def __init__( reward_fn: either a RewardFn or a RewardNet instance that will supply the rewards used for training the agent. venv: vectorized environment to train in. + rng: random number generator used for exploration and for sampling. exploration_frac: fraction of the trajectories that will be generated partially randomly rather than only by the agent when sampling. switch_prob: the probability of switching the current policy at each step for the exploratory samples. random_prob: the probability of picking the random policy when switching during exploration. - seed: random seed for exploratory trajectories. custom_logger: Where to log to; if None (default), creates a new logger. """ self.algorithm = algorithm @@ -164,6 +164,7 @@ def __init__( reward_fn = reward_fn.predict_processed self.reward_fn = reward_fn self.exploration_frac = exploration_frac + self.rng = rng # The BufferingWrapper records all trajectories, so we can return # them after training. This should come first (before the wrapper that @@ -185,9 +186,11 @@ def __init__( self.algorithm.set_env(self.venv) # Unlike with BufferingWrapper, we should use `algorithm.get_env()` instead # of `venv` when interacting with `algorithm`. + algo_venv = self.algorithm.get_env() + assert algo_venv is not None policy_callable = rollout._policy_to_callable( self.algorithm, - self.algorithm.get_env(), + algo_venv, # By setting deterministic_policy to False, we ensure that the rollouts # are collected from a deterministic policy only if self.algorithm is # deterministic. If self.algorithm is stochastic, then policy_callable @@ -196,10 +199,10 @@ def __init__( ) self.exploration_wrapper = exploration_wrapper.ExplorationWrapper( policy_callable=policy_callable, - venv=self.algorithm.get_env(), + venv=algo_venv, random_prob=random_prob, switch_prob=switch_prob, - seed=seed, + rng=self.rng, ) def train(self, steps: int, **kwargs) -> None: @@ -258,45 +261,57 @@ def sample(self, steps: int) -> Sequence[types.TrajectoryWithRew]: # here because 1) they might miss initial timesteps taken by the RL agent # and 2) their rewards are the ones provided by the reward model! # Instead, we collect the trajectories using the BufferingWrapper. + algo_venv = self.algorithm.get_env() + assert algo_venv is not None rollout.generate_trajectories( self.algorithm, - self.algorithm.get_env(), + algo_venv, sample_until=sample_until, # By setting deterministic_policy to False, we ensure that the rollouts # are collected from a deterministic policy only if self.algorithm is # deterministic. If self.algorithm is stochastic, then policy_callable # will also be stochastic. 
deterministic_policy=False, + rng=self.rng, ) additional_trajs, _ = self.buffering_wrapper.pop_finished_trajectories() agent_trajs = list(agent_trajs) + list(additional_trajs) agent_trajs = _get_trajectories(agent_trajs, agent_steps) - exploration_trajs = [] + trajectories = list(agent_trajs) + if exploration_steps > 0: self.logger.log(f"Sampling {exploration_steps} exploratory transitions.") sample_until = rollout.make_sample_until( min_timesteps=exploration_steps, min_episodes=None, ) + algo_venv = self.algorithm.get_env() + assert algo_venv is not None rollout.generate_trajectories( policy=self.exploration_wrapper, - venv=self.algorithm.get_env(), + venv=algo_venv, sample_until=sample_until, - # buffering_wrapper collects rollouts from a non-deterministic policy + # buffering_wrapper collects rollouts from a non-deterministic policy, # so we do that here as well for consistency. deterministic_policy=False, + rng=self.rng, ) exploration_trajs, _ = self.buffering_wrapper.pop_finished_trajectories() exploration_trajs = _get_trajectories(exploration_trajs, exploration_steps) - # We call _get_trajectories separately on agent_trajs and exploration_trajs - # and then just concatenate. This could mean we return slightly too many - # transitions, but it gets the proportion of exploratory and agent transitions - # roughly right. - return list(agent_trajs) + list(exploration_trajs) + # We call _get_trajectories separately on agent_trajs and exploration_trajs + # and then just concatenate. This could mean we return slightly too many + # transitions, but it gets the proportion of exploratory and agent + # transitions roughly right. + trajectories.extend(list(exploration_trajs)) + return trajectories - @TrajectoryGenerator.logger.setter + @property + def logger(self): + return super().logger + + @logger.setter def logger(self, value: imit_logger.HierarchicalLogger): self._logger = value self.algorithm.set_logger(self.logger) @@ -320,7 +335,7 @@ def _get_trajectories( steps_cumsum = np.cumsum([len(traj) for traj in trajectories]) # Now we find the first index that gives us enough # total steps: - idx = (steps_cumsum >= steps).argmax() + idx = int((steps_cumsum >= steps).argmax()) # we need to include the element at position idx trajectories = trajectories[: idx + 1] # sanity check @@ -365,10 +380,11 @@ def __init__( self.noise_prob = noise_prob self.discount_factor = discount_factor self.threshold = threshold - self.is_ensemble, base_model = is_base_model_ensemble(self.model) + base_model = get_base_model(model) + self.ensemble_model = None # if the base model is an ensemble model, then keep the base model as # model to get rewards from all networks - if self.is_ensemble: + if isinstance(base_model, reward_nets.RewardEnsemble): # reward_model may include an AddSTDRewardWrapper for RL training; but we # must train directly on the base model for reward model training. 
is_base = model is base_model @@ -382,11 +398,11 @@ def __init__( "RewardEnsemble can only be wrapped" f" by AddSTDRewardWrapper but found {type(model).__name__}.", ) - self.model = base_model + self.ensemble_model = base_model self.member_pref_models = [] - for member in self.model.members: + for member in self.ensemble_model.members: member_pref_model = PreferenceModel( - member, + cast(reward_nets.RewardNet, member), # nn.ModuleList is not generic self.noise_prob, self.discount_factor, self.threshold, @@ -431,6 +447,8 @@ def forward( rews2 = self.rewards(trans2) probs[i] = self.probability(rews1, rews2) if gt_reward_available: + frag1 = cast(TrajectoryWithRew, frag1) + frag2 = cast(TrajectoryWithRew, frag2) gt_rews_1 = th.from_numpy(frag1.rews) gt_rews_2 = th.from_numpy(frag2.rews) gt_probs[i] = self.probability(gt_rews_1, gt_rews_2) @@ -452,15 +470,20 @@ def rewards(self, transitions: Transitions) -> th.Tensor: action = transitions.acts next_state = transitions.next_obs done = transitions.dones - if self.is_ensemble: - rews = self.model.predict_processed_all(state, action, next_state, done) - assert rews.shape == (len(state), self.model.num_members) - return util.safe_to_tensor(rews).to(self.model.device) + if self.ensemble_model is not None: + rews_np = self.ensemble_model.predict_processed_all( + state, + action, + next_state, + done, + ) + assert rews_np.shape == (len(state), self.ensemble_model.num_members) + rews = util.safe_to_tensor(rews_np).to(self.ensemble_model.device) else: preprocessed = self.model.preprocess(state, action, next_state, done) rews = self.model(*preprocessed) assert rews.shape == (len(state),) - return rews + return rews def probability(self, rews1: th.Tensor, rews2: th.Tensor) -> th.Tensor: """Computes the Boltzmann rational probability that the first trajectory is best. @@ -478,17 +501,17 @@ def probability(self, rews1: th.Tensor, rews2: th.Tensor) -> th.Tensor: () for non-ensemble model which is a torch scalar. """ # check rews has correct shape based on the model - expected_dims = 2 if self.is_ensemble else 1 + expected_dims = 2 if self.ensemble_model is not None else 1 assert rews1.ndim == rews2.ndim == expected_dims # First, we compute the difference of the returns of # the two fragments. We have a special case for a discount # factor of 1 to avoid unnecessary computation (especially # since this is the default setting). if self.discount_factor == 1: - returns_diff = (rews2 - rews1).sum(axis=0) + returns_diff = (rews2 - rews1).sum(axis=0) # type: ignore[call-overload] else: discounts = self.discount_factor ** th.arange(len(rews1)) - if self.is_ensemble: + if self.ensemble_model is not None: discounts = discounts.reshape(-1, 1) returns_diff = (discounts * (rews2 - rews1)).sum(axis=0) # Clip to avoid overflows (which in particular may occur @@ -499,7 +522,7 @@ def probability(self, rews1: th.Tensor, rews2: th.Tensor) -> th.Tensor: # probability that fragment 1 is preferred. model_probability = 1 / (1 + returns_diff.exp()) probability = self.noise_prob * 0.5 + (1 - self.noise_prob) * model_probability - if self.is_ensemble: + if self.ensemble_model is not None: assert probability.shape == (self.model.num_members,) else: assert probability.shape == () @@ -551,21 +574,21 @@ class RandomFragmenter(Fragmenter): def __init__( self, - seed: Optional[float] = None, + rng: np.random.Generator, warning_threshold: int = 10, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, ): """Initialize the fragmenter. 
Args: - seed: an optional seed for the internal RNG + rng: the random number generator warning_threshold: give a warning if the number of available transitions is less than this many times the number of required samples. Set to 0 to disable this warning. custom_logger: Where to log to; if None (default), creates a new logger. """ super().__init__(custom_logger) - self.rng = random.Random(seed) + self.rng = rng self.warning_threshold = warning_threshold def __call__( @@ -617,9 +640,9 @@ def __call__( # we need two fragments for each comparison for _ in range(2 * num_pairs): - traj = self.rng.choices(trajectories, weights, k=1)[0] + traj = self.rng.choice(trajectories, p=np.array(weights) / sum(weights)) n = len(traj) - start = self.rng.randint(0, n - fragment_length) + start = self.rng.integers(0, n - fragment_length, endpoint=True) end = start + fragment_length terminal = (end == n) and traj.terminal fragment = TrajectoryWithRew( @@ -669,7 +692,7 @@ def __init__( ValueError: Preference model not wrapped over an ensemble of networks. """ super().__init__(custom_logger=custom_logger) - if not preference_model.is_ensemble: + if preference_model.ensemble_model is None: raise ValueError( "PreferenceModel not wrapped over an ensemble of networks.", ) @@ -735,12 +758,12 @@ def variance_estimate(self, rews1, rews2) -> float: var_estimate = (returns1 - returns2).var().item() else: # uncertainty_on is probability or label probs = self.preference_model.probability(rews1, rews2) - probs = probs.cpu().numpy() - assert probs.shape == (self.preference_model.model.num_members,) + probs_np = probs.cpu().numpy() + assert probs_np.shape == (self.preference_model.model.num_members,) if self.uncertainty_on == "probability": - var_estimate = probs.var() + var_estimate = probs_np.var() elif self.uncertainty_on == "label": # uncertainty_on is label - preds = (probs > 0.5).astype(np.float32) + preds = (probs_np > 0.5).astype(np.float32) # probability estimate of Bernoulli random variable prob_estimate = preds.mean() # variance estimate of Bernoulli random variable @@ -755,20 +778,20 @@ class PreferenceGatherer(abc.ABC): def __init__( self, - seed: Optional[int] = None, + rng: Optional[np.random.Generator] = None, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, ): """Initializes the preference gatherer. Args: - seed: seed for the internal RNG, if applicable + rng: random number generator, if applicable. custom_logger: Where to log to; if None (default), creates a new logger. """ # The random seed isn't used here, but it's useful to have this # as an argument nevertheless because that means we can always # pass in a seed in training scripts (without worrying about whether # the PreferenceGatherer we use needs one). - del seed + del rng self.logger = custom_logger or imit_logger.configure() @abc.abstractmethod @@ -798,7 +821,7 @@ def __init__( temperature: float = 1, discount_factor: float = 1, sample: bool = True, - seed: Optional[int] = None, + rng: Optional[np.random.Generator] = None, threshold: float = 50, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, ): @@ -815,21 +838,28 @@ def __init__( a Bernoulli distribution (or 0.5 in the case of ties with zero temperature). If False, then the underlying Bernoulli probabilities are returned instead. - seed: seed for the internal RNG (only used if temperature > 0 and sample) + rng: random number generator, only used if + ``temperature > 0`` and ``sample=True`` threshold: preferences are sampled from a softmax of returns. 
To avoid overflows, we clip differences in returns that are above this threshold (after multiplying with temperature). This threshold is therefore in logspace. The default value of 50 means that probabilities below 2e-22 are rounded up to 2e-22. custom_logger: Where to log to; if None (default), creates a new logger. + + Raises: + ValueError: if `sample` is true and no random state is provided. """ super().__init__(custom_logger=custom_logger) self.temperature = temperature self.discount_factor = discount_factor self.sample = sample - self.rng = np.random.default_rng(seed=seed) + self.rng = rng self.threshold = threshold + if self.sample and self.rng is None: + raise ValueError("If `sample` is True, then `rng` must be provided.") + def __call__(self, fragment_pairs: Sequence[TrajectoryWithRewPair]) -> np.ndarray: """Computes probability fragment 1 is preferred over fragment 2.""" returns1, returns2 = self._reward_sums(fragment_pairs) @@ -893,7 +923,7 @@ def __init__(self, max_size: Optional[int] = None): self.fragments1: List[TrajectoryWithRew] = [] self.fragments2: List[TrajectoryWithRew] = [] self.max_size = max_size - self.preferences = np.array([]) + self.preferences: np.ndarray = np.array([]) def push(self, fragments: Sequence[TrajectoryWithRewPair], preferences: np.ndarray): """Add more samples to the dataset. @@ -1094,11 +1124,11 @@ def __init__( self, preference_model: PreferenceModel, loss: RewardLoss, + rng: np.random.Generator, batch_size: int = 32, epochs: int = 1, lr: float = 1e-3, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, - seed: Optional[int] = None, regularizer_factory: Optional[regularizers.RegularizerFactory] = None, ): """Initialize the reward model trainer. @@ -1106,13 +1136,13 @@ def __init__( Args: preference_model: the preference model to train the reward network. loss: the loss to use + rng: the random number generator to use for splitting the dataset into + training and validation. batch_size: number of fragment pairs per batch epochs: number of epochs in each training iteration (can be adjusted on the fly by specifying an `epoch_multiplier` in `self.train()` if longer training is desired in specific cases). lr: the learning rate - seed: the random seed to use for splitting the dataset into training - and validation. custom_logger: Where to log to; if None (default), creates a new logger. regularizer_factory: if you would like to apply regularization during training, specify a regularizer factory here. The factory will be @@ -1124,7 +1154,7 @@ def __init__( self.batch_size = batch_size self.epochs = epochs self.optim = th.optim.AdamW(self._preference_model.parameters(), lr=lr) - self.seed = seed + self.rng = rng self.regularizer = ( regularizer_factory(optimizer=self.optim, logger=self.logger) if regularizer_factory is not None @@ -1169,7 +1199,8 @@ def _train( train_dataset, val_dataset = data_th.random_split( dataset, lengths=[train_length, val_length], - generator=th.Generator().manual_seed(self.seed) if self.seed else None, + # we convert the numpy generator to the pytorch generator. 
+ generator=th.Generator().manual_seed(util.make_seeds(self.rng)), ) dataloader = self._make_data_loader(train_dataset) val_dataloader = self._make_data_loader(val_dataset) @@ -1247,11 +1278,11 @@ def __init__( self, preference_model: PreferenceModel, loss: RewardLoss, + rng: np.random.Generator, batch_size: int = 32, epochs: int = 1, lr: float = 1e-3, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, - seed: Optional[int] = None, regularizer_factory: Optional[regularizers.RegularizerFactory] = None, ): """Initialize the reward model trainer. @@ -1259,13 +1290,12 @@ def __init__( Args: preference_model: the preference model to train the reward network. loss: the loss to use + rng: random state for the internal RNG used in bagging batch_size: number of fragment pairs per batch epochs: number of epochs in each training iteration (can be adjusted on the fly by specifying an `epoch_multiplier` in `self.train()` if longer training is desired in specific cases). lr: the learning rate - seed: the random seed to use for splitting the dataset into training - and validation, and for bagging. custom_logger: Where to log to; if None (default), creates a new logger. regularizer_factory: A factory for creating a regularizer. If None, no regularization is used. @@ -1273,7 +1303,7 @@ def __init__( Raises: TypeError: if model is not a RewardEnsemble. """ - if not preference_model.is_ensemble: + if preference_model.ensemble_model is None: raise TypeError( "PreferenceModel of a RewardEnsemble expected by EnsembleTrainer.", ) @@ -1285,6 +1315,7 @@ def __init__( epochs=epochs, lr=lr, custom_logger=custom_logger, + rng=rng, regularizer_factory=regularizer_factory, ) self.member_trainers = [] @@ -1297,11 +1328,9 @@ def __init__( lr=lr, custom_logger=self.logger, regularizer_factory=regularizer_factory, + rng=self.rng, ) self.member_trainers.append(reward_trainer) - self.rng = th.Generator() - if seed: - self.rng = self.rng.manual_seed(seed) @property def logger(self): @@ -1319,7 +1348,8 @@ def _train(self, dataset: PreferenceDataset, epoch_multiplier: float = 1.0) -> N dataset, replacement=True, num_samples=len(dataset), - generator=self.rng, + # we convert the numpy generator to the pytorch generator. 
+ generator=th.Generator().manual_seed(util.make_seeds(self.rng)), ) for member_idx in range(len(self.member_trainers)): # sampler gives new indexes on every call @@ -1345,35 +1375,36 @@ def _train(self, dataset: PreferenceDataset, epoch_multiplier: float = 1.0) -> N self.logger.record(k + "_std", np.std(v)) -def is_base_model_ensemble(reward_model): +def get_base_model(reward_model: reward_nets.RewardNet) -> reward_nets.RewardNet: base_model = reward_model while hasattr(base_model, "base"): - base_model = base_model.base + base_model = cast(reward_nets.RewardNet, base_model.base) - return isinstance(base_model, reward_nets.RewardEnsemble), base_model + return base_model def _make_reward_trainer( preference_model: PreferenceModel, loss: RewardLoss, + rng: np.random.Generator, reward_trainer_kwargs: Optional[Mapping[str, Any]] = None, - seed: Optional[int] = None, ) -> RewardTrainer: """Construct the correct type of reward trainer for this reward function.""" if reward_trainer_kwargs is None: reward_trainer_kwargs = {} - if preference_model.is_ensemble: + if preference_model.ensemble_model is not None: return EnsembleTrainer( preference_model, loss, - seed=seed, + rng=rng, **reward_trainer_kwargs, ) else: return BasicRewardTrainer( preference_model, loss=loss, + rng=rng, **reward_trainer_kwargs, ) @@ -1403,7 +1434,7 @@ def __init__( initial_epoch_multiplier: float = 200.0, custom_logger: Optional[imit_logger.HierarchicalLogger] = None, allow_variable_horizon: bool = False, - seed: Optional[int] = None, + rng: Optional[np.random.Generator] = None, query_schedule: Union[str, type_aliases.Schedule] = "hyperbolic", ): """Initialize the preference comparison trainer. @@ -1458,7 +1489,8 @@ def __init__( condition, and can seriously confound evaluation. Read https://imitation.readthedocs.io/en/latest/guide/variable_horizon.html before overriding this. - seed: seed to use for initializing subcomponents such as fragmenter. + rng: random number generator to use for initializing subcomponents such as + fragmenter. Only used when default components are used; if you instantiate your own fragmenter, preference gatherer, etc., you are responsible for seeding them! @@ -1485,14 +1517,35 @@ def __init__( self._iteration = 0 self.model = reward_model + self.rng = rng + + # are any of the optional args that require a rng None? + has_any_rng_args_none = None in ( + preference_gatherer, + fragmenter, + reward_trainer, + ) + + if self.rng is None and has_any_rng_args_none: + raise ValueError( + "If you don't provide a random state, you must provide your own " + "seeded fragmenter, preference gatherer, and reward_trainer. 
" + "You can initialize a random state with `np.random.default_rng(seed)`.", + ) + elif self.rng is not None and not has_any_rng_args_none: + raise ValueError( + "If you provide your own fragmenter, preference gatherer, " + "and reward trainer, you don't need to provide a random state.", + ) if reward_trainer is None: + assert self.rng is not None preference_model = PreferenceModel(reward_model) loss = CrossEntropyRewardLoss() self.reward_trainer = _make_reward_trainer( preference_model, loss, - seed=seed, + rng=self.rng, ) else: self.reward_trainer = reward_trainer @@ -1503,15 +1556,24 @@ def __init__( self.reward_trainer.logger = self.logger self.trajectory_generator = trajectory_generator self.trajectory_generator.logger = self.logger - self.fragmenter = fragmenter or RandomFragmenter( - custom_logger=self.logger, - seed=seed, - ) + if fragmenter: + self.fragmenter = fragmenter + else: + assert self.rng is not None + self.fragmenter = RandomFragmenter( + custom_logger=self.logger, + rng=self.rng, + ) self.fragmenter.logger = self.logger - self.preference_gatherer = preference_gatherer or SyntheticGatherer( - custom_logger=self.logger, - seed=seed, - ) + if preference_gatherer: + self.preference_gatherer = preference_gatherer + else: + assert self.rng is not None + self.preference_gatherer = SyntheticGatherer( + custom_logger=self.logger, + rng=self.rng, + ) + self.preference_gatherer.logger = self.logger self.fragment_length = fragment_length diff --git a/src/imitation/data/buffer.py b/src/imitation/data/buffer.py index 3f7f283d8..5554f403a 100644 --- a/src/imitation/data/buffer.py +++ b/src/imitation/data/buffer.py @@ -1,7 +1,7 @@ """Buffers to store NumPy arrays and transitions in.""" import dataclasses -from typing import Mapping, Optional, Tuple +from typing import Any, Mapping, Optional, Tuple import numpy as np from stable_baselines3.common import vec_env @@ -9,6 +9,25 @@ from imitation.data import types +def num_samples(data: Mapping[Any, np.ndarray]) -> int: + """Computes the number of samples contained in `data`. + + Args: + data: A Mapping from keys to NumPy arrays. + + Returns: + The unique length of the first dimension of arrays contained in `data`. + + Raises: + ValueError: The length is not unique. + """ + n_samples_list = [arr.shape[0] for arr in data.values()] + n_samples_np = np.unique(n_samples_list) + if len(n_samples_np) > 1: + raise ValueError("Keys map to different length values.") + return int(n_samples_np[0]) + + class Buffer: """A FIFO ring buffer for NumPy arrays of a fixed shape and dtype. @@ -111,8 +130,8 @@ def from_data( ValueError: `data` has items mapping to arrays differing in the length of their first axis. 
""" - data_capacities = [arr.shape[0] for arr in data.values()] - data_capacities = np.unique(data_capacities) + data_capacities_list = [arr.shape[0] for arr in data.values()] + data_capacities = np.unique(data_capacities_list) if len(data) == 0: raise ValueError("No keys in data.") if len(data_capacities) > 1: @@ -150,12 +169,7 @@ def store(self, data: Mapping[str, np.ndarray], truncate_ok: bool = False) -> No if len(unexpected_keys) > 0: raise ValueError(f"Unexpected keys {unexpected_keys}") - n_samples = [arr.shape[0] for arr in data.values()] - n_samples = np.unique(n_samples) - if len(n_samples) > 1: - raise ValueError("Keys map to different length values.") - n_samples = n_samples[0] - + n_samples = num_samples(data) if n_samples == 0: raise ValueError("Trying to store empty data.") if n_samples > self.capacity: @@ -192,11 +206,7 @@ def _store_easy(self, data: Mapping[str, np.ndarray]) -> None: data: Same as in `self.store`'s docstring, except with the additional constraint `size(data) <= self.capacity - self._idx`. """ - n_samples = [arr.shape[0] for arr in data.values()] - n_samples = np.unique(n_samples) - assert len(n_samples) == 1 - n_samples = n_samples[0] - + n_samples = num_samples(data) assert n_samples <= self.capacity - self._idx idx_hi = self._idx + n_samples for k, arr in data.items(): @@ -222,7 +232,7 @@ def sample(self, n_samples: int) -> Mapping[str, np.ndarray]: ind = np.random.randint(self.size(), size=n_samples) return {k: buffer[ind] for k, buffer in self._arrays.items()} - def size(self) -> Optional[int]: + def size(self) -> int: """Returns the number of samples stored in the buffer.""" assert 0 <= self._n_data <= self.capacity return self._n_data @@ -250,7 +260,7 @@ def __init__( capacity: The number of samples that can be stored. venv: The environment whose action and observation spaces can be used to determine the data shapes of the underlying - buffers. Overrides all the following arguments. + buffers. Mutually exclusive with shape and dtype arguments. obs_shape: The shape of the observation space. act_shape: The shape of the action space. obs_dtype: The dtype of the observation space. @@ -259,19 +269,27 @@ def __init__( Raises: ValueError: Couldn't infer the observation and action shapes and dtypes from the arguments. + ValueError: Specified both venv and shapes/dtypes. 
""" - params = [obs_shape, act_shape, obs_dtype, act_dtype] + params = (obs_shape, act_shape, obs_dtype, act_dtype) if venv is not None: - if np.any([x is not None for x in params]): - raise ValueError("Specified shape or dtype and environment.") + if not all(x is None for x in params): + raise ValueError( + "Cannot specify both shape/dtype and also environment.", + ) obs_shape = tuple(venv.observation_space.shape) act_shape = tuple(venv.action_space.shape) obs_dtype = venv.observation_space.dtype act_dtype = venv.action_space.dtype else: - if np.any([x is None for x in params]): + if any(x is None for x in params): raise ValueError("Shape or dtype missing and no environment specified.") + assert obs_shape is not None + assert act_shape is not None + assert obs_dtype is not None + assert act_dtype is not None + self.capacity = capacity sample_shapes = { "obs": obs_shape, @@ -284,8 +302,8 @@ def __init__( "obs": obs_dtype, "acts": act_dtype, "next_obs": obs_dtype, - "dones": bool, - "infos": np.object, + "dones": np.dtype(bool), + "infos": np.dtype(object), } self._buffer = Buffer(capacity, sample_shapes=sample_shapes, dtypes=dtypes) diff --git a/src/imitation/data/rollout.py b/src/imitation/data/rollout.py index 310c11ca7..8148a61e5 100644 --- a/src/imitation/data/rollout.py +++ b/src/imitation/data/rollout.py @@ -3,7 +3,17 @@ import collections import dataclasses import logging -from typing import Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Union +from typing import ( + Any, + Callable, + Dict, + Hashable, + List, + Mapping, + Optional, + Sequence, + Union, +) import numpy as np from stable_baselines3.common.base_class import BaseAlgorithm @@ -28,7 +38,12 @@ def unwrap_traj(traj: types.TrajectoryWithRew) -> types.TrajectoryWithRew: Returns: A copy of `traj` with replaced `obs` and `rews` fields. + + Raises: + ValueError: If `traj.infos` is None """ + if traj.infos is None: + raise ValueError("Trajectory must have infos to unwrap") ep_info = traj.infos[-1]["rollout"] res = dataclasses.replace(traj, obs=ep_info["obs"], rews=ep_info["rews"]) assert len(res.obs) == len(res.acts) + 1 @@ -52,7 +67,7 @@ def __init__(self): def add_step( self, - step_dict: Mapping[str, np.ndarray], + step_dict: Mapping[str, Union[np.ndarray, Mapping[str, Any]]], key: Hashable = None, ) -> None: """Add a single step to the partial trajectory identified by `key`. @@ -88,11 +103,10 @@ def finish_trajectory( del self.partial_trajectories[key] out_dict_unstacked = collections.defaultdict(list) for part_dict in part_dicts: - for key, array in part_dict.items(): - out_dict_unstacked[key].append(array) + for k, array in part_dict.items(): + out_dict_unstacked[k].append(array) out_dict_stacked = { - key: np.stack(arr_list, axis=0) - for key, arr_list in out_dict_unstacked.items() + k: np.stack(arr_list, axis=0) for k, arr_list in out_dict_unstacked.items() } traj = types.TrajectoryWithRew(**out_dict_stacked, terminal=terminal) assert traj.rews.shape[0] == traj.acts.shape[0] == traj.obs.shape[0] - 1 @@ -125,7 +139,7 @@ def add_steps_and_auto_finish( A list of completed trajectories. There should be one trajectory for each `True` in the `dones` argument. 
""" - trajs = [] + trajs: List[types.TrajectoryWithRew] = [] for env_idx in range(len(obs)): assert env_idx in self.partial_trajectories assert list(self.partial_trajectories[env_idx][0].keys()) == ["obs"], ( @@ -281,7 +295,7 @@ def get_actions(states): ) return acts - elif isinstance(policy, Callable): + elif callable(policy): # When a policy callable is passed, by default we will use it directly. # We are not able to change the determinism of the policy when it is a # callable that only takes in the states. @@ -309,9 +323,9 @@ def generate_trajectories( policy: AnyPolicy, venv: VecEnv, sample_until: GenTrajTerminationFn, + rng: np.random.Generator, *, deterministic_policy: bool = False, - rng: np.random.RandomState = np.random, ) -> Sequence[types.TrajectoryWithRew]: """Generate trajectory dictionaries from a policy and an environment. @@ -358,9 +372,11 @@ def generate_trajectories( # # To start with, all environments are active. active = np.ones(venv.num_envs, dtype=bool) + assert isinstance(obs, np.ndarray), "Dict/tuple observations are not supported." while np.any(active): acts = get_actions(obs) obs, rews, dones, infos = venv.step(acts) + assert isinstance(obs, np.ndarray) # If an environment is inactive, i.e. the episode completed for that # environment after `sample_until(trajectories)` was true, then we do @@ -389,7 +405,7 @@ def generate_trajectories( # `trajectories` sooner. Shuffle to avoid bias in order. This is important # when callees end up truncating the number of trajectories or transitions. # It is also cheap, since we're just shuffling pointers. - rng.shuffle(trajectories) + rng.shuffle(trajectories) # type: ignore[arg-type] # Sanity checks. for trajectory in trajectories: @@ -474,7 +490,7 @@ def flatten_trajectories( The trajectories flattened into a single batch of Transitions. """ keys = ["obs", "next_obs", "acts", "dones", "infos"] - parts = {key: [] for key in keys} + parts: Mapping[str, List[np.ndarray]] = {key: [] for key in keys} for traj in trajectories: parts["acts"].append(traj.acts) @@ -512,6 +528,7 @@ def generate_transitions( policy: AnyPolicy, venv: VecEnv, n_timesteps: int, + rng: np.random.Generator, *, truncate: bool = True, **kwargs, @@ -526,6 +543,7 @@ def generate_transitions( - None, in which case actions will be sampled randomly venv: The vectorized environments to interact with. n_timesteps: The minimum number of timesteps to sample. + rng: The random state to use for sampling trajectories. truncate: If True, then drop any additional samples to ensure that exactly `n_timesteps` samples are returned. **kwargs: Passed-through to generate_trajectories. @@ -539,6 +557,7 @@ def generate_transitions( policy, venv, sample_until=make_min_timesteps(n_timesteps), + rng=rng, **kwargs, ) transitions = flatten_trajectories_with_rew(traj) @@ -553,6 +572,7 @@ def rollout( policy: AnyPolicy, venv: VecEnv, sample_until: GenTrajTerminationFn, + rng: np.random.Generator, *, unwrap: bool = True, exclude_infos: bool = True, @@ -571,6 +591,7 @@ def rollout( 3) None, in which case actions will be sampled randomly. venv: The vectorized environments. sample_until: End condition for rollout sampling. + rng: Random state to use for sampling. unwrap: If True, then save original observations and rewards (instead of potentially wrapped observations and rewards) by calling `unwrap_traj()`. @@ -585,7 +606,13 @@ def rollout( may be collected to avoid biasing process towards short episodes; the user should truncate if required. 
""" - trajs = generate_trajectories(policy, venv, sample_until, **kwargs) + trajs = generate_trajectories( + policy, + venv, + sample_until, + rng=rng, + **kwargs, + ) if unwrap: trajs = [unwrap_traj(traj) for traj in trajs] if exclude_infos: diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index 363d86b0e..c0801446a 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -5,7 +5,18 @@ import os import pathlib import warnings -from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, TypeVar, Union, cast +from typing import ( + Any, + Dict, + Mapping, + Optional, + Sequence, + Tuple, + TypeVar, + Union, + cast, + overload, +) import numpy as np import torch as th @@ -169,8 +180,11 @@ def transitions_collate_fn( return result +TransitionsMinimalSelf = TypeVar("TransitionsMinimalSelf", bound="TransitionsMinimal") + + @dataclasses.dataclass(frozen=True) -class TransitionsMinimal(th_data.Dataset): +class TransitionsMinimal(th_data.Dataset, Sequence[Mapping[str, np.ndarray]]): """A Torch-compatible `Dataset` of obs-act transitions. This class and its subclasses are usually instantiated via @@ -200,7 +214,7 @@ class TransitionsMinimal(th_data.Dataset): infos: np.ndarray """Array of info dicts. Shape: (batch_size,).""" - def __len__(self): + def __len__(self) -> int: """Returns number of transitions. Always positive.""" return len(self.obs) @@ -239,6 +253,14 @@ def __post_init__(self): # def __getitem__(self, key: int) -> Mapping[str, np.ndarray]: # pass # pragma: no cover + @overload + def __getitem__(self, key: int) -> Mapping[str, np.ndarray]: + pass + + @overload + def __getitem__(self: TransitionsMinimalSelf, key: slice) -> TransitionsMinimalSelf: + pass + def __getitem__(self, key): """See TransitionsMinimal docstring for indexing and slicing semantics.""" d = dataclass_quick_asdict(self) @@ -345,22 +367,25 @@ def load(path: AnyPath) -> Sequence[Trajectory]: # .npz format and the old pickle based format. To tell the difference we need to # look at the type of the resulting object. If it's the new compressed format, # it should be a Mapping that we need to decode, whereas if it's the old format - # it's just the sequence of trajectories and we can return it directly. + # it's just the sequence of trajectories, and we can return it directly. data = np.load(path, allow_pickle=True) if isinstance(data, Sequence): # old format warnings.warn("Loading old version of Trajectory's", DeprecationWarning) return data elif isinstance(data, Mapping): # new format num_trajs = len(data["indices"]) - fields = ( + fields = [ # Account for the extra obs in each trajectory np.split(data["obs"], data["indices"] + np.arange(num_trajs) + 1), np.split(data["acts"], data["indices"]), np.split(data["infos"], data["indices"]), data["terminal"], - ) + ] if "rews" in data: - fields += (np.split(data["rews"], data["indices"]),) + fields = [ + *fields, + np.split(data["rews"], data["indices"]), + ] return [TrajectoryWithRew(*args) for args in zip(*fields)] else: return [Trajectory(*args) for args in zip(*fields)] @@ -395,9 +420,9 @@ def save(path: AnyPath, trajectories: Sequence[Trajectory]): ValueError: If the trajectories are not all of the same type, i.e. some are `Trajectory` and others are `TrajectoryWithRew`. 
""" - p = pathlib.Path(path) + p = pathlib.Path(path_to_str(path)) p.parent.mkdir(parents=True, exist_ok=True) - tmp_path = f"{path}.tmp" + tmp_path = f"{p}.tmp" infos = [ # Replace 'None' values for `infos`` with array of empty dicts @@ -423,5 +448,5 @@ def save(path: AnyPath, trajectories: Sequence[Trajectory]): np.savez_compressed(f, **condensed) # Ensure atomic write - os.replace(tmp_path, path) - logging.info(f"Dumped demonstrations to {path}.") + os.replace(tmp_path, p) + logging.info(f"Dumped demonstrations to {p}.") diff --git a/src/imitation/data/wrappers.py b/src/imitation/data/wrappers.py index cd5b4ac26..09ad42247 100644 --- a/src/imitation/data/wrappers.py +++ b/src/imitation/data/wrappers.py @@ -1,9 +1,10 @@ """Environment wrappers for collecting rollouts.""" -from typing import Sequence, Tuple +from typing import List, Optional, Sequence, Tuple import gym import numpy as np +import numpy.typing as npt from stable_baselines3.common.vec_env import VecEnv, VecEnvWrapper from imitation.data import rollout, types @@ -15,6 +16,14 @@ class BufferingWrapper(VecEnvWrapper): Retrieve saved transitions using `pop_transitions()`. """ + error_on_premature_event: bool + _trajectories: List[types.TrajectoryWithRew] + _ep_lens: List[int] + _init_reset: bool + _traj_accum: Optional[rollout.TrajectoryAccumulator] + _timesteps: Optional[npt.NDArray[np.int_]] + n_transitions: Optional[int] + def __init__(self, venv: VecEnv, error_on_premature_reset: bool = True): """Builds BufferingWrapper. @@ -81,6 +90,7 @@ def step_wait(self): def _finish_partial_trajectories(self) -> Sequence[types.TrajectoryWithRew]: """Finishes and returns partial trajectories in `self._traj_accum`.""" + assert self._traj_accum is not None trajs = [] for i in range(self.num_envs): # Check that we have any transitions at all. diff --git a/src/imitation/envs/examples/model_envs.py b/src/imitation/envs/examples/model_envs.py index e5687b192..fa4747764 100644 --- a/src/imitation/envs/examples/model_envs.py +++ b/src/imitation/envs/examples/model_envs.py @@ -12,7 +12,7 @@ def make_random_trans_mat( n_states, n_actions, max_branch_factor, - rand_state=np.random, + rng: np.random.Generator, ) -> np.ndarray: """Make a 'random' transition matrix. @@ -28,7 +28,7 @@ def make_random_trans_mat( n_actions: Number of actions. max_branch_factor: Maximum number of states that can be reached from each state-action pair. - rand_state: NumPy random state. + rng: NumPy random state. Returns: The transition matrix `mat`, where `mat[s,a,next_s]` gives the probability @@ -39,26 +39,26 @@ def make_random_trans_mat( for action in range(n_actions): # uniformly sample a number of successors in [1,max_branch_factor] # for this action - succs = rand_state.randint(1, max_branch_factor + 1) - next_states = rand_state.choice(n_states, size=(succs,), replace=False) + succs = rng.integers(1, max_branch_factor + 1) + next_states = rng.choice(n_states, size=(succs,), replace=False) # generate random vec in probability simplex - next_vec = rand_state.dirichlet(np.ones((succs,))) + next_vec = rng.dirichlet(np.ones((succs,))) next_vec = next_vec / np.sum(next_vec) out_mat[start_state, action, next_states] = next_vec return out_mat -def make_random_state_dist( +def make_rng_dist( n_avail: int, n_states: int, - rand_state: np.random.RandomState = np.random, + rng: np.random.Generator, ) -> np.ndarray: """Make a random initial state distribution over n_states. Args: n_avail: Number of states available to transition into. n_states: Total number of states. 
- rand_state: NumPy random state. + rng: NumPy random state. Returns: An initial state distribution that is zero at all but a uniformly random @@ -71,8 +71,8 @@ def make_random_state_dist( """ # noqa: DAR402 assert 0 < n_avail <= n_states init_dist = np.zeros((n_states,)) - next_states = rand_state.choice(n_states, size=(n_avail,), replace=False) - avail_state_dist = rand_state.dirichlet(np.ones((n_avail,))) + next_states = rng.choice(n_states, size=(n_avail,), replace=False) + avail_state_dist = rng.dirichlet(np.ones((n_avail,))) init_dist[next_states] = avail_state_dist assert np.sum(init_dist > 0) == n_avail init_dist = init_dist / np.sum(init_dist) @@ -83,7 +83,7 @@ def make_obs_mat( n_states: int, is_random: bool, obs_dim: Optional[int], - rand_state: np.random.RandomState = np.random, + rng: np.random.Generator, ) -> np.ndarray: """Makes an observation matrix with a single observation for each state. @@ -94,16 +94,16 @@ def make_obs_mat( If `False`, are unique one-hot vectors for each state. obs_dim (int or NoneType): Must be `None` if `is_random == False`. Otherwise, this must be set to the size of the random vectors. - rand_state (np.random.RandomState): Random number generator. + rng (np.random.Generator): Random number generator. Returns: A matrix of shape `(n_states, obs_dim if is_random else n_states)`. """ - if not is_random: - assert obs_dim is None if is_random: - obs_mat = rand_state.normal(0, 2, (n_states, obs_dim)) + assert obs_dim is not None + obs_mat = rng.normal(0, 2, (n_states, obs_dim)) else: + assert obs_dim is None obs_mat = np.identity(n_states) assert ( obs_mat.ndim == 2 and obs_mat.shape[:1] == (n_states,) and obs_mat.shape[1] > 0 @@ -145,7 +145,7 @@ def __init__( super().__init__() # this generator is ONLY for constructing the MDP, not for controlling # random outcomes during rollouts - rand_gen = np.random.RandomState(generator_seed) + rng = np.random.default_rng(generator_seed) if random_obs: if obs_dim is None: obs_dim = n_states @@ -155,21 +155,21 @@ def __init__( n_states=n_states, is_random=random_obs, obs_dim=obs_dim, - rand_state=rand_gen, + rng=rng, ) self._transition_matrix = make_random_trans_mat( n_states=n_states, n_actions=n_actions, max_branch_factor=branch_factor, - rand_state=rand_gen, + rng=rng, ) - self._initial_state_dist = make_random_state_dist( + self._initial_state_dist = make_rng_dist( n_avail=branch_factor, n_states=n_states, - rand_state=rand_gen, + rng=rng, ) self._horizon = horizon - self._reward_weights = rand_gen.randn(self._observation_matrix.shape[-1]) + self._reward_weights = rng.standard_normal(self._observation_matrix.shape[-1]) self._reward_matrix = self._observation_matrix @ self._reward_weights assert self._reward_matrix.shape == (self.n_states,) diff --git a/src/imitation/envs/resettable_env.py b/src/imitation/envs/resettable_env.py index 0e8780d13..877003bca 100644 --- a/src/imitation/envs/resettable_env.py +++ b/src/imitation/envs/resettable_env.py @@ -27,7 +27,7 @@ def __init__(self): self._action_space = None self.cur_state = None self._n_actions_taken = None - self.rand_state: Optional[np.random.RandomState] = None + self.rng: Optional[np.random.Generator] = None self.seed() @abc.abstractmethod @@ -113,7 +113,7 @@ def seed(self, seed=None): # Gym API wants list of seeds to be returned for some reason, so # generate a seed explicitly in this case seed = np.random.randint(0, 1 << 31) - self.rand_state = np.random.RandomState(seed) + self.rng = np.random.default_rng(seed) return [seed] def reset(self): @@ -177,12 
+177,12 @@ def action_space(self) -> gym.Space: return self._action_space def initial_state(self): - return self.rand_state.choice(self.n_states, p=self.initial_state_dist) + return self.rng.choice(self.n_states, p=self.initial_state_dist) def transition(self, state, action): out_dist = self.transition_matrix[state, action] choice_states = np.arange(self.n_states) - return int(self.rand_state.choice(choice_states, p=out_dist, size=())) + return int(self.rng.choice(choice_states, p=out_dist, size=())) def reward(self, state, action, new_state): reward = self.reward_matrix[state] diff --git a/src/imitation/policies/base.py b/src/imitation/policies/base.py index 3f9b0d919..f0ee6d588 100644 --- a/src/imitation/policies/base.py +++ b/src/imitation/policies/base.py @@ -105,7 +105,12 @@ def __init__( e.g. `nn.BatchNorm*` or `nn.LayerNorm`. """ super().__init__(observation_space) - self.normalize = normalize_class(self.features_dim) + # Below we have to ignore the type error when initializing the class because + # there is no simple way of specifying a protocol that admits one positional + # argument for the number of features while being compatible with nn.Module. + # (it would require defining a base class and forcing all the subclasses + # to inherit from it). + self.normalize = normalize_class(self.features_dim) # type: ignore[call-arg] def forward(self, observations: th.Tensor) -> th.Tensor: flattened = super().forward(observations) diff --git a/src/imitation/policies/exploration_wrapper.py b/src/imitation/policies/exploration_wrapper.py index 651476278..2a7da8060 100644 --- a/src/imitation/policies/exploration_wrapper.py +++ b/src/imitation/policies/exploration_wrapper.py @@ -1,11 +1,10 @@ """Wrapper to turn a policy into a more exploratory version.""" -from typing import Optional - import numpy as np from stable_baselines3.common import vec_env from imitation.data import rollout +from imitation.util import util class ExplorationWrapper: @@ -25,7 +24,7 @@ def __init__( venv: vec_env.VecEnv, random_prob: float, switch_prob: float, - seed: Optional[int] = None, + rng: np.random.Generator, ): """Initializes the ExplorationWrapper. @@ -34,14 +33,16 @@ def __init__( venv: The environment to use (needed for sampling random actions). random_prob: The probability of picking the random policy when switching. switch_prob: The probability of switching away from the current policy. - seed: The random seed to use. + rng: The random state to use for seeding the environment and for + switching policies. 
""" self.wrapped_policy = policy_callable self.random_prob = random_prob self.switch_prob = switch_prob self.venv = venv - self.rng = np.random.RandomState(seed) + self.rng = rng + seed = util.make_seeds(self.rng) self.venv.action_space.seed(seed) self.current_policy = policy_callable @@ -54,13 +55,13 @@ def _random_policy(self, obs: np.ndarray) -> np.ndarray: def _switch(self) -> None: """Pick a new policy at random.""" - if self.rng.rand() < self.random_prob: + if self.rng.random() < self.random_prob: self.current_policy = self._random_policy else: self.current_policy = self.wrapped_policy def __call__(self, obs: np.ndarray) -> np.ndarray: acts = self.current_policy(obs) - if self.rng.rand() < self.switch_prob: + if self.rng.random() < self.switch_prob: self._switch() return acts diff --git a/src/imitation/policies/replay_buffer_wrapper.py b/src/imitation/policies/replay_buffer_wrapper.py index 9bb011063..ab95c18f5 100644 --- a/src/imitation/policies/replay_buffer_wrapper.py +++ b/src/imitation/policies/replay_buffer_wrapper.py @@ -1,11 +1,10 @@ """Wrapper for reward labeling for transitions sampled from a replay buffer.""" - from typing import Mapping, Type import numpy as np from gym import spaces -from stable_baselines3.common.buffers import BaseBuffer, ReplayBuffer +from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.type_aliases import ReplayBufferSamples from imitation.rewards.reward_function import RewardFn @@ -24,7 +23,7 @@ def _samples_to_reward_fn_input( ) -class ReplayBufferRewardWrapper(BaseBuffer): +class ReplayBufferRewardWrapper(ReplayBuffer): """Relabel the rewards in transitions sampled from a ReplayBuffer.""" def __init__( @@ -63,16 +62,20 @@ def __init__( _base_kwargs = {k: v for k, v in kwargs.items() if k in ["device", "n_envs"]} super().__init__(buffer_size, observation_space, action_space, **_base_kwargs) - @property - def pos(self) -> int: + # TODO(juan) remove the type ignore once the merged PR + # https://github.com/python/mypy/pull/13475 + # is released into a mypy version on pypi. + + @property # type: ignore[override] + def pos(self) -> int: # type: ignore[override] return self.replay_buffer.pos @pos.setter def pos(self, pos: int): self.replay_buffer.pos = pos - @property - def full(self) -> bool: + @property # type: ignore[override] + def full(self) -> bool: # type: ignore[override] return self.replay_buffer.full @full.setter diff --git a/src/imitation/policies/serialize.py b/src/imitation/policies/serialize.py index de2b80859..15b408412 100644 --- a/src/imitation/policies/serialize.py +++ b/src/imitation/policies/serialize.py @@ -19,6 +19,8 @@ # Note: a VecEnv will always be passed first and then any kwargs. There is just no # proper way to specify this in python yet. For details see # https://stackoverflow.com/questions/61569324/type-annotation-for-callable-that-takes-kwargs +# TODO(juan) this can be fixed using ParamSpec +# (https://github.com/HumanCompatibleAI/imitation/issues/574) PolicyLoaderFn = Callable[..., policies.BasePolicy] """A policy loader function that takes a VecEnv before any other custom arguments and returns a stable_baselines3 base policy policy.""" @@ -50,24 +52,24 @@ def load_stable_baselines_model( The deserialized RL algorithm. 
""" logging.info(f"Loading Stable Baselines policy for '{cls}' from '{path}'") - path = pathlib.Path(path) + path_obj = pathlib.Path(path) - if path.is_dir(): - path = path / "model.zip" - if not path.exists(): + if path_obj.is_dir(): + path_obj = path_obj / "model.zip" + if not path_obj.exists(): raise FileNotFoundError( f"Expected '{path}' to be a directory containing a 'model.zip' file.", ) # SOMEDAY(adam): added 2022-01, can probably remove this check in 2023 - vec_normalize_path = path.parent / "vec_normalize.pkl" + vec_normalize_path = path_obj.parent / "vec_normalize.pkl" if vec_normalize_path.exists(): raise FileExistsError( "Outdated policy format: we do not support restoring normalization " "statistics from '{vec_normalize_path}'", ) - return cls.load(path, env=venv, **kwargs) + return cls.load(path_obj, env=venv, **kwargs) def _load_stable_baselines_from_file( @@ -224,6 +226,7 @@ def __init__( self.policy_dir = policy_dir def _on_step(self) -> bool: + assert self.model is not None output_dir = os.path.join(self.policy_dir, f"{self.num_timesteps:012d}") save_stable_model(output_dir, self.model) return True diff --git a/src/imitation/rewards/reward_nets.py b/src/imitation/rewards/reward_nets.py index 41a7e6f82..b712f5af1 100644 --- a/src/imitation/rewards/reward_nets.py +++ b/src/imitation/rewards/reward_nets.py @@ -1,7 +1,7 @@ """Constructs deep network reward models.""" import abc -from typing import Callable, Iterable, Optional, Sequence, Tuple, Type +from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Tuple, Type, cast import gym import numpy as np @@ -83,20 +83,31 @@ def preprocess( del state, action, next_state, done # unused # preprocess - state_th = preprocessing.preprocess_obs( - state_th, - self.observation_space, - self.normalize_images, + # we only support array spaces, so we cast + # the observation to torch tensors. + state_th = cast( + th.Tensor, + preprocessing.preprocess_obs( + state_th, + self.observation_space, + self.normalize_images, + ), ) - action_th = preprocessing.preprocess_obs( - action_th, - self.action_space, - self.normalize_images, + action_th = cast( + th.Tensor, + preprocessing.preprocess_obs( + action_th, + self.action_space, + self.normalize_images, + ), ) - next_state_th = preprocessing.preprocess_obs( - next_state_th, - self.observation_space, - self.normalize_images, + next_state_th = cast( + th.Tensor, + preprocessing.preprocess_obs( + next_state_th, + self.observation_space, + self.normalize_images, + ), ) done_th = done_th.to(th.float32) @@ -416,7 +427,7 @@ def __init__( if self.use_done: combined_size += 1 - full_build_mlp_kwargs = { + full_build_mlp_kwargs: Dict[str, Any] = { "hid_sizes": (32, 32), **kwargs, # we do not want the values below to be overridden @@ -518,7 +529,7 @@ def __init__( if self.use_done: output_size *= 2 - full_build_cnn_kwargs = { + full_build_cnn_kwargs: Dict[str, Any] = { "hid_channels": (32, 32), **kwargs, # we do not want the values below to be overridden @@ -602,7 +613,7 @@ class NormalizedRewardNet(PredictProcessedWrapper): def __init__( self, base: RewardNet, - normalize_output_layer: Type[nn.Module], + normalize_output_layer: Type[networks.BaseNorm], ): """Initialize the NormalizedRewardNet. @@ -914,7 +925,7 @@ def predict_processed_all( next_state: np.ndarray, done: np.ndarray, **kwargs, - ) -> Tuple[np.ndarray, np.ndarray]: + ) -> np.ndarray: """Get the results of predict processed on all of the members. 
Args: @@ -929,11 +940,11 @@ def predict_processed_all( shape `(batch_size, num_members)`. """ batch_size = state.shape[0] - rewards = [ + rewards_list = [ member.predict_processed(state, action, next_state, done, **kwargs) for member in self.members ] - rewards = np.stack(rewards, axis=-1) + rewards: np.ndarray = np.stack(rewards_list, axis=-1) assert rewards.shape == (batch_size, self.num_members) return rewards diff --git a/src/imitation/rewards/reward_wrapper.py b/src/imitation/rewards/reward_wrapper.py index b333cfbb8..7afa551b3 100644 --- a/src/imitation/rewards/reward_wrapper.py +++ b/src/imitation/rewards/reward_wrapper.py @@ -4,7 +4,9 @@ from typing import Deque import numpy as np -from stable_baselines3.common import callbacks, vec_env +from stable_baselines3.common import callbacks +from stable_baselines3.common import logger as sb_logger +from stable_baselines3.common import vec_env from imitation.rewards import reward_function @@ -21,7 +23,7 @@ def __init__(self, episode_rewards: Deque[float], *args, **kwargs): **kwargs: Passed through to `callbacks.BaseCallback`. """ self.episode_rewards = episode_rewards - super().__init__(self, *args, **kwargs) + super().__init__(*args, **kwargs) def _on_step(self) -> bool: return True @@ -30,6 +32,7 @@ def _on_rollout_start(self) -> None: if len(self.episode_rewards) == 0: return mean = sum(self.episode_rewards) / len(self.episode_rewards) + assert isinstance(self.logger, sb_logger.Logger) self.logger.record("rollout/ep_rew_wrapped_mean", mean) @@ -62,7 +65,7 @@ def __init__( """ assert not isinstance(venv, RewardVecEnvWrapper) super().__init__(venv) - self.episode_rewards = collections.deque(maxlen=ep_history) + self.episode_rewards: Deque = collections.deque(maxlen=ep_history) self._cumulative_rew = np.zeros((venv.num_envs,)) self.reward_fn = reward_fn self._old_obs = None diff --git a/src/imitation/rewards/serialize.py b/src/imitation/rewards/serialize.py index 54b674e6f..07817291b 100644 --- a/src/imitation/rewards/serialize.py +++ b/src/imitation/rewards/serialize.py @@ -1,19 +1,12 @@ """Load serialized reward functions of different types.""" -from typing import Any, Callable, Iterable, Sequence, Type, Union +from typing import Any, Callable, Iterable, Sequence, Type, Union, cast import numpy as np import torch as th from stable_baselines3.common.vec_env import VecEnv -from imitation.rewards import reward_function -from imitation.rewards.reward_nets import ( - AddSTDRewardWrapper, - NormalizedRewardNet, - RewardNet, - RewardNetWrapper, - ShapedRewardNet, -) +from imitation.rewards import reward_function, reward_nets from imitation.util import registry, util # TODO(sam): I suspect this whole file can be replaced with th.load calls. Try @@ -55,9 +48,9 @@ def __call__( def _strip_wrappers( - reward_net: RewardNet, - wrapper_types: Iterable[Type[RewardNetWrapper]], -) -> RewardNet: + reward_net: reward_nets.RewardNet, + wrapper_types: Iterable[Type[reward_nets.RewardNetWrapper]], +) -> reward_nets.RewardNet: """Attempts to remove provided wrappers. 
Strips wrappers of type `wrapper_type` from `reward_net` in order until either the @@ -74,7 +67,7 @@ def _strip_wrappers( for wrapper_type in wrapper_types: assert issubclass( wrapper_type, - RewardNetWrapper, + reward_nets.RewardNetWrapper, ), f"trying to remove non-wrapper type {wrapper_type}" if isinstance(reward_net, wrapper_type): @@ -86,7 +79,7 @@ def _strip_wrappers( def _make_functional( - net: RewardNet, + net: reward_nets.RewardNet, attr: str = "predict", default_kwargs=None, **kwargs, @@ -97,7 +90,7 @@ def _make_functional( return lambda *args: getattr(net, attr)(*args, **default_kwargs) -WrapperPrefix = Sequence[Type[RewardNet]] +WrapperPrefix = Sequence[Type[reward_nets.RewardNet]] def _prefix_matches(wrappers: Sequence[Type[Any]], prefix: Sequence[Type[Any]]): @@ -120,9 +113,9 @@ def _prefix_matches(wrappers: Sequence[Type[Any]], prefix: Sequence[Type[Any]]): def _validate_wrapper_structure( - reward_net: Union[RewardNet, RewardNetWrapper], + reward_net: Union[reward_nets.RewardNet, reward_nets.RewardNetWrapper], prefixes: Iterable[WrapperPrefix], -) -> RewardNet: +) -> reward_nets.RewardNet: """Reward net if it has a valid structure. A wrapper prefix specifies, from outermost to innermost, which wrappers must @@ -155,7 +148,7 @@ def _validate_wrapper_structure( wrappers = [] while hasattr(wrapper, "base"): wrappers.append(wrapper.__class__) - wrapper = wrapper.base + wrapper = cast(reward_nets.RewardNet, wrapper.base) wrappers.append(wrapper.__class__) # append the final reward net if any(_prefix_matches(wrappers, prefix) for prefix in prefixes): @@ -198,16 +191,20 @@ def f( key="RewardNet_shaped", value=lambda path, _, **kwargs: ValidateRewardFn( _make_functional( - _validate_wrapper_structure(th.load(str(path)), {(ShapedRewardNet,)}), + _validate_wrapper_structure( + th.load(str(path)), + {(reward_nets.ShapedRewardNet,)}, + ), ), ), ) - reward_registry.register( key="RewardNet_unshaped", value=lambda path, _, **kwargs: ValidateRewardFn( - _make_functional(_strip_wrappers(th.load(str(path)), (ShapedRewardNet,))), + _make_functional( + _strip_wrappers(th.load(str(path)), (reward_nets.ShapedRewardNet,)), + ), ), ) @@ -215,7 +212,10 @@ def f( key="RewardNet_normalized", value=lambda path, _, **kwargs: ValidateRewardFn( _make_functional( - _validate_wrapper_structure(th.load(str(path)), {(NormalizedRewardNet,)}), + _validate_wrapper_structure( + th.load(str(path)), + {(reward_nets.NormalizedRewardNet,)}, + ), attr="predict_processed", default_kwargs={"update_stats": False}, **kwargs, @@ -226,7 +226,9 @@ def f( reward_registry.register( key="RewardNet_unnormalized", value=lambda path, _, **kwargs: ValidateRewardFn( - _make_functional(_strip_wrappers(th.load(str(path)), (NormalizedRewardNet,))), + _make_functional( + _strip_wrappers(th.load(str(path)), (reward_nets.NormalizedRewardNet,)), + ), ), ) @@ -238,11 +240,14 @@ def f( _validate_wrapper_structure( th.load(str(path)), { - (AddSTDRewardWrapper,), - (NormalizedRewardNet, AddSTDRewardWrapper), + (reward_nets.AddSTDRewardWrapper,), + ( + reward_nets.NormalizedRewardNet, + reward_nets.AddSTDRewardWrapper, + ), }, ), - (NormalizedRewardNet,), + (reward_nets.NormalizedRewardNet,), ), attr="predict_processed", default_kwargs={}, @@ -251,7 +256,6 @@ def f( ), ) - reward_registry.register(key="zero", value=load_zero) diff --git a/src/imitation/scripts/analyze.py b/src/imitation/scripts/analyze.py index 74617891c..3787ceb88 100644 --- a/src/imitation/scripts/analyze.py +++ b/src/imitation/scripts/analyze.py @@ -9,7 +9,7 @@ import 
tempfile import warnings from collections import OrderedDict -from typing import Any, Callable, List, Mapping, Optional, Sequence, Set +from typing import Any, Callable, Iterable, List, Mapping, Optional, Sequence, Set import pandas as pd from sacred.observers import FileStorageObserver @@ -49,14 +49,15 @@ def _gather_sacred_dicts( sacred_dirs = itertools.chain.from_iterable( sacred_util.filter_subdirs(source_dir) for source_dir in source_dirs ) - sacred_dicts = [] + sacred_dicts_list = [] for sacred_dir in sacred_dirs: try: - sacred_dicts.append(sacred_util.SacredDicts.load_from_dir(sacred_dir)) + sacred_dicts_list.append(sacred_util.SacredDicts.load_from_dir(sacred_dir)) except json.JSONDecodeError: warnings.warn(f"Invalid JSON file in {sacred_dir}", RuntimeWarning) + sacred_dicts: Iterable = sacred_dicts_list if run_name is not None: sacred_dicts = filter( lambda sd: get(sd.run, "experiment.name") == run_name, @@ -217,7 +218,6 @@ def _return_summaries(sd: sacred_util.SacredDicts) -> dict: ], ) - # If `verbosity` is at least the length of this list, then we use all table_entry_fns # as columns of table. # Otherwise, use only the subset at index `verbosity`. The subset of columns is diff --git a/src/imitation/scripts/common/common.py b/src/imitation/scripts/common/common.py index c215460e2..532e8780b 100644 --- a/src/imitation/scripts/common/common.py +++ b/src/imitation/scripts/common/common.py @@ -3,8 +3,9 @@ import contextlib import logging import os -from typing import Any, Mapping, Sequence, Tuple, Union +from typing import Any, Generator, Mapping, Sequence, Tuple, Union +import numpy as np import sacred from stable_baselines3.common import vec_env @@ -74,6 +75,12 @@ def fast(): locals() # quieten flake8 +@common_ingredient.capture +def make_rng(_seed) -> np.random.Generator: + """Creates a `np.random.Generator` with the given seed.""" + return np.random.default_rng(_seed) + + @common_ingredient.capture def make_log_dir( _run, @@ -130,7 +137,6 @@ def setup_logging( @contextlib.contextmanager @common_ingredient.capture def make_venv( - _seed, env_name: str, num_vec: int, parallel: bool, @@ -138,7 +144,7 @@ def make_venv( max_episode_steps: int, env_make_kwargs: Mapping[str, Any], **kwargs, -) -> vec_env.VecEnv: +) -> Generator[vec_env.VecEnv, None, None]: """Builds the vector environment. Args: @@ -156,12 +162,13 @@ def make_venv( Yields: The constructed vector environment. """ + rng = make_rng() # Note: we create the venv outside the try -- finally block for the case that env # creation fails. venv = util.make_vec_env( env_name, - num_vec, - seed=_seed, + rng=rng, + n_envs=num_vec, parallel=parallel, max_episode_steps=max_episode_steps, log_dir=log_dir, diff --git a/src/imitation/scripts/common/demonstrations.py b/src/imitation/scripts/common/demonstrations.py index 8b5beda04..487ee8853 100644 --- a/src/imitation/scripts/common/demonstrations.py +++ b/src/imitation/scripts/common/demonstrations.py @@ -56,6 +56,7 @@ def generate_expert_trajs( Raises: ValueError: If n_expert_demos is None. 
""" + rng = common.make_rng() if n_expert_demos is None: raise ValueError("n_expert_demos must be specified when rollout_path is None") @@ -67,6 +68,7 @@ def generate_expert_trajs( expert.get_expert_policy(rollout_env), rollout_env, rollout.make_sample_until(min_episodes=n_expert_demos), + rng=rng, ) diff --git a/src/imitation/scripts/common/reward.py b/src/imitation/scripts/common/reward.py index 548bca855..c40d3751f 100644 --- a/src/imitation/scripts/common/reward.py +++ b/src/imitation/scripts/common/reward.py @@ -6,7 +6,6 @@ import sacred from stable_baselines3.common import vec_env -from torch import nn from imitation.rewards import reward_nets from imitation.util import networks @@ -85,7 +84,7 @@ def _make_reward_net( venv: vec_env.VecEnv, net_cls: Type[reward_nets.RewardNet], net_kwargs: Mapping[str, Any], - normalize_output_layer: Optional[Type[nn.Module]], + normalize_output_layer: Optional[Type[networks.BaseNorm]], ): """Helper function for creating reward nets.""" reward_net = net_cls( @@ -108,7 +107,7 @@ def make_reward_net( venv: vec_env.VecEnv, net_cls: Type[reward_nets.RewardNet], net_kwargs: Mapping[str, Any], - normalize_output_layer: Optional[Type[nn.Module]], + normalize_output_layer: Optional[Type[networks.BaseNorm]], add_std_alpha: Optional[float], ensemble_size: Optional[int], ensemble_member_config: Optional[Mapping[str, Any]], @@ -150,9 +149,14 @@ def make_reward_net( for _ in range(ensemble_size) ] - reward_net = net_cls(venv.observation_space, venv.action_space, members) + reward_net: reward_nets.RewardNet = net_cls( + venv.observation_space, + venv.action_space, + members, + ) if add_std_alpha is not None: + assert isinstance(reward_net, reward_nets.RewardNetWithVariance) reward_net = reward_nets.AddSTDRewardWrapper( reward_net, default_alpha=add_std_alpha, diff --git a/src/imitation/scripts/common/rl.py b/src/imitation/scripts/common/rl.py index dd1a7b311..2bd3759a2 100644 --- a/src/imitation/scripts/common/rl.py +++ b/src/imitation/scripts/common/rl.py @@ -17,9 +17,13 @@ from imitation.policies import serialize from imitation.policies.replay_buffer_wrapper import ReplayBufferRewardWrapper from imitation.rewards.reward_function import RewardFn +from imitation.scripts.common import common from imitation.scripts.common.train import train_ingredient -rl_ingredient = sacred.Ingredient("rl", ingredients=[train_ingredient]) +rl_ingredient = sacred.Ingredient( + "rl", + ingredients=[train_ingredient, common.common_ingredient], +) logger = logging.getLogger(__name__) @@ -168,11 +172,11 @@ def make_rl_algo( @rl_ingredient.capture def load_rl_algo_from_path( + _seed: int, agent_path: str, venv: vec_env.VecEnv, rl_cls: Type[base_class.BaseAlgorithm], rl_kwargs: Mapping[str, Any], - _seed: int, relabel_reward_fn: Optional[RewardFn] = None, ) -> base_class.BaseAlgorithm: rl_kwargs = dict(rl_kwargs) diff --git a/src/imitation/scripts/common/train.py b/src/imitation/scripts/common/train.py index f5aa3c1bb..bcbf55f59 100644 --- a/src/imitation/scripts/common/train.py +++ b/src/imitation/scripts/common/train.py @@ -9,8 +9,9 @@ import imitation.util.networks from imitation.data import rollout from imitation.policies import base +from imitation.scripts.common import common -train_ingredient = sacred.Ingredient("train") +train_ingredient = sacred.Ingredient("train", ingredients=[common.common_ingredient]) logger = logging.getLogger(__name__) @@ -92,6 +93,7 @@ def eval_policy( "monitor_return" key). 
"expert_stats" gives the return value of `rollout_stats()` on the expert demonstrations loaded from `rollout_path`. """ + rng = common.make_rng() sample_until_eval = rollout.make_min_episodes(n_episodes_eval) if isinstance(rl_algo, base_class.BaseAlgorithm): # Set RL algorithm's env to venv, removing any cruft wrappers that the RL @@ -107,6 +109,7 @@ def eval_policy( rl_algo, train_env, sample_until=sample_until_eval, + rng=rng, ) return rollout.rollout_stats(trajs) diff --git a/src/imitation/scripts/common/wb.py b/src/imitation/scripts/common/wb.py index 2689e16a2..e64edc978 100644 --- a/src/imitation/scripts/common/wb.py +++ b/src/imitation/scripts/common/wb.py @@ -49,15 +49,12 @@ def wandb_init( env_name = _run.config["common"]["env_name"] root_seed = _run.config["seed"] - updated_wandb_kwargs = {} - updated_wandb_kwargs.update(wandb_kwargs) - updated_wandb_kwargs.update( - dict( - name="-".join([wandb_name_prefix, env_name, f"seed{root_seed}"]), - tags=[env_name, f"seed{root_seed}"] + ([wandb_tag] if wandb_tag else []), - dir=log_dir, - ), - ) + updated_wandb_kwargs: Mapping[str, Any] = { + **wandb_kwargs, + "name": f"{wandb_name_prefix}-{env_name}-seed{root_seed}", + "tags": [env_name, f"seed{root_seed}"] + ([wandb_tag] if wandb_tag else []), + "dir": log_dir, + } try: import wandb except ModuleNotFoundError as e: diff --git a/src/imitation/scripts/eval_policy.py b/src/imitation/scripts/eval_policy.py index 4121a27b0..610882fd5 100644 --- a/src/imitation/scripts/eval_policy.py +++ b/src/imitation/scripts/eval_policy.py @@ -53,7 +53,6 @@ def f(env: gym.Env, i: int) -> gym.Env: @eval_policy_ex.main def eval_policy( _run, - _seed: int, eval_n_timesteps: Optional[int], eval_n_episodes: Optional[int], render: bool, @@ -67,7 +66,6 @@ def eval_policy( """Rolls a policy out in an environment, collecting statistics. Args: - _seed: generated by Sacred. eval_n_timesteps: Minimum number of timesteps to evaluate for. Set exactly one of `eval_n_episodes` and `eval_n_timesteps`. eval_n_episodes: Minimum number of episodes to evaluate for. Set exactly @@ -86,6 +84,7 @@ def eval_policy( Returns: Return value of `imitation.util.rollout.rollout_stats()`. """ + rng = common.make_rng() log_dir = common.make_log_dir() sample_until = rollout.make_sample_until(eval_n_timesteps, eval_n_episodes) post_wrappers = [video_wrapper_factory(log_dir, **video_kwargs)] if videos else None @@ -102,6 +101,7 @@ def eval_policy( expert.get_expert_policy(venv), venv, sample_until, + rng=rng, ) if rollout_save_path: diff --git a/src/imitation/scripts/parallel.py b/src/imitation/scripts/parallel.py index 2d96d23bc..c30ad1149 100644 --- a/src/imitation/scripts/parallel.py +++ b/src/imitation/scripts/parallel.py @@ -3,7 +3,7 @@ import collections.abc import copy import os -from typing import Any, Callable, Mapping, Optional, Sequence +from typing import Any, Callable, Dict, Mapping, Optional, Sequence import ray import ray.tune @@ -165,7 +165,7 @@ def inner(config: Mapping[str, Any], reporter) -> Mapping[str, Any]: sacred.SETTINGS.CAPTURE_MODE = "sys" run_kwargs = config - updated_run_kwargs = {} + updated_run_kwargs: Dict[str, Any] = {} # Import inside function rather than in module because Sacred experiments # are not picklable, and Ray requires this function to be picklable. 
from imitation.scripts.train_adversarial import train_adversarial_ex @@ -179,14 +179,10 @@ def inner(config: Mapping[str, Any], reporter) -> Mapping[str, Any]: ex.observers = [FileStorageObserver("sacred")] # Apply base configs to get modified `named_configs` and `config_updates`. - named_configs = [] - named_configs.extend(base_named_configs) - named_configs.extend(run_kwargs["named_configs"]) + named_configs = base_named_configs + run_kwargs["named_configs"] updated_run_kwargs["named_configs"] = named_configs - config_updates = {} - config_updates.update(base_config_updates) - config_updates.update(run_kwargs["config_updates"]) + config_updates = {**base_config_updates, **run_kwargs["config_updates"]} updated_run_kwargs["config_updates"] = config_updates # Add other run_kwargs items to updated_run_kwargs. diff --git a/src/imitation/scripts/train_adversarial.py b/src/imitation/scripts/train_adversarial.py index 2910babdf..11db52341 100644 --- a/src/imitation/scripts/train_adversarial.py +++ b/src/imitation/scripts/train_adversarial.py @@ -60,14 +60,13 @@ def dummy_config(): algorithm_specific = {} # noqa: F841 -for ingredient in [train_adversarial_ex] + train_adversarial_ex.ingredients: +for ingredient in [train_adversarial_ex, *train_adversarial_ex.ingredients]: _add_hook(ingredient) @train_adversarial_ex.capture def train_adversarial( _run, - _seed: int, show_config: bool, algo_cls: Type[common.AdversarialTrainer], algorithm_kwargs: Mapping[str, Any], @@ -84,7 +83,6 @@ def train_adversarial( - Generator policies are saved to `f"{log_dir}/checkpoints/{step}/gen_policy/"`. Args: - _seed: Random seed. show_config: Print the merged config before starting training. This is analogous to the print_config command, but will show config after rather than before merging `algorithm_specific` arguments. diff --git a/src/imitation/scripts/train_imitation.py b/src/imitation/scripts/train_imitation.py index 8d7085577..d7d7cdfe9 100644 --- a/src/imitation/scripts/train_imitation.py +++ b/src/imitation/scripts/train_imitation.py @@ -3,14 +3,14 @@ import logging import os.path as osp import warnings -from typing import Any, Mapping, Optional, Type +from typing import Any, Mapping, Optional, Sequence, Type, cast from sacred.observers import FileStorageObserver from stable_baselines3.common import policies, utils, vec_env from imitation.algorithms import bc as bc_algorithm from imitation.algorithms.dagger import SimpleDAggerTrainer -from imitation.data import rollout +from imitation.data import rollout, types from imitation.scripts.common import common, demonstrations, expert, train from imitation.scripts.config.train_imitation import train_imitation_ex @@ -49,6 +49,7 @@ def make_policy( "lr_schedule": utils.get_schedule_fn(1), }, ) + policy: policies.BasePolicy if agent_path is not None: warnings.warn( "When agent_path is specified, policy_cls and policy_kwargs are ignored.", @@ -84,12 +85,13 @@ def train_imitation( Returns: Statistics for rollouts from the trained policy and demonstration data. 
""" + rng = common.make_rng() custom_logger, log_dir = common.setup_logging() with common.make_venv() as venv: imit_policy = make_policy(venv, agent_path=agent_path) - expert_trajs = None + expert_trajs: Optional[Sequence[types.Trajectory]] = None if not use_dagger or dagger["use_offline_rollouts"]: expert_trajs = demonstrations.get_expert_trajectories() @@ -99,6 +101,7 @@ def train_imitation( policy=imit_policy, demonstrations=expert_trajs, custom_logger=custom_logger, + rng=rng, **bc_kwargs, ) bc_train_kwargs = dict(log_rollouts_venv=venv, **bc_train_kwargs) @@ -117,6 +120,7 @@ def train_imitation( expert_policy=expert_policy, custom_logger=custom_logger, bc_trainer=bc_trainer, + rng=rng, ) model.train( total_timesteps=int(dagger["total_timesteps"]), @@ -132,12 +136,15 @@ def train_imitation( imit_stats = train.eval_policy(imit_policy, venv) - return { - "imit_stats": imit_stats, - "expert_stats": rollout.rollout_stats( - model._all_demos if use_dagger else expert_trajs, - ), - } + stats = {"imit_stats": imit_stats} + trajectories = model._all_demos if use_dagger else expert_trajs + assert trajectories is not None + if all(isinstance(t, types.TrajectoryWithRew) for t in trajectories): + expert_stats = rollout.rollout_stats( + cast(Sequence[types.TrajectoryWithRew], trajectories), + ) + stats["expert_stats"] = expert_stats + return stats @train_imitation_ex.command diff --git a/src/imitation/scripts/train_preference_comparisons.py b/src/imitation/scripts/train_preference_comparisons.py index 6b454c43a..006c5854f 100644 --- a/src/imitation/scripts/train_preference_comparisons.py +++ b/src/imitation/scripts/train_preference_comparisons.py @@ -56,7 +56,6 @@ def save_checkpoint( @train_preference_comparisons_ex.main def train_preference_comparisons( - _seed: int, total_timesteps: int, total_comparisons: int, num_iterations: int, @@ -84,7 +83,6 @@ def train_preference_comparisons( """Train a reward model using preference comparisons. Args: - _seed: Random seed. total_timesteps: number of environment interaction steps total_comparisons: number of preferences to gather in total num_iterations: number of times to train the agent against the reward model @@ -148,6 +146,7 @@ def train_preference_comparisons( ValueError: Inconsistency between config and deserialized policy normalization. 
""" custom_logger, log_dir = common.setup_logging() + rng = common.make_rng() with common.make_venv() as venv: reward_net = reward.make_reward_net(venv) @@ -172,7 +171,7 @@ def train_preference_comparisons( reward_fn=reward_net, venv=venv, exploration_frac=exploration_frac, - seed=_seed, + rng=rng, custom_logger=custom_logger, **trajectory_generator_kwargs, ) @@ -186,15 +185,17 @@ def train_preference_comparisons( ) trajectory_generator = preference_comparisons.TrajectoryDataset( trajectories=types.load_with_rewards(trajectory_path), - seed=_seed, + rng=rng, custom_logger=custom_logger, **trajectory_generator_kwargs, ) - fragmenter = preference_comparisons.RandomFragmenter( - **fragmenter_kwargs, - seed=_seed, - custom_logger=custom_logger, + fragmenter: preference_comparisons.Fragmenter = ( + preference_comparisons.RandomFragmenter( + **fragmenter_kwargs, + rng=rng, + custom_logger=custom_logger, + ) ) preference_model = preference_comparisons.PreferenceModel( **preference_model_kwargs, @@ -210,7 +211,7 @@ def train_preference_comparisons( ) gatherer = gatherer_cls( **gatherer_kwargs, - seed=_seed, + rng=rng, custom_logger=custom_logger, ) @@ -219,8 +220,8 @@ def train_preference_comparisons( reward_trainer = preference_comparisons._make_reward_trainer( preference_model, loss, + rng, reward_trainer_kwargs, - seed=_seed, ) main_trainer = preference_comparisons.PreferenceComparisons( @@ -236,7 +237,6 @@ def train_preference_comparisons( initial_comparison_frac=initial_comparison_frac, custom_logger=custom_logger, allow_variable_horizon=allow_variable_horizon, - seed=_seed, query_schedule=query_schedule, ) diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index 7122cd701..d68778d1a 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -87,6 +87,7 @@ def train_rl( Returns: The return value of `rollout_stats()` using the final policy. """ + rng = common.make_rng() custom_logger, log_dir = common.setup_logging() rollout_dir = osp.join(log_dir, "rollouts") policy_dir = osp.join(log_dir, "policies") @@ -120,7 +121,9 @@ def train_rl( ) if policy_save_interval > 0: - save_policy_callback = serialize.SavePolicyCallback(policy_dir) + save_policy_callback: callbacks.EventCallback = ( + serialize.SavePolicyCallback(policy_dir) + ) save_policy_callback = callbacks.EveryNTimesteps( policy_save_interval, save_policy_callback, @@ -144,7 +147,7 @@ def train_rl( ) types.save( save_path, - rollout.rollout(rl_algo, rl_algo.get_env(), sample_until), + rollout.rollout(rl_algo, rl_algo.get_env(), sample_until, rng=rng), ) if policy_save_final: output_dir = os.path.join(policy_dir, "final") diff --git a/src/imitation/testing/reward_improvement.py b/src/imitation/testing/reward_improvement.py index c8ad3780b..c06f97c70 100644 --- a/src/imitation/testing/reward_improvement.py +++ b/src/imitation/testing/reward_improvement.py @@ -45,15 +45,15 @@ def is_significant_reward_improvement( def mean_reward_improved_by( - old_rewards: Iterable[float], - new_rewards: Iterable[float], + old_rews: Iterable[float], + new_rews: Iterable[float], min_improvement: float, ): """Checks if mean rewards improved wrt. to old rewards by a certain amount. Args: - old_rewards: Iterable of "old" trajectory rewards (e.g. before training). - new_rewards: Iterable of "new" trajectory rewards (e.g. after training). + old_rews: Iterable of "old" trajectory rewards (e.g. before training). + new_rews: Iterable of "new" trajectory rewards (e.g. after training). 
min_improvement: The minimum amount of improvement that we expect. Returns: @@ -66,4 +66,5 @@ def mean_reward_improved_by( >>> mean_reward_improved_by([5, 8, 7], [8, 9, 10], 5) False """ - return np.mean(new_rewards) - np.mean(old_rewards) >= min_improvement + improvement = np.mean(new_rews) - np.mean(old_rews) # type: ignore[call-overload] + return improvement >= min_improvement diff --git a/src/imitation/util/logger.py b/src/imitation/util/logger.py index ea59fed6e..77e54df8c 100644 --- a/src/imitation/util/logger.py +++ b/src/imitation/util/logger.py @@ -57,7 +57,7 @@ def _build_output_formats( A list of output formats, one corresponding to each `format_strs`. """ os.makedirs(folder, exist_ok=True) - output_formats = [] + output_formats: List[sb_logger.KVWriter] = [] for f in format_strs: if f == "wandb": output_formats.append(WandbOutputFormat()) @@ -123,6 +123,7 @@ class HierarchicalLogger(sb_logger.Logger): _key_prefixes: List[str] _subdir: Optional[str] _name: Optional[str] + format_strs: Sequence[str] def __init__( self, @@ -265,6 +266,7 @@ def accumulate_means(self, name: str) -> Generator[None, None, None]: if subdir in self._cached_loggers: logger = self._cached_loggers[subdir] else: + assert self.default_logger.dir is not None folder = os.path.join(self.default_logger.dir, "raw", subdir) os.makedirs(folder, exist_ok=True) output_formats = _build_output_formats(folder, self.format_strs) @@ -352,7 +354,7 @@ def __init__(self): """ try: import wandb - except ModuleNotFoundError as e: + except ModuleNotFoundError as e: # pragma: no cover raise ModuleNotFoundError( "Trying to log data with `WandbOutputFormat` " "but `wandb` not installed: try `pip install wandb`.", diff --git a/src/imitation/util/networks.py b/src/imitation/util/networks.py index 664f5f081..187fecc59 100644 --- a/src/imitation/util/networks.py +++ b/src/imitation/util/networks.py @@ -1,9 +1,9 @@ """Helper methods to build and run neural networks.""" +import abc import collections import contextlib import functools -from abc import ABC, abstractclassmethod -from typing import Iterable, Optional, Type, Union +from typing import Iterable, Optional, OrderedDict, Type, Union import torch as th from torch import nn @@ -44,7 +44,7 @@ def forward(self, x): return new_value -class BaseNorm(nn.Module, ABC): +class BaseNorm(nn.Module, abc.ABC): """Base class for layers that try to normalize the input to mean 0 and variance 1. Similar to BatchNorm, LayerNorm, etc. but whereas they only use statistics from @@ -88,7 +88,7 @@ def forward(self, x: th.Tensor) -> th.Tensor: return (x - self.running_mean) / th.sqrt(self.running_var + self.eps) - @abstractclassmethod + @abc.abstractmethod def update_stats(self, batch: th.Tensor) -> None: """Update `self.running_mean`, `self.running_var` and `self.count`.""" @@ -193,7 +193,7 @@ def update_stats(self, batch: th.Tensor) -> None: self.running_var += learning_rate * delta_var self.count += b_size - self.num_batches += 1 + self.num_batches += 1 # type: ignore[misc] def build_mlp( @@ -234,7 +234,7 @@ def build_mlp( Raises: ValueError: if squeeze_output was supplied with out_size!=1. 
""" - layers = collections.OrderedDict() + layers: OrderedDict[str, nn.Module] = collections.OrderedDict() if name is None: prefix = "" @@ -246,7 +246,14 @@ def build_mlp( # Normalize input layer if normalize_input_layer: - layers[f"{prefix}normalize_input"] = normalize_input_layer(in_size) + try: + layer_instance = normalize_input_layer(in_size) # type: ignore[call-arg] + except TypeError as exc: + raise ValueError( + f"normalize_input_layer={normalize_input_layer} is not a valid " + "normalization layer type accepting only one argument (in_size).", + ) from exc + layers[f"{prefix}normalize_input"] = layer_instance # Hidden layers prev_size = in_size @@ -309,7 +316,7 @@ def build_cnn( Raises: ValueError: if squeeze_output was supplied with out_size!=1. """ - layers = collections.OrderedDict() + layers: OrderedDict[str, nn.Module] = collections.OrderedDict() if name is None: prefix = "" diff --git a/src/imitation/util/sacred.py b/src/imitation/util/sacred.py index d6d7fe8bd..b96f872df 100644 --- a/src/imitation/util/sacred.py +++ b/src/imitation/util/sacred.py @@ -7,6 +7,8 @@ from typing import Any, Callable, NamedTuple, Sequence, Union import sacred +import sacred.observers +import sacred.run from imitation.data import types @@ -78,6 +80,8 @@ def filter_subdirs( def build_sacred_symlink(log_dir: types.AnyPath, run: sacred.run.Run) -> None: """Constructs a symlink "{log_dir}/sacred" => "${SACRED_PATH}".""" + if isinstance(log_dir, bytes): + log_dir = log_dir.decode("utf-8") log_dir = pathlib.Path(log_dir) sacred_dir = get_sacred_dir_from_run(run) diff --git a/src/imitation/util/util.py b/src/imitation/util/util.py index f9d4e83bc..bbb7b2c37 100644 --- a/src/imitation/util/util.py +++ b/src/imitation/util/util.py @@ -5,16 +5,20 @@ import itertools import os import uuid +import warnings from typing import ( Any, Callable, Iterable, Iterator, + List, Mapping, Optional, Sequence, + Tuple, TypeVar, Union, + overload, ) import gym @@ -63,8 +67,9 @@ def make_unique_timestamp() -> str: def make_vec_env( env_name: str, + *, + rng: np.random.Generator, n_envs: int = 8, - seed: int = 0, parallel: bool = False, log_dir: Optional[str] = None, max_episode_steps: Optional[int] = None, @@ -75,8 +80,8 @@ def make_vec_env( Args: env_name: The Env's string id in Gym. + rng: The random state to use to seed the environment. n_envs: The number of duplicate environments. - seed: The environment seed. parallel: If True, uses SubprocVecEnv; otherwise, DummyVecEnv. log_dir: If specified, saves Monitor output to this directory. 
         max_episode_steps: If specified, wraps each env in a TimeLimit wrapper
@@ -99,7 +104,7 @@ def make_vec_env(
     spec = gym.spec(env_name)
     env_make_kwargs = env_make_kwargs or {}
 
-    def make_env(i, this_seed):
+    def make_env(i: int, this_seed: int) -> gym.Env:
         # Previously, we directly called `gym.make(env_name)`, but running
         # `imitation.scripts.train_adversarial` within `imitation.scripts.parallel`
         # created a weird interaction between Gym and Ray -- `gym.make` would fail
@@ -136,9 +141,10 @@ def make_env(i, this_seed):
 
         return env
 
-    rng = np.random.RandomState(seed)
-    env_seeds = rng.randint(0, (1 << 31) - 1, (n_envs,))
-    env_fns = [functools.partial(make_env, i, s) for i, s in enumerate(env_seeds)]
+    env_seeds = make_seeds(rng, n_envs)
+    env_fns: List[Callable[[], gym.Env]] = [
+        functools.partial(make_env, i, s) for i, s in enumerate(env_seeds)
+    ]
     if parallel:
         # See GH hill-a/stable-baselines issue #217
         return SubprocVecEnv(env_fns, start_method="forkserver")
@@ -146,6 +152,39 @@ def make_env(i, this_seed):
         return DummyVecEnv(env_fns)
 
 
+@overload
+def make_seeds(
+    rng: np.random.Generator,
+) -> int:
+    ...
+
+
+@overload
+def make_seeds(rng: np.random.Generator, n: int) -> List[int]:
+    ...
+
+
+def make_seeds(
+    rng: np.random.Generator,
+    n: Optional[int] = None,
+) -> Union[Sequence[int], int]:
+    """Generate n random seeds from a random state.
+
+    Args:
+        rng: The random state to use to generate seeds.
+        n: The number of seeds to generate.
+
+    Returns:
+        A list of `n` random seeds, or a single random seed if `n` is None.
+    """
+    seeds_arr = rng.integers(0, (1 << 31) - 1, (n if n is not None else 1,))
+    seeds: List[int] = seeds_arr.tolist()
+    if n is None:
+        return seeds[0]
+    else:
+        return seeds
+
+
 def docstring_parameter(*args, **kwargs):
     """Treats the docstring as a format string, substituting in the arguments."""
 
@@ -172,23 +211,23 @@ def endless_iter(iterable: Iterable[T]) -> Iterator[T]:
         0
 
     Args:
-        iterable: The object to endlessly iterate over.
+        iterable: The non-iterator iterable object to endlessly iterate over.
 
     Returns:
         An iterator that repeats the elements in `iterable` forever.
 
     Raises:
-        ValueError: `iterable` is empty -- the first call it to returns no elements.
+        ValueError: if iterable is an iterator -- that will be exhausted, so
+            cannot be iterated over endlessly.
     """
-    try:
-        next(iter(iterable))
-    except StopIteration:
-        raise ValueError(f"iterable {iterable} had no elements to iterate over.")
+    if iter(iterable) == iterable:
+        raise ValueError("endless_iter needs a non-iterator Iterable.")
+    _, iterable = get_first_iter_element(iterable)
     return itertools.chain.from_iterable(itertools.repeat(iterable))
 
 
-def safe_to_tensor(numpy_array: np.ndarray, **kwargs) -> th.Tensor:
+def safe_to_tensor(array: Union[np.ndarray, th.Tensor], **kwargs) -> th.Tensor:
     """Converts a NumPy array to a PyTorch tensor.
 
     The data is copied in the case where the array is non-writable. Unfortunately if
@@ -196,16 +235,61 @@ def safe_to_tensor(numpy_array: np.ndarray, **kwargs) -> th.Tensor:
     undefined behavior if you try to write to the tensor.
 
     Args:
-        numpy_array: The numpy array to convert to a PyTorch tensor.
+        array: The array to convert to a PyTorch tensor.
         kwargs: Additional keyword arguments to pass to `th.as_tensor`.
 
     Returns:
-        A PyTorch tensor with the same content as `numpy_array`.
+        A PyTorch tensor with the same content as `array`.
""" - if not numpy_array.flags.writeable: - numpy_array = numpy_array.copy() + if isinstance(array, th.Tensor): + return array + + if not array.flags.writeable: + array = array.copy() + + return th.as_tensor(array, **kwargs) + + +@overload +def safe_to_numpy(obj: Union[np.ndarray, th.Tensor], warn: bool = False) -> np.ndarray: + ... + + +@overload +def safe_to_numpy(obj: None, warn: bool = False) -> None: + ... + + +def safe_to_numpy( + obj: Optional[Union[np.ndarray, th.Tensor]], + warn: bool = False, +) -> Optional[np.ndarray]: + """Convert torch tensor to numpy. - return th.as_tensor(numpy_array, **kwargs) + If the object is already a numpy array, return it as is. + If the object is none, returns none. + + Args: + obj: torch tensor object to convert to numpy array + warn: if True, warn if the object is not already a numpy array. Useful for + warning the user of a potential performance hit if a torch tensor is + not the expected input type. + + Returns: + Object converted to numpy array + """ + if obj is None: + # We ignore the type due to https://github.com/google/pytype/issues/445 + return None # pytype: disable=bad-return-type + elif isinstance(obj, np.ndarray): + return obj + else: + if warn: + warnings.warn( + "Converted tensor to numpy array, might affect performance. " + "Make sure this is the intended behavior.", + ) + return obj.detach().cpu().numpy() def tensor_iter_norm( @@ -236,3 +320,42 @@ def tensor_iter_norm( # = sum(x**ord for x in tensor for tensor in tensor_iter)**(1/ord) # = th.norm(concatenated tensors) return th.norm(norm_tensor, p=ord) + + +def get_first_iter_element(iterable: Iterable[T]) -> Tuple[T, Iterable[T]]: + """Get first element of an iterable and a new fresh iterable. + + The fresh iterable has the first element added back using ``itertools.chain``. + If the iterable is not an iterator, this is equivalent to + ``(next(iter(iterable)), iterable)``. + + Args: + iterable: The iterable to get the first element of. + + Returns: + A tuple containing the first element of the iterable, and a fresh iterable + with all the elements. + + Raises: + ValueError: `iterable` is empty -- the first call to it returns no elements. + """ + iterator = iter(iterable) + try: + first_element = next(iterator) + except StopIteration: + raise ValueError(f"iterable {iterable} had no elements to iterate over.") + + return_iterable: Iterable[T] + if iterator == iterable: + # `iterable` was an iterator. Getting `first_element` will have removed it + # from `iterator`, so we need to add a fresh iterable with `first_element` + # added back in. + return_iterable = itertools.chain([first_element], iterator) + else: + # `iterable` was not an iterator; we can just return `iterable`. + # `iter(iterable)` will give a fresh iterator containing the first element. + # It's preferable to return `iterable` without modification so that users + # can generate new iterators from it as needed. 
+ return_iterable = iterable + + return first_element, return_iterable diff --git a/src/imitation/util/video_wrapper.py b/src/imitation/util/video_wrapper.py index 5fc2ae4f5..6da02aa70 100644 --- a/src/imitation/util/video_wrapper.py +++ b/src/imitation/util/video_wrapper.py @@ -1,6 +1,7 @@ """Wrapper to record rendered video frames from an environment.""" import os +from typing import Optional import gym from gym.wrappers.monitoring import video_recorder @@ -11,6 +12,11 @@ class VideoWrapper(gym.Wrapper): """Creates videos from wrapped environment by calling render after each timestep.""" + episode_id: int + video_recorder: Optional[video_recorder.VideoRecorder] + single_video: bool + directory: str + def __init__( self, env: gym.Env, @@ -33,7 +39,7 @@ def __init__( self.video_recorder = None self.single_video = single_video - self.directory = os.path.abspath(directory) + self.directory = str(os.path.abspath(directory)) os.makedirs(self.directory) def _reset_video_recorder(self) -> None: diff --git a/tests/algorithms/test_adversarial.py b/tests/algorithms/test_adversarial.py index fc06235fb..a61ca8d34 100644 --- a/tests/algorithms/test_adversarial.py +++ b/tests/algorithms/test_adversarial.py @@ -2,7 +2,7 @@ import contextlib import os -from typing import Any, Mapping +from typing import Any, Mapping, Type, Union import numpy as np import pytest @@ -42,7 +42,7 @@ EXPERT_BATCH_SIZES = [1, 128] -@pytest.fixture(params=ALGORITHM_KWARGS.values(), ids=ALGORITHM_KWARGS.keys()) +@pytest.fixture(params=ALGORITHM_KWARGS.values(), ids=list(ALGORITHM_KWARGS.keys())) def _algorithm_kwargs(request): """Auto-parametrizes `_rl_algorithm_cls` for the `trainer` fixture.""" return dict(request.param) @@ -58,12 +58,14 @@ def make_trainer( algorithm_kwargs: Mapping[str, Any], tmpdir: str, expert_transitions: types.Transitions, + rng: np.random.Generator, expert_batch_size: int = 1, env_name: str = "seals/CartPole-v0", num_envs: int = 1, parallel: bool = False, convert_dataset: bool = False, ): + expert_data: Union[th_data.DataLoader, th_data.Dataset] if convert_dataset: expert_data = th_data.DataLoader( expert_transitions, @@ -75,10 +77,15 @@ def make_trainer( else: expert_data = expert_transitions - venv = util.make_vec_env(env_name, n_envs=num_envs, parallel=parallel) + venv = util.make_vec_env( + env_name, + n_envs=num_envs, + parallel=parallel, + rng=rng, + ) model_cls = algorithm_kwargs["model_class"] gen_algo = model_cls(algorithm_kwargs["policy_class"], venv) - reward_net_cls = reward_nets.BasicRewardNet + reward_net_cls: Type[reward_nets.RewardNet] = reward_nets.BasicRewardNet if algorithm_kwargs["algorithm_cls"] == airl.AIRL: reward_net_cls = reward_nets.BasicShapedRewardNet reward_net = reward_net_cls(venv.observation_space, venv.action_space) @@ -100,15 +107,21 @@ def make_trainer( venv.close() -def test_airl_fail_fast(custom_logger, tmpdir): +def test_airl_fail_fast(custom_logger, tmpdir, rng): venv = util.make_vec_env( "seals/CartPole-v0", n_envs=1, parallel=False, + rng=rng, ) gen_algo = stable_baselines3.DQN(stable_baselines3.dqn.MlpPolicy, venv) - small_data = rollout.generate_transitions(gen_algo, venv, n_timesteps=20) + small_data = rollout.generate_transitions( + gen_algo, + venv, + n_timesteps=20, + rng=rng, + ) reward_net = reward_nets.BasicShapedRewardNet( observation_space=venv.observation_space, action_space=venv.action_space, @@ -126,9 +139,14 @@ def test_airl_fail_fast(custom_logger, tmpdir): ) -@pytest.fixture(params=ALGORITHM_KWARGS.values(), ids=ALGORITHM_KWARGS.keys()) -def 
trainer(request, tmpdir, expert_transitions): - with make_trainer(request.param, tmpdir, expert_transitions) as trainer: +@pytest.fixture(params=ALGORITHM_KWARGS.values(), ids=list(ALGORITHM_KWARGS.keys())) +def trainer(request, tmpdir, expert_transitions, rng): + with make_trainer( + request.param, + tmpdir, + expert_transitions, + rng, + ) as trainer: yield trainer @@ -176,11 +194,13 @@ def trainer_parametrized( _expert_batch_size, tmpdir, expert_transitions, + rng, ): with make_trainer( _algorithm_kwargs, tmpdir, expert_transitions, + rng=rng, parallel=_parallel, convert_dataset=_convert_dataset, expert_batch_size=_expert_batch_size, @@ -188,12 +208,17 @@ def trainer_parametrized( yield trainer -def test_train_disc_step_no_crash(trainer_parametrized, _expert_batch_size): +def test_train_disc_step_no_crash( + trainer_parametrized, + _expert_batch_size, + rng, +): transitions = rollout.generate_transitions( trainer_parametrized.gen_algo, trainer_parametrized.venv, n_timesteps=_expert_batch_size, truncate=True, + rng=rng, ) trainer_parametrized.train_disc( gen_samples=types.dataclass_quick_asdict(transitions), @@ -214,12 +239,14 @@ def trainer_batch_sizes( _expert_batch_size, tmpdir, expert_transitions, + rng, ): with make_trainer( _algorithm_kwargs, tmpdir, expert_transitions, expert_batch_size=_expert_batch_size, + rng=rng, ) as trainer: yield trainer @@ -229,6 +256,7 @@ def test_train_disc_improve_D( tmpdir, expert_transitions, _expert_batch_size, + rng, n_steps=3, ): expert_samples = expert_transitions[:_expert_batch_size] @@ -238,6 +266,7 @@ def test_train_disc_improve_D( trainer_batch_sizes.venv_train, n_timesteps=_expert_batch_size, truncate=True, + rng=rng, ) gen_samples = types.dataclass_quick_asdict(gen_samples) init_stats = final_stats = None @@ -258,13 +287,20 @@ def _env_name(request): @pytest.fixture -def trainer_diverse_env(_algorithm_kwargs, _env_name, tmpdir, expert_transitions): +def trainer_diverse_env( + _algorithm_kwargs, + _env_name, + tmpdir, + expert_transitions, + rng, +): if _algorithm_kwargs["model_class"] == stable_baselines3.DQN: pytest.skip("DQN does not support all environments.") with make_trainer( _algorithm_kwargs, tmpdir, expert_transitions, + rng=rng, env_name=_env_name, ) as trainer: yield trainer @@ -274,6 +310,7 @@ def trainer_diverse_env(_algorithm_kwargs, _env_name, tmpdir, expert_transitions def test_logits_expert_is_high_log_policy_act_prob( trainer_diverse_env: common.AdversarialTrainer, n_timesteps: int, + rng, ): """Smoke test calling `logits_expert_is_high` on `AdversarialTrainer`. @@ -283,11 +320,13 @@ def test_logits_expert_is_high_log_policy_act_prob( Args: trainer_diverse_env: The trainer to test. n_timesteps: The number of timesteps of rollouts to collect. + rng: The random state to use. 
""" trans = rollout.generate_transitions( policy=None, venv=trainer_diverse_env.venv, n_timesteps=n_timesteps, + rng=rng, ) obs, acts, next_obs, dones = trainer_diverse_env.reward_train.preprocess( @@ -300,6 +339,7 @@ def test_logits_expert_is_high_log_policy_act_prob( log_act_prob_non_none = th.as_tensor(log_act_prob_non_none).to(obs.device) for log_act_prob in [None, log_act_prob_non_none]: + maybe_error_ctx: contextlib.AbstractContextManager if isinstance(trainer_diverse_env, airl.AIRL) and log_act_prob is None: maybe_error_ctx = pytest.raises(TypeError, match="Non-None.*required.*") else: diff --git a/tests/algorithms/test_bc.py b/tests/algorithms/test_bc.py index 3e01a12e2..fce89e63c 100644 --- a/tests/algorithms/test_bc.py +++ b/tests/algorithms/test_bc.py @@ -47,6 +47,7 @@ def trainer( expert_data_type, custom_logger, cartpole_expert_trajectories, + rng, ): trans = rollout.flatten_trajectories(cartpole_expert_trajectories) if expert_data_type == "data_loader": @@ -70,10 +71,11 @@ def trainer( batch_size=batch_size, demonstrations=expert_data, custom_logger=custom_logger, + rng=rng, ) -def test_weight_decay_init_error(cartpole_venv, custom_logger): +def test_weight_decay_init_error(cartpole_venv, custom_logger, rng): with pytest.raises(ValueError, match=".*weight_decay.*"): bc.BC( observation_space=cartpole_venv.observation_space, @@ -81,6 +83,7 @@ def test_weight_decay_init_error(cartpole_venv, custom_logger): demonstrations=None, optimizer_kwargs=dict(weight_decay=1e-4), custom_logger=custom_logger, + rng=rng, ) @@ -101,6 +104,7 @@ def test_bc(trainer: bc.BC, cartpole_venv): 15, return_episode_rewards=True, ) + assert isinstance(novice_rewards, list) trainer.train( n_epochs=1, @@ -114,6 +118,7 @@ def test_bc(trainer: bc.BC, cartpole_venv): 15, return_episode_rewards=True, ) + assert isinstance(rewards_after_training, list) assert reward_improvement.is_significant_reward_improvement( novice_rewards, rewards_after_training, @@ -163,6 +168,7 @@ def test_bc_data_loader_empty_iter_error( no_yield_after_iter: bool, custom_logger: logger.HierarchicalLogger, cartpole_expert_trajectories, + rng, ) -> None: """Check that we error out if the DataLoader suddenly stops yielding any batches. @@ -173,6 +179,7 @@ def test_bc_data_loader_empty_iter_error( no_yield_after_iter: Data loader stops yielding after this many calls. custom_logger: Where to log to. cartpole_expert_trajectories: The expert trajectories to use. + rng: Random state to use. 
""" batch_size = 32 trans = rollout.flatten_trajectories(cartpole_expert_trajectories) @@ -180,13 +187,16 @@ def test_bc_data_loader_empty_iter_error( bad_data_loader = _DataLoaderFailsOnNthIter( dummy_yield_value=dummy_yield_value, - no_yield_after_iter=no_yield_after_iter, + # add 1 as BC uses up an iteration from getting the first element + # for type checking + no_yield_after_iter=no_yield_after_iter + 1, ) trainer = bc.BC( observation_space=cartpole_venv.observation_space, action_space=cartpole_venv.action_space, batch_size=batch_size, custom_logger=custom_logger, + rng=rng, ) trainer.set_demonstrations(bad_data_loader) with pytest.raises(AssertionError, match=".*no data.*"): diff --git a/tests/algorithms/test_dagger.py b/tests/algorithms/test_dagger.py index 1bb204e2c..683c34f8f 100644 --- a/tests/algorithms/test_dagger.py +++ b/tests/algorithms/test_dagger.py @@ -41,7 +41,7 @@ def test_beta_schedule(): assert np.allclose(three_step_sched(i), (3 - i) / 3 if i <= 2 else 0) -def test_traj_collector_seed(tmpdir, pendulum_venv): +def test_traj_collector_seed(tmpdir, pendulum_venv, rng): collector = dagger.InteractiveTrajectoryCollector( venv=pendulum_venv, get_robot_acts=lambda o: [ @@ -49,6 +49,7 @@ def test_traj_collector_seed(tmpdir, pendulum_venv): ], beta=0.5, save_dir=tmpdir, + rng=rng, ) seeds1 = collector.seed(42) obs1 = collector.reset() @@ -59,7 +60,7 @@ def test_traj_collector_seed(tmpdir, pendulum_venv): np.testing.assert_array_equal(obs1, obs2) -def test_traj_collector(tmpdir, pendulum_venv): +def test_traj_collector(tmpdir, pendulum_venv, rng): robot_calls = 0 num_episodes = 0 @@ -73,6 +74,7 @@ def get_random_acts(obs): get_robot_acts=get_random_acts, beta=0.5, save_dir=tmpdir, + rng=rng, ) collector.reset() zero_acts = np.zeros( @@ -110,6 +112,7 @@ def _build_dagger_trainer( expert_policy, pendulum_expert_rollouts: List[TrajectoryWithRew], custom_logger, + rng: np.random.Generator, ): del expert_policy if pendulum_expert_rollouts is not None: @@ -122,6 +125,7 @@ def _build_dagger_trainer( action_space=venv.action_space, optimizer_kwargs=dict(lr=1e-3), custom_logger=custom_logger, + rng=rng, ) return dagger.DAggerTrainer( venv=venv, @@ -129,6 +133,7 @@ def _build_dagger_trainer( beta_schedule=beta_schedule, bc_trainer=bc_trainer, custom_logger=custom_logger, + rng=rng, ) @@ -137,14 +142,16 @@ def _build_simple_dagger_trainer( venv, beta_schedule, expert_policy, - pendulum_expert_rollouts: List[TrajectoryWithRew], + pendulum_expert_rollouts: Optional[List[TrajectoryWithRew]], custom_logger, + rng, ): bc_trainer = bc.BC( observation_space=venv.observation_space, action_space=venv.action_space, optimizer_kwargs=dict(lr=1e-3), custom_logger=custom_logger, + rng=rng, ) return dagger.SimpleDAggerTrainer( venv=venv, @@ -154,6 +161,7 @@ def _build_simple_dagger_trainer( expert_policy=expert_policy, expert_trajs=pendulum_expert_rollouts, custom_logger=custom_logger, + rng=rng, ) @@ -171,6 +179,7 @@ def init_trainer_fn( pendulum_expert_policy, maybe_pendulum_expert_trajectories: Optional[List[TrajectoryWithRew]], custom_logger, + rng, ): # Provide a trainer initialization fixture in addition `trainer` fixture below # for tests that want to initialize multiple DAggerTrainer. 
@@ -182,6 +191,7 @@ def init_trainer_fn( pendulum_expert_policy, maybe_pendulum_expert_trajectories, custom_logger, + rng, ) @@ -198,6 +208,7 @@ def simple_dagger_trainer( pendulum_expert_policy, maybe_pendulum_expert_trajectories: Optional[List[TrajectoryWithRew]], custom_logger, + rng, ): return _build_simple_dagger_trainer( tmpdir, @@ -206,6 +217,7 @@ def simple_dagger_trainer( pendulum_expert_policy, maybe_pendulum_expert_trajectories, custom_logger, + rng, ) @@ -215,6 +227,7 @@ def test_trainer_needs_demos_exception_error( ): assert trainer.round_num == 0 error_ctx = pytest.raises(dagger.NeedsDemosException) + ctx: contextlib.AbstractContextManager if maybe_pendulum_expert_trajectories is not None and isinstance( trainer, dagger.SimpleDAggerTrainer, @@ -237,13 +250,14 @@ def test_trainer_needs_demos_exception_error( trainer.extend_and_update(dict(n_epochs=1)) -def test_trainer_train_arguments(trainer, pendulum_expert_policy): +def test_trainer_train_arguments(trainer, pendulum_expert_policy, rng): def add_samples(): collector = trainer.create_trajectory_collector() rollout.generate_trajectories( pendulum_expert_policy, collector, sample_until=rollout.make_min_timesteps(40), + rng=rng, ) # Lower default number of epochs for the no-arguments call that follows. @@ -371,6 +385,7 @@ def test_simple_dagger_space_mismatch_error( pendulum_expert_policy, maybe_pendulum_expert_trajectories: Optional[List[TrajectoryWithRew]], custom_logger, + rng, ): class MismatchedSpace(gym.spaces.Space): """Dummy space that is not equal to any other space.""" @@ -388,26 +403,34 @@ class MismatchedSpace(gym.spaces.Space): pendulum_expert_policy, maybe_pendulum_expert_trajectories, custom_logger, + rng, ) -def test_dagger_not_enough_transitions_error(tmpdir, custom_logger): - venv = util.make_vec_env("CartPole-v0") +def test_dagger_not_enough_transitions_error(tmpdir, custom_logger, rng): + venv = util.make_vec_env("CartPole-v0", rng=rng) # Initialize with large batch size to ensure error down the line. 
bc_trainer = bc.BC( observation_space=venv.observation_space, action_space=venv.action_space, batch_size=100_000, custom_logger=custom_logger, + rng=rng, ) trainer = dagger.DAggerTrainer( venv=venv, scratch_dir=tmpdir, bc_trainer=bc_trainer, custom_logger=custom_logger, + rng=rng, ) collector = trainer.create_trajectory_collector() policy = base.RandomPolicy(venv.observation_space, venv.action_space) - rollout.generate_trajectories(policy, collector, rollout.make_min_episodes(1)) + rollout.generate_trajectories( + policy, + collector, + rollout.make_min_episodes(1), + rng=rng, + ) with pytest.raises(ValueError, match="Not enough transitions.*"): trainer.extend_and_update() diff --git a/tests/algorithms/test_density_baselines.py b/tests/algorithms/test_density_baselines.py index 481d8c493..f8a6ab069 100644 --- a/tests/algorithms/test_density_baselines.py +++ b/tests/algorithms/test_density_baselines.py @@ -1,5 +1,6 @@ """Tests for `imitation.algorithms.density_baselines`.""" +from dataclasses import asdict from typing import Sequence import numpy as np @@ -44,6 +45,7 @@ def test_density_reward( is_stationary, pendulum_venv, pendulum_expert_trajectories: Sequence[TrajectoryWithRew], + rng, ): # use only a subset of trajectories expert_trajectories_all = pendulum_expert_trajectories[:8] @@ -57,6 +59,7 @@ def test_density_reward( is_stationary=is_stationary, kernel_bandwidth=0.2, standardise_inputs=True, + rng=rng, ) reward_fn.train() @@ -71,6 +74,7 @@ def test_density_reward( random_policy, pendulum_venv, sample_until=sample_until, + rng=rng, ) expert_trajectories_test = expert_trajectories_all[n_experts // 2 :] random_returns = score_trajectories(random_trajectories, reward_fn) @@ -85,6 +89,7 @@ def test_density_reward( def test_density_trainer_smoke( pendulum_venv, pendulum_expert_trajectories: Sequence[TrajectoryWithRew], + rng, ): # tests whether density trainer runs, not whether it's good # (it's actually really poor) @@ -94,7 +99,69 @@ def test_density_trainer_smoke( demonstrations=rollouts, venv=pendulum_venv, rl_algo=rl_algo, + rng=rng, ) density_trainer.train() density_trainer.train_policy(n_timesteps=2) density_trainer.test_policy(n_trajectories=2) + + +def test_density_with_other_trajectory_types( + pendulum_expert_trajectories: Sequence[TrajectoryWithRew], + pendulum_venv, + rng, +): + rl_algo = stable_baselines3.PPO(policies.ActorCriticPolicy, pendulum_venv) + rollouts = pendulum_expert_trajectories[:2] + transitions = rollout.flatten_trajectories_with_rew(rollouts) + transitions_mappings = [ + asdict(transitions), + ] + + minimal_transitions = types.TransitionsMinimal( + obs=transitions.obs, + acts=transitions.acts, + infos=transitions.infos, + ) + d = DensityAlgorithm( + demonstrations=transitions_mappings, + venv=pendulum_venv, + rl_algo=rl_algo, + rng=rng, + ) + d.train() + d.train_policy(n_timesteps=2) + d.test_policy(n_trajectories=2) + + d = DensityAlgorithm( + demonstrations=minimal_transitions, + venv=pendulum_venv, + rl_algo=rl_algo, + rng=rng, + ) + d.train() + d.train_policy(n_timesteps=2) + d.test_policy(n_trajectories=2) + + +def test_density_trainer_raises( + pendulum_venv, + rng, +): + rl_algo = stable_baselines3.PPO(policies.ActorCriticPolicy, pendulum_venv) + density_trainer = DensityAlgorithm( + venv=pendulum_venv, + rl_algo=rl_algo, + rng=rng, + demonstrations=None, + density_type=DensityType.STATE_STATE_DENSITY, + ) + with pytest.raises(ValueError, match="STATE_STATE_DENSITY requires next_obs_b"): + density_trainer._get_demo_from_batch( + np.zeros((1, 3)), + 
np.zeros((1, 1)), + None, + ) + + with pytest.raises(TypeError, match="Unsupported demonstration type"): + density_trainer.set_demonstrations("foo") # type: ignore[arg-type] diff --git a/tests/algorithms/test_mce_irl.py b/tests/algorithms/test_mce_irl.py index 7bb9b7560..50c022eee 100644 --- a/tests/algorithms/test_mce_irl.py +++ b/tests/algorithms/test_mce_irl.py @@ -233,14 +233,13 @@ def test_policy_om_reasonable_mdp(discount: float): assert np.allclose(Dt[0], mdp.initial_state_dist) -def test_tabular_policy(): +def test_tabular_policy(rng): """Tests tabular policy prediction, especially timestep calculation and masking.""" state_space = gym.spaces.Discrete(2) action_space = gym.spaces.Discrete(2) pi = np.stack( [np.eye(2), 1 - np.eye(2)], ) - rng = np.random.RandomState(42) tabular = TabularPolicy( state_space=state_space, action_space=action_space, @@ -251,25 +250,25 @@ def test_tabular_policy(): states = np.array([0, 1, 1, 0, 1]) actions, timesteps = tabular.predict(states) np.testing.assert_array_equal(states, actions) - np.testing.assert_equal(timesteps, 1) + np.testing.assert_equal(timesteps[0], 1) mask = np.zeros((5,), dtype=bool) actions, timesteps = tabular.predict(states, timesteps, mask) np.testing.assert_array_equal(1 - states, actions) - np.testing.assert_equal(timesteps, 2) + np.testing.assert_equal(timesteps[0], 2) mask = np.ones((5,), dtype=bool) actions, timesteps = tabular.predict(states, timesteps, mask) np.testing.assert_array_equal(states, actions) - np.testing.assert_equal(timesteps, 1) + np.testing.assert_equal(timesteps[0], 1) mask = (1 - states).astype(bool) actions, timesteps = tabular.predict(states, timesteps, mask) np.testing.assert_array_equal(np.zeros((5,)), actions) - np.testing.assert_equal(timesteps, 2 - mask.astype(int)) + np.testing.assert_equal(timesteps[0], 2 - mask.astype(int)) -def test_tabular_policy_randomness(): +def test_tabular_policy_randomness(rng): state_space = gym.spaces.Discrete(2) action_space = gym.spaces.Discrete(2) pi = np.array( @@ -280,7 +279,6 @@ def test_tabular_policy_randomness(): ], ], ) - rng = np.random.RandomState(42) tabular = TabularPolicy( state_space=state_space, action_space=action_space, @@ -288,16 +286,16 @@ def test_tabular_policy_randomness(): rng=rng, ) - actions, _ = tabular.predict(np.zeros((100,), dtype=int)) + actions, _ = tabular.predict(np.zeros((1000,), dtype=int)) assert 0.45 <= np.mean(actions) <= 0.55 - ones_obs = np.ones((100,), dtype=int) + ones_obs = np.ones((1000,), dtype=int) actions, _ = tabular.predict(ones_obs) assert 0.05 <= np.mean(actions) <= 0.15 actions, _ = tabular.predict(ones_obs, deterministic=True) np.testing.assert_equal(actions, 0) -def test_mce_irl_demo_formats(): +def test_mce_irl_demo_formats(rng): mdp = model_envs.RandomMDP( n_states=5, n_actions=3, @@ -313,6 +311,7 @@ def test_mce_irl_demo_formats(): policy=None, venv=state_venv, sample_until=rollout.make_min_timesteps(100), + rng=rng, ) demonstrations = { "trajs": trajs, @@ -337,7 +336,13 @@ def test_mce_irl_demo_formats(): use_done=False, hid_sizes=[], ) - mce_irl = MCEIRL(demo, mdp, reward_net, linf_eps=1e-3) + mce_irl = MCEIRL( + demo, + mdp, + reward_net, + linf_eps=1e-3, + rng=rng, + ) assert np.allclose(mce_irl.demo_state_om.sum(), mdp.horizon + 1) final_counts[kind] = mce_irl.train(max_iter=5) @@ -357,6 +362,7 @@ def test_mce_irl_demo_formats(): def test_mce_irl_reasonable_mdp( model_kwargs: Mapping[str, Any], discount: float, + rng, ): with th.random.fork_rng(): th.random.manual_seed(715298) @@ -377,7 +383,14 @@ def 
test_mce_irl_reasonable_mdp( use_done=False, **model_kwargs, ) - mce_irl = MCEIRL(D, mdp, reward_net, linf_eps=1e-3, discount=discount) + mce_irl = MCEIRL( + D, + mdp, + reward_net, + linf_eps=1e-3, + discount=discount, + rng=rng, + ) final_counts = mce_irl.train() assert np.allclose(final_counts, D, atol=1e-3, rtol=1e-3) @@ -390,6 +403,7 @@ def test_mce_irl_reasonable_mdp( mce_irl.policy, state_venv, sample_until=rollout.make_min_episodes(5), + rng=rng, ) stats = rollout.rollout_stats(trajs) if discount > 0.0: # skip check when discount==0.0 (random policy) diff --git a/tests/algorithms/test_preference_comparisons.py b/tests/algorithms/test_preference_comparisons.py index c1f1f4575..331ae6d1e 100644 --- a/tests/algorithms/test_preference_comparisons.py +++ b/tests/algorithms/test_preference_comparisons.py @@ -23,10 +23,12 @@ @pytest.fixture -def venv(): +def venv(rng): + rng return util.make_vec_env( "seals/CartPole-v0", n_envs=1, + rng=rng, ) @@ -55,13 +57,16 @@ def agent(venv): @pytest.fixture -def random_fragmenter(): - return preference_comparisons.RandomFragmenter(seed=0, warning_threshold=0) +def random_fragmenter(rng): + return preference_comparisons.RandomFragmenter( + rng=rng, + warning_threshold=0, + ) @pytest.fixture -def agent_trainer(agent, reward_net, venv): - return preference_comparisons.AgentTrainer(agent, reward_net, venv) +def agent_trainer(agent, reward_net, venv, rng): + return preference_comparisons.AgentTrainer(agent, reward_net, venv, rng) def _check_trajs_equal( @@ -73,14 +78,17 @@ def _check_trajs_equal( assert np.array_equal(traj1.obs, traj2.obs) assert np.array_equal(traj1.acts, traj2.acts) assert np.array_equal(traj1.rews, traj2.rews) + assert traj1.infos is not None + assert traj2.infos is not None assert np.array_equal(traj1.infos, traj2.infos) assert traj1.terminal == traj2.terminal -def test_mismatched_spaces(venv, agent): +def test_mismatched_spaces(venv, agent, rng): other_venv = util.make_vec_env( "seals/MountainCar-v0", n_envs=1, + rng=rng, ) bad_reward_net = reward_nets.BasicRewardNet( other_venv.observation_space, @@ -90,7 +98,12 @@ def test_mismatched_spaces(venv, agent): ValueError, match="spaces do not match", ): - preference_comparisons.AgentTrainer(agent, bad_reward_net, venv) + preference_comparisons.AgentTrainer( + agent, + bad_reward_net, + venv, + rng=rng, + ) def test_trajectory_dataset_seeding( @@ -99,12 +112,12 @@ def test_trajectory_dataset_seeding( ): dataset1 = preference_comparisons.TrajectoryDataset( cartpole_expert_trajectories, - seed=0, + rng=np.random.default_rng(0), ) sample1 = dataset1.sample(num_samples) dataset2 = preference_comparisons.TrajectoryDataset( cartpole_expert_trajectories, - seed=0, + rng=np.random.default_rng(0), ) sample2 = dataset2.sample(num_samples) @@ -112,7 +125,7 @@ def test_trajectory_dataset_seeding( dataset3 = preference_comparisons.TrajectoryDataset( cartpole_expert_trajectories, - seed=42, + rng=np.random.default_rng(42), ) sample3 = dataset3.sample(num_samples) with pytest.raises(AssertionError): @@ -124,10 +137,11 @@ def test_trajectory_dataset_seeding( def test_trajectory_dataset_len( cartpole_expert_trajectories: Sequence[TrajectoryWithRew], num_steps: int, + rng, ): dataset = preference_comparisons.TrajectoryDataset( cartpole_expert_trajectories, - seed=0, + rng=rng, ) sample = dataset.sample(num_steps) lengths = [len(t) for t in sample] @@ -138,10 +152,11 @@ def test_trajectory_dataset_len( def test_trajectory_dataset_too_long( cartpole_expert_trajectories: Sequence[TrajectoryWithRew], + rng, ): 
dataset = preference_comparisons.TrajectoryDataset( cartpole_expert_trajectories, - seed=0, + rng=rng, ) with pytest.raises(RuntimeError, match="Asked for.*but only.* available"): dataset.sample(100000) @@ -149,11 +164,12 @@ def test_trajectory_dataset_too_long( def test_trajectory_dataset_shuffle( cartpole_expert_trajectories: Sequence[TrajectoryWithRew], + rng, num_steps: int = 400, ): dataset = preference_comparisons.TrajectoryDataset( cartpole_expert_trajectories, - seed=0, + rng, ) sample = dataset.sample(num_steps) sample2 = dataset.sample(num_steps) @@ -175,6 +191,79 @@ def test_transitions_left_in_buffer(agent_trainer): agent_trainer.train(steps=1) +@pytest.mark.parametrize( + "schedule", + ["constant", "hyperbolic", "inverse_quadratic", lambda t: 1 / (1 + t**3)], +) +def test_preference_comparisons_raises( + agent_trainer, + reward_net, + random_fragmenter, + preference_model, + custom_logger, + schedule, + rng, +): + loss = preference_comparisons.CrossEntropyRewardLoss() + reward_trainer = preference_comparisons.BasicRewardTrainer( + preference_model, + loss, + rng=rng, + ) + gatherer = preference_comparisons.SyntheticGatherer(rng=rng) + # no rng, must provide fragmenter, preference gatherer, reward trainer + no_rng_msg = ( + ".*don't provide.*random state.*provide.*fragmenter" + ".*preference gatherer.*reward_trainer.*" + ) + + def build_preference_comparsions(gatherer, reward_trainer, fragmenter, rng): + preference_comparisons.PreferenceComparisons( + agent_trainer, + reward_net, + num_iterations=2, + transition_oversampling=2, + reward_trainer=reward_trainer, + preference_gatherer=gatherer, + fragmenter=fragmenter, + custom_logger=custom_logger, + query_schedule=schedule, + rng=rng, + ) + + with pytest.raises(ValueError, match=no_rng_msg): + build_preference_comparsions(gatherer, None, None, rng=None) + + with pytest.raises(ValueError, match=no_rng_msg): + build_preference_comparsions(None, reward_trainer, None, rng=None) + + with pytest.raises(ValueError, match=no_rng_msg): + build_preference_comparsions(None, None, random_fragmenter, rng=None) + + # This should not raise + build_preference_comparsions(gatherer, reward_trainer, random_fragmenter, rng=None) + + # if providing fragmenter, preference gatherer, reward trainer, does not need rng. + with_rng_msg = ( + "provide.*fragmenter.*preference gatherer.*reward trainer" + ".*don't need.*random state.*" + ) + + with pytest.raises(ValueError, match=with_rng_msg): + build_preference_comparsions( + gatherer, + reward_trainer, + random_fragmenter, + rng=rng, + ) + + # This should not raise + build_preference_comparsions(None, None, None, rng=rng) + build_preference_comparsions(gatherer, None, None, rng=rng) + build_preference_comparsions(None, reward_trainer, None, rng=rng) + build_preference_comparsions(None, None, random_fragmenter, rng=rng) + + @pytest.mark.parametrize( "schedule", ["constant", "hyperbolic", "inverse_quadratic", lambda t: 1 / (1 + t**3)], @@ -185,6 +274,7 @@ def test_trainer_no_crash( random_fragmenter, custom_logger, schedule, + rng, ): main_trainer = preference_comparisons.PreferenceComparisons( agent_trainer, @@ -196,6 +286,7 @@ def test_trainer_no_crash( custom_logger=custom_logger, query_schedule=schedule, initial_epoch_multiplier=2, + rng=rng, ) result = main_trainer.train(100, 10) # We don't expect good performance after training for 10 (!) 
timesteps, @@ -204,7 +295,7 @@ def test_trainer_no_crash( assert 0.0 < result["reward_accuracy"] <= 1.0 -def test_reward_ensemble_trainer_raises_type_error(venv): +def test_reward_ensemble_trainer_raises_type_error(venv, rng): reward_net = reward_nets.BasicRewardNet(venv.observation_space, venv.action_space) preference_model = preference_comparisons.PreferenceModel( model=reward_net, @@ -221,6 +312,7 @@ def test_reward_ensemble_trainer_raises_type_error(venv): preference_comparisons.EnsembleTrainer( preference_model, loss, + rng=rng, ) @@ -229,6 +321,7 @@ def test_correct_reward_trainer_used_by_default( reward_net, random_fragmenter, custom_logger, + rng, ): main_trainer = preference_comparisons.PreferenceComparisons( agent_trainer, @@ -237,6 +330,7 @@ def test_correct_reward_trainer_used_by_default( transition_oversampling=2, fragment_length=2, fragmenter=random_fragmenter, + rng=rng, custom_logger=custom_logger, ) @@ -258,6 +352,7 @@ def test_init_raises_error_when_trying_use_improperly_wrapped_ensemble( venv, random_fragmenter, custom_logger, + rng, ): reward_net = testing_reward_nets.make_ensemble( venv.observation_space, @@ -279,11 +374,18 @@ def test_init_raises_error_when_trying_use_improperly_wrapped_ensemble( transition_oversampling=2, fragment_length=2, fragmenter=random_fragmenter, + rng=rng, custom_logger=custom_logger, ) -def test_discount_rate_no_crash(agent_trainer, venv, random_fragmenter, custom_logger): +def test_discount_rate_no_crash( + agent_trainer, + venv, + random_fragmenter, + custom_logger, + rng, +): # also use a non-zero noise probability to check that doesn't cause errors reward_net = reward_nets.BasicRewardNet(venv.observation_space, venv.action_space) preference_model = preference_comparisons.PreferenceModel( @@ -296,6 +398,7 @@ def test_discount_rate_no_crash(agent_trainer, venv, random_fragmenter, custom_l reward_trainer = preference_comparisons.BasicRewardTrainer( preference_model, loss, + rng=rng, ) main_trainer = preference_comparisons.PreferenceComparisons( @@ -305,14 +408,22 @@ def test_discount_rate_no_crash(agent_trainer, venv, random_fragmenter, custom_l transition_oversampling=2, fragment_length=2, fragmenter=random_fragmenter, + rng=rng, reward_trainer=reward_trainer, custom_logger=custom_logger, ) main_trainer.train(100, 10) -def test_synthetic_gatherer_deterministic(agent_trainer, random_fragmenter): - gatherer = preference_comparisons.SyntheticGatherer(temperature=0) +def test_synthetic_gatherer_deterministic( + agent_trainer, + random_fragmenter, + rng, +): + gatherer = preference_comparisons.SyntheticGatherer( + temperature=0, + rng=rng, + ) trajectories = agent_trainer.sample(10) fragments = random_fragmenter(trajectories, fragment_length=2, num_pairs=2) preferences1 = gatherer(fragments) @@ -320,6 +431,20 @@ def test_synthetic_gatherer_deterministic(agent_trainer, random_fragmenter): assert np.all(preferences1 == preferences2) +def test_synthetic_gatherer_raises( + agent_trainer, + random_fragmenter, +): + with pytest.raises( + ValueError, + match="If `sample` is True, then `rng` must be provided", + ): + preference_comparisons.SyntheticGatherer( + temperature=0, + sample=True, + ) + + def test_fragments_terminal(random_fragmenter): trajectories = [ types.TrajectoryWithRew( @@ -346,7 +471,7 @@ def test_fragments_terminal(random_fragmenter): def test_fragments_too_short_error(agent_trainer): trajectories = agent_trainer.sample(2) random_fragmenter = preference_comparisons.RandomFragmenter( - seed=0, + rng=np.random.default_rng(0), 
warning_threshold=0, ) with pytest.raises( @@ -373,11 +498,11 @@ def test_preference_dataset_errors(agent_trainer, random_fragmenter): dataset.push(fragments, preferences) -def test_preference_dataset_queue(agent_trainer, random_fragmenter): +def test_preference_dataset_queue(agent_trainer, random_fragmenter, rng): dataset = preference_comparisons.PreferenceDataset(max_size=5) trajectories = agent_trainer.sample(10) - gatherer = preference_comparisons.SyntheticGatherer() + gatherer = preference_comparisons.SyntheticGatherer(rng=rng) for i in range(6): fragments = random_fragmenter(trajectories, fragment_length=2, num_pairs=1) preferences = gatherer(fragments) @@ -389,11 +514,16 @@ def test_preference_dataset_queue(agent_trainer, random_fragmenter): assert len(dataset) == 5 -def test_store_and_load_preference_dataset(agent_trainer, random_fragmenter, tmp_path): +def test_store_and_load_preference_dataset( + agent_trainer, + random_fragmenter, + tmp_path, + rng, +): dataset = preference_comparisons.PreferenceDataset() trajectories = agent_trainer.sample(10) fragments = random_fragmenter(trajectories, fragment_length=2, num_pairs=2) - gatherer = preference_comparisons.SyntheticGatherer() + gatherer = preference_comparisons.SyntheticGatherer(rng=rng) preferences = gatherer(fragments) dataset.push(fragments, preferences) @@ -415,11 +545,13 @@ def test_exploration_no_crash( venv, random_fragmenter, custom_logger, + rng, ): agent_trainer = preference_comparisons.AgentTrainer( agent, reward_net, venv, + rng=rng, exploration_frac=0.5, ) main_trainer = preference_comparisons.PreferenceComparisons( @@ -429,6 +561,7 @@ def test_exploration_no_crash( transition_oversampling=2, fragment_length=5, fragmenter=random_fragmenter, + rng=rng, custom_logger=custom_logger, ) main_trainer.train(100, 10) @@ -441,6 +574,7 @@ def test_active_fragmenter_discount_rate_no_crash( random_fragmenter, uncertainty_on, custom_logger, + rng, ): # also use a non-zero noise probability to check that doesn't cause errors reward_net = reward_nets.RewardEnsemble( @@ -477,6 +611,7 @@ def test_active_fragmenter_discount_rate_no_crash( reward_trainer = preference_comparisons.EnsembleTrainer( preference_model, loss, + rng=rng, ) main_trainer = preference_comparisons.PreferenceComparisons( @@ -486,6 +621,7 @@ def test_active_fragmenter_discount_rate_no_crash( transition_oversampling=2, fragment_length=2, fragmenter=fragmenter, + rng=rng, reward_trainer=reward_trainer, custom_logger=custom_logger, ) @@ -507,6 +643,7 @@ def test_reward_trainer_regularization_no_crash( custom_logger, preference_model, interval_param_scaler, + rng, ): reward_net = reward_nets.BasicRewardNet(venv.observation_space, venv.action_space) loss = preference_comparisons.CrossEntropyRewardLoss() @@ -522,6 +659,7 @@ def test_reward_trainer_regularization_no_crash( loss, regularizer_factory=regularizer_factory, custom_logger=custom_logger, + rng=rng, ) main_trainer = preference_comparisons.PreferenceComparisons( @@ -533,6 +671,7 @@ def test_reward_trainer_regularization_no_crash( fragmenter=random_fragmenter, reward_trainer=reward_trainer, custom_logger=custom_logger, + rng=rng, ) main_trainer.train(50, 50) @@ -544,6 +683,7 @@ def test_reward_trainer_regularization_raises( custom_logger, preference_model, interval_param_scaler, + rng, ): reward_net = reward_nets.BasicRewardNet(venv.observation_space, venv.action_space) loss = preference_comparisons.CrossEntropyRewardLoss() @@ -559,6 +699,7 @@ def test_reward_trainer_regularization_raises( loss, 
regularizer_factory=regularizer_factory, custom_logger=custom_logger, + rng=rng, ) main_trainer = preference_comparisons.PreferenceComparisons( @@ -570,6 +711,7 @@ def test_reward_trainer_regularization_raises( fragmenter=random_fragmenter, reward_trainer=reward_trainer, custom_logger=custom_logger, + rng=rng, ) with pytest.raises( ValueError, @@ -662,12 +804,15 @@ def test_agent_trainer_sample(venv, agent_trainer): ) -def test_agent_trainer_sample_image_observations(): +def test_agent_trainer_sample_image_observations(rng): """Test `AgentTrainer.sample()` in an image environment. SB3 algorithms may rearrange the channel dimension in environments with image observations, but `sample()` should return observations matching the original environment. + + Args: + rng: Random number generator (with a fixed seed). """ venv = DummyVecEnv([lambda: FakeImageEnv()]) reward_net = reward_nets.BasicRewardNet(venv.observation_space, venv.action_space) @@ -683,6 +828,7 @@ def test_agent_trainer_sample_image_observations(): reward_net, venv, exploration_frac=0.5, + rng=rng, ) trajectories = agent_trainer.sample(2) assert len(trajectories) > 0 diff --git a/tests/conftest.py b/tests/conftest.py index 6cf8a9f82..63e88b6b8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ from typing import Sequence import gym +import numpy as np import pytest import torch from filelock import FileLock @@ -26,6 +27,7 @@ def load_or_rollout_trajectories( cache_path, policy, venv, + rng, ) -> Sequence[TrajectoryWithRew]: os.makedirs(os.path.dirname(cache_path), exist_ok=True) with FileLock(cache_path + ".lock"): @@ -40,6 +42,7 @@ def load_or_rollout_trajectories( policy, venv, rollout.make_sample_until(min_timesteps=2000, min_episodes=57), + rng=rng, ) types.save(cache_path, rollouts) return rollouts @@ -71,6 +74,7 @@ def cartpole_expert_trajectories( cartpole_expert_policy, cartpole_venv, pytestconfig, + rng, ) -> Sequence[TrajectoryWithRew]: rollouts_path = str( pytestconfig.cache.makedir("experts") / CARTPOLE_ENV_NAME / "rollout.npz", @@ -79,6 +83,7 @@ def cartpole_expert_trajectories( rollouts_path, cartpole_expert_policy, cartpole_venv, + rng, ) @@ -92,12 +97,14 @@ def pendulum_venv() -> VecEnv: @pytest.fixture def pendulum_expert_policy() -> BasePolicy: - return PPO.load( + policy = PPO.load( load_from_hub( "HumanCompatibleAI/ppo-Pendulum-v1", "ppo-Pendulum-v1.zip", ), ).policy + assert policy is not None + return policy @pytest.fixture @@ -105,6 +112,7 @@ def pendulum_expert_trajectories( pendulum_expert_policy, pendulum_venv, pytestconfig, + rng, ) -> Sequence[TrajectoryWithRew]: rollouts_path = str( pytestconfig.cache.makedir("experts") / PENDULUM_ENV_NAME / "rollout.npz", @@ -113,6 +121,7 @@ def pendulum_expert_trajectories( rollouts_path, pendulum_expert_policy, pendulum_venv, + rng=rng, ) @@ -134,3 +143,8 @@ def torch_single_threaded(): @pytest.fixture() def custom_logger(tmpdir: str) -> logger.HierarchicalLogger: return logger.configure(tmpdir) + + +@pytest.fixture() +def rng() -> np.random.Generator: + return np.random.default_rng() diff --git a/tests/data/test_buffer.py b/tests/data/test_buffer.py index b64cf5fc7..7f8caf4e0 100644 --- a/tests/data/test_buffer.py +++ b/tests/data/test_buffer.py @@ -53,7 +53,7 @@ def test_buffer(capacity, chunk_len, sample_shape) -> None: buf = Buffer( capacity, sample_shapes={"a": sample_shape, "b": sample_shape}, - dtypes={"a": float, "b": float}, + dtypes={"a": np.dtype(float), "b": np.dtype(float)}, ) to_insert = 3 * capacity @@ -213,7 +213,10 @@ def 
test_buffer_init_errors(): def test_replay_buffer_init_errors(): - with pytest.raises(ValueError, match=r"Specified.* and environment"): + with pytest.raises( + ValueError, + match=r"Cannot specify both shape/dtype and also environment", + ): ReplayBuffer(15, venv="MockEnv", obs_shape=(10, 10)) with pytest.raises(ValueError, match=r"Shape or dtype missing.*"): ReplayBuffer(15, obs_shape=(10, 10), act_shape=(15,), obs_dtype=bool) diff --git a/tests/data/test_rollout.py b/tests/data/test_rollout.py index 911e6b0f5..34c136bcc 100644 --- a/tests/data/test_rollout.py +++ b/tests/data/test_rollout.py @@ -38,12 +38,14 @@ def step(self, action): def _sample_fixed_length_trajectories( episode_lengths: Sequence[int], min_episodes: int, + rng: np.random.Generator, policy_type: str = "policy", **kwargs, ) -> Sequence[types.Trajectory]: venv = vec_env.DummyVecEnv( [functools.partial(TerminalSentinelEnv, length) for length in episode_lengths], ) + policy: rollout.AnyPolicy if policy_type == "policy": policy = RandomPolicy(venv.observation_space, venv.action_space) elif policy_type == "callable": @@ -63,6 +65,7 @@ def policy(x): policy, venv, sample_until=sample_until, + rng=rng, **kwargs, ) return trajectories @@ -72,7 +75,7 @@ def policy(x): "policy_type", ["policy", "callable", "random"], ) -def test_complete_trajectories(policy_type) -> None: +def test_complete_trajectories(policy_type, rng) -> None: """Checks trajectories include the terminal observation. This is hidden by default by VecEnv's auto-reset; we add it back in using @@ -80,6 +83,7 @@ def test_complete_trajectories(policy_type) -> None: Args: policy_type: Kind of policy to use when generating trajectories. + rng: Random state to use. """ min_episodes = 13 max_acts = 5 @@ -88,6 +92,7 @@ def test_complete_trajectories(policy_type) -> None: [max_acts] * num_envs, min_episodes, policy_type=policy_type, + rng=rng, ) assert len(trajectories) >= min_episodes expected_obs = np.array([[0]] * max_acts + [[1]]) @@ -117,6 +122,7 @@ def test_unbiased_trajectories( episode_lengths: Sequence[int], min_episodes: int, expected_counts: Mapping[int, int], + rng, ) -> None: """Checks trajectories are sampled without bias towards shorter episodes. @@ -136,8 +142,13 @@ def test_unbiased_trajectories( min_episodes: The minimum number of episodes to sample. expected_counts: Mapping from episode length to expected number of episodes of that length (omit if 0 episodes of that length expected). + rng: Random state to use. """ - trajectories = _sample_fixed_length_trajectories(episode_lengths, min_episodes) + trajectories = _sample_fixed_length_trajectories( + episode_lengths, + min_episodes, + rng, + ) assert len(trajectories) == sum(expected_counts.values()) traj_lens = np.array([len(traj) for traj in trajectories]) for length, count in expected_counts.items(): @@ -153,9 +164,9 @@ def test_seed_trajectories(): However, `TerminalSentinelEnv` is fixed-length deterministic, so there are no such confounders in this test. 
""" - rng_a1 = np.random.RandomState(0) - rng_a2 = np.random.RandomState(0) - rng_b = np.random.RandomState(1) + rng_a1 = np.random.default_rng(0) + rng_a2 = np.random.default_rng(0) + rng_b = np.random.default_rng(1) traj_a1 = _sample_fixed_length_trajectories([3, 5], 2, rng=rng_a1) traj_a2 = _sample_fixed_length_trajectories([3, 5], 2, rng=rng_a2) traj_b = _sample_fixed_length_trajectories([3, 5], 2, rng=rng_b) @@ -175,10 +186,13 @@ def step(self, action): return obs / 2, rew / 2, done, info -def test_rollout_stats(): +def test_rollout_stats(rng): """Applying `ObsRewIncrementWrapper` halves the reward mean. `rollout_stats` should reflect this. + + Args: + rng: Random state to use (with fixed seed). """ env = gym.make("CartPole-v1") env = monitor.Monitor(env, None) @@ -186,7 +200,12 @@ def test_rollout_stats(): venv = vec_env.DummyVecEnv([lambda: env]) policy = serialize.load_policy("zero", venv) - trajs = rollout.generate_trajectories(policy, venv, rollout.make_min_episodes(10)) + trajs = rollout.generate_trajectories( + policy, + venv, + rollout.make_min_episodes(10), + rng=rng, + ) s = rollout.rollout_stats(trajs) np.testing.assert_allclose(s["return_mean"], s["monitor_return_mean"] / 2) @@ -195,10 +214,13 @@ def test_rollout_stats(): np.testing.assert_allclose(s["return_max"], s["monitor_return_max"] / 2) -def test_unwrap_traj(): +def test_unwrap_traj(rng): """Check that unwrap_traj reverses `ObsRewIncrementWrapper`. Also check that unwrapping twice is a no-op. + + Args: + rng: Random state to use (with fixed seed). """ env = gym.make("CartPole-v1") env = wrappers.RolloutInfoWrapper(env) @@ -206,7 +228,12 @@ def test_unwrap_traj(): venv = vec_env.DummyVecEnv([lambda: env]) policy = serialize.load_policy("zero", venv) - trajs = rollout.generate_trajectories(policy, venv, rollout.make_min_episodes(10)) + trajs = rollout.generate_trajectories( + policy, + venv, + rollout.make_min_episodes(10), + rng=rng, + ) trajs_unwrapped = [rollout.unwrap_traj(t) for t in trajs] trajs_unwrapped_twice = [rollout.unwrap_traj(t) for t in trajs_unwrapped] @@ -221,6 +248,23 @@ def test_unwrap_traj(): np.testing.assert_equal(t1.rews, t2.rews) +def test_unwrap_traj_raises_no_infos(): + """Check that unwrap_traj raises ValueError if no infos in trajectory.""" + with pytest.raises(ValueError, match="Trajectory must have infos to unwrap"): + acts = np.array([0]) + obs = np.array([0, 0]) + rews = np.array([0.0]) + rollout.unwrap_traj( + types.TrajectoryWithRew( + acts=acts, + obs=obs, + terminal=False, + rews=rews, + infos=None, + ), + ) + + def test_make_sample_until_errors(): with pytest.raises(ValueError, match="At least one.*"): rollout.make_sample_until(min_timesteps=None, min_episodes=None) @@ -251,18 +295,19 @@ def test_compute_returns(gamma): assert abs(rollout.discounted_sum(rewards, gamma) - returns) < 1e-8 -def test_generate_trajectories_type_error(): +def test_generate_trajectories_type_error(rng): venv = vec_env.DummyVecEnv([functools.partial(TerminalSentinelEnv, 1)]) sample_until = rollout.make_min_episodes(1) with pytest.raises(TypeError, match="Policy must be.*got instead"): rollout.generate_trajectories( "strings_are_not_valid_policies", venv, + rng=rng, sample_until=sample_until, ) -def test_generate_trajectories_value_error(): +def test_generate_trajectories_value_error(rng): venv = vec_env.DummyVecEnv([functools.partial(TerminalSentinelEnv, 1)]) sample_until = rollout.make_min_episodes(1) @@ -271,5 +316,6 @@ def test_generate_trajectories_value_error(): lambda obs: np.zeros(len(obs), 
dtype=int), venv, sample_until=sample_until, + rng=rng, deterministic_policy=True, ) diff --git a/tests/data/test_types.py b/tests/data/test_types.py index af8398ba3..dbfacf3b2 100644 --- a/tests/data/test_types.py +++ b/tests/data/test_types.py @@ -1,13 +1,12 @@ """Tests of `imitation.data.types`.""" - import contextlib import copy import dataclasses import os import pathlib import pickle -from typing import Any, Callable +from typing import Any, Callable, Sequence import gym import numpy as np @@ -27,7 +26,7 @@ LENGTHS = [0, 1, 2, 10] -def _check_1d_shape(fn: Callable[[np.ndarray], Any], length: float, expected_msg: str): +def _check_1d_shape(fn: Callable[[np.ndarray], Any], length: int, expected_msg: str): for shape in [(), (length, 1), (length, 2), (length - 1,), (length + 1,)]: with pytest.raises(ValueError, match=expected_msg): fn(np.zeros(shape)) @@ -208,6 +207,7 @@ def test_save_trajectories( use_rewards, type_safe, ): + chdir_context: contextlib.AbstractContextManager """Check that trajectories are properly saved.""" if use_chdir: # Test no relative path without directory edge-case. @@ -234,12 +234,13 @@ def test_save_trajectories( with pytest.raises(ValueError): types.save(save_path, [trajectory, trajectory_rew]) + loaded_trajs: Sequence[types.Trajectory] if type_safe: if use_rewards: loaded_trajs = types.load_with_rewards(save_path) else: with pytest.raises(ValueError): - loaded_trajs = types.load_with_rewards(save_path) + types.load_with_rewards(save_path) loaded_trajs = types.load(save_path) else: loaded_trajs = types.load(save_path) @@ -271,6 +272,7 @@ def test_invalid_trajectories( ValueError, match=r"infos when present must be present for each action.*", ): + assert traj.infos is not None dataclasses.replace(traj, infos=traj.infos[:-1]) with pytest.raises( ValueError, diff --git a/tests/data/test_wrappers.py b/tests/data/test_wrappers.py index 364bebb33..14a7626c8 100644 --- a/tests/data/test_wrappers.py +++ b/tests/data/test_wrappers.py @@ -52,13 +52,13 @@ def _make_buffering_venv( error_on_premature_reset: bool, ) -> BufferingWrapper: venv = DummyVecEnv([_CountingEnv] * 2) - venv = BufferingWrapper(venv, error_on_premature_reset) - venv.reset() - return venv + wrapped_venv = BufferingWrapper(venv, error_on_premature_reset) + wrapped_venv.reset() + return wrapped_venv -def _assert_equal_scrambled_vectors(a: np.ndarray, b: np.ndarray) -> bool: - """Returns True if `a` and `b` are identical up to sorting.""" +def _assert_equal_scrambled_vectors(a: np.ndarray, b: np.ndarray) -> None: + """Raises AssertionError if `a` and `b` are not identical up to sorting.""" assert a.shape == b.shape assert a.ndim == 1 np.testing.assert_allclose(np.sort(a), np.sort(b)) @@ -166,13 +166,13 @@ def make_env(ep_len): transitions_list.append(venv_buffer.pop_transitions()) # Build expected transitions - expect_obs = [] + expect_obs_list = [] for ep_len in episode_lengths: n_complete, remainder = divmod(n_steps, ep_len) - expect_obs.extend([np.arange(ep_len)] * n_complete) - expect_obs.append(np.arange(remainder)) + expect_obs_list.extend([np.arange(ep_len)] * n_complete) + expect_obs_list.append(np.arange(remainder)) - expect_obs = np.concatenate(expect_obs) + expect_obs = np.concatenate(expect_obs_list) expect_next_obs = expect_obs + 1 expect_acts = expect_obs * 2.1 expect_rews = expect_next_obs * 10 diff --git a/tests/policies/test_exploration_wrapper.py b/tests/policies/test_exploration_wrapper.py index b2972e5e6..48b7b2469 100644 --- a/tests/policies/test_exploration_wrapper.py +++ 
b/tests/policies/test_exploration_wrapper.py @@ -11,10 +11,11 @@ def constant_policy(obs): return np.zeros(len(obs), dtype=int) -def make_wrapper(random_prob, switch_prob): +def make_wrapper(random_prob, switch_prob, rng): venv = util.make_vec_env( "seals/CartPole-v0", n_envs=1, + rng=rng, ) return ( exploration_wrapper.ExplorationWrapper( @@ -22,13 +23,13 @@ def make_wrapper(random_prob, switch_prob): venv=venv, random_prob=random_prob, switch_prob=switch_prob, - seed=0, + rng=rng, ), venv, ) -def test_random_prob(): +def test_random_prob(rng): """Test that `random_prob` produces right behaviors of policy switching. The policy always makes an initial switch when ExplorationWrapper is applied. @@ -39,22 +40,25 @@ def test_random_prob(): (2) `random_prob=1.0`: Initial and following policies are always random policies. (3) `random_prob=0.5`: Around half-half for constant and random policies. + Args: + rng (np.random.Generator): random number generator. + Raises: ValueError: Unknown policy type to switch. """ - wrapper, _ = make_wrapper(random_prob=0.0, switch_prob=0.5) + wrapper, _ = make_wrapper(random_prob=0.0, switch_prob=0.5, rng=rng) assert wrapper.current_policy == constant_policy for _ in range(100): wrapper._switch() assert wrapper.current_policy == constant_policy - wrapper, _ = make_wrapper(random_prob=1.0, switch_prob=0.5) + wrapper, _ = make_wrapper(random_prob=1.0, switch_prob=0.5, rng=rng) assert wrapper.current_policy == wrapper._random_policy for _ in range(100): wrapper._switch() assert wrapper.current_policy == wrapper._random_policy - wrapper, _ = make_wrapper(random_prob=0.5, switch_prob=0.5) + wrapper, _ = make_wrapper(random_prob=0.5, switch_prob=0.5, rng=rng) num_random = 0 num_constant = 0 for _ in range(1000): @@ -70,7 +74,7 @@ def test_random_prob(): assert num_constant > 450 -def test_switch_prob(): +def test_switch_prob(rng): """Test that `switch_prob` produces right behaviors of policy switching. The policy always makes an initial switch when ExplorationWrapper is applied. @@ -80,8 +84,11 @@ def test_switch_prob(): (1) `switch_prob=0.0`: The policy never switches after initial switch. (2) `switch_prob=1.0`: The policy always switches and the distribution of policies is determined by `random_prob`. + + Args: + rng (np.random.Generator): random number generator. 
""" - wrapper, venv = make_wrapper(random_prob=0.5, switch_prob=0.0) + wrapper, venv = make_wrapper(random_prob=0.5, switch_prob=0.0, rng=rng) policy = wrapper.current_policy np.random.seed(0) obs = np.random.rand(100, 2) @@ -90,7 +97,7 @@ def test_switch_prob(): assert wrapper.current_policy == policy def _always_switch(random_prob, num_steps, seed): - wrapper, _ = make_wrapper(random_prob=random_prob, switch_prob=1.0) + wrapper, _ = make_wrapper(random_prob=random_prob, switch_prob=1.0, rng=rng) np.random.seed(seed) num_random = 0 num_constant = 0 @@ -116,10 +123,10 @@ def _always_switch(random_prob, num_steps, seed): assert num_constant == 1000 -def test_valid_output(): +def test_valid_output(rng): """Ensure that we test both the random and the wrapped policy at least once.""" for random_prob in [0.0, 0.5, 1.0]: - wrapper, venv = make_wrapper(random_prob=random_prob, switch_prob=0.5) + wrapper, venv = make_wrapper(random_prob=random_prob, switch_prob=0.5, rng=rng) np.random.seed(0) obs = np.random.rand(100, 2) for action in wrapper(obs): diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index ee8dcd804..c7953a006 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -2,6 +2,7 @@ import functools import pathlib +from typing import cast import gym import numpy as np @@ -24,11 +25,21 @@ @pytest.mark.parametrize("env_name", SIMPLE_ENVS) @pytest.mark.parametrize("policy_type", HARDCODED_TYPES) -def test_actions_valid(env_name, policy_type): +def test_actions_valid(env_name, policy_type, rng): """Test output actions of our custom policies always lie in action space.""" - venv = util.make_vec_env(env_name, n_envs=1, parallel=False) + venv = util.make_vec_env( + env_name, + n_envs=1, + parallel=False, + rng=rng, + ) policy = serialize.load_policy(policy_type, venv) - transitions = rollout.generate_transitions(policy, venv, n_timesteps=100) + transitions = rollout.generate_transitions( + policy, + venv, + n_timesteps=100, + rng=rng, + ) for a in transitions.acts: assert venv.action_space.contains(a) @@ -41,11 +52,15 @@ def test_actions_valid(env_name, policy_type): ("sac", SIMPLE_CONTINUOUS_ENV), ], ) -def test_save_stable_model_errors_and_warnings(tmpdir, policy_env_name_pair): +def test_save_stable_model_errors_and_warnings( + tmpdir, + policy_env_name_pair, + rng, +): """Check errors and warnings in `save_stable_model()`.""" policy, env_name = policy_env_name_pair tmpdir = pathlib.Path(tmpdir) - venv = util.make_vec_env(env_name) + venv = util.make_vec_env(env_name, rng=rng) # Trigger FileNotFoundError for no model.{zip,pkl} dir_a = tmpdir / "a" @@ -64,9 +79,14 @@ def test_save_stable_model_errors_and_warnings(tmpdir, policy_env_name_pair): serialize.load_policy(policy, venv, path=str(dir_nonexistent)) -def _test_serialize_identity(env_name, model_cfg, tmpdir): +def _test_serialize_identity(env_name, model_cfg, tmpdir, rng): """Test output actions of deserialized policy are same as original.""" - venv = util.make_vec_env(env_name, n_envs=1, parallel=False) + venv = util.make_vec_env( + env_name, + n_envs=1, + parallel=False, + rng=rng, + ) model_name, model_cls_name = model_cfg model_cls = registry.load_attr(model_cls_name) @@ -81,7 +101,7 @@ def _test_serialize_identity(env_name, model_cfg, tmpdir): venv, n_timesteps=1000, deterministic_policy=True, - rng=np.random.RandomState(0), + rng=np.random.default_rng(0), ) serialize.save_stable_model(tmpdir, model) @@ -93,7 +113,7 @@ def _test_serialize_identity(env_name, model_cfg, 
tmpdir): venv, n_timesteps=1000, deterministic_policy=True, - rng=np.random.RandomState(0), + rng=np.random.default_rng(0), ) assert np.allclose(orig_rollout.acts, new_rollout.acts) @@ -107,16 +127,21 @@ def _test_serialize_identity(env_name, model_cfg, tmpdir): @pytest.mark.parametrize("env_name", SIMPLE_ENVS) @pytest.mark.parametrize("model_cfg", NORMAL_CONFIGS) -def test_serialize_identity(env_name, model_cfg, tmpdir): +def test_serialize_identity(env_name, model_cfg, tmpdir, rng): """Test output actions of deserialized policy are same as original.""" - _test_serialize_identity(env_name, model_cfg, tmpdir) + _test_serialize_identity(env_name, model_cfg, tmpdir, rng) @pytest.mark.parametrize("env_name", [SIMPLE_CONTINUOUS_ENV]) @pytest.mark.parametrize("model_cfg", CONTINUOUS_ONLY_CONFIGS) -def test_serialize_identity_continuous_only(env_name, model_cfg, tmpdir): +def test_serialize_identity_continuous_only( + env_name, + model_cfg, + tmpdir, + rng, +): """Test serialize identity for continuous_only algorithms.""" - _test_serialize_identity(env_name, model_cfg, tmpdir) + _test_serialize_identity(env_name, model_cfg, tmpdir, rng) class ZeroModule(nn.Module): @@ -161,7 +186,11 @@ def test_normalize_features_extractor(obs_space: gym.Space) -> None: for i in range(10): obs = th.as_tensor([obs_space.sample()]) - obs = preprocessing.preprocess_obs(obs, obs_space) + # TODO(juan) the cast below is because preprocess_obs has too general a type. + # this should be replaced with an overload or a generic. + # https://github.com/DLR-RM/stable-baselines3/issues/1065 + obs = cast(th.Tensor, preprocessing.preprocess_obs(obs, obs_space)) + assert isinstance(obs, th.Tensor) flattened_obs = obs.flatten(1, -1) extracted = {k: extractor(obs) for k, extractor in extractors.items()} for k, v in extracted.items(): diff --git a/tests/policies/test_replay_buffer_wrapper.py b/tests/policies/test_replay_buffer_wrapper.py index 4ae73d39e..0eeee334b 100644 --- a/tests/policies/test_replay_buffer_wrapper.py +++ b/tests/policies/test_replay_buffer_wrapper.py @@ -29,36 +29,37 @@ def make_algo_with_wrapped_buffer( rl_cls: Type[off_policy_algorithm.OffPolicyAlgorithm], policy_cls: Type[BasePolicy], replay_buffer_class: Type[buffers.ReplayBuffer], + rng: np.random.Generator, buffer_size: int = 100, ) -> off_policy_algorithm.OffPolicyAlgorithm: - venv = util.make_vec_env("Pendulum-v1", n_envs=1) - rl_kwargs = dict( + venv = util.make_vec_env("Pendulum-v1", n_envs=1, rng=rng) + rl_algo = rl_cls( + policy=policy_cls, + policy_kwargs=dict(), + env=venv, + seed=42, replay_buffer_class=ReplayBufferRewardWrapper, replay_buffer_kwargs=dict( replay_buffer_class=replay_buffer_class, reward_fn=zero_reward_fn, ), buffer_size=buffer_size, - ) - rl_algo = rl_cls( - policy=policy_cls, - policy_kwargs=dict(), - env=venv, - seed=42, - **rl_kwargs, - ) + ) # type: ignore[call-arg] return rl_algo -def test_invalid_args(): +def test_invalid_args(rng): with pytest.raises( TypeError, match=r".*unexpected keyword argument 'replay_buffer_class'.*", ): + # we ignore the type because we are intentionally + # passing the wrong type for the test make_algo_with_wrapped_buffer( rl_cls=sb3.PPO, policy_cls=policies.ActorCriticPolicy, replay_buffer_class=buffers.ReplayBuffer, + rng=rng, ) with pytest.raises(AssertionError, match=r".*only ReplayBuffer is supported.*"): @@ -66,10 +67,11 @@ def test_invalid_args(): rl_cls=sb3.SAC, policy_cls=sb3.sac.policies.SACPolicy, replay_buffer_class=buffers.DictReplayBuffer, + rng=rng, ) -def 
test_wrapper_class(tmpdir): +def test_wrapper_class(tmpdir, rng): buffer_size = 15 total_timesteps = 20 @@ -78,6 +80,7 @@ def test_wrapper_class(tmpdir): policy_cls=sb3.sac.policies.SACPolicy, replay_buffer_class=buffers.ReplayBuffer, buffer_size=buffer_size, + rng=rng, ) rl_algo.learn(total_timesteps=total_timesteps) diff --git a/tests/rewards/test_reward_fn.py b/tests/rewards/test_reward_fn.py index 5e14aa3d6..7a0e42410 100644 --- a/tests/rewards/test_reward_fn.py +++ b/tests/rewards/test_reward_fn.py @@ -7,7 +7,7 @@ OBS = np.random.randint(0, 10, (64, 100)) ACTS = NEXT_OBS = OBS -DONES = np.zeros(64, dtype=np.bool) +DONES = np.zeros(64, dtype=np.bool_) def _funky_reward_fn(obs, act, next_obs, done): diff --git a/tests/rewards/test_reward_nets.py b/tests/rewards/test_reward_nets.py index c0e3212e0..c65332d0c 100644 --- a/tests/rewards/test_reward_nets.py +++ b/tests/rewards/test_reward_nets.py @@ -100,9 +100,9 @@ def torch_transitions() -> TorchTransitions: """A batch of states, actions, next_states, and dones as th.Tensors for Env2D.""" return ( th.zeros((10, 5, 5)), - th.zeros((10, 1), dtype=int), + th.zeros((10, 1), dtype=th.int), th.zeros((10, 5, 5)), - th.zeros((10,), dtype=bool), + th.zeros((10,), dtype=th.bool), ) @@ -224,8 +224,13 @@ def _sample(space, n): return np.array([space.sample() for _ in range(n)]) -def _make_env_and_save_reward_net(env_name, reward_type, tmpdir, is_image=False): - venv = util.make_vec_env(env_name, n_envs=1, parallel=False) +def _make_env_and_save_reward_net(env_name, reward_type, tmpdir, rng, is_image=False): + venv = util.make_vec_env( + env_name, + n_envs=1, + parallel=False, + rng=rng, + ) save_path = os.path.join(tmpdir, "norm_reward.pt") assert reward_type in [ @@ -259,11 +264,12 @@ def _make_env_and_save_reward_net(env_name, reward_type, tmpdir, is_image=False) return venv, save_path -def _is_reward_valid(env_name, reward_type, tmpdir, is_image): +def _is_reward_valid(env_name, reward_type, tmpdir, rng, is_image): venv, tmppath = _make_env_and_save_reward_net( env_name, reward_type, tmpdir, + rng, is_image=is_image, ) @@ -283,20 +289,25 @@ def _is_reward_valid(env_name, reward_type, tmpdir, is_image): @pytest.mark.parametrize("env_name", ENVS) @pytest.mark.parametrize("reward_type", DESERIALIZATION_TYPES) -def test_reward_valid(env_name, reward_type, tmpdir): +def test_reward_valid(env_name, reward_type, tmpdir, rng): """Test output of reward function is appropriate shape and type.""" - _is_reward_valid(env_name, reward_type, tmpdir, is_image=False) + _is_reward_valid(env_name, reward_type, tmpdir, rng, is_image=False) @pytest.mark.parametrize("env_name", IMAGE_ENVS) @pytest.mark.parametrize("reward_type", DESERIALIZATION_TYPES) -def test_reward_valid_image(env_name, reward_type, tmpdir): +def test_reward_valid_image(env_name, reward_type, tmpdir, rng): """Test output of reward function is appropriate shape and type.""" - _is_reward_valid(env_name, reward_type, tmpdir, is_image=True) + _is_reward_valid(env_name, reward_type, tmpdir, rng, is_image=True) -def test_strip_wrappers_basic(): - venv = util.make_vec_env("FrozenLake-v1", n_envs=1, parallel=False) +def test_strip_wrappers_basic(rng): + venv = util.make_vec_env( + "FrozenLake-v1", + n_envs=1, + parallel=False, + rng=rng, + ) net = reward_nets.BasicRewardNet(venv.observation_space, venv.action_space) net = reward_nets.NormalizedRewardNet(net, networks.RunningNorm) net = serialize._strip_wrappers( @@ -309,8 +320,8 @@ def test_strip_wrappers_basic(): assert isinstance(net, 
reward_nets.BasicRewardNet) -def test_strip_wrappers_image_basic(): - venv = util.make_vec_env("Asteroids-v4", n_envs=1, parallel=False) +def test_strip_wrappers_image_basic(rng): + venv = util.make_vec_env("Asteroids-v4", n_envs=1, parallel=False, rng=rng) net = reward_nets.CnnRewardNet(venv.observation_space, venv.action_space) net = reward_nets.NormalizedRewardNet(net, networks.RunningNorm) net = serialize._strip_wrappers( @@ -323,8 +334,13 @@ def test_strip_wrappers_image_basic(): assert isinstance(net, reward_nets.CnnRewardNet) -def test_strip_wrappers_complex(): - venv = util.make_vec_env("FrozenLake-v1", n_envs=1, parallel=False) +def test_strip_wrappers_complex(rng): + venv = util.make_vec_env( + "FrozenLake-v1", + n_envs=1, + parallel=False, + rng=rng, + ) net = reward_nets.BasicRewardNet(venv.observation_space, venv.action_space) net = reward_nets.ShapedRewardNet(net, _potential, discount_factor=0.99) net = reward_nets.NormalizedRewardNet(net, networks.RunningNorm) @@ -344,8 +360,8 @@ def test_strip_wrappers_complex(): assert isinstance(net, reward_nets.BasicRewardNet) -def test_strip_wrappers_image_complex(): - venv = util.make_vec_env("Asteroids-v4", n_envs=1, parallel=False) +def test_strip_wrappers_image_complex(rng): + venv = util.make_vec_env("Asteroids-v4", n_envs=1, parallel=False, rng=rng) net = reward_nets.CnnRewardNet(venv.observation_space, venv.action_space) net = reward_nets.ShapedRewardNet(net, _potential, discount_factor=0.99) net = reward_nets.NormalizedRewardNet(net, networks.RunningNorm) @@ -419,11 +435,12 @@ def forward(*args): @pytest.mark.parametrize("env_name", ENVS) -def test_cant_load_unnorm_as_norm(env_name, tmpdir): +def test_cant_load_unnorm_as_norm(env_name, tmpdir, rng): venv, tmppath = _make_env_and_save_reward_net( env_name, "RewardNet_unnormalized", tmpdir, + rng=rng, ) with pytest.raises(TypeError): serialize.load_reward("RewardNet_normalized", tmppath, venv) @@ -435,11 +452,17 @@ def _serialize_deserialize_identity( net_kwargs, normalize_rewards, tmpdir, + rng, ): """Does output of deserialized reward network match that of original?""" logging.info(f"Testing {net_cls}") - venv = util.make_vec_env(env_name, n_envs=1, parallel=False) + venv = util.make_vec_env( + env_name, + n_envs=1, + parallel=False, + rng=rng, + ) original = net_cls(venv.observation_space, venv.action_space, **net_kwargs) if normalize_rewards: original = reward_nets.NormalizedRewardNet(original, networks.RunningNorm) @@ -452,7 +475,12 @@ def _serialize_deserialize_identity( assert original.observation_space == loaded.observation_space assert original.action_space == loaded.action_space - transitions = rollout.generate_transitions(random, venv, n_timesteps=100) + transitions = rollout.generate_transitions( + random, + venv, + n_timesteps=100, + rng=rng, + ) if isinstance(original, reward_nets.NormalizedRewardNet): wrapped_rew_fn = serialize.load_reward("RewardNet_normalized", tmppath, venv) @@ -510,6 +538,7 @@ def test_serialize_identity( net_kwargs, normalize_rewards, tmpdir, + rng, ): """Does output of deserialized reward MLP match that of original?""" _serialize_deserialize_identity( @@ -518,6 +547,7 @@ def test_serialize_identity( net_kwargs, normalize_rewards, tmpdir, + rng, ) @@ -531,6 +561,7 @@ def test_serialize_identity_images( net_kwargs, normalize_rewards, tmpdir, + rng, ): """Does output of deserialized reward CNN match that of original?""" _serialize_deserialize_identity( @@ -539,6 +570,7 @@ def test_serialize_identity_images( net_kwargs, normalize_rewards, tmpdir, 
+ rng, ) @@ -735,7 +767,9 @@ def test_predict_processed_wrappers_pass_on_kwargs( zero_reward_net: testing_reward_nets.MockRewardNet, numpy_transitions: NumpyTransitions, ): - zero_reward_net.predict_processed = mock.Mock(return_value=np.zeros((10,))) + zero_reward_net.predict_processed = mock.Mock( # type: ignore[assignment] + return_value=np.zeros((10,)), + ) wrapped_reward_net = make_predict_processed_wrapper( zero_reward_net, ) @@ -814,7 +848,7 @@ def forward(self): @pytest.mark.parametrize("normalize_input_layer", [None, networks.RunningNorm]) -def test_training_regression(normalize_input_layer): +def test_training_regression(normalize_input_layer, rng): """Test reward_net normalization by training a regression model.""" venv = DummyVecEnv([lambda: gym.make("CartPole-v0")] * 2) reward_net = reward_nets.BasicRewardNet( @@ -834,7 +868,12 @@ def test_training_regression(normalize_input_layer): # Getting transitions from a random policy random = base.RandomPolicy(venv.observation_space, venv.action_space) for _ in range(2): - transitions = rollout.generate_transitions(random, venv, n_timesteps=100) + transitions = rollout.generate_transitions( + random, + venv, + n_timesteps=100, + rng=rng, + ) trans_args = ( transitions.obs, transitions.acts, diff --git a/tests/rewards/test_reward_wrapper.py b/tests/rewards/test_reward_wrapper.py index 8d8d8fc92..893bff59a 100644 --- a/tests/rewards/test_reward_wrapper.py +++ b/tests/rewards/test_reward_wrapper.py @@ -17,20 +17,20 @@ def __call__(self, obs, act, next_obs, steps=None): return (np.arange(len(obs)) + 1).astype("float32") -def test_reward_overwrite(): +def test_reward_overwrite(rng): """Test that reward wrapper actually overwrites base rewards.""" env_name = "Pendulum-v1" num_envs = 3 - env = util.make_vec_env(env_name, num_envs) + env = util.make_vec_env(env_name, rng=rng, n_envs=num_envs) reward_fn = FunkyReward() wrapped_env = reward_wrapper.RewardVecEnvWrapper(env, reward_fn) policy = RandomPolicy(env.observation_space, env.action_space) sample_until = rollout.make_min_episodes(10) default_stats = rollout.rollout_stats( - rollout.generate_trajectories(policy, env, sample_until), + rollout.generate_trajectories(policy, env, sample_until, rng), ) wrapped_stats = rollout.rollout_stats( - rollout.generate_trajectories(policy, wrapped_env, sample_until), + rollout.generate_trajectories(policy, wrapped_env, sample_until, rng), ) # Pendulum-v1 always has negative rewards assert default_stats["return_max"] < 0 diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index f461dc889..5b8959312 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -15,7 +15,7 @@ import sys import tempfile from collections import Counter -from typing import Dict, List, Mapping, Optional +from typing import Any, Dict, List, Mapping, Optional from unittest import mock import numpy as np @@ -470,6 +470,7 @@ def _check_train_ex_result(result: dict): assert "monitor_return_mean" not in expert_stats imit_stats = result.get("imit_stats") + assert isinstance(imit_stats, dict) _check_rollout_stats(imit_stats) @@ -600,8 +601,8 @@ def test_transfer_learning(tmpdir: str) -> None: Args: tmpdir: Temporary directory to save results to. 
""" - tmpdir = pathlib.Path(tmpdir) - log_dir_train = tmpdir / "train" + tmpdir_path = pathlib.Path(tmpdir) + log_dir_train = tmpdir_path / "train" run = train_adversarial.train_adversarial_ex.run( command_name="airl", named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["adversarial"], @@ -615,7 +616,7 @@ def test_transfer_learning(tmpdir: str) -> None: _check_rollout_stats(run.result["imit_stats"]) - log_dir_data = tmpdir / "train_rl" + log_dir_data = tmpdir_path / "train_rl" reward_path = log_dir_train / "checkpoints" / "final" / "reward_test.pt" run = train_rl.train_rl_ex.run( named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["rl"], @@ -649,9 +650,9 @@ def test_preference_comparisons_transfer_learning( tmpdir: Temporary directory to save results to. named_configs_dict: Named configs for preference_comparisons and rl. """ - tmpdir = pathlib.Path(tmpdir) + tmpdir_path = pathlib.Path(tmpdir) - log_dir_train = tmpdir / "train" + log_dir_train = tmpdir_path / "train" run = train_preference_comparisons.train_preference_comparisons_ex.run( named_configs=["pendulum"] + ALGO_FAST_CONFIGS["preference_comparison"] @@ -669,7 +670,7 @@ def test_preference_comparisons_transfer_learning( reward_type = "RewardNet_unnormalized" load_reward_kwargs = {} - log_dir_data = tmpdir / "train_rl" + log_dir_data = tmpdir_path / "train_rl" reward_path = log_dir_train / "checkpoints" / "final" / "reward_net.pt" agent_path = log_dir_train / "checkpoints" / "final" / "policy" run = train_rl.train_rl_ex.run( @@ -686,10 +687,18 @@ def test_preference_comparisons_transfer_learning( _check_rollout_stats(run.result) -def test_train_rl_double_normalization(tmpdir: str): - venv = util.make_vec_env("CartPole-v1", n_envs=1, parallel=False) - net = reward_nets.BasicRewardNet(venv.observation_space, venv.action_space) - net = reward_nets.NormalizedRewardNet(net, networks.RunningNorm) +def test_train_rl_double_normalization(tmpdir: str, rng): + venv = util.make_vec_env( + "CartPole-v1", + n_envs=1, + parallel=False, + rng=rng, + ) + basic_reward_net = reward_nets.BasicRewardNet( + venv.observation_space, + venv.action_space, + ) + net = reward_nets.NormalizedRewardNet(basic_reward_net, networks.RunningNorm) tmppath = os.path.join(tmpdir, "reward.pt") th.save(net, tmppath) @@ -782,14 +791,14 @@ def test_parallel_arg_errors(tmpdir): def _generate_test_rollouts(tmpdir: str, env_named_config: str) -> pathlib.Path: - tmpdir = pathlib.Path(tmpdir) + tmpdir_path = pathlib.Path(tmpdir) train_rl.train_rl_ex.run( named_configs=[env_named_config] + ALGO_FAST_CONFIGS["rl"], config_updates=dict( common=dict(log_dir=tmpdir), ), ) - rollout_path = tmpdir / "rollouts/final.pkl" + rollout_path = tmpdir_path / "rollouts/final.pkl" return rollout_path.absolute() @@ -854,7 +863,7 @@ def _run_train_bc_for_test_analyze_imit(run_name, sacred_logs_dir, log_dir): ), ) def test_analyze_imitation(tmpdir: str, run_names: List[str], run_sacred_fn): - sacred_logs_dir = tmpdir = pathlib.Path(tmpdir) + sacred_logs_dir = tmpdir_path = pathlib.Path(tmpdir) # Generate sacred logs (other logs are put in separate tmpdir for deletion). 
     for run_name in run_names:
@@ -870,8 +879,8 @@ def check(run_name: Optional[str], count: int) -> None:
                 source_dirs=[sacred_logs_dir],
                 env_name="seals/CartPole-v0",
                 run_name=run_name,
-                csv_output_path=tmpdir / "analysis.csv",
-                tex_output_path=tmpdir / "analysis.tex",
+                csv_output_path=tmpdir_path / "analysis.csv",
+                tex_output_path=tmpdir_path / "analysis.tex",
                 print_table=True,
             ),
         )
@@ -889,7 +898,7 @@ def test_analyze_gather_tb(tmpdir: str):
     if os.name == "nt":  # pragma: no cover
         pytest.skip("gather_tb uses symlinks: not supported by Windows")
 
-    config_updates = dict(local_dir=tmpdir, run_name="test")
+    config_updates: Dict[str, Any] = dict(local_dir=tmpdir, run_name="test")
     config_updates.update(PARALLEL_CONFIG_LOW_RESOURCE)
     parallel_run = parallel.parallel_ex.run(
         named_configs=["generate_test_data"],
diff --git a/tests/test_envs.py b/tests/test_envs.py
index fa53931d1..eb55db036 100644
--- a/tests/test_envs.py
+++ b/tests/test_envs.py
@@ -1,4 +1,5 @@
 """Tests for `imitation.envs.*`."""
+from typing import List
 
 import gym
 import numpy as np
@@ -16,8 +17,7 @@
     if env_spec.id.startswith("imitation/")
 ]
 
-DETERMINISTIC_ENVS = []
-
+DETERMINISTIC_ENVS: List[str] = []
 
 env = pytest.fixture(seals_test.make_env_fixture(skip_fn=pytest.skip))
 
diff --git a/tests/test_regularization.py b/tests/test_regularization.py
index 7c09a1ccf..fea0a6bff 100644
--- a/tests/test_regularization.py
+++ b/tests/test_regularization.py
@@ -87,11 +87,11 @@ def test_interval_param_scaler_raises(interval_param_scaler):
     with pytest.raises(ValueError, match="train_loss must be a scalar"):
         scaler(1.0, th.Tensor([1.0, 2.0]), 1.0)
     with pytest.raises(ValueError, match="train_loss must be a scalar"):
-        scaler(1.0, "random value", th.tensor(1.0))  # type: ignore
+        scaler(1.0, "random value", th.tensor(1.0))
     with pytest.raises(ValueError, match="val_loss must be a scalar"):
-        scaler(1.0, 1.0, "random value")  # type: ignore
+        scaler(1.0, 1.0, "random value")
     with pytest.raises(ValueError, match="lambda_ must be a float"):
-        scaler(th.tensor(1.0), 1.0, 1.0)  # type: ignore
+        scaler(th.tensor(1.0), 1.0, 1.0)
     with pytest.raises(ValueError, match="lambda_ must not be zero.*"):
         scaler(0.0, 1.0, 1.0)
     with pytest.raises(ValueError, match="lambda_ must be non-negative.*"):
@@ -131,12 +131,12 @@ def test_interval_param_scaler_init_raises():
         ValueError,
         match="tolerable_interval must be a tuple of length 2",
     ):
-        updaters.IntervalParamScaler(0.5, (0.1, 0.9, 0.5))  # type: ignore
+        updaters.IntervalParamScaler(0.5, (0.1, 0.9, 0.5))  # type: ignore[arg-type]
     with pytest.raises(
         ValueError,
         match="tolerable_interval must be a tuple of length 2",
     ):
-        updaters.IntervalParamScaler(0.5, (0.1,))  # type: ignore
+        updaters.IntervalParamScaler(0.5, (0.1,))  # type: ignore[arg-type]
 
     # the first element of the interval must be at least 0.
     with pytest.raises(
@@ -326,7 +326,7 @@ class SimpleLossRegularizer(regularizers.LossRegularizer):
     It multiplies the total loss by lambda_+1.
     """
 
-    def _loss_penalty(self, loss: th.Tensor) -> th.Tensor:
+    def _loss_penalty(self, loss: regularizers.Scalar) -> regularizers.Scalar:
         return loss * self.lambda_  # this multiplies the total loss by lambda_+1.
 
 
diff --git a/tests/util/test_networks.py b/tests/util/test_networks.py
index 6e8bfc533..147e1185b 100644
--- a/tests/util/test_networks.py
+++ b/tests/util/test_networks.py
@@ -71,7 +71,7 @@ def update_stats(self, batch: th.Tensor) -> None:
         self.running_var += learning_rate * S - delta**2
 
         self.count += b_size
-        self.num_batches += 1
+        self.num_batches += 1  # type: ignore[misc]
 
 
 @pytest.mark.parametrize("normalization_layer", NORMALIZATION_LAYERS)
@@ -243,6 +243,20 @@ def test_build_mlp_norm_training(init_kwargs) -> None:
         optimizer.step()
 
 
+def test_build_mlp_raises_on_invalid_normalize_input_layer() -> None:
+    """Test that `networks.build_mlp()` raises on invalid input layer."""
+    with pytest.raises(
+        ValueError,
+        match="normalize_input_layer.*not a valid normalization layer.*",
+    ):
+        networks.build_mlp(
+            in_size=1,
+            hid_sizes=[16, 16],
+            out_size=1,
+            normalize_input_layer=th.nn.Module,
+        )
+
+
 def test_input_validation_on_ema_norm():
     with pytest.raises(ValueError):
         networks.EMANorm(128, decay=1.1)
diff --git a/tests/util/test_util.py b/tests/util/test_util.py
index d7cd18cbe..ce663d8e0 100644
--- a/tests/util/test_util.py
+++ b/tests/util/test_util.py
@@ -19,12 +19,43 @@ def test_endless_iter():
     assert next(it) == 0
     assert next(it) == 1
     assert next(it) == 0
+    assert next(it) == 1
+    assert next(it) == 0
 
 
 def test_endless_iter_error():
     x = []
     with pytest.raises(ValueError, match="no elements"):
         util.endless_iter(x)
+    with pytest.raises(ValueError, match="needs a non-iterator Iterable"):
+        generator = (x for x in range(5))
+        util.endless_iter(generator)
+
+
+@given(
+    st.lists(
+        st.integers(),
+        min_size=1,
+    ),
+)
+def test_get_first_iter_element(input_seq):
+    with pytest.raises(ValueError, match="iterable.* had no elements"):
+        util.get_first_iter_element([])
+
+    first_element, new_iterable = util.get_first_iter_element(input_seq)
+    assert first_element == input_seq[0]
+    assert input_seq is new_iterable
+
+    def generator_fn():
+        for x in input_seq:
+            yield x
+
+    generator = generator_fn()
+    assert generator == iter(generator)
+    first_element, new_iterable = util.get_first_iter_element(generator)
+    assert first_element == input_seq[0]
+    assert list(new_iterable) == input_seq
+    assert list(new_iterable) == []
 
 
 @given(
@@ -67,6 +98,13 @@ def test_safe_to_tensor():
     assert not np.may_share_memory(numpy, torch)
 
 
+def test_safe_to_numpy():
+    tensor = th.tensor([1, 2, 3])
+    numpy = util.safe_to_numpy(tensor)
+    assert (numpy == tensor.numpy()).all()
+    assert util.safe_to_numpy(None) is None
+
+
 def test_tensor_iter_norm():
     # vector is [1,0,1,1,-5,-6]; its 2-norm is 8, and 1-norm is 14
     tensor_list = [
diff --git a/tests/util/test_wb_logger.py b/tests/util/test_wb_logger.py
index fe755ad31..f3b1a85a9 100644
--- a/tests/util/test_wb_logger.py
+++ b/tests/util/test_wb_logger.py
@@ -87,7 +87,9 @@ def finish(self):
 mock_wandb = MockWandb()
 
 
-@mock.patch.object(wandb, "__init__", mock_wandb.__init__)
+# we ignore the type below as one should technically not access the
+# __init__ method directly but only by creating an instance.
+@mock.patch.object(wandb, "__init__", mock_wandb.__init__)  # type: ignore[misc]
 @mock.patch.object(wandb, "init", mock_wandb.init)
 @mock.patch.object(wandb, "log", mock_wandb.log)
 @mock.patch.object(wandb, "finish", mock_wandb.finish)