Skip to content

Commit

Permalink
Enable loading models with <freejoint/> attribute (#5)
Browse files Browse the repository at this point in the history
* Fix set_kinematic(False)

* Fix warning

* Can load models with freejoints

* Refactor

* Refactor

---------

Co-authored-by: Nikita Chernyadev <[email protected]>
  • Loading branch information
2 people authored and stepjam committed Jul 29, 2024
1 parent 514663d commit 9cdde12
Show file tree
Hide file tree
Showing 11 changed files with 127 additions and 3 deletions.
3 changes: 3 additions & 0 deletions mojo/elements/body.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,6 @@ def set_kinematic(self, value: bool):
if value and not self.is_kinematic():
self.mjcf.add("freejoint")
self._mojo.mark_dirty()
elif not value and self.is_kinematic():
self.remove_all_joints()
self._mojo.mark_dirty()
11 changes: 11 additions & 0 deletions mojo/elements/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ def _is_kinematic(elem: mjcf.Element):
return has_freejoint or has_joints or _is_kinematic(elem.parent)


def _remove_all_joints(elem: mjcf.Element):
if hasattr(elem, "freejoint") and elem.freejoint is not None:
elem.freejoint.remove()
if hasattr(elem, "joint") and len(elem.joint) > 0:
for joint in elem.joint:
joint.remove()


def _find_freejoint(elem: mjcf.Element):
if elem.parent is None:
# Root of tree
Expand Down Expand Up @@ -73,6 +81,9 @@ def get_quaternion(self) -> np.ndarray:
def is_kinematic(self) -> bool:
return _is_kinematic(self.mjcf)

def remove_all_joints(self):
_remove_all_joints(self.mjcf)

@property
def id(self):
return self._mojo.physics.bind(self.mjcf).element_id
Expand Down
3 changes: 3 additions & 0 deletions mojo/elements/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,6 @@ def set_kinematic(self, value: bool):
if value and not self.is_kinematic():
self.mjcf.parent.add("freejoint")
self._mojo.mark_dirty()
elif not value and self.is_kinematic():
self.parent.remove_all_joints()
self._mojo.mark_dirty()
4 changes: 2 additions & 2 deletions mojo/elements/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def create(

def get_joint_position(self) -> float:
"""Get current joint position."""
return float(self._mojo.physics.bind(self.mjcf).qpos)
return float(self._mojo.physics.bind(self.mjcf).qpos.item())

def set_joint_position(self, value: float):
self._mojo.physics.bind(self.mjcf).qpos *= 0
Expand All @@ -63,4 +63,4 @@ def set_joint_position(self, value: float):

def get_joint_velocity(self) -> float:
"""Get current joint velocity."""
return float(self._mojo.physics.bind(self.mjcf).qvel)
return float(self._mojo.physics.bind(self.mjcf).qvel.item())
27 changes: 27 additions & 0 deletions mojo/elements/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@

import numpy as np
from dm_control import mjcf
from lxml import etree

from mojo.elements.consts import TextureMapping

# Default minimum distance between two geoms for them to be considered in collision.
_DEFAULT_COLLISION_MARGIN: float = 1e-8
_FREEJOINT_TAG = "freejoint"
_WORLDBODY_TAG = "worldbody"


def has_collision(
Expand Down Expand Up @@ -77,6 +80,30 @@ def load_mesh(
return mesh


def resolve_freejoints(
root_model: mjcf.RootElement, model: mjcf.RootElement
) -> mjcf.RootElement:
child_xml = model.to_xml()
if len(child_xml.findall(f".//{_FREEJOINT_TAG}")) > 0:
root_xml = root_model.to_xml()
worldbody = root_xml.find(_WORLDBODY_TAG)
search_expr = f".//{child_xml.tag}"
for attr_name, attr_value in child_xml.attrib.items():
search_expr += f"[@{attr_name}='{attr_value}']"
child_xml = root_xml.find(search_expr)
freejoints = child_xml.findall(f".//{_FREEJOINT_TAG}")
for joint in freejoints:
worldbody.append(joint.getparent())
if len(child_xml) == 0:
child_xml.getparent().remove(child_xml)
root_model = mjcf.from_xml_string(
etree.tostring(root_xml),
escape_separators=True,
assets=root_model.get_assets(),
)
return root_model


class AssetStore:
"""Container for Mujoco assets."""

Expand Down
10 changes: 9 additions & 1 deletion mojo/mojo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mojo.elements.body import Body
from mojo.elements.element import MujocoElement
from mojo.elements.model import MujocoModel
from mojo.elements.utils import AssetStore
from mojo.elements.utils import AssetStore, resolve_freejoints


class Mojo:
Expand Down Expand Up @@ -108,6 +108,7 @@ def load_model(
path: str,
parent: MujocoElement = None,
on_loaded: Optional[Callable[[mjcf.RootElement], None]] = None,
handle_freejoints: bool = False,
):
"""Load a Mujoco model from xml file and attach to specified parent element.
Expand All @@ -116,6 +117,8 @@ def load_model(
If None, it attaches to the root element.
:param on_loaded: Optional callback to be executed after model is loaded.
Use it to customize the Mujoco model before attaching it to the parent.
:param handle_freejoints: If true handles <freejoint/> elements.
Freejoint bodies will be re-parented to the worldbody.
:return: A Body element representing the attached model.
"""

Expand All @@ -124,6 +127,11 @@ def load_model(
on_loaded(model_mjcf)
attach_site = self.root_element.mjcf if parent is None else parent.mjcf
attached_model_mjcf = attach_site.attach(model_mjcf)
if handle_freejoints:
root_model_mjcf = resolve_freejoints(
self.root_element.mjcf, attached_model_mjcf
)
self.root_element = MujocoElement(self, root_model_mjcf)
self.mark_dirty()
return Body(self, attached_model_mjcf)

Expand Down
9 changes: 9 additions & 0 deletions tests/assets/models/sphere.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<mujoco model="sphere">

<worldbody>
<body name="sphere">
<freejoint/>
<geom type="sphere" size="0.05 0.05 0.05"/>
</body>
</worldbody>
</mujoco>
13 changes: 13 additions & 0 deletions tests/assets/models/sphere_and_box.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
<mujoco model="sphere_and_box">

<worldbody>
<body name="sphere">
<freejoint/>
<geom type="sphere" size="0.05 0.05 0.05"/>
</body>
<body name="cube">
<freejoint/>
<geom type="box" size="0.05 0.05 0.05"/>
</body>
</worldbody>
</mujoco>
7 changes: 7 additions & 0 deletions tests/elements/test_body.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,10 @@ def test_get_set_kinematic(mojo: Mojo, body: Body):
mojo.step() # Objects should fall
after = body.get_position()
assert np.any(np.not_equal(before, after))


def test_is_kinematic(mojo: Mojo, body: Body):
body.set_kinematic(True)
assert body.is_kinematic()
body.set_kinematic(False)
assert not body.is_kinematic()
7 changes: 7 additions & 0 deletions tests/elements/test_geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,10 @@ def test_get_set_kinematic(mojo: Mojo, geom: Geom):
mojo.step() # Objects should fall
after = geom.get_position()
assert np.any(np.not_equal(before, after))


def test_is_kinematic(mojo: Mojo, geom: Geom):
geom.set_kinematic(True)
assert geom.is_kinematic()
geom.set_kinematic(False)
assert not geom.is_kinematic()
36 changes: 36 additions & 0 deletions tests/elements/test_mojo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from pathlib import Path

import pytest

from mojo import Mojo
from mojo.elements import Body


@pytest.fixture()
def mojo() -> Mojo:
return Mojo(str(Path(__file__).parents[1] / "world.xml"))


def load_model(mojo: Mojo, model_name: str, handle_freejoints: bool) -> Body:
body = mojo.load_model(
str(Path(__file__).parents[1] / "assets" / "models" / model_name),
handle_freejoints=handle_freejoints,
)
_ = mojo.physics
return body


def test_load_freejoint(mojo: Mojo):
load_model(mojo, "sphere.xml", True)


def test_load_freejoint_raises_by_default(mojo: Mojo):
with pytest.raises(ValueError):
load_model(mojo, "sphere.xml", False)


def test_load_freejoint_hierarchy(mojo: Mojo):
sphere_and_box = load_model(mojo, "sphere_and_box.xml", True)
for joint in sphere_and_box.joints:
assert joint.mjcf.tag == "freejoint"
assert joint.mjcf.parent.parent.tag == "worldbody"

0 comments on commit 9cdde12

Please sign in to comment.