Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
stepjam committed Apr 16, 2024
1 parent 358f93c commit a95e233
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 25 deletions.
8 changes: 8 additions & 0 deletions mojo/elements/body.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,11 @@ def has_collided(self, other: Body = None):
this_object_id = self._mojo.physics.bind(self.mjcf).element_id
other_object_id = self._mojo.physics.bind(other.mjcf).element_id
return has_collision(self._mojo.physics, other_object_id, this_object_id)

def remove(self):
self.mjcf.remove()

def set_kinematic(self, value: bool):
if value and not self.is_kinematic():
self.mjcf.add("freejoint")
self._mojo.mark_dirty()
47 changes: 23 additions & 24 deletions mojo/elements/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ def _is_kinematic(elem: mjcf.Element):
return has_freejoint or has_joints or _is_kinematic(elem.parent)


def _find_freejoint(elem: mjcf.Element):
if elem.parent is None:
# Root of tree
return None
if free_joint := getattr(elem, "freejoint", None):
return free_joint
return _find_freejoint(elem.parent)


class MujocoElement(ABC):
def __init__(self, mojo: Mojo, mjcf_elem: mjcf.RootElement):
self._mojo = mojo
Expand All @@ -35,51 +44,41 @@ def id(self):

def set_position(self, position: np.ndarray):
position = np.array(position) # ensure is numpy array
if hasattr(self.mjcf, "freejoint") and self.mjcf.freejoint is not None:
self._mojo.physics.bind(self.mjcf.freejoint).qpos[:3] = position
if freejoint := _find_freejoint(self.mjcf):
self._mojo.physics.bind(freejoint).qpos[:3] = position
else:
self._mojo.physics.bind(self.mjcf).pos = position
self.mjcf.pos = position

def get_position(self) -> np.ndarray:
# if the element has a free joint (and thus is a body), then access qpos
if hasattr(self.mjcf, "freejoint") and self.mjcf.freejoint is not None:
return self._mojo.physics.bind(self.mjcf.freejoint).qpos[:3].copy()
if freejoint := _find_freejoint(self.mjcf):
return self._mojo.physics.bind(freejoint).qpos[:3].copy()
return self._mojo.physics.bind(self.mjcf).xpos.copy()

def set_quaternion(self, quaternion: np.ndarray):
# wxyz
quaternion = np.array(quaternion) # ensure is numpy array
if hasattr(self.mjcf, "freejoint") and self.mjcf.freejoint is not None:
self._mojo.physics.bind(self.mjcf.freejoint).qpos[3:] = quaternion
binded = self._mojo.physics.bind(self.mjcf)
if binded.quat is not None:
binded.quat = quaternion
mat = np.zeros(9)
mujoco.mju_quat2Mat(mat, quaternion)
self._mojo.physics.bind(self.mjcf).xmat = mat
if freejoint := _find_freejoint(self.mjcf):
self._mojo.physics.bind(freejoint).qpos[3:] = quaternion
else:
mat = np.zeros(9)
mujoco.mju_quat2Mat(mat, quaternion)
self._mojo.physics.bind(self.mjcf).xmat = mat
self.mjcf.quat = quaternion

def get_quaternion(self) -> np.ndarray:
quat = np.zeros(4)
mujoco.mju_mat2Quat(quat, self._mojo.physics.bind(self.mjcf).xmat)
return quat

def set_kinematic(self, value: bool):
if value and not self.is_kinematic():
self.mjcf.add("freejoint")
self._mojo.mark_dirty()
if (
not value
and self.is_kinematic()
and hasattr(self.mjcf, "freejoint")
and self.mjcf.freejoint is not None
):
self.mjcf.freejoint.remove()

def is_kinematic(self) -> bool:
return _is_kinematic(self.mjcf)

@property
def id(self):
return self._mojo.physics.bind(self.mjcf).element_id

def __eq__(self, other):
return (
isinstance(other, MujocoElement)
Expand Down
15 changes: 14 additions & 1 deletion mojo/elements/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def create(
mesh_scale: np.ndarray = None,
group: int = 1,
density: float = 1000,
mass: float = None,
) -> Self:
position = np.array([0, 0, 0]) if position is None else position
quaternion = np.array([1, 0, 0, 0]) if quaternion is None else quaternion
Expand All @@ -57,6 +58,9 @@ def create(
"To create mesh geom, 'mesh_file' must be defined "
"and 'geom_type' must be GeomType.MESH"
)
kwargs = {}
if mass is not None:
kwargs["mass"] = mass
new_geom = parent.mjcf.add(
"geom",
type=geom_type.value,
Expand All @@ -66,6 +70,7 @@ def create(
rgba=color,
group=group,
density=density,
**kwargs,
)
new_geom_obj = Geom(mojo, new_geom)
if mesh_path:
Expand Down Expand Up @@ -104,7 +109,10 @@ def set_texture(
color: np.ndarray = None,
):
# First check if we have loaded this texture
key_name = f"{texture_path}_{mapping.value}"
key_name = (
f"{texture_path}_{mapping.value}_{tex_uniform}_{tex_repeat}_"
f"{emission}_{specular}_{shininess}_{reflectance}_{color}"
)
material = self._mojo.get_material(key_name)
if material is None:
material = load_texture(
Expand Down Expand Up @@ -162,3 +170,8 @@ def has_collided(self, other: Geom = None):
this_object_id = self._mojo.physics.bind(self.mjcf).element_id
other_object_id = self._mojo.physics.bind(other.mjcf).element_id
return has_collision(self._mojo.physics, other_object_id, this_object_id)

def set_kinematic(self, value: bool):
if value and not self.is_kinematic():
self.mjcf.parent.add("freejoint")
self._mojo.mark_dirty()

0 comments on commit a95e233

Please sign in to comment.