Skip to content

Commit

Permalink
ppo_continuous_action huggingface integration (#423)
Browse files Browse the repository at this point in the history
* add ppo_continuous_action huggingface integration

* fix
  • Loading branch information
sdpkjc authored Oct 16, 2023
1 parent cf20043 commit 2d660b6
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
32 changes: 32 additions & 0 deletions cleanrl/ppo_continuous_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ def parse_args():
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to capture videos of the agent performances (check out `videos` folder)")
parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to save model into the `runs/{run_name}` folder")
parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to upload the saved model to huggingface")
parser.add_argument("--hf-entity", type=str, default="",
help="the user or org name of the model repository from the Hugging Face Hub")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="HalfCheetah-v4",
Expand Down Expand Up @@ -319,5 +325,31 @@ def get_action_and_value(self, x, action=None):
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

if args.save_model:
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
torch.save(agent.state_dict(), model_path)
print(f"model saved to {model_path}")
from cleanrl_utils.evals.ppo_eval import evaluate

episodic_returns = evaluate(
model_path,
make_env,
args.env_id,
eval_episodes=10,
run_name=f"{run_name}-eval",
Model=Agent,
device=device,
gamma=args.gamma,
)
for idx, episodic_return in enumerate(episodic_returns):
writer.add_scalar("eval/episodic_return", episodic_return, idx)

if args.upload_model:
from cleanrl_utils.huggingface import push_to_hub

repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
push_to_hub(args, episodic_returns, repo_id, "PPO", f"runs/{run_name}", f"videos/{run_name}-eval")

envs.close()
writer.close()
56 changes: 56 additions & 0 deletions cleanrl_utils/evals/ppo_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from typing import Callable

import gymnasium as gym
import torch


def evaluate(
model_path: str,
make_env: Callable,
env_id: str,
eval_episodes: int,
run_name: str,
Model: torch.nn.Module,
device: torch.device = torch.device("cpu"),
capture_video: bool = True,
gamma: float = 0.99,
):
envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, capture_video, run_name, gamma)])
agent = Model(envs).to(device)
agent.load_state_dict(torch.load(model_path, map_location=device))
agent.eval()

obs, _ = envs.reset()
episodic_returns = []
while len(episodic_returns) < eval_episodes:
actions, _, _, _ = agent.get_action_and_value(torch.Tensor(obs).to(device))
next_obs, _, _, _, infos = envs.step(actions.cpu().numpy())
if "final_info" in infos:
for info in infos["final_info"]:
if "episode" not in info:
continue
print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}")
episodic_returns += [info["episode"]["r"]]
obs = next_obs

return episodic_returns


if __name__ == "__main__":
from huggingface_hub import hf_hub_download

from cleanrl.ppo_continuous_action import Agent, make_env

model_path = hf_hub_download(
repo_id="sdpkjc/Hopper-v4-ppo_continuous_action-seed1", filename="ppo_continuous_action.cleanrl_model"
)
evaluate(
model_path,
make_env,
"Hopper-v4",
eval_episodes=10,
run_name=f"eval",
Model=Agent,
device="cpu",
capture_video=False,
)

1 comment on commit 2d660b6

@vercel
Copy link

@vercel vercel bot commented on 2d660b6 Oct 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Successfully deployed to the following URLs:

cleanrl – ./

cleanrl-vwxyzjn.vercel.app
cleanrl-git-master-vwxyzjn.vercel.app
cleanrl.vercel.app
docs.cleanrl.dev

Please sign in to comment.