From 1d72e0851a88f9383de7708e439a6d4257645837 Mon Sep 17 00:00:00 2001 From: Nikita Chernyadev Date: Wed, 17 Apr 2024 12:47:32 +0100 Subject: [PATCH 1/5] Fix set_kinematic(False) --- mojo/elements/body.py | 3 +++ mojo/elements/geom.py | 3 +++ tests/elements/test_body.py | 7 +++++++ tests/elements/test_geom.py | 7 +++++++ 4 files changed, 20 insertions(+) diff --git a/mojo/elements/body.py b/mojo/elements/body.py index fa3509b..bf67253 100644 --- a/mojo/elements/body.py +++ b/mojo/elements/body.py @@ -98,3 +98,6 @@ def set_kinematic(self, value: bool): if value and not self.is_kinematic(): self.mjcf.add("freejoint") self._mojo.mark_dirty() + elif not value and self.is_kinematic(): + self.mjcf.freejoint.remove() + self._mojo.mark_dirty() diff --git a/mojo/elements/geom.py b/mojo/elements/geom.py index 09a1572..d45b586 100644 --- a/mojo/elements/geom.py +++ b/mojo/elements/geom.py @@ -175,3 +175,6 @@ def set_kinematic(self, value: bool): if value and not self.is_kinematic(): self.mjcf.parent.add("freejoint") self._mojo.mark_dirty() + elif not value and self.is_kinematic(): + self.mjcf.parent.freejoint.remove() + self._mojo.mark_dirty() diff --git a/tests/elements/test_body.py b/tests/elements/test_body.py index 5fdbd6a..5ff7c7d 100644 --- a/tests/elements/test_body.py +++ b/tests/elements/test_body.py @@ -45,3 +45,10 @@ def test_get_set_kinematic(mojo: Mojo, body: Body): mojo.step() # Objects should fall after = body.get_position() assert np.any(np.not_equal(before, after)) + + +def test_is_kinematic(mojo: Mojo, body: Body): + body.set_kinematic(True) + assert body.is_kinematic() + body.set_kinematic(False) + assert not body.is_kinematic() diff --git a/tests/elements/test_geom.py b/tests/elements/test_geom.py index 45ec766..966442c 100644 --- a/tests/elements/test_geom.py +++ b/tests/elements/test_geom.py @@ -80,3 +80,10 @@ def test_get_set_kinematic(mojo: Mojo, geom: Geom): mojo.step() # Objects should fall after = geom.get_position() assert np.any(np.not_equal(before, after)) + + +def test_is_kinematic(mojo: Mojo, geom: Geom): + geom.set_kinematic(True) + assert geom.is_kinematic() + geom.set_kinematic(False) + assert not geom.is_kinematic() From de8f0aaefa9e6d4ccf7e4d7a75a66e206dda0cb8 Mon Sep 17 00:00:00 2001 From: Nikita Chernyadev Date: Wed, 17 Apr 2024 12:48:02 +0100 Subject: [PATCH 2/5] Fix warning --- mojo/elements/joint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mojo/elements/joint.py b/mojo/elements/joint.py index b76a662..f664f16 100644 --- a/mojo/elements/joint.py +++ b/mojo/elements/joint.py @@ -54,7 +54,7 @@ def create( def get_joint_position(self) -> float: """Get current joint position.""" - return float(self._mojo.physics.bind(self.mjcf).qpos) + return float(self._mojo.physics.bind(self.mjcf).qpos.item()) def set_joint_position(self, value: float): self._mojo.physics.bind(self.mjcf).qpos *= 0 @@ -63,4 +63,4 @@ def set_joint_position(self, value: float): def get_joint_velocity(self) -> float: """Get current joint velocity.""" - return float(self._mojo.physics.bind(self.mjcf).qvel) + return float(self._mojo.physics.bind(self.mjcf).qvel.item()) From c85ec89be9059f15bb782cbe11ea97f626f8ddda Mon Sep 17 00:00:00 2001 From: Nikita Chernyadev Date: Wed, 17 Apr 2024 12:50:16 +0100 Subject: [PATCH 3/5] Can load models with freejoints --- mojo/mojo.py | 28 ++++++++++++++++++++ tests/assets/models/sphere.xml | 9 +++++++ tests/assets/models/sphere_and_box.xml | 13 ++++++++++ tests/elements/test_mojo.py | 36 ++++++++++++++++++++++++++ 4 files changed, 86 insertions(+) create mode 100644 tests/assets/models/sphere.xml create mode 100644 tests/assets/models/sphere_and_box.xml create mode 100644 tests/elements/test_mojo.py diff --git a/mojo/mojo.py b/mojo/mojo.py index 9a98169..da02754 100644 --- a/mojo/mojo.py +++ b/mojo/mojo.py @@ -3,6 +3,7 @@ import mujoco.viewer import numpy as np from dm_control import mjcf +from lxml import etree from mojo.elements.body import Body from mojo.elements.element import MujocoElement @@ -108,6 +109,7 @@ def load_model( path: str, parent: MujocoElement = None, on_loaded: Optional[Callable[[mjcf.RootElement], None]] = None, + handle_freejoints: bool = False, ): """Load a Mujoco model from xml file and attach to specified parent element. @@ -116,6 +118,8 @@ def load_model( If None, it attaches to the root element. :param on_loaded: Optional callback to be executed after model is loaded. Use it to customize the Mujoco model before attaching it to the parent. + :param handle_freejoints: If true handles elements. + Freejoint bodies will be re-parented to the worldbody. :return: A Body element representing the attached model. """ @@ -124,6 +128,30 @@ def load_model( on_loaded(model_mjcf) attach_site = self.root_element.mjcf if parent is None else parent.mjcf attached_model_mjcf = attach_site.attach(model_mjcf) + if handle_freejoints: + FREEJOINT = "freejoint" + WORLDBODY = "worldbody" + attached_xml = attached_model_mjcf.to_xml() + attached_freejoints = attached_xml.findall(f".//{FREEJOINT}") + if len(attached_freejoints) > 0: + root_xml = self.root_element.mjcf.to_xml() + worldbody = root_xml.find(WORLDBODY) + xpath_expr = f".//{attached_xml.tag}" + for attr_name, attr_value in attached_xml.attrib.items(): + xpath_expr += f"[@{attr_name}='{attr_value}']" + attached_xml = root_xml.find(xpath_expr) + freejoints = attached_xml.findall(f".//{FREEJOINT}") + for freejoint in freejoints: + worldbody.append(freejoint.getparent()) + if len(attached_xml) == 0: + attached_xml.getparent().remove(attached_xml) + + root_model = mjcf.from_xml_string( + etree.tostring(root_xml), + escape_separators=True, + assets=self.root_element.mjcf.get_assets(), + ) + self.root_element = MujocoElement(self, root_model) self.mark_dirty() return Body(self, attached_model_mjcf) diff --git a/tests/assets/models/sphere.xml b/tests/assets/models/sphere.xml new file mode 100644 index 0000000..801b5cc --- /dev/null +++ b/tests/assets/models/sphere.xml @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/tests/assets/models/sphere_and_box.xml b/tests/assets/models/sphere_and_box.xml new file mode 100644 index 0000000..fa1eac6 --- /dev/null +++ b/tests/assets/models/sphere_and_box.xml @@ -0,0 +1,13 @@ + + + + + + + + + + + + + diff --git a/tests/elements/test_mojo.py b/tests/elements/test_mojo.py new file mode 100644 index 0000000..1a68588 --- /dev/null +++ b/tests/elements/test_mojo.py @@ -0,0 +1,36 @@ +from pathlib import Path + +import pytest + +from mojo import Mojo +from mojo.elements import Body + + +@pytest.fixture() +def mojo() -> Mojo: + return Mojo(str(Path(__file__).parents[1] / "world.xml")) + + +def load_model(mojo: Mojo, model_name: str, handle_freejoints: bool) -> Body: + body = mojo.load_model( + str(Path(__file__).parents[1] / "assets" / "models" / model_name), + handle_freejoints=handle_freejoints, + ) + _ = mojo.physics + return body + + +def test_load_freejoint(mojo: Mojo): + load_model(mojo, "sphere.xml", True) + + +def test_load_freejoint_raises_by_default(mojo: Mojo): + with pytest.raises(ValueError): + load_model(mojo, "sphere.xml", False) + + +def test_load_freejoint_hierarchy(mojo: Mojo): + sphere_and_box = load_model(mojo, "sphere_and_box.xml", True) + for joint in sphere_and_box.joints: + assert joint.mjcf.tag == "freejoint" + assert joint.mjcf.parent.parent.tag == "worldbody" From 0ea8af6f3a4749f81472cc141c21087f8e88c0d2 Mon Sep 17 00:00:00 2001 From: Nikita Chernyadev Date: Wed, 17 Apr 2024 12:56:58 +0100 Subject: [PATCH 4/5] Refactor --- mojo/elements/body.py | 2 +- mojo/elements/element.py | 11 +++++++++++ mojo/elements/geom.py | 2 +- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/mojo/elements/body.py b/mojo/elements/body.py index bf67253..bf2d5df 100644 --- a/mojo/elements/body.py +++ b/mojo/elements/body.py @@ -99,5 +99,5 @@ def set_kinematic(self, value: bool): self.mjcf.add("freejoint") self._mojo.mark_dirty() elif not value and self.is_kinematic(): - self.mjcf.freejoint.remove() + self.remove_all_joints() self._mojo.mark_dirty() diff --git a/mojo/elements/element.py b/mojo/elements/element.py index b8a7341..e892101 100644 --- a/mojo/elements/element.py +++ b/mojo/elements/element.py @@ -20,6 +20,14 @@ def _is_kinematic(elem: mjcf.Element): return has_freejoint or has_joints or _is_kinematic(elem.parent) +def _remove_all_joints(elem: mjcf.Element): + if hasattr(elem, "freejoint") and elem.freejoint is not None: + elem.freejoint.remove() + if hasattr(elem, "joint") and len(elem.joint) > 0: + for joint in elem.joint: + joint.remove() + + def _find_freejoint(elem: mjcf.Element): if elem.parent is None: # Root of tree @@ -73,6 +81,9 @@ def get_quaternion(self) -> np.ndarray: def is_kinematic(self) -> bool: return _is_kinematic(self.mjcf) + def remove_all_joints(self): + _remove_all_joints(self.mjcf) + @property def id(self): return self._mojo.physics.bind(self.mjcf).element_id diff --git a/mojo/elements/geom.py b/mojo/elements/geom.py index d45b586..72e75c6 100644 --- a/mojo/elements/geom.py +++ b/mojo/elements/geom.py @@ -176,5 +176,5 @@ def set_kinematic(self, value: bool): self.mjcf.parent.add("freejoint") self._mojo.mark_dirty() elif not value and self.is_kinematic(): - self.mjcf.parent.freejoint.remove() + self.parent.remove_all_joints() self._mojo.mark_dirty() From b488f6aae8d59e31d1d17151e2126f145276d15b Mon Sep 17 00:00:00 2001 From: Nikita Chernyadev Date: Wed, 17 Apr 2024 13:10:56 +0100 Subject: [PATCH 5/5] Refactor --- mojo/elements/utils.py | 27 +++++++++++++++++++++++++++ mojo/mojo.py | 30 +++++------------------------- 2 files changed, 32 insertions(+), 25 deletions(-) diff --git a/mojo/elements/utils.py b/mojo/elements/utils.py index 97f7e1a..8ab8956 100644 --- a/mojo/elements/utils.py +++ b/mojo/elements/utils.py @@ -5,11 +5,14 @@ import numpy as np from dm_control import mjcf +from lxml import etree from mojo.elements.consts import TextureMapping # Default minimum distance between two geoms for them to be considered in collision. _DEFAULT_COLLISION_MARGIN: float = 1e-8 +_FREEJOINT_TAG = "freejoint" +_WORLDBODY_TAG = "worldbody" def has_collision( @@ -77,6 +80,30 @@ def load_mesh( return mesh +def resolve_freejoints( + root_model: mjcf.RootElement, model: mjcf.RootElement +) -> mjcf.RootElement: + child_xml = model.to_xml() + if len(child_xml.findall(f".//{_FREEJOINT_TAG}")) > 0: + root_xml = root_model.to_xml() + worldbody = root_xml.find(_WORLDBODY_TAG) + search_expr = f".//{child_xml.tag}" + for attr_name, attr_value in child_xml.attrib.items(): + search_expr += f"[@{attr_name}='{attr_value}']" + child_xml = root_xml.find(search_expr) + freejoints = child_xml.findall(f".//{_FREEJOINT_TAG}") + for joint in freejoints: + worldbody.append(joint.getparent()) + if len(child_xml) == 0: + child_xml.getparent().remove(child_xml) + root_model = mjcf.from_xml_string( + etree.tostring(root_xml), + escape_separators=True, + assets=root_model.get_assets(), + ) + return root_model + + class AssetStore: """Container for Mujoco assets.""" diff --git a/mojo/mojo.py b/mojo/mojo.py index da02754..e3c8195 100644 --- a/mojo/mojo.py +++ b/mojo/mojo.py @@ -3,12 +3,11 @@ import mujoco.viewer import numpy as np from dm_control import mjcf -from lxml import etree from mojo.elements.body import Body from mojo.elements.element import MujocoElement from mojo.elements.model import MujocoModel -from mojo.elements.utils import AssetStore +from mojo.elements.utils import AssetStore, resolve_freejoints class Mojo: @@ -129,29 +128,10 @@ def load_model( attach_site = self.root_element.mjcf if parent is None else parent.mjcf attached_model_mjcf = attach_site.attach(model_mjcf) if handle_freejoints: - FREEJOINT = "freejoint" - WORLDBODY = "worldbody" - attached_xml = attached_model_mjcf.to_xml() - attached_freejoints = attached_xml.findall(f".//{FREEJOINT}") - if len(attached_freejoints) > 0: - root_xml = self.root_element.mjcf.to_xml() - worldbody = root_xml.find(WORLDBODY) - xpath_expr = f".//{attached_xml.tag}" - for attr_name, attr_value in attached_xml.attrib.items(): - xpath_expr += f"[@{attr_name}='{attr_value}']" - attached_xml = root_xml.find(xpath_expr) - freejoints = attached_xml.findall(f".//{FREEJOINT}") - for freejoint in freejoints: - worldbody.append(freejoint.getparent()) - if len(attached_xml) == 0: - attached_xml.getparent().remove(attached_xml) - - root_model = mjcf.from_xml_string( - etree.tostring(root_xml), - escape_separators=True, - assets=self.root_element.mjcf.get_assets(), - ) - self.root_element = MujocoElement(self, root_model) + root_model_mjcf = resolve_freejoints( + self.root_element.mjcf, attached_model_mjcf + ) + self.root_element = MujocoElement(self, root_model_mjcf) self.mark_dirty() return Body(self, attached_model_mjcf)