Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable loading models with <freejoint/> attribute #5

Merged
merged 5 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mojo/elements/body.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,6 @@ def set_kinematic(self, value: bool):
if value and not self.is_kinematic():
self.mjcf.add("freejoint")
self._mojo.mark_dirty()
elif not value and self.is_kinematic():
self.remove_all_joints()
self._mojo.mark_dirty()
11 changes: 11 additions & 0 deletions mojo/elements/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ def _is_kinematic(elem: mjcf.Element):
return has_freejoint or has_joints or _is_kinematic(elem.parent)


def _remove_all_joints(elem: mjcf.Element):
if hasattr(elem, "freejoint") and elem.freejoint is not None:
elem.freejoint.remove()
if hasattr(elem, "joint") and len(elem.joint) > 0:
for joint in elem.joint:
joint.remove()


def _find_freejoint(elem: mjcf.Element):
if elem.parent is None:
# Root of tree
Expand Down Expand Up @@ -73,6 +81,9 @@ def get_quaternion(self) -> np.ndarray:
def is_kinematic(self) -> bool:
return _is_kinematic(self.mjcf)

def remove_all_joints(self):
_remove_all_joints(self.mjcf)

@property
def id(self):
return self._mojo.physics.bind(self.mjcf).element_id
Expand Down
3 changes: 3 additions & 0 deletions mojo/elements/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,6 @@ def set_kinematic(self, value: bool):
if value and not self.is_kinematic():
self.mjcf.parent.add("freejoint")
self._mojo.mark_dirty()
elif not value and self.is_kinematic():
self.parent.remove_all_joints()
self._mojo.mark_dirty()
4 changes: 2 additions & 2 deletions mojo/elements/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def create(

def get_joint_position(self) -> float:
"""Get current joint position."""
return float(self._mojo.physics.bind(self.mjcf).qpos)
return float(self._mojo.physics.bind(self.mjcf).qpos.item())

def set_joint_position(self, value: float):
self._mojo.physics.bind(self.mjcf).qpos *= 0
Expand All @@ -63,4 +63,4 @@ def set_joint_position(self, value: float):

def get_joint_velocity(self) -> float:
"""Get current joint velocity."""
return float(self._mojo.physics.bind(self.mjcf).qvel)
return float(self._mojo.physics.bind(self.mjcf).qvel.item())
27 changes: 27 additions & 0 deletions mojo/elements/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@

import numpy as np
from dm_control import mjcf
from lxml import etree

from mojo.elements.consts import TextureMapping

# Default minimum distance between two geoms for them to be considered in collision.
_DEFAULT_COLLISION_MARGIN: float = 1e-8
_FREEJOINT_TAG = "freejoint"
_WORLDBODY_TAG = "worldbody"


def has_collision(
Expand Down Expand Up @@ -77,6 +80,30 @@ def load_mesh(
return mesh


def resolve_freejoints(
root_model: mjcf.RootElement, model: mjcf.RootElement
) -> mjcf.RootElement:
child_xml = model.to_xml()
if len(child_xml.findall(f".//{_FREEJOINT_TAG}")) > 0:
root_xml = root_model.to_xml()
worldbody = root_xml.find(_WORLDBODY_TAG)
search_expr = f".//{child_xml.tag}"
for attr_name, attr_value in child_xml.attrib.items():
search_expr += f"[@{attr_name}='{attr_value}']"
child_xml = root_xml.find(search_expr)
freejoints = child_xml.findall(f".//{_FREEJOINT_TAG}")
for joint in freejoints:
worldbody.append(joint.getparent())
if len(child_xml) == 0:
child_xml.getparent().remove(child_xml)
root_model = mjcf.from_xml_string(
etree.tostring(root_xml),
escape_separators=True,
assets=root_model.get_assets(),
)
return root_model


class AssetStore:
"""Container for Mujoco assets."""

Expand Down
10 changes: 9 additions & 1 deletion mojo/mojo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mojo.elements.body import Body
from mojo.elements.element import MujocoElement
from mojo.elements.model import MujocoModel
from mojo.elements.utils import AssetStore
from mojo.elements.utils import AssetStore, resolve_freejoints


class Mojo:
Expand Down Expand Up @@ -108,6 +108,7 @@ def load_model(
path: str,
parent: MujocoElement = None,
on_loaded: Optional[Callable[[mjcf.RootElement], None]] = None,
handle_freejoints: bool = False,
):
"""Load a Mujoco model from xml file and attach to specified parent element.

Expand All @@ -116,6 +117,8 @@ def load_model(
If None, it attaches to the root element.
:param on_loaded: Optional callback to be executed after model is loaded.
Use it to customize the Mujoco model before attaching it to the parent.
:param handle_freejoints: If true handles <freejoint/> elements.
Freejoint bodies will be re-parented to the worldbody.
:return: A Body element representing the attached model.
"""

Expand All @@ -124,6 +127,11 @@ def load_model(
on_loaded(model_mjcf)
attach_site = self.root_element.mjcf if parent is None else parent.mjcf
attached_model_mjcf = attach_site.attach(model_mjcf)
if handle_freejoints:
root_model_mjcf = resolve_freejoints(
self.root_element.mjcf, attached_model_mjcf
)
self.root_element = MujocoElement(self, root_model_mjcf)
self.mark_dirty()
return Body(self, attached_model_mjcf)

Expand Down
9 changes: 9 additions & 0 deletions tests/assets/models/sphere.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<mujoco model="sphere">

<worldbody>
<body name="sphere">
<freejoint/>
<geom type="sphere" size="0.05 0.05 0.05"/>
</body>
</worldbody>
</mujoco>
13 changes: 13 additions & 0 deletions tests/assets/models/sphere_and_box.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
<mujoco model="sphere_and_box">

<worldbody>
<body name="sphere">
<freejoint/>
<geom type="sphere" size="0.05 0.05 0.05"/>
</body>
<body name="cube">
<freejoint/>
<geom type="box" size="0.05 0.05 0.05"/>
</body>
</worldbody>
</mujoco>
7 changes: 7 additions & 0 deletions tests/elements/test_body.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,10 @@ def test_get_set_kinematic(mojo: Mojo, body: Body):
mojo.step() # Objects should fall
after = body.get_position()
assert np.any(np.not_equal(before, after))


def test_is_kinematic(mojo: Mojo, body: Body):
body.set_kinematic(True)
assert body.is_kinematic()
body.set_kinematic(False)
assert not body.is_kinematic()
7 changes: 7 additions & 0 deletions tests/elements/test_geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,10 @@ def test_get_set_kinematic(mojo: Mojo, geom: Geom):
mojo.step() # Objects should fall
after = geom.get_position()
assert np.any(np.not_equal(before, after))


def test_is_kinematic(mojo: Mojo, geom: Geom):
geom.set_kinematic(True)
assert geom.is_kinematic()
geom.set_kinematic(False)
assert not geom.is_kinematic()
36 changes: 36 additions & 0 deletions tests/elements/test_mojo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from pathlib import Path

import pytest

from mojo import Mojo
from mojo.elements import Body


@pytest.fixture()
def mojo() -> Mojo:
return Mojo(str(Path(__file__).parents[1] / "world.xml"))


def load_model(mojo: Mojo, model_name: str, handle_freejoints: bool) -> Body:
body = mojo.load_model(
str(Path(__file__).parents[1] / "assets" / "models" / model_name),
handle_freejoints=handle_freejoints,
)
_ = mojo.physics
return body


def test_load_freejoint(mojo: Mojo):
load_model(mojo, "sphere.xml", True)


def test_load_freejoint_raises_by_default(mojo: Mojo):
with pytest.raises(ValueError):
load_model(mojo, "sphere.xml", False)


def test_load_freejoint_hierarchy(mojo: Mojo):
sphere_and_box = load_model(mojo, "sphere_and_box.xml", True)
for joint in sphere_and_box.joints:
assert joint.mjcf.tag == "freejoint"
assert joint.mjcf.parent.parent.tag == "worldbody"