diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index ae79565f..0b0537af 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -16,7 +16,7 @@
 - [ ] I have updated the documentation and previewed the changes via `mkdocs serve`.
 - [ ] I have updated the tests accordingly (if applicable).
-If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See https://github.com/vwxyzjn/cleanrl/pull/137 as an example PR.
+If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See https://github.com/vwxyzjn/cleanrl/pull/137 as an example PR. Feel free to remove this section if you don't need it.
 
 - [ ] I have contacted [vwxyzjn](https://github.com/vwxyzjn) to obtain access to the [openrlbenchmark W&B team](https://wandb.ai/openrlbenchmark) (**required**).
 - [ ] I have tracked applicable experiments in [openrlbenchmark/cleanrl](https://wandb.ai/openrlbenchmark/cleanrl) with `--capture-video` flag toggled on (**required**).
 - [ ] I have added additional documentation and previewed the changes via `mkdocs serve`.
diff --git a/.gitignore b/.gitignore
index 6e70cef4..69f6da45 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+compare.pdf
+compare.png
 balance_bot.xml
 cleanrl/ppo_continuous_action_isaacgym/isaacgym/examples
 cleanrl/ppo_continuous_action_isaacgym/isaacgym/isaacgym
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 4a7addd7..7319e7f3 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -35,7 +35,7 @@ repos:
     hooks:
       - id: codespell
         args:
-          - --ignore-words-list=nd,reacher,thist,ths,magent
+          - --ignore-words-list=nd,reacher,thist,ths,magent,nin
           - --skip=docs/css/termynal.css,docs/js/termynal.js
   - repo: https://github.com/python-poetry/poetry
     rev: 1.2.1
diff --git a/cleanrl_utils/rlops.py b/cleanrl_utils/rlops.py
new file mode 100644
index 00000000..d2ae40af
--- /dev/null
+++ b/cleanrl_utils/rlops.py
@@ -0,0 +1,250 @@
+import argparse
+import os
+from distutils.util import strtobool
+from typing import List
+from urllib.parse import parse_qs, urlparse
+
+import expt
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+import wandb
+import wandb.apis.reports as wb  # noqa
+from expt import Hypothesis, Run
+from rich.console import Console
+
+wandb.require("report-editing")
+api = wandb.Api()
+
+
+def parse_args():
+    # fmt: off
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--wandb-project-name", type=str, default="cleanrl",
+        help="the wandb's project name")
+    parser.add_argument("--wandb-entity", type=str, default="openrlbenchmark",
+        help="the entity (team) of wandb's project")
+    # TODO: update the docs for filter
+    parser.add_argument("--filters", nargs="+", default=["v1.0.0b2-9-g4605546", "rlops-pilot"],
+        help="the filters of the runsets; each filter is an experiment name optionally followed by a query string with `tag` and `user` keys (e.g., `--filters 'ddpg_continuous_action?tag=pr-299' 'ddpg_continuous_action?tag=rlops-pilot'`)")
+    parser.add_argument("--env-ids", nargs="+", default=["Hopper-v2", "Walker2d-v2", "HalfCheetah-v2"],
+        help="the ids of the environment to compare")
+    parser.add_argument("--output-filename", type=str, default="rlops_static/compare.png",
+        help="the output filename of the plot")
+    parser.add_argument("--rolling", type=int, default=100,
+        help="the rolling window for smoothing the curves")
+    parser.add_argument("--metric-last-n-average-window", type=int, default=100,
+        help="the number of last episodes to average the metric over in the result table")
+    parser.add_argument("--scan-history", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
+        help="if toggled, we will pull the complete metrics from wandb instead of sampling 500 data points (recommended for generating tables)")
+    parser.add_argument("--report", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
+        help="if toggled, a wandb report will be created")
+    # fmt: on
+    return parser.parse_args()
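+
+
+# NOTE (illustration): each `--filters` entry is parsed with urllib below, e.g.,
+#   urlparse("ddpg_continuous_action_jax?user=costa-huang&tag=pr-298")
+# yields path="ddpg_continuous_action_jax" (the experiment name) and
+# query="user=costa-huang&tag=pr-298", which parse_qs turns into
+#   {"user": ["costa-huang"], "tag": ["pr-298"]}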
+
+
+def create_hypothesis(name: str, wandb_runs: List[wandb.apis.public.Run], scan_history: bool = False) -> Hypothesis:
+    runs = []
+    for idx, run in enumerate(wandb_runs):
+        if scan_history:
+            wandb_run = pd.DataFrame([row for row in run.scan_history()])
+        else:
+            wandb_run = run.history()
+        if "videos" in wandb_run:
+            wandb_run = wandb_run.drop(columns=["videos"], axis=1)
+        runs += [Run(f"seed{idx}", wandb_run)]
+    return Hypothesis(name, runs)
+
+
+class Runset:
+    def __init__(self, name: str, filters: dict, entity: str, project: str, groupby: str = "", color: str = "#000000"):
+        self.name = name
+        self.filters = filters
+        self.entity = entity
+        self.project = project
+        self.groupby = groupby
+        self.color = color
+
+    @property
+    def runs(self):
+        return wandb.Api().runs(path=f"{self.entity}/{self.project}", filters=self.filters)
+
+    @property
+    def report_runset(self):
+        return wb.Runset(
+            name=self.name,
+            entity=self.entity,
+            project=self.project,
+            filters={"$or": [self.filters]},
+            groupby=[self.groupby] if len(self.groupby) > 0 else None,
+        )
+
+
+def compare(
+    runsetss: List[List[Runset]],
+    env_ids: List[str],
+    ncols: int,
+    rolling: int,
+    metric_last_n_average_window: int,
+    scan_history: bool = False,
+    output_filename: str = "compare.png",
+):
+    blocks = []
+    for idx, env_id in enumerate(env_ids):
+        pg = wb.PanelGrid(
+            runsets=[runsets[idx].report_runset for runsets in runsetss],
+            panels=[
+                wb.LinePlot(
+                    x="global_step",
+                    y=["charts/episodic_return"],
+                    title=env_id,
+                    title_x="Steps",
+                    title_y="Episodic Return",
+                    max_runs_to_show=100,
+                    smoothing_factor=0.8,
+                    groupby_rangefunc="stderr",
+                    legend_template="${runsetName}",
+                ),
+                wb.LinePlot(
+                    x="_runtime",
+                    y=["charts/episodic_return"],
+                    title=env_id,
+                    title_y="Episodic Return",
+                    max_runs_to_show=100,
+                    smoothing_factor=0.8,
+                    groupby_rangefunc="stderr",
+                    legend_template="${runsetName}",
+                ),
+                # wb.MediaBrowser(
+                #     num_columns=2,
+                #     media_keys="videos",
+                # ),
+            ],
+        )
+        custom_run_colors = {}
+        for runsets in runsetss:
+            custom_run_colors.update(
+                {(runsets[idx].report_runset.name, runsets[idx].runs[0].config["exp_name"]): runsets[idx].color}
+            )
+        pg.custom_run_colors = custom_run_colors  # IMPORTANT: custom_run_colors is implemented as a custom `setter` that needs to be overwritten unlike regular dictionaries
+        blocks += [pg]
+
+    nrows = np.ceil(len(env_ids) / ncols).astype(int)
+    figsize = (ncols * 4, nrows * 3)
+    fig, axes = plt.subplots(
+        nrows=nrows,
+        ncols=ncols,
+        figsize=figsize,
+        # sharex=True,
+        # sharey=True,
+    )
+
+    result_table = pd.DataFrame(index=env_ids, columns=[runsets[0].name for runsets in runsetss])
+    for idx, env_id in enumerate(env_ids):
+        ex = expt.Experiment("Comparison")
+        for runsets in runsetss:
+            h = create_hypothesis(runsets[idx].name, runsets[idx].runs, scan_history)
+            ex.add_hypothesis(h)
+
+        # for each run `i`, take the average of the last `metric_last_n_average_window`
+        # episodes as r_i, then report the mean and std of r_i across runs.
+        result = []
+        for hypothesis in ex.hypotheses:
+            raw_result = []
+            for run in hypothesis.runs:
+                raw_result += [run.df["charts/episodic_return"].dropna()[-metric_last_n_average_window:].mean()]
+            raw_result = np.array(raw_result)
+            result += [f"{raw_result.mean():.2f} ± {raw_result.std():.2f}"]
+        result_table.loc[env_id] = result
+
+        ax = axes.flatten()[idx]
+        ex.plot(
+            ax=ax,
+            title=env_id,
+            x="global_step",
+            y="charts/episodic_return",
+            err_style="band",
+            std_alpha=0.1,
+            rolling=rolling,
+            colors=[runsets[idx].color for runsets in runsetss],
+            # n_samples=500,
+            legend=False,
+        )
+
+    print(result_table)
+
+    h, l = ax.get_legend_handles_labels()
+    fig.legend(h, l, loc="upper center", ncol=ncols)
+    num_legend_rows = len(h) // 2
+    # dynamically adjust the top of subplot to make room for legend
+    fig.subplots_adjust(top=1 - 0.07 * num_legend_rows)
+    # remove the empty axes
+    for ax in axes.flatten()[len(env_ids) :]:
+        ax.remove()
+
+    print(f"saving figure to {output_filename}")
+    if os.path.dirname(output_filename) != "":
+        os.makedirs(os.path.dirname(output_filename), exist_ok=True)
+    plt.savefig(f"{output_filename}", bbox_inches="tight")
+    plt.savefig(f"{output_filename.replace('.png', '.pdf')}", bbox_inches="tight")
+    return blocks
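+
+
+# Overall flow (for readers): parse each `--filters` entry into an experiment name
+# plus a tag/user query, build one `Runset` per (filter, env_id) pair, verify each
+# runset is non-empty, then hand everything to `compare`, which prints the result
+# table, saves the matplotlib figure, and returns the wandb report blocks.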
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    console = Console()
+    blocks = []
+    runsetss = []
+    colors = sns.color_palette(n_colors=len(args.filters)).as_hex()
+    for filter_str, color in zip(args.filters, colors):
+        print("=========", filter_str)
+        parse_result = urlparse(filter_str)
+        exp_name = parse_result.path
+        query = parse_qs(parse_result.query)
+        user = [{"username": query["user"][0]}] if "user" in query else []
+        include_tag_groups = [{"tags": {"$in": [tag]}} for tag in query["tag"]] if "tag" in query else []
+        runsets = []
+        for env_id in args.env_ids:
+            runsets += [
+                Runset(
+                    name=f"CleanRL's {exp_name} ({query})",
+                    filters={
+                        "$and": [
+                            {"config.env_id.value": env_id},
+                            *include_tag_groups,
+                            *user,
+                            {"config.exp_name.value": exp_name},
+                        ]
+                    },
+                    entity=args.wandb_entity,
+                    project=args.wandb_project_name,
+                    groupby="exp_name",
+                    color=color,
+                )
+            ]
+            console.print(f"CleanRL's {exp_name} [green]({query})[/] in [purple]{env_id}[/] has {len(runsets[-1].runs)} runs")
+            for run in runsets[-1].runs:
+                console.print(f"┣━━ [link={run.url}]{run.name}[/link] with tags = {run.tags}")
+            assert len(runsets[-1].runs) > 0, f"CleanRL's {exp_name} ({query}) in {env_id} has no runs"
+        runsetss += [runsets]
+
+    blocks = compare(
+        runsetss,
+        args.env_ids,
+        output_filename=args.output_filename,
+        ncols=2,
+        rolling=args.rolling,
+        metric_last_n_average_window=args.metric_last_n_average_window,
+        scan_history=args.scan_history,
+    )
+    if args.report:
+        print("saving report")
+        report = wb.Report(
+            project="cleanrl",
+            title=f"Regression Report: {exp_name}",
+            description=str(args.filters),
+            blocks=blocks,
+        )
+        report.save()
+        print(f"view the generated report at {report.url}")
diff --git a/cleanrl_utils/rlops_tags.py b/cleanrl_utils/rlops_tags.py
new file mode 100644
index 00000000..fc15ecb0
--- /dev/null
+++ b/cleanrl_utils/rlops_tags.py
@@ -0,0 +1,80 @@
+import argparse
+from urllib.parse import parse_qs, urlparse
+
+import wandb
+from rich.console import Console
+
+api = wandb.Api()
+
+
+def parse_args():
+    # fmt: off
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--wandb-project-name", type=str, default="cleanrl",
+        help="the wandb's project name")
+    parser.add_argument("--wandb-entity", type=str, default="openrlbenchmark",
+        help="the entity (team) of wandb's project")
+    parser.add_argument("--filters", nargs="+", default=["v1.0.0b2-9-g4605546", "rlops-pilot"],
+        help="the experiment names of the runsets (a name may carry a query string, e.g., `ddpg_continuous_action?tag=rlops-pilot&user=costa-huang`)")
+    parser.add_argument("--add", type=str, default="",
+        help="the tag to be added to any runs with the `--source-tag`")
+    parser.add_argument("--remove", type=str, default="",
+        help="the tag to be removed from any runs with the `--source-tag`")
+    parser.add_argument("--source-tag", type=str, default="v1.0.0b2-7-g4bb6766",
+        help="the source tag of the set of runs")
+    # fmt: on
+    return parser.parse_args()
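+
+
+# Example semantics: `--add latest --source-tag pr-299` appends the tag "latest" to
+# every matched run that currently has the tag "pr-299"; `--remove` is the inverse.
+# Changes are only written back to wandb after an interactive confirmation below.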
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    console = Console()
+
+    # parse filter string
+    for filter_str in args.filters:
+        parse_result = urlparse(filter_str)
+        exp_name = parse_result.path
+        query = parse_qs(parse_result.query)
+        user = [{"username": query["user"][0]}] if "user" in query else []
+        include_tag_groups = [{"tags": {"$in": [tag]}} for tag in query["tag"]] if "tag" in query else []
+        metric = query["metric"][0] if "metric" in query else "charts/episodic_return"
+        wandb_project_name = query["wpn"][0] if "wpn" in query else args.wandb_project_name
+        wandb_entity = query["we"][0] if "we" in query else args.wandb_entity
+        custom_env_id_key = query["ceik"][0] if "ceik" in query else "env_id"
+
+        runs = api.runs(
+            path=f"{wandb_entity}/{wandb_project_name}",  # may be overridden per filter via the `wpn`/`we` query keys
+            filters={
+                "$and": [
+                    *include_tag_groups,
+                    *user,
+                    {"config.exp_name.value": exp_name},
+                ]
+            },
+        )
+        print(len(runs))
+        confirmation_str = "You are about to make the following changes:\n"
+        modified_runs = []
+        for run in runs:
+            tags = run.tags
+            if args.add and args.add not in tags and args.source_tag in tags:
+                confirmation_str += (
+                    f"Adding the tag '{args.add}' to [link={run.url}]{run.name}[/link], which has tags {str(tags)}\n"
+                )
+                tags.append(args.add)
+                run.tags = tags
+                modified_runs.append(run)
+            if args.remove and args.remove in tags and args.source_tag in tags:
+                confirmation_str += (
+                    f"Removing the tag '{args.remove}' from [link={run.url}]{run.name}[/link], which has tags {str(tags)}\n"
+                )
+                tags.remove(args.remove)
+                run.tags = tags
+                modified_runs.append(run)
+
+        console.print(confirmation_str)
+        response = input("Are you sure you want to proceed? (y/n):")
+        if response.lower() == "y":
+            for run in modified_runs:
+                print(f"Updating {run.name}")
+                run.update()
diff --git a/docs/advanced/rlops.md b/docs/advanced/rlops.md
new file mode 100644
index 00000000..97527877
--- /dev/null
+++ b/docs/advanced/rlops.md
@@ -0,0 +1,223 @@
+# RLops
+
+This document describes how we do "RLops" to validate new features / bug fixes and avoid introducing regressions.
+
+
+## Background
+DRL is brittle and has a series of reproducibility issues — even bug fixes can sometimes introduce performance regressions (e.g., see [how a bug fix of contact force in MuJoCo results in worse performance for PPO](https://github.com/openai/gym/pull/2762#discussion_r853488897)). Therefore, it is essential to understand how the proposed changes impact the performance of the algorithms. Broadly, we wish to distinguish two types of contributions: 1) **non-performance-impacting changes** and 2) **performance-impacting changes**.
+
+* **non-performance-impacting changes**: this type of change does *not* impact the performance of the algorithm, such as documentation fixes ([:material-github: #282](https://github.com/vwxyzjn/cleanrl/pull/282)), renaming variables ([:material-github: #257](https://github.com/vwxyzjn/cleanrl/pull/257)), and removing unused code ([:material-github: #287](https://github.com/vwxyzjn/cleanrl/pull/287)). For this type of change, we can easily merge them without worrying too much about the consequences.
+* **performance-impacting changes**: this type of change impacts the algorithm's performance. Examples include making a slight modification to the `gamma` parameter in PPO ([:material-github: #209](https://github.com/vwxyzjn/cleanrl/pull/209)), properly handling action bounds in DDPG ([:material-github: #211](https://github.com/vwxyzjn/cleanrl/pull/211)), and fixing bugs ([:material-github: #281](https://github.com/vwxyzjn/cleanrl/pull/281)).
+
+
+**Importantly, no matter how small a performance-impacting change may seem, we need to re-run the benchmark to ensure there is no regression**. This document proposes a workflow for re-running the benchmark and checking for regressions seamlessly.
+
+## Methodology
+
+
+### (Step 1) Run the benchmark
+
+Given a new feature, we create a PR and then run the benchmark experiments through [`benchmark.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl_utils/benchmark.py), such as the following:
+
+```bash
+poetry install --with mujoco,pybullet
+python -c "import mujoco_py"
+xvfb-run -a python -m cleanrl_utils.benchmark \
+    --env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 \
+    --command "poetry run python cleanrl/ddpg_continuous_action.py --track --capture-video" \
+    --num-seeds 3 \
+    --workers 1
+```
+
+Under the hood, this script invokes the `--autotag` feature, which tries to tag the experiments with version control information, such as the git tag (e.g., `v1.0.0b2-8-g6081d30`) and the GitHub PR number (e.g., `pr-299`). This is useful for comparing the performance of the same algorithm across different versions.
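+
+The version tag follows `git describe` conventions: `v1.0.0b2-8-g6081d30` reads as "8 commits after the `v1.0.0b2` tag, at commit `6081d30`" (the `g` stands for "git"). As a rough sketch of the idea (not necessarily the exact `--autotag` implementation), such a tag can be reproduced locally with `git describe --tags`.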
+
+![](./rlops/tags.png)
+
+### (Step 2) Regression check
+
+Let's say our latest experiments are tagged with `pr-299`. We can then run the following command to compare their performance with our pilot experiments, tagged `rlops-pilot`. Note that the pilot experiments include all experiments before we started using RLops (i.e., `rlops-pilot` is the baseline).
+
+
+```bash
+python -m cleanrl_utils.rlops \
+    --wandb-project-name cleanrl \
+    --wandb-entity openrlbenchmark \
+    --filters 'ddpg_continuous_action?tag=pr-299' 'ddpg_continuous_action?tag=rlops-pilot' \
+    --env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 \
+    --output-filename compare.png \
+    --scan-history \
+    --metric-last-n-average-window 100 \
+    --report
+```
+```
+CleanRL's ddpg_continuous_action (pr-299) in HalfCheetah-v2 has 3 runs
+┣━━ HalfCheetah-v2__ddpg_continuous_action__4__1667280971 with tags = ['pr-299', 'v1.0.0b2-8-g6081d30']
+┣━━ HalfCheetah-v2__ddpg_continuous_action__3__1667271574 with tags = ['pr-299', 'v1.0.0b2-8-g6081d30']
+┣━━ HalfCheetah-v2__ddpg_continuous_action__2__1667261986 with tags = ['pr-299', 'v1.0.0b2-8-g6081d30']
+CleanRL's ddpg_continuous_action (pr-299) in Walker2d-v2 has 3 runs
+┣━━ Walker2d-v2__ddpg_continuous_action__4__1667284233 with tags = ['pr-299', 'v1.0.0b2-8-g6081d30']
+┣━━ Walker2d-v2__ddpg_continuous_action__3__1667274709 with tags = ['pr-299', 'v1.0.0b2-8-g6081d30']
+┣━━ Walker2d-v2__ddpg_continuous_action__2__1667265261 with tags = ['pr-299', 'v1.0.0b2-8-g6081d30']
+CleanRL's ddpg_continuous_action (pr-299) in Hopper-v2 has 3 runs
+┣━━ Hopper-v2__ddpg_continuous_action__4__1667287363 with tags = ['pr-299', 'v1.0.0b2-8-g6081d30']
+┣━━ Hopper-v2__ddpg_continuous_action__3__1667277826 with tags = ['pr-299', 'v1.0.0b2-8-g6081d30']
+┣━━ Hopper-v2__ddpg_continuous_action__2__1667268434 with tags = ['pr-299', 'v1.0.0b2-8-g6081d30']
+CleanRL's ddpg_continuous_action (rlops-pilot) in HalfCheetah-v2 has 3 runs
+┣━━ HalfCheetah-v2__ddpg_continuous_action__3__1651008691 with tags = ['latest', 'rlops-pilot']
+┣━━ HalfCheetah-v2__ddpg_continuous_action__2__1651004631 with tags = ['latest', 'rlops-pilot']
+┣━━ HalfCheetah-v2__ddpg_continuous_action__1__1651000539 with tags = ['latest', 'rlops-pilot']
+CleanRL's ddpg_continuous_action (rlops-pilot) in Walker2d-v2 has 3 runs
+┣━━ Walker2d-v2__ddpg_continuous_action__3__1651008768 with tags = ['latest', 'rlops-pilot']
+┣━━ Walker2d-v2__ddpg_continuous_action__2__1651004640 with tags = ['latest', 'rlops-pilot']
+┣━━ Walker2d-v2__ddpg_continuous_action__1__1651000539 with tags = ['latest', 'rlops-pilot']
+CleanRL's ddpg_continuous_action (rlops-pilot) in Hopper-v2 has 3 runs
+┣━━ Hopper-v2__ddpg_continuous_action__3__1651008797 with tags = ['latest', 'rlops-pilot']
+┣━━ Hopper-v2__ddpg_continuous_action__2__1651004715 with tags = ['latest', 'rlops-pilot']
+┣━━ Hopper-v2__ddpg_continuous_action__1__1651000539 with tags = ['latest', 'rlops-pilot']
+
+
+               CleanRL's ddpg_continuous_action (pr-299) CleanRL's ddpg_continuous_action (rlops-pilot)
+HalfCheetah-v2                         10210.57 ± 196.22                              9205.65 ± 1093.88
+Walker2d-v2                             1661.14 ± 250.01                               1447.09 ± 260.24
+Hopper-v2                               1007.44 ± 148.29                               1126.37 ± 278.02
+```
+
+
+The command generates the table above, which reports the mean and standard deviation of the algorithm's performance over the last 100 episodes, computed by scanning the complete training data (enabled by `--scan-history`).
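+
+The aggregation behind each table cell is simple: for every run (seed), average the last `--metric-last-n-average-window` episodic returns to get one number per seed, then report the mean and standard deviation across seeds. A minimal sketch of the computation, with made-up per-seed averages:
+
+```python
+import numpy as np
+
+# hypothetical per-seed averages of the last 100 episodic returns
+per_seed_averages = np.array([10010.2, 10390.8, 10230.7])
+print(f"{per_seed_averages.mean():.2f} ± {per_seed_averages.std():.2f}")  # 10210.57 ± 156.03
+```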
+
+!!! info
+
+    To make the script run faster, we can choose not to use `--scan-history`, in which case wandb samples 500 data points from the training data. This is the default behavior and is much faster.
+
+
+It also generates the following image and a wandb report.
+
+![](./rlops/rlops.png)
+
+
+
+
+
+!!! info
+
+    **Support for multiple filters, tags, and users**: each `--filters` entry is an experiment name, optionally followed by a query string that narrows the selection, e.g., `--filters 'ddpg_continuous_action_jax?user=costa-huang&tag=pr-298'` selects the `ddpg_continuous_action_jax` runs created by `costa-huang` and tagged `pr-298`. Here are some examples:
+
+    ```bash
+    python -m cleanrl_utils.rlops \
+        --wandb-project-name cleanrl \
+        --wandb-entity openrlbenchmark \
+        --filters 'ddpg_continuous_action_jax?user=costa-huang&tag=rlops-pilot' 'ddpg_continuous_action_jax?user=costa-huang&tag=pr-298' \
+        --env-ids Hopper-v2 Walker2d-v2 HalfCheetah-v2 \
+        --output-filename compare.png \
+        --report
+
+    python -m cleanrl_utils.rlops \
+        --wandb-project-name cleanrl \
+        --wandb-entity openrlbenchmark \
+        --filters 'ddpg_continuous_action_jax?user=joaogui1&tag=rlops-pilot' 'ddpg_continuous_action_jax?user=joaogui1&tag=pr-298' \
+        --env-ids Hopper-v2 Walker2d-v2 HalfCheetah-v2 \
+        --output-filename compare.png \
+        --report
+    ```
+
+
+!!! warning
+
+    The RLops procedure (e.g., the `cleanrl_utils.rlops` script) is still in its early stage. Please feel free to open an issue if you have any questions or suggestions.
+
+### (Step 3) Update the documentation
+
+Once we confirm there is no regression in the performance, we can update the documentation to display the new benchmark results. Run the previous command without comparing against previous tags:
+
+```bash
+python -m cleanrl_utils.rlops \
+    --wandb-project-name cleanrl \
+    --wandb-entity openrlbenchmark \
+    --filters 'ddpg_continuous_action?tag=pr-299' \
+    --env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 \
+    --output-filename compare.png \
+    --scan-history \
+    --metric-last-n-average-window 100
+```
+
+which gives us a table like the one below and a `compare.png` with the learning curves.
+
+```
+               CleanRL's ddpg_continuous_action (pr-299)
+HalfCheetah-v2                         10210.57 ± 196.22
+Walker2d-v2                             1661.14 ± 250.01
+Hopper-v2                               1007.44 ± 148.29
+```
+
+We will use them to update the [experimental result section](https://github.com/vwxyzjn/cleanrl/blob/master/docs/rl-algorithms/ddpg.md#experiment-results) in the docs and replace the learning curves with the new ones.
+
+
+![](./rlops/docs-update.png)
+
+### (Step 4) Update the tags
+
+
+As the last step before merging the PR, we shall update the wandb tags. We label the new experiments as `latest` (and remove the tag `latest` from the `v1.0.0b2-7-gxfd3d3` experiments correspondingly).
+
+```bash
+python -m cleanrl_utils.rlops_tags \
+    --add latest \
+    --source-tag pr-299 \
+    --filters 'ddpg_continuous_action' \
+    --wandb-project-name cleanrl \
+    --wandb-entity openrlbenchmark
+```
+```
+You are about to make the following changes:
+Adding the tag 'latest' to HalfCheetah-v2__ddpg_continuous_action__3__1667429781, which has tags ['pr-299', 'v1.0.0b2-8-g6081d30']
+Adding the tag 'latest' to Hopper-v2__ddpg_continuous_action__3__1667425897, which has tags ['pr-299', 'v1.0.0b2-8-g6081d30']
+Adding the tag 'latest' to Walker2d-v2__ddpg_continuous_action__3__1667422863, which has tags ['pr-299', 'v1.0.0b2-8-g6081d30']
+Adding the tag 'latest' to HalfCheetah-v2__ddpg_continuous_action__2__1667419688, which has tags ['pr-299', 'v1.0.0b2-8-g6081d30']
+Adding the tag 'latest' to Hopper-v2__ddpg_continuous_action__2__1667415702, which has tags ['pr-299', 'v1.0.0b2-8-g6081d30']
+Adding the tag 'latest' to Walker2d-v2__ddpg_continuous_action__2__1667412699, which has tags ['pr-299', 'v1.0.0b2-8-g6081d30']
+Adding the tag 'latest' to HalfCheetah-v2__ddpg_continuous_action__1__1667409617, which has tags ['pr-299', 'v1.0.0b2-8-g6081d30']
+Adding the tag 'latest' to Hopper-v2__ddpg_continuous_action__1__1667405668, which has tags ['pr-299', 'v1.0.0b2-8-g6081d30']
+Adding the tag 'latest' to Walker2d-v2__ddpg_continuous_action__1__1667402741, which has tags ['pr-299', 'v1.0.0b2-8-g6081d30']
+
+Are you sure you want to proceed? (y/n):
+```
+
+Press `y` to confirm the changes after reviewing them.
+
+Then, we shall remove the tag `latest` from the previous experiments.
+
+```bash
+python -m cleanrl_utils.rlops_tags \
+    --remove latest \
+    --source-tag rlops-pilot \
+    --filters 'ddpg_continuous_action' \
+    --wandb-project-name cleanrl \
+    --wandb-entity openrlbenchmark
+```
+```
+Removing the tag 'latest' from Walker2d-v2__ddpg_continuous_action__1__1656400724, which has tags ['latest', 'rlops-pilot']
+Removing the tag 'latest' from Walker2d-v2__ddpg_continuous_action__2__1656400725, which has tags ['latest', 'rlops-pilot']
+Removing the tag 'latest' from HalfCheetah-v2__ddpg_continuous_action__1__1656400725, which has tags ['latest', 'rlops-pilot']
+Removing the tag 'latest' from Walker2d-v2__ddpg_continuous_action__3__1656400724, which has tags ['latest', 'rlops-pilot']
+Removing the tag 'latest' from Hopper-v2__ddpg_continuous_action__2__1656400724, which has tags ['latest', 'rlops-pilot']
+Removing the tag 'latest' from Hopper-v2__ddpg_continuous_action__3__1656400724, which has tags ['latest', 'rlops-pilot']
+Removing the tag 'latest' from HalfCheetah-v2__ddpg_continuous_action__2__1656400724, which has tags ['latest', 'rlops-pilot']
+Removing the tag 'latest' from HalfCheetah-v2__ddpg_continuous_action__3__1656400724, which has tags ['latest', 'rlops-pilot']
+Removing the tag 'latest' from Hopper-v2__ddpg_continuous_action__1__1656400724, which has tags ['latest', 'rlops-pilot']
+
+Are you sure you want to proceed? (y/n):
+```
+Press `y` to confirm the changes after reviewing them.
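+
+To double-check the final tag state, you can query the runs directly with the wandb API (a minimal sketch that reuses the MongoDB-style filters from `cleanrl_utils/rlops.py`; the entity/project below are the ones used throughout this guide):
+
+```python
+import wandb
+
+api = wandb.Api()
+runs = api.runs(
+    path="openrlbenchmark/cleanrl",
+    filters={"$and": [{"tags": {"$in": ["latest"]}}, {"config.exp_name.value": "ddpg_continuous_action"}]},
+)
+for run in runs:
+    print(run.name, run.tags)  # every remaining `latest` run should now come from pr-299
+```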
+
+### (Step 5) Merge the PR
+
+Finally, we can merge the PR.
+
+
+## Checklist
+
+Here is a checklist that summarizes the steps above.
+
+- [ ] Run the benchmark on the new code
+- [ ] Compare the results w/ the previous version via `python -m cleanrl_utils.rlops`
+- [ ] Update the tags
+- [ ] Once regression tests pass, update the documentation
\ No newline at end of file
diff --git a/docs/advanced/rlops/docs-update.png b/docs/advanced/rlops/docs-update.png
new file mode 100644
index 00000000..c945de08
Binary files /dev/null and b/docs/advanced/rlops/docs-update.png differ
diff --git a/docs/advanced/rlops/rlops.png b/docs/advanced/rlops/rlops.png
new file mode 100644
index 00000000..e1572a64
Binary files /dev/null and b/docs/advanced/rlops/rlops.png differ
diff --git a/docs/advanced/rlops/tags.png b/docs/advanced/rlops/tags.png
new file mode 100644
index 00000000..00129305
Binary files /dev/null and b/docs/advanced/rlops/tags.png differ
diff --git a/mkdocs.yml b/mkdocs.yml
index 7dc9d5a0..bb022ac7 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -91,6 +91,7 @@ nav:
   - Advanced:
     - advanced/hyperparameter-tuning.md
     - advanced/resume-training.md
+    - advanced/rlops.md
   - Community:
     - contribution.md
     - cleanrl-supported-papers-projects.md
diff --git a/poetry.lock b/poetry.lock
index beb22a2d..d9d89138 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2296,7 +2296,7 @@ testing = ["func-timeout", "jaraco.itertools", "pytest (>=6)", "pytest-black (>=
 [metadata]
 lock-version = "1.1"
 python-versions = ">=3.7.1,<3.10"
-content-hash = "78f4008a7a5d3ac846d8b6d07c9e6d23cfd4136404324c2daf516e4288041e8d"
+content-hash = "28e942604fdfc9d399d83208658bcf45ad37ce5874954cd62818f9028bb08494"
 
 [metadata.files]
 absl-py = [
diff --git a/pyproject.toml b/pyproject.toml
index ca14cf70..1092407b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,6 +18,7 @@ stable-baselines3 = "1.2.0"
 
 [tool.poetry.group.dev.dependencies]
 pre-commit = "^2.20.0"
+rich = "<12.0"
 
 [tool.poetry.group.atari]
 optional = true