Skip to content

Commit

Permalink
Add site object and add extra functions
Browse files Browse the repository at this point in the history
  • Loading branch information
stepjam committed Apr 5, 2024
1 parent 9a0d1c6 commit 58e7594
Show file tree
Hide file tree
Showing 9 changed files with 224 additions and 2 deletions.
1 change: 1 addition & 0 deletions mojo/elements/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from mojo.elements.joint import Joint
from mojo.elements.light import Light
from mojo.elements.model import MujocoModel
from mojo.elements.site import Site
8 changes: 7 additions & 1 deletion mojo/elements/body.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mujoco_utils import mjcf_utils
from typing_extensions import Self

from mojo.elements import geom
from mojo.elements import geom, joint
from mojo.elements.element import MujocoElement
from mojo.elements.utils import has_collision

Expand Down Expand Up @@ -53,6 +53,12 @@ def geoms(self) -> list[geom.Geom]:
geoms = self.mjcf.find_all("geom") or []
return [geom.Geom(self._mojo, mjcf) for mjcf in geoms]

@property
def joints(self) -> list[joint.Joint]:
# Loop through all children
joints = self.mjcf.find_all("joint") or []
return [joint.Joint(self._mojo, mjcf) for mjcf in joints]

def set_position(self, position: np.ndarray):
position = np.array(position) # ensure is numpy array
if self.mjcf.freejoint is not None:
Expand Down
8 changes: 8 additions & 0 deletions mojo/elements/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ class GeomType(Enum):
SDF = "sdf"


class SiteType(Enum):
SPHERE = "sphere"
CAPSULE = "capsule"
ELLIPSOID = "ellipsoid"
CYLINDER = "cylinder"
BOX = "box"


class TextureMapping(Enum):
PLANAR = "2d"
CUBE = "cube"
Expand Down
5 changes: 5 additions & 0 deletions mojo/elements/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def get_joint_position(self) -> float:
"""Get current joint position."""
return float(self._mojo.physics.bind(self.mjcf).qpos)

def set_joint_position(self, value: float):
self._mojo.physics.bind(self.mjcf).qpos *= 0
self._mojo.physics.bind(self.mjcf).qpos += value
self._mojo.mark_dirty()

def get_joint_velocity(self) -> float:
"""Get current joint velocity."""
return float(self._mojo.physics.bind(self.mjcf).qvel)
7 changes: 7 additions & 0 deletions mojo/elements/light.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,5 +102,12 @@ def set_direction(self, direction: np.ndarray):
def get_direction(self) -> np.ndarray:
return self.mjcf.dir

def set_shadows(self, value: bool):
self.mjcf.castshadow = value
self._mojo.physics.bind(self.mjcf).castshadow = value

def is_using_shadows(self) -> bool:
return self.mjcf.castshadow == "true"

def get_light_type(self) -> LightType:
return LightType.DIRECTIONAL if self.mjcf.directional else LightType.SPOTLIGHT
139 changes: 139 additions & 0 deletions mojo/elements/site.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import mujoco
import numpy as np
from mujoco_utils import mjcf_utils
from typing_extensions import Self

from mojo.elements import body
from mojo.elements.consts import SiteType, TextureMapping
from mojo.elements.element import MujocoElement
from mojo.elements.utils import load_texture

if TYPE_CHECKING:
from mojo import Mojo
from mojo.elements.body import Body


class Site(MujocoElement):
@staticmethod
def get(
mojo: Mojo,
name: str,
parent: MujocoElement = None,
) -> Self:
root_mjcf = mojo.root_element.mjcf if parent is None else parent.mjcf
mjcf = mjcf_utils.safe_find(root_mjcf, "site", name)
return Site(mojo, mjcf)

@staticmethod
def create(
mojo: Mojo,
parent: body.Body = None,
size: np.ndarray = None,
position: np.ndarray = None,
quaternion: np.ndarray = None,
color: np.ndarray = None,
site_type: SiteType = SiteType.SPHERE,
group: int = 1,
) -> Self:
position = np.array([0, 0, 0]) if position is None else position
quaternion = np.array([1, 0, 0, 0]) if quaternion is None else quaternion
size = np.array([0.1, 0.1, 0.1]) if size is None else size
color = np.array([1, 1, 1, 1]) if color is None else color
parent = body.Body.create(mojo) if parent is None else parent
new_geom = parent.mjcf.add(
"site",
type=site_type.value,
pos=position,
quat=quaternion,
size=size,
rgba=color,
group=group,
)
new_site_obj = Site(mojo, new_geom)
mojo.mark_dirty()
return new_site_obj

@property
def parent(self) -> "Body":
# Have to do this due to circular import
from mojo.elements.body import Body

return Body(self._mojo, self.mjcf.parent)

def set_position(self, position: np.ndarray):
position = np.array(position) # ensure is numpy array
if self.mjcf.parent.freejoint:
self._mojo.physics.bind(self.mjcf.parent.freejoint).qpos[:3] = position
self._mojo.physics.bind(self.mjcf).pos = position
self.mjcf.pos = position

def get_position(self) -> np.ndarray:
if self.mjcf.parent.freejoint:
return self._mojo.physics.bind(self.mjcf.parent.freejoint).qpos[:3].copy()
return self._mojo.physics.bind(self.mjcf).xpos

def set_quaternion(self, quaternion: np.ndarray):
# wxyz
quaternion = np.array(quaternion) # ensure is numpy array
if self.mjcf.parent.freejoint is not None:
self._mojo.physics.bind(self.mjcf.parent.freejoint).qpos[3:] = quaternion
mat = np.zeros(9)
mujoco.mju_quat2Mat(mat, quaternion)
self._mojo.physics.bind(self.mjcf).xmat = mat
self.mjcf.quat = quaternion

def get_quaternion(self) -> np.ndarray:
if self.mjcf.parent.freejoint is not None:
return self._mojo.physics.bind(self.mjcf.parent.freejoint).qpos[3:].copy()
quat = np.zeros(4)
mujoco.mju_mat2Quat(quat, self._mojo.physics.bind(self.mjcf).xmat)
return quat

def set_color(self, color: np.ndarray):
color = np.array(color)
if len(color) == 3:
color = np.concatenate([color, [1]]) # add alpha
self._mojo.physics.bind(self.mjcf).rgba = color
self.mjcf.rgba = color

def get_color(self) -> np.ndarray:
return np.array(self._mojo.physics.bind(self.mjcf).rgba)

def set_texture(
self,
texture_path: str,
mapping: TextureMapping = TextureMapping.CUBE,
tex_repeat: np.ndarray = None,
tex_uniform: bool = False,
emission: float = 0.0,
specular: float = 0.0,
shininess: float = 0.0,
reflectance: float = 0.0,
color: np.ndarray = None,
):
# First check if we have loaded this texture
key_name = f"{texture_path}_{mapping.value}"
material = self._mojo.get_material(key_name)
if material is None:
material = load_texture(
self._mojo.root_element.mjcf,
texture_path,
mapping,
tex_repeat,
tex_uniform,
emission,
specular,
shininess,
reflectance,
color,
)
self._mojo.store_material(key_name, material)
self.mjcf.material = material
if self.mjcf.rgba is None:
# Have a default white color for texture
self.set_color(np.ones(4))
self._mojo.mark_dirty()
2 changes: 1 addition & 1 deletion mojo/mojo.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def load_model(
model_mjcf = mjcf.from_path(path)
if on_loaded is not None:
on_loaded(model_mjcf)
attach_site = self.root_element.mjcf if parent is None else parent
attach_site = self.root_element.mjcf if parent is None else parent.mjcf
attached_model_mjcf = attach_site.attach(model_mjcf)
self.mark_dirty()
return Body(self, attached_model_mjcf)
Expand Down
9 changes: 9 additions & 0 deletions tests/elements/test_light.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ def test_get_set_active(mojo, light):
assert light.is_active() == expected


def test_get_set_shadows(mojo, light):
expected = False
light.set_shadows(expected)
assert light.is_using_shadows() == expected
expected = True
light.set_shadows(expected)
assert light.is_using_shadows() == expected


def test_get_set_ambient(mojo, light):
expected = np.array([0.8, 0.8, 0.8])
light.set_ambient(expected)
Expand Down
47 changes: 47 additions & 0 deletions tests/elements/test_site.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from pathlib import Path

import numpy as np
import pytest
from numpy.testing import assert_array_equal

from mojo import Mojo
from mojo.elements import Site


@pytest.fixture()
def mojo() -> Mojo:
return Mojo(str(Path(__file__).parents[1] / "world.xml"))


@pytest.fixture()
def site(mojo: Mojo) -> Site:
return Site.create(mojo)


def test_get_set_position(mojo: Mojo, site: Site):
expected = np.array([2, 2, 2])
site.set_position(expected)
assert_array_equal(site.get_position(), expected)


def test_get_set_quaternion(mojo: Mojo, site: Site):
expected = np.array([0, 1, 0, 0])
site.set_quaternion(expected)
assert_array_equal(site.get_quaternion(), expected)


def test_get_set_color(mojo: Mojo, site: Site):
expected = np.array([0.8, 0.8, 0.8, 1.0], dtype=np.float32)
site.set_color(expected)
assert_array_equal(site.get_color(), expected)


def test_set_texture(mojo: Mojo, site: Site):
# just test that there are no exceptions
site.set_texture(
str(Path(__file__).parents[1] / "assets" / "textures" / "texture00.png")
)


def test_get_parent(mojo: Mojo, site: Site):
assert site.parent is not None

0 comments on commit 58e7594

Please sign in to comment.