Skip to content

Commit

Permalink
Get/set matrix for sites
Browse files Browse the repository at this point in the history
  • Loading branch information
stepjam committed Apr 5, 2024
1 parent 58e7594 commit 69c3098
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
10 changes: 10 additions & 0 deletions mojo/elements/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,16 @@ def get_quaternion(self) -> np.ndarray:
mujoco.mju_mat2Quat(quat, self._mojo.physics.bind(self.mjcf).xmat)
return quat

def set_matrix(self, matrix: np.ndarray):
assert matrix.shape == (3, 3)
self._mojo.physics.bind(self.mjcf).xmat = np.reshape(matrix, (9,))
quat = np.zeros(4)
mujoco.mju_mat2Quat(quat, self._mojo.physics.bind(self.mjcf).xmat)
self.mjcf.quat = quat

def get_matrix(self) -> np.ndarray:
return np.reshape(self._mojo.physics.bind(self.mjcf).xmat.copy(), (3, 3))

def set_color(self, color: np.ndarray):
color = np.array(color)
if len(color) == 3:
Expand Down
6 changes: 6 additions & 0 deletions tests/elements/test_site.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ def test_get_set_quaternion(mojo: Mojo, site: Site):
assert_array_equal(site.get_quaternion(), expected)


def test_get_set_matrix(mojo: Mojo, site: Site):
expected = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
site.set_matrix(expected)
assert_array_equal(site.get_matrix(), expected)


def test_get_set_color(mojo: Mojo, site: Site):
expected = np.array([0.8, 0.8, 0.8, 1.0], dtype=np.float32)
site.set_color(expected)
Expand Down

0 comments on commit 69c3098

Please sign in to comment.