diff --git a/mojo/elements/body.py b/mojo/elements/body.py
index fa3509b..bf2d5df 100644
--- a/mojo/elements/body.py
+++ b/mojo/elements/body.py
@@ -98,3 +98,6 @@ def set_kinematic(self, value: bool):
if value and not self.is_kinematic():
self.mjcf.add("freejoint")
self._mojo.mark_dirty()
+ elif not value and self.is_kinematic():
+ self.remove_all_joints()
+ self._mojo.mark_dirty()
diff --git a/mojo/elements/element.py b/mojo/elements/element.py
index b8a7341..e892101 100644
--- a/mojo/elements/element.py
+++ b/mojo/elements/element.py
@@ -20,6 +20,14 @@ def _is_kinematic(elem: mjcf.Element):
return has_freejoint or has_joints or _is_kinematic(elem.parent)
+def _remove_all_joints(elem: mjcf.Element):
+ if hasattr(elem, "freejoint") and elem.freejoint is not None:
+ elem.freejoint.remove()
+ if hasattr(elem, "joint") and len(elem.joint) > 0:
+ for joint in elem.joint:
+ joint.remove()
+
+
def _find_freejoint(elem: mjcf.Element):
if elem.parent is None:
# Root of tree
@@ -73,6 +81,9 @@ def get_quaternion(self) -> np.ndarray:
def is_kinematic(self) -> bool:
return _is_kinematic(self.mjcf)
+ def remove_all_joints(self):
+ _remove_all_joints(self.mjcf)
+
@property
def id(self):
return self._mojo.physics.bind(self.mjcf).element_id
diff --git a/mojo/elements/geom.py b/mojo/elements/geom.py
index 09a1572..72e75c6 100644
--- a/mojo/elements/geom.py
+++ b/mojo/elements/geom.py
@@ -175,3 +175,6 @@ def set_kinematic(self, value: bool):
if value and not self.is_kinematic():
self.mjcf.parent.add("freejoint")
self._mojo.mark_dirty()
+ elif not value and self.is_kinematic():
+ self.parent.remove_all_joints()
+ self._mojo.mark_dirty()
diff --git a/mojo/elements/joint.py b/mojo/elements/joint.py
index b76a662..f664f16 100644
--- a/mojo/elements/joint.py
+++ b/mojo/elements/joint.py
@@ -54,7 +54,7 @@ def create(
def get_joint_position(self) -> float:
"""Get current joint position."""
- return float(self._mojo.physics.bind(self.mjcf).qpos)
+ return float(self._mojo.physics.bind(self.mjcf).qpos.item())
def set_joint_position(self, value: float):
self._mojo.physics.bind(self.mjcf).qpos *= 0
@@ -63,4 +63,4 @@ def set_joint_position(self, value: float):
def get_joint_velocity(self) -> float:
"""Get current joint velocity."""
- return float(self._mojo.physics.bind(self.mjcf).qvel)
+ return float(self._mojo.physics.bind(self.mjcf).qvel.item())
diff --git a/mojo/elements/utils.py b/mojo/elements/utils.py
index 97f7e1a..8ab8956 100644
--- a/mojo/elements/utils.py
+++ b/mojo/elements/utils.py
@@ -5,11 +5,14 @@
import numpy as np
from dm_control import mjcf
+from lxml import etree
from mojo.elements.consts import TextureMapping
# Default minimum distance between two geoms for them to be considered in collision.
_DEFAULT_COLLISION_MARGIN: float = 1e-8
+_FREEJOINT_TAG = "freejoint"
+_WORLDBODY_TAG = "worldbody"
def has_collision(
@@ -77,6 +80,30 @@ def load_mesh(
return mesh
+def resolve_freejoints(
+ root_model: mjcf.RootElement, model: mjcf.RootElement
+) -> mjcf.RootElement:
+ child_xml = model.to_xml()
+ if len(child_xml.findall(f".//{_FREEJOINT_TAG}")) > 0:
+ root_xml = root_model.to_xml()
+ worldbody = root_xml.find(_WORLDBODY_TAG)
+ search_expr = f".//{child_xml.tag}"
+ for attr_name, attr_value in child_xml.attrib.items():
+ search_expr += f"[@{attr_name}='{attr_value}']"
+ child_xml = root_xml.find(search_expr)
+ freejoints = child_xml.findall(f".//{_FREEJOINT_TAG}")
+ for joint in freejoints:
+ worldbody.append(joint.getparent())
+ if len(child_xml) == 0:
+ child_xml.getparent().remove(child_xml)
+ root_model = mjcf.from_xml_string(
+ etree.tostring(root_xml),
+ escape_separators=True,
+ assets=root_model.get_assets(),
+ )
+ return root_model
+
+
class AssetStore:
"""Container for Mujoco assets."""
diff --git a/mojo/mojo.py b/mojo/mojo.py
index 9a98169..e3c8195 100644
--- a/mojo/mojo.py
+++ b/mojo/mojo.py
@@ -7,7 +7,7 @@
from mojo.elements.body import Body
from mojo.elements.element import MujocoElement
from mojo.elements.model import MujocoModel
-from mojo.elements.utils import AssetStore
+from mojo.elements.utils import AssetStore, resolve_freejoints
class Mojo:
@@ -108,6 +108,7 @@ def load_model(
path: str,
parent: MujocoElement = None,
on_loaded: Optional[Callable[[mjcf.RootElement], None]] = None,
+ handle_freejoints: bool = False,
):
"""Load a Mujoco model from xml file and attach to specified parent element.
@@ -116,6 +117,8 @@ def load_model(
If None, it attaches to the root element.
:param on_loaded: Optional callback to be executed after model is loaded.
Use it to customize the Mujoco model before attaching it to the parent.
+ :param handle_freejoints: If true handles elements.
+ Freejoint bodies will be re-parented to the worldbody.
:return: A Body element representing the attached model.
"""
@@ -124,6 +127,11 @@ def load_model(
on_loaded(model_mjcf)
attach_site = self.root_element.mjcf if parent is None else parent.mjcf
attached_model_mjcf = attach_site.attach(model_mjcf)
+ if handle_freejoints:
+ root_model_mjcf = resolve_freejoints(
+ self.root_element.mjcf, attached_model_mjcf
+ )
+ self.root_element = MujocoElement(self, root_model_mjcf)
self.mark_dirty()
return Body(self, attached_model_mjcf)
diff --git a/tests/assets/models/sphere.xml b/tests/assets/models/sphere.xml
new file mode 100644
index 0000000..801b5cc
--- /dev/null
+++ b/tests/assets/models/sphere.xml
@@ -0,0 +1,9 @@
+
+
+
+
+
+
+
+
+
diff --git a/tests/assets/models/sphere_and_box.xml b/tests/assets/models/sphere_and_box.xml
new file mode 100644
index 0000000..fa1eac6
--- /dev/null
+++ b/tests/assets/models/sphere_and_box.xml
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/tests/elements/test_body.py b/tests/elements/test_body.py
index 5fdbd6a..5ff7c7d 100644
--- a/tests/elements/test_body.py
+++ b/tests/elements/test_body.py
@@ -45,3 +45,10 @@ def test_get_set_kinematic(mojo: Mojo, body: Body):
mojo.step() # Objects should fall
after = body.get_position()
assert np.any(np.not_equal(before, after))
+
+
+def test_is_kinematic(mojo: Mojo, body: Body):
+ body.set_kinematic(True)
+ assert body.is_kinematic()
+ body.set_kinematic(False)
+ assert not body.is_kinematic()
diff --git a/tests/elements/test_geom.py b/tests/elements/test_geom.py
index 45ec766..966442c 100644
--- a/tests/elements/test_geom.py
+++ b/tests/elements/test_geom.py
@@ -80,3 +80,10 @@ def test_get_set_kinematic(mojo: Mojo, geom: Geom):
mojo.step() # Objects should fall
after = geom.get_position()
assert np.any(np.not_equal(before, after))
+
+
+def test_is_kinematic(mojo: Mojo, geom: Geom):
+ geom.set_kinematic(True)
+ assert geom.is_kinematic()
+ geom.set_kinematic(False)
+ assert not geom.is_kinematic()
diff --git a/tests/elements/test_mojo.py b/tests/elements/test_mojo.py
new file mode 100644
index 0000000..1a68588
--- /dev/null
+++ b/tests/elements/test_mojo.py
@@ -0,0 +1,36 @@
+from pathlib import Path
+
+import pytest
+
+from mojo import Mojo
+from mojo.elements import Body
+
+
+@pytest.fixture()
+def mojo() -> Mojo:
+ return Mojo(str(Path(__file__).parents[1] / "world.xml"))
+
+
+def load_model(mojo: Mojo, model_name: str, handle_freejoints: bool) -> Body:
+ body = mojo.load_model(
+ str(Path(__file__).parents[1] / "assets" / "models" / model_name),
+ handle_freejoints=handle_freejoints,
+ )
+ _ = mojo.physics
+ return body
+
+
+def test_load_freejoint(mojo: Mojo):
+ load_model(mojo, "sphere.xml", True)
+
+
+def test_load_freejoint_raises_by_default(mojo: Mojo):
+ with pytest.raises(ValueError):
+ load_model(mojo, "sphere.xml", False)
+
+
+def test_load_freejoint_hierarchy(mojo: Mojo):
+ sphere_and_box = load_model(mojo, "sphere_and_box.xml", True)
+ for joint in sphere_and_box.joints:
+ assert joint.mjcf.tag == "freejoint"
+ assert joint.mjcf.parent.parent.tag == "worldbody"