update convert_model script
oleflb committed Feb 1, 2025
1 parent d0461d0 commit d4a1804
Showing 1 changed file with 40 additions and 21 deletions.
61 changes: 40 additions & 21 deletions tools/machine-learning/mujoco/scripts/convert_model.py
@@ -3,15 +3,26 @@
 import click
 import openvino as ov
 import torch
-from nao_env.nao_standing import OFFSET_QPOS
+from nao_env import nao_standing, nao_walking
 from stable_baselines3 import PPO
-from stable_baselines3.common.policies import BasePolicy
+from stable_baselines3.common.policies import ActorCriticPolicy
 from torch import nn
 
 
+class UndefinedObservationSpaceError(ValueError):
+    def __init__(self) -> None:
+        super().__init__("observation space must have a fixed size.")
+
+
+class UndefinedActionSpaceError(ValueError):
+    def __init__(self) -> None:
+        super().__init__("action space must have a fixed size.")
+
+
 class OnnxableSB3Policy(nn.Module):
-    def __init__(self, policy: BasePolicy) -> None:
+    def __init__(self, policy: ActorCriticPolicy, offset: torch.Tensor) -> None:
         super().__init__()
+        self.offset = offset
         self.policy = policy
 
     def unscale_action(self, scaled_action: torch.Tensor) -> torch.Tensor:
@@ -36,33 +47,44 @@ def forward(self, observation: torch.Tensor) -> torch.Tensor:
         else:
             actions = self.clip_action(actions)
 
-        return actions + torch.from_numpy(OFFSET_QPOS)
+        return actions + self.offset
 
 
 @click.command()
-@click.option(
-    "--load-policy",
+@click.argument(
+    "policy",
     type=click.Path(exists=True),
-    default=None,
-    help="Load a policy from a file.",
+    # the policy checkpoint to convert (click arguments accept no help text)
 )
+@click.argument(
+    "environment-type",
+    type=click.Choice(["NaoStanding", "NaoStandup", "NaoWalking"]),
+)
-def main(load_policy: str) -> None:
-    path = Path(load_policy)
+def main(policy: str, environment_type: str) -> None:
+    path = Path(policy)
     name = path.parent.name
-    model = PPO.load(load_policy)
-    network = OnnxableSB3Policy(model.policy)
-    observation_size = model.observation_space.shape
+    model = PPO.load(policy)
+
+    observation_size = model.observation_space.shape
+    if observation_size is None:
+        raise UndefinedObservationSpaceError()
+    action_size = model.action_space.shape
+    if action_size is None:
+        raise UndefinedActionSpaceError()
+
+    offset = {
+        "NaoStanding": torch.from_numpy(nao_standing.OFFSET_QPOS),
+        "NaoStandup": torch.zeros(action_size),
+        "NaoWalking": torch.from_numpy(nao_walking.OFFSET_QPOS),
+    }[environment_type]
+
+    network = OnnxableSB3Policy(model.policy, offset)
     Path("result").mkdir(exist_ok=True)
 
-    observation = torch.zeros(1, *observation_size)
-
-    print(observation.shape, network.forward(observation))
-
     with torch.inference_mode():
         torch.onnx.export(
             network,
-            torch.randn(1, *observation_size),
+            (torch.randn(observation_size),),
             f"result/{name}-model.onnx",
             input_names=["input"],
             output_names=["output"],
@@ -72,9 +94,6 @@ def main(load_policy: str) -> None:
     ov_model = ov.convert_model(f"result/{name}-model.onnx")
     ov.save_model(ov_model, f"result/{name}-policy-ov.xml")
 
-    nn = ov.compile_model(ov_model)
-    print(nn(observation))
-
 
 if __name__ == "__main__":
     main()
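
The script now takes the checkpoint and the environment type as positional arguments instead of the old --load-policy option, and output files are named after the checkpoint's parent directory (name = path.parent.name). A hypothetical invocation, with a placeholder checkpoint path not taken from this commit:

    python tools/machine-learning/mujoco/scripts/convert_model.py \
        runs/nao_walking/model.zip NaoWalking

This would write result/nao_walking-model.onnx and result/nao_walking-policy-ov.xml.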

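The commit also drops the inline ov.compile_model smoke test, since the observation variable it printed no longer exists. A minimal sketch of an equivalent offline check, assuming openvino is installed and the file name from the hypothetical invocation above; the observation size (123) is a placeholder, not from this commit:

    import numpy as np
    import openvino as ov

    # Load and compile the converted policy on CPU.
    core = ov.Core()
    compiled = core.compile_model("result/nao_walking-policy-ov.xml", "CPU")

    # The ONNX export uses torch.randn(observation_size) with no batch
    # dimension, so feed a bare observation vector of matching size.
    observation = np.zeros((123,), dtype=np.float32)
    print(compiled(observation))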