diff --git a/mojo/elements/body.py b/mojo/elements/body.py index fa3509b..bf2d5df 100644 --- a/mojo/elements/body.py +++ b/mojo/elements/body.py @@ -98,3 +98,6 @@ def set_kinematic(self, value: bool): if value and not self.is_kinematic(): self.mjcf.add("freejoint") self._mojo.mark_dirty() + elif not value and self.is_kinematic(): + self.remove_all_joints() + self._mojo.mark_dirty() diff --git a/mojo/elements/element.py b/mojo/elements/element.py index b8a7341..e892101 100644 --- a/mojo/elements/element.py +++ b/mojo/elements/element.py @@ -20,6 +20,14 @@ def _is_kinematic(elem: mjcf.Element): return has_freejoint or has_joints or _is_kinematic(elem.parent) +def _remove_all_joints(elem: mjcf.Element): + if hasattr(elem, "freejoint") and elem.freejoint is not None: + elem.freejoint.remove() + if hasattr(elem, "joint") and len(elem.joint) > 0: + for joint in elem.joint: + joint.remove() + + def _find_freejoint(elem: mjcf.Element): if elem.parent is None: # Root of tree @@ -73,6 +81,9 @@ def get_quaternion(self) -> np.ndarray: def is_kinematic(self) -> bool: return _is_kinematic(self.mjcf) + def remove_all_joints(self): + _remove_all_joints(self.mjcf) + @property def id(self): return self._mojo.physics.bind(self.mjcf).element_id diff --git a/mojo/elements/geom.py b/mojo/elements/geom.py index 09a1572..72e75c6 100644 --- a/mojo/elements/geom.py +++ b/mojo/elements/geom.py @@ -175,3 +175,6 @@ def set_kinematic(self, value: bool): if value and not self.is_kinematic(): self.mjcf.parent.add("freejoint") self._mojo.mark_dirty() + elif not value and self.is_kinematic(): + self.parent.remove_all_joints() + self._mojo.mark_dirty() diff --git a/mojo/elements/joint.py b/mojo/elements/joint.py index b76a662..f664f16 100644 --- a/mojo/elements/joint.py +++ b/mojo/elements/joint.py @@ -54,7 +54,7 @@ def create( def get_joint_position(self) -> float: """Get current joint position.""" - return float(self._mojo.physics.bind(self.mjcf).qpos) + return float(self._mojo.physics.bind(self.mjcf).qpos.item()) def set_joint_position(self, value: float): self._mojo.physics.bind(self.mjcf).qpos *= 0 @@ -63,4 +63,4 @@ def set_joint_position(self, value: float): def get_joint_velocity(self) -> float: """Get current joint velocity.""" - return float(self._mojo.physics.bind(self.mjcf).qvel) + return float(self._mojo.physics.bind(self.mjcf).qvel.item()) diff --git a/mojo/elements/utils.py b/mojo/elements/utils.py index 97f7e1a..8ab8956 100644 --- a/mojo/elements/utils.py +++ b/mojo/elements/utils.py @@ -5,11 +5,14 @@ import numpy as np from dm_control import mjcf +from lxml import etree from mojo.elements.consts import TextureMapping # Default minimum distance between two geoms for them to be considered in collision. _DEFAULT_COLLISION_MARGIN: float = 1e-8 +_FREEJOINT_TAG = "freejoint" +_WORLDBODY_TAG = "worldbody" def has_collision( @@ -77,6 +80,30 @@ def load_mesh( return mesh +def resolve_freejoints( + root_model: mjcf.RootElement, model: mjcf.RootElement +) -> mjcf.RootElement: + child_xml = model.to_xml() + if len(child_xml.findall(f".//{_FREEJOINT_TAG}")) > 0: + root_xml = root_model.to_xml() + worldbody = root_xml.find(_WORLDBODY_TAG) + search_expr = f".//{child_xml.tag}" + for attr_name, attr_value in child_xml.attrib.items(): + search_expr += f"[@{attr_name}='{attr_value}']" + child_xml = root_xml.find(search_expr) + freejoints = child_xml.findall(f".//{_FREEJOINT_TAG}") + for joint in freejoints: + worldbody.append(joint.getparent()) + if len(child_xml) == 0: + child_xml.getparent().remove(child_xml) + root_model = mjcf.from_xml_string( + etree.tostring(root_xml), + escape_separators=True, + assets=root_model.get_assets(), + ) + return root_model + + class AssetStore: """Container for Mujoco assets.""" diff --git a/mojo/mojo.py b/mojo/mojo.py index 9a98169..e3c8195 100644 --- a/mojo/mojo.py +++ b/mojo/mojo.py @@ -7,7 +7,7 @@ from mojo.elements.body import Body from mojo.elements.element import MujocoElement from mojo.elements.model import MujocoModel -from mojo.elements.utils import AssetStore +from mojo.elements.utils import AssetStore, resolve_freejoints class Mojo: @@ -108,6 +108,7 @@ def load_model( path: str, parent: MujocoElement = None, on_loaded: Optional[Callable[[mjcf.RootElement], None]] = None, + handle_freejoints: bool = False, ): """Load a Mujoco model from xml file and attach to specified parent element. @@ -116,6 +117,8 @@ def load_model( If None, it attaches to the root element. :param on_loaded: Optional callback to be executed after model is loaded. Use it to customize the Mujoco model before attaching it to the parent. + :param handle_freejoints: If true handles elements. + Freejoint bodies will be re-parented to the worldbody. :return: A Body element representing the attached model. """ @@ -124,6 +127,11 @@ def load_model( on_loaded(model_mjcf) attach_site = self.root_element.mjcf if parent is None else parent.mjcf attached_model_mjcf = attach_site.attach(model_mjcf) + if handle_freejoints: + root_model_mjcf = resolve_freejoints( + self.root_element.mjcf, attached_model_mjcf + ) + self.root_element = MujocoElement(self, root_model_mjcf) self.mark_dirty() return Body(self, attached_model_mjcf) diff --git a/tests/assets/models/sphere.xml b/tests/assets/models/sphere.xml new file mode 100644 index 0000000..801b5cc --- /dev/null +++ b/tests/assets/models/sphere.xml @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/tests/assets/models/sphere_and_box.xml b/tests/assets/models/sphere_and_box.xml new file mode 100644 index 0000000..fa1eac6 --- /dev/null +++ b/tests/assets/models/sphere_and_box.xml @@ -0,0 +1,13 @@ + + + + + + + + + + + + + diff --git a/tests/elements/test_body.py b/tests/elements/test_body.py index 5fdbd6a..5ff7c7d 100644 --- a/tests/elements/test_body.py +++ b/tests/elements/test_body.py @@ -45,3 +45,10 @@ def test_get_set_kinematic(mojo: Mojo, body: Body): mojo.step() # Objects should fall after = body.get_position() assert np.any(np.not_equal(before, after)) + + +def test_is_kinematic(mojo: Mojo, body: Body): + body.set_kinematic(True) + assert body.is_kinematic() + body.set_kinematic(False) + assert not body.is_kinematic() diff --git a/tests/elements/test_geom.py b/tests/elements/test_geom.py index 45ec766..966442c 100644 --- a/tests/elements/test_geom.py +++ b/tests/elements/test_geom.py @@ -80,3 +80,10 @@ def test_get_set_kinematic(mojo: Mojo, geom: Geom): mojo.step() # Objects should fall after = geom.get_position() assert np.any(np.not_equal(before, after)) + + +def test_is_kinematic(mojo: Mojo, geom: Geom): + geom.set_kinematic(True) + assert geom.is_kinematic() + geom.set_kinematic(False) + assert not geom.is_kinematic() diff --git a/tests/elements/test_mojo.py b/tests/elements/test_mojo.py new file mode 100644 index 0000000..1a68588 --- /dev/null +++ b/tests/elements/test_mojo.py @@ -0,0 +1,36 @@ +from pathlib import Path + +import pytest + +from mojo import Mojo +from mojo.elements import Body + + +@pytest.fixture() +def mojo() -> Mojo: + return Mojo(str(Path(__file__).parents[1] / "world.xml")) + + +def load_model(mojo: Mojo, model_name: str, handle_freejoints: bool) -> Body: + body = mojo.load_model( + str(Path(__file__).parents[1] / "assets" / "models" / model_name), + handle_freejoints=handle_freejoints, + ) + _ = mojo.physics + return body + + +def test_load_freejoint(mojo: Mojo): + load_model(mojo, "sphere.xml", True) + + +def test_load_freejoint_raises_by_default(mojo: Mojo): + with pytest.raises(ValueError): + load_model(mojo, "sphere.xml", False) + + +def test_load_freejoint_hierarchy(mojo: Mojo): + sphere_and_box = load_model(mojo, "sphere_and_box.xml", True) + for joint in sphere_and_box.joints: + assert joint.mjcf.tag == "freejoint" + assert joint.mjcf.parent.parent.tag == "worldbody"