diff --git a/cleanrl/ppo_continuous_action.py b/cleanrl/ppo_continuous_action.py
index 0845222c8..0f2f3b033 100644
--- a/cleanrl/ppo_continuous_action.py
+++ b/cleanrl/ppo_continuous_action.py
@@ -33,6 +33,12 @@ def parse_args():
         help="the entity (team) of wandb's project")
     parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
         help="whether to capture videos of the agent performances (check out `videos` folder)")
+    parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
+        help="whether to save model into the `runs/{run_name}` folder")
+    parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
+        help="whether to upload the saved model to huggingface")
+    parser.add_argument("--hf-entity", type=str, default="",
+        help="the user or org name of the model repository from the Hugging Face Hub")
 
     # Algorithm specific arguments
     parser.add_argument("--env-id", type=str, default="HalfCheetah-v4",
@@ -319,5 +325,31 @@ def get_action_and_value(self, x, action=None):
         print("SPS:", int(global_step / (time.time() - start_time)))
         writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
 
+    if args.save_model:
+        model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
+        torch.save(agent.state_dict(), model_path)
+        print(f"model saved to {model_path}")
+        from cleanrl_utils.evals.ppo_eval import evaluate
+
+        episodic_returns = evaluate(
+            model_path,
+            make_env,
+            args.env_id,
+            eval_episodes=10,
+            run_name=f"{run_name}-eval",
+            Model=Agent,
+            device=device,
+            gamma=args.gamma,
+        )
+        for idx, episodic_return in enumerate(episodic_returns):
+            writer.add_scalar("eval/episodic_return", episodic_return, idx)
+
+        if args.upload_model:
+            from cleanrl_utils.huggingface import push_to_hub
+
+            repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
+            repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
+            push_to_hub(args, episodic_returns, repo_id, "PPO", f"runs/{run_name}", f"videos/{run_name}-eval")
+
     envs.close()
     writer.close()
diff --git a/cleanrl_utils/evals/ppo_eval.py b/cleanrl_utils/evals/ppo_eval.py
new file mode 100644
index 000000000..05091f567
--- /dev/null
+++ b/cleanrl_utils/evals/ppo_eval.py
@@ -0,0 +1,56 @@
+from typing import Callable
+
+import gymnasium as gym
+import torch
+
+
+def evaluate(
+    model_path: str,
+    make_env: Callable,
+    env_id: str,
+    eval_episodes: int,
+    run_name: str,
+    Model: torch.nn.Module,
+    device: torch.device = torch.device("cpu"),
+    capture_video: bool = True,
+    gamma: float = 0.99,
+):
+    envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, capture_video, run_name, gamma)])
+    agent = Model(envs).to(device)
+    agent.load_state_dict(torch.load(model_path, map_location=device))
+    agent.eval()
+
+    obs, _ = envs.reset()
+    episodic_returns = []
+    while len(episodic_returns) < eval_episodes:
+        actions, _, _, _ = agent.get_action_and_value(torch.Tensor(obs).to(device))
+        next_obs, _, _, _, infos = envs.step(actions.cpu().numpy())
+        if "final_info" in infos:
+            for info in infos["final_info"]:
+                if "episode" not in info:
+                    continue
+                print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}")
+                episodic_returns += [info["episode"]["r"]]
+        obs = next_obs
+
+    return episodic_returns
+
+
+if __name__ == "__main__":
+    from huggingface_hub import hf_hub_download
+
+    from cleanrl.ppo_continuous_action import Agent, make_env
+
+    model_path = hf_hub_download(
+        repo_id="sdpkjc/Hopper-v4-ppo_continuous_action-seed1", filename="ppo_continuous_action.cleanrl_model"
+    )
+    evaluate(
+        model_path,
+        make_env,
+        "Hopper-v4",
+        eval_episodes=10,
+        run_name=f"eval",
+        Model=Agent,
+        device="cpu",
+        capture_video=False,
+    )
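
Usage note (not part of the diff): a minimal sketch of how the new flags could be passed to the training script, assuming it is run from the repository root with CleanRL's dependencies installed; the Hugging Face entity below is a placeholder. Per the diff, `--save-model` writes the checkpoint to `runs/{run_name}` and triggers a 10-episode evaluation via `cleanrl_utils.evals.ppo_eval.evaluate`, and `--upload-model` additionally calls `push_to_hub` with the run folder and the `videos/{run_name}-eval` folder.

    python cleanrl/ppo_continuous_action.py --env-id Hopper-v4 --save-model True --upload-model True --hf-entity <your-hf-entity>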