diff --git a/mojo/mojo.py b/mojo/mojo.py index 6517443..b9e1365 100644 --- a/mojo/mojo.py +++ b/mojo/mojo.py @@ -1,3 +1,5 @@ +from typing import Callable, Optional + import mujoco.viewer import numpy as np from dm_control import mjcf @@ -85,8 +87,25 @@ def get_material(self, path: str) -> mjcf.Element: def store_material(self, path: str, material_mjcf: mjcf.Element) -> mjcf.Element: self._texture_store[path] = material_mjcf - def load_model(self, path: str, parent: MujocoElement = None): + def load_model( + self, + path: str, + parent: MujocoElement = None, + on_loaded: Optional[Callable[[mjcf.RootElement], None]] = None, + ): + """Load a Mujoco model from the xml file and attach it to the specified parent element. + + :param path: The file path to the Mujoco model XML file. + :param parent: The parent MujocoElement to which the loaded model will be attached. + If None, it attaches to the root element. + :param on_loaded: An optional callback function to be executed after the model is loaded. + Use it to customize the Mujoco model before attaching it to the parent. + :return: A Body element representing the attached model. + """ + model_mjcf = mjcf.from_path(path) + if on_loaded is not None: + on_loaded(model_mjcf) attach_site = self.root_element.mjcf if parent is None else parent attached_model_mjcf = attach_site.attach(model_mjcf) self.mark_dirty()