Skip to content

Commit

Permalink
Add limited asset stores for textures and meshes
Browse files Browse the repository at this point in the history
  • Loading branch information
Nikita Chernyadev committed Apr 4, 2024
1 parent 0b7f6b1 commit f9f7922
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 11 deletions.
41 changes: 41 additions & 0 deletions mojo/elements/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import uuid
import warnings
from collections import OrderedDict
from typing import Optional

import numpy as np
from dm_control import mjcf
Expand Down Expand Up @@ -72,3 +75,41 @@ def load_mesh(
uid = str(uuid.uuid4())
mesh = mjcf_model.asset.add("mesh", name=f"mesh_{uid}", file=path, scale=scale)
return mesh


class AssetStore:
"""Container for Mujoco assets."""

DEFAULT_CAPACITY = 32

def __init__(self, capacity: Optional[int] = None):
self._store: OrderedDict[str, mjcf.Element] = OrderedDict()
self._capacity = capacity

def get(self, path: str) -> Optional[mjcf.Element]:
"""Get MJCF asset by path."""
return self._store.get(path, None)

def remove(self, path: str) -> None:
"""Remove MJCF asset by path."""
if path in self._store:
asset = self._store.pop(path)
self._unload_asset(asset)

def store(self, path: str, asset_mjcf: mjcf.Element) -> None:
"""Add new MJCF asset."""
self._store[path] = asset_mjcf
if self._capacity and len(self._store) > self._capacity:
warnings.warn(
f"The capacity of the store ({self._capacity}) has been exceeded."
f"Removing the oldest asset.",
UserWarning,
)
_, asset = self._store.popitem(last=False)
self._unload_asset(asset)

@staticmethod
def _unload_asset(asset: mjcf.Element) -> None:
if asset.tag == "material":
asset.texture.remove()
asset.remove()
29 changes: 18 additions & 11 deletions mojo/mojo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,21 @@
from mojo.elements.body import Body
from mojo.elements.element import MujocoElement
from mojo.elements.model import MujocoModel
from mojo.elements.utils import AssetStore


class Mojo:
def __init__(self, base_model_path: str, timestep: float = 0.01):
def __init__(
self,
base_model_path: str,
timestep: float = 0.01,
texture_store_capacity: int = AssetStore.DEFAULT_CAPACITY,
mesh_store_capacity: int = AssetStore.DEFAULT_CAPACITY,
):
model_mjcf = mjcf.from_path(base_model_path)
self.root_element = MujocoModel(self, model_mjcf)
self._texture_store: dict[str, mjcf.Element] = {}
self._mesh_store: dict[str, mjcf.Element] = {}
self._texture_store: AssetStore = AssetStore(texture_store_capacity)
self._mesh_store: AssetStore = AssetStore(mesh_store_capacity)
self._dirty = True
self._passive_dirty = False
self._passive_viewer_handle = None
Expand Down Expand Up @@ -84,17 +91,17 @@ def step(self):
self._create_physics_from_model()
self.physics.step()

def get_material(self, path: str) -> mjcf.Element:
return self._texture_store.get(path, None)
def get_material(self, path: str) -> Optional[mjcf.Element]:
return self._texture_store.get(path)

def store_material(self, path: str, material_mjcf: mjcf.Element) -> mjcf.Element:
self._texture_store[path] = material_mjcf
def store_material(self, path: str, material_mjcf: mjcf.Element) -> None:
self._texture_store.store(path, material_mjcf)

def get_mesh(self, path: str) -> mjcf.Element:
return self._mesh_store.get(path, None)
def get_mesh(self, path: str) -> Optional[mjcf.Element]:
return self._mesh_store.get(path)

def store_mesh(self, path: str, mesh_mjcf: mjcf.Element) -> mjcf.Element:
self._mesh_store[path] = mesh_mjcf
def store_mesh(self, path: str, mesh_mjcf: mjcf.Element) -> None:
self._mesh_store.store(path, mesh_mjcf)

def load_model(
self,
Expand Down
57 changes: 57 additions & 0 deletions tests/utils/test_asset_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import shutil
import tempfile
from pathlib import Path

import pytest

from mojo import Mojo
from mojo.elements import Geom

TEXTURE_STORE_CAPACITY = 10
MESH_STORE_CAPACITY = 10


@pytest.fixture()
def mojo() -> Mojo:
return Mojo(
str(Path(__file__).parents[1] / "world.xml"),
texture_store_capacity=TEXTURE_STORE_CAPACITY,
mesh_store_capacity=MESH_STORE_CAPACITY,
)


@pytest.fixture()
def geom(mojo: Mojo) -> Geom:
return Geom.create(mojo)


def test_texture_store(mojo: Mojo, geom: Geom):
initial_count = len(mojo.root_element.mjcf.asset.texture)
texture_path = Path(__file__).parents[1] / "assets" / "textures" / "texture00.png"
geom.set_texture(str(texture_path))
with pytest.warns(UserWarning):
with tempfile.TemporaryDirectory() as temp_dir:
for i in range(TEXTURE_STORE_CAPACITY * 2):
temp_path = Path(temp_dir) / f"{i}{texture_path.suffix}"
shutil.copy2(texture_path, temp_path)
geom.set_texture(str(temp_path))
assert (
len(mojo.root_element.mjcf.asset.texture) - initial_count
<= TEXTURE_STORE_CAPACITY
)


def test_mesh_store(mojo: Mojo, geom: Geom):
initial_count = len(mojo.root_element.mjcf.asset.mesh)
mesh_path = Path(__file__).parents[1] / "assets" / "models" / "mug.obj"
geom.set_mesh(str(mesh_path))
with pytest.warns(UserWarning):
with tempfile.TemporaryDirectory() as temp_dir:
for i in range(MESH_STORE_CAPACITY * 2):
temp_path = Path(temp_dir) / f"{i}{mesh_path.suffix}"
shutil.copy2(mesh_path, temp_path)
geom.set_mesh(str(temp_path))
assert (
len(mojo.root_element.mjcf.asset.mesh) - initial_count
<= MESH_STORE_CAPACITY
)

0 comments on commit f9f7922

Please sign in to comment.