From 58e75940a039ce013a9c03046aa6ae6c65905b6e Mon Sep 17 00:00:00 2001 From: = Date: Fri, 5 Apr 2024 11:04:55 +0100 Subject: [PATCH] Add site object and add extra functions --- mojo/elements/__init__.py | 1 + mojo/elements/body.py | 8 +- mojo/elements/consts.py | 8 ++ mojo/elements/joint.py | 5 ++ mojo/elements/light.py | 7 ++ mojo/elements/site.py | 139 +++++++++++++++++++++++++++++++++++ mojo/mojo.py | 2 +- tests/elements/test_light.py | 9 +++ tests/elements/test_site.py | 47 ++++++++++++ 9 files changed, 224 insertions(+), 2 deletions(-) create mode 100644 mojo/elements/site.py create mode 100644 tests/elements/test_site.py diff --git a/mojo/elements/__init__.py b/mojo/elements/__init__.py index a7238ea..a647391 100644 --- a/mojo/elements/__init__.py +++ b/mojo/elements/__init__.py @@ -5,3 +5,4 @@ from mojo.elements.joint import Joint from mojo.elements.light import Light from mojo.elements.model import MujocoModel +from mojo.elements.site import Site diff --git a/mojo/elements/body.py b/mojo/elements/body.py index f6945fd..b04a152 100644 --- a/mojo/elements/body.py +++ b/mojo/elements/body.py @@ -8,7 +8,7 @@ from mujoco_utils import mjcf_utils from typing_extensions import Self -from mojo.elements import geom +from mojo.elements import geom, joint from mojo.elements.element import MujocoElement from mojo.elements.utils import has_collision @@ -53,6 +53,12 @@ def geoms(self) -> list[geom.Geom]: geoms = self.mjcf.find_all("geom") or [] return [geom.Geom(self._mojo, mjcf) for mjcf in geoms] + @property + def joints(self) -> list[joint.Joint]: + # Loop through all children + joints = self.mjcf.find_all("joint") or [] + return [joint.Joint(self._mojo, mjcf) for mjcf in joints] + def set_position(self, position: np.ndarray): position = np.array(position) # ensure is numpy array if self.mjcf.freejoint is not None: diff --git a/mojo/elements/consts.py b/mojo/elements/consts.py index 97679de..03569a5 100644 --- a/mojo/elements/consts.py +++ b/mojo/elements/consts.py @@ -13,6 +13,14 @@ class GeomType(Enum): SDF = "sdf" +class SiteType(Enum): + SPHERE = "sphere" + CAPSULE = "capsule" + ELLIPSOID = "ellipsoid" + CYLINDER = "cylinder" + BOX = "box" + + class TextureMapping(Enum): PLANAR = "2d" CUBE = "cube" diff --git a/mojo/elements/joint.py b/mojo/elements/joint.py index b6ff7a5..b76a662 100644 --- a/mojo/elements/joint.py +++ b/mojo/elements/joint.py @@ -56,6 +56,11 @@ def get_joint_position(self) -> float: """Get current joint position.""" return float(self._mojo.physics.bind(self.mjcf).qpos) + def set_joint_position(self, value: float): + self._mojo.physics.bind(self.mjcf).qpos *= 0 + self._mojo.physics.bind(self.mjcf).qpos += value + self._mojo.mark_dirty() + def get_joint_velocity(self) -> float: """Get current joint velocity.""" return float(self._mojo.physics.bind(self.mjcf).qvel) diff --git a/mojo/elements/light.py b/mojo/elements/light.py index 6c6de0d..b5ae12b 100644 --- a/mojo/elements/light.py +++ b/mojo/elements/light.py @@ -102,5 +102,12 @@ def set_direction(self, direction: np.ndarray): def get_direction(self) -> np.ndarray: return self.mjcf.dir + def set_shadows(self, value: bool): + self.mjcf.castshadow = value + self._mojo.physics.bind(self.mjcf).castshadow = value + + def is_using_shadows(self) -> bool: + return self.mjcf.castshadow == "true" + def get_light_type(self) -> LightType: return LightType.DIRECTIONAL if self.mjcf.directional else LightType.SPOTLIGHT diff --git a/mojo/elements/site.py b/mojo/elements/site.py new file mode 100644 index 0000000..a0ef8f1 --- /dev/null +++ b/mojo/elements/site.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import mujoco +import numpy as np +from mujoco_utils import mjcf_utils +from typing_extensions import Self + +from mojo.elements import body +from mojo.elements.consts import SiteType, TextureMapping +from mojo.elements.element import MujocoElement +from mojo.elements.utils import load_texture + +if TYPE_CHECKING: + from mojo import Mojo + from mojo.elements.body import Body + + +class Site(MujocoElement): + @staticmethod + def get( + mojo: Mojo, + name: str, + parent: MujocoElement = None, + ) -> Self: + root_mjcf = mojo.root_element.mjcf if parent is None else parent.mjcf + mjcf = mjcf_utils.safe_find(root_mjcf, "site", name) + return Site(mojo, mjcf) + + @staticmethod + def create( + mojo: Mojo, + parent: body.Body = None, + size: np.ndarray = None, + position: np.ndarray = None, + quaternion: np.ndarray = None, + color: np.ndarray = None, + site_type: SiteType = SiteType.SPHERE, + group: int = 1, + ) -> Self: + position = np.array([0, 0, 0]) if position is None else position + quaternion = np.array([1, 0, 0, 0]) if quaternion is None else quaternion + size = np.array([0.1, 0.1, 0.1]) if size is None else size + color = np.array([1, 1, 1, 1]) if color is None else color + parent = body.Body.create(mojo) if parent is None else parent + new_geom = parent.mjcf.add( + "site", + type=site_type.value, + pos=position, + quat=quaternion, + size=size, + rgba=color, + group=group, + ) + new_site_obj = Site(mojo, new_geom) + mojo.mark_dirty() + return new_site_obj + + @property + def parent(self) -> "Body": + # Have to do this due to circular import + from mojo.elements.body import Body + + return Body(self._mojo, self.mjcf.parent) + + def set_position(self, position: np.ndarray): + position = np.array(position) # ensure is numpy array + if self.mjcf.parent.freejoint: + self._mojo.physics.bind(self.mjcf.parent.freejoint).qpos[:3] = position + self._mojo.physics.bind(self.mjcf).pos = position + self.mjcf.pos = position + + def get_position(self) -> np.ndarray: + if self.mjcf.parent.freejoint: + return self._mojo.physics.bind(self.mjcf.parent.freejoint).qpos[:3].copy() + return self._mojo.physics.bind(self.mjcf).xpos + + def set_quaternion(self, quaternion: np.ndarray): + # wxyz + quaternion = np.array(quaternion) # ensure is numpy array + if self.mjcf.parent.freejoint is not None: + self._mojo.physics.bind(self.mjcf.parent.freejoint).qpos[3:] = quaternion + mat = np.zeros(9) + mujoco.mju_quat2Mat(mat, quaternion) + self._mojo.physics.bind(self.mjcf).xmat = mat + self.mjcf.quat = quaternion + + def get_quaternion(self) -> np.ndarray: + if self.mjcf.parent.freejoint is not None: + return self._mojo.physics.bind(self.mjcf.parent.freejoint).qpos[3:].copy() + quat = np.zeros(4) + mujoco.mju_mat2Quat(quat, self._mojo.physics.bind(self.mjcf).xmat) + return quat + + def set_color(self, color: np.ndarray): + color = np.array(color) + if len(color) == 3: + color = np.concatenate([color, [1]]) # add alpha + self._mojo.physics.bind(self.mjcf).rgba = color + self.mjcf.rgba = color + + def get_color(self) -> np.ndarray: + return np.array(self._mojo.physics.bind(self.mjcf).rgba) + + def set_texture( + self, + texture_path: str, + mapping: TextureMapping = TextureMapping.CUBE, + tex_repeat: np.ndarray = None, + tex_uniform: bool = False, + emission: float = 0.0, + specular: float = 0.0, + shininess: float = 0.0, + reflectance: float = 0.0, + color: np.ndarray = None, + ): + # First check if we have loaded this texture + key_name = f"{texture_path}_{mapping.value}" + material = self._mojo.get_material(key_name) + if material is None: + material = load_texture( + self._mojo.root_element.mjcf, + texture_path, + mapping, + tex_repeat, + tex_uniform, + emission, + specular, + shininess, + reflectance, + color, + ) + self._mojo.store_material(key_name, material) + self.mjcf.material = material + if self.mjcf.rgba is None: + # Have a default white color for texture + self.set_color(np.ones(4)) + self._mojo.mark_dirty() diff --git a/mojo/mojo.py b/mojo/mojo.py index cffa845..9a98169 100644 --- a/mojo/mojo.py +++ b/mojo/mojo.py @@ -122,7 +122,7 @@ def load_model( model_mjcf = mjcf.from_path(path) if on_loaded is not None: on_loaded(model_mjcf) - attach_site = self.root_element.mjcf if parent is None else parent + attach_site = self.root_element.mjcf if parent is None else parent.mjcf attached_model_mjcf = attach_site.attach(model_mjcf) self.mark_dirty() return Body(self, attached_model_mjcf) diff --git a/tests/elements/test_light.py b/tests/elements/test_light.py index ac758fc..2fe47a0 100644 --- a/tests/elements/test_light.py +++ b/tests/elements/test_light.py @@ -34,6 +34,15 @@ def test_get_set_active(mojo, light): assert light.is_active() == expected +def test_get_set_shadows(mojo, light): + expected = False + light.set_shadows(expected) + assert light.is_using_shadows() == expected + expected = True + light.set_shadows(expected) + assert light.is_using_shadows() == expected + + def test_get_set_ambient(mojo, light): expected = np.array([0.8, 0.8, 0.8]) light.set_ambient(expected) diff --git a/tests/elements/test_site.py b/tests/elements/test_site.py new file mode 100644 index 0000000..56f932a --- /dev/null +++ b/tests/elements/test_site.py @@ -0,0 +1,47 @@ +from pathlib import Path + +import numpy as np +import pytest +from numpy.testing import assert_array_equal + +from mojo import Mojo +from mojo.elements import Site + + +@pytest.fixture() +def mojo() -> Mojo: + return Mojo(str(Path(__file__).parents[1] / "world.xml")) + + +@pytest.fixture() +def site(mojo: Mojo) -> Site: + return Site.create(mojo) + + +def test_get_set_position(mojo: Mojo, site: Site): + expected = np.array([2, 2, 2]) + site.set_position(expected) + assert_array_equal(site.get_position(), expected) + + +def test_get_set_quaternion(mojo: Mojo, site: Site): + expected = np.array([0, 1, 0, 0]) + site.set_quaternion(expected) + assert_array_equal(site.get_quaternion(), expected) + + +def test_get_set_color(mojo: Mojo, site: Site): + expected = np.array([0.8, 0.8, 0.8, 1.0], dtype=np.float32) + site.set_color(expected) + assert_array_equal(site.get_color(), expected) + + +def test_set_texture(mojo: Mojo, site: Site): + # just test that there are no exceptions + site.set_texture( + str(Path(__file__).parents[1] / "assets" / "textures" / "texture00.png") + ) + + +def test_get_parent(mojo: Mojo, site: Site): + assert site.parent is not None