From 9a0d1c6248e58acd67e164bda9cbcdc58afd62ba Mon Sep 17 00:00:00 2001 From: Nikita Chernyadev Date: Thu, 4 Apr 2024 21:28:52 +0100 Subject: [PATCH] Add asset stores of limited capacity for textures and meshes (#3) * Add limited asset stores for textures and meshes * Rename store to add --------- Co-authored-by: Nikita Chernyadev --- mojo/elements/utils.py | 41 ++++++++++++++++++++++++ mojo/mojo.py | 29 ++++++++++------- tests/utils/test_asset_store.py | 57 +++++++++++++++++++++++++++++++++ 3 files changed, 116 insertions(+), 11 deletions(-) create mode 100644 tests/utils/test_asset_store.py diff --git a/mojo/elements/utils.py b/mojo/elements/utils.py index 2df5881..97f7e1a 100644 --- a/mojo/elements/utils.py +++ b/mojo/elements/utils.py @@ -1,4 +1,7 @@ import uuid +import warnings +from collections import OrderedDict +from typing import Optional import numpy as np from dm_control import mjcf @@ -72,3 +75,41 @@ def load_mesh( uid = str(uuid.uuid4()) mesh = mjcf_model.asset.add("mesh", name=f"mesh_{uid}", file=path, scale=scale) return mesh + + +class AssetStore: + """Container for Mujoco assets.""" + + DEFAULT_CAPACITY = 32 + + def __init__(self, capacity: Optional[int] = None): + self._store: OrderedDict[str, mjcf.Element] = OrderedDict() + self._capacity = capacity + + def get(self, path: str) -> Optional[mjcf.Element]: + """Get MJCF asset by path.""" + return self._store.get(path, None) + + def remove(self, path: str) -> None: + """Remove MJCF asset by path.""" + if path in self._store: + asset = self._store.pop(path) + self._unload_asset(asset) + + def add(self, path: str, asset_mjcf: mjcf.Element) -> None: + """Add new MJCF asset.""" + self._store[path] = asset_mjcf + if self._capacity and len(self._store) > self._capacity: + warnings.warn( + f"The capacity of the store ({self._capacity}) has been exceeded." + f"Removing the oldest asset.", + UserWarning, + ) + _, asset = self._store.popitem(last=False) + self._unload_asset(asset) + + @staticmethod + def _unload_asset(asset: mjcf.Element) -> None: + if asset.tag == "material": + asset.texture.remove() + asset.remove() diff --git a/mojo/mojo.py b/mojo/mojo.py index f725d4e..cffa845 100644 --- a/mojo/mojo.py +++ b/mojo/mojo.py @@ -7,14 +7,21 @@ from mojo.elements.body import Body from mojo.elements.element import MujocoElement from mojo.elements.model import MujocoModel +from mojo.elements.utils import AssetStore class Mojo: - def __init__(self, base_model_path: str, timestep: float = 0.01): + def __init__( + self, + base_model_path: str, + timestep: float = 0.01, + texture_store_capacity: int = AssetStore.DEFAULT_CAPACITY, + mesh_store_capacity: int = AssetStore.DEFAULT_CAPACITY, + ): model_mjcf = mjcf.from_path(base_model_path) self.root_element = MujocoModel(self, model_mjcf) - self._texture_store: dict[str, mjcf.Element] = {} - self._mesh_store: dict[str, mjcf.Element] = {} + self._texture_store: AssetStore = AssetStore(texture_store_capacity) + self._mesh_store: AssetStore = AssetStore(mesh_store_capacity) self._dirty = True self._passive_dirty = False self._passive_viewer_handle = None @@ -84,17 +91,17 @@ def step(self): self._create_physics_from_model() self.physics.step() - def get_material(self, path: str) -> mjcf.Element: - return self._texture_store.get(path, None) + def get_material(self, path: str) -> Optional[mjcf.Element]: + return self._texture_store.get(path) - def store_material(self, path: str, material_mjcf: mjcf.Element) -> mjcf.Element: - self._texture_store[path] = material_mjcf + def store_material(self, path: str, material_mjcf: mjcf.Element) -> None: + self._texture_store.add(path, material_mjcf) - def get_mesh(self, path: str) -> mjcf.Element: - return self._mesh_store.get(path, None) + def get_mesh(self, path: str) -> Optional[mjcf.Element]: + return self._mesh_store.get(path) - def store_mesh(self, path: str, mesh_mjcf: mjcf.Element) -> mjcf.Element: - self._mesh_store[path] = mesh_mjcf + def store_mesh(self, path: str, mesh_mjcf: mjcf.Element) -> None: + self._mesh_store.add(path, mesh_mjcf) def load_model( self, diff --git a/tests/utils/test_asset_store.py b/tests/utils/test_asset_store.py new file mode 100644 index 0000000..abbc2d3 --- /dev/null +++ b/tests/utils/test_asset_store.py @@ -0,0 +1,57 @@ +import shutil +import tempfile +from pathlib import Path + +import pytest + +from mojo import Mojo +from mojo.elements import Geom + +TEXTURE_STORE_CAPACITY = 10 +MESH_STORE_CAPACITY = 10 + + +@pytest.fixture() +def mojo() -> Mojo: + return Mojo( + str(Path(__file__).parents[1] / "world.xml"), + texture_store_capacity=TEXTURE_STORE_CAPACITY, + mesh_store_capacity=MESH_STORE_CAPACITY, + ) + + +@pytest.fixture() +def geom(mojo: Mojo) -> Geom: + return Geom.create(mojo) + + +def test_texture_store(mojo: Mojo, geom: Geom): + initial_count = len(mojo.root_element.mjcf.asset.texture) + texture_path = Path(__file__).parents[1] / "assets" / "textures" / "texture00.png" + geom.set_texture(str(texture_path)) + with pytest.warns(UserWarning): + with tempfile.TemporaryDirectory() as temp_dir: + for i in range(TEXTURE_STORE_CAPACITY * 2): + temp_path = Path(temp_dir) / f"{i}{texture_path.suffix}" + shutil.copy2(texture_path, temp_path) + geom.set_texture(str(temp_path)) + assert ( + len(mojo.root_element.mjcf.asset.texture) - initial_count + <= TEXTURE_STORE_CAPACITY + ) + + +def test_mesh_store(mojo: Mojo, geom: Geom): + initial_count = len(mojo.root_element.mjcf.asset.mesh) + mesh_path = Path(__file__).parents[1] / "assets" / "models" / "mug.obj" + geom.set_mesh(str(mesh_path)) + with pytest.warns(UserWarning): + with tempfile.TemporaryDirectory() as temp_dir: + for i in range(MESH_STORE_CAPACITY * 2): + temp_path = Path(temp_dir) / f"{i}{mesh_path.suffix}" + shutil.copy2(mesh_path, temp_path) + geom.set_mesh(str(temp_path)) + assert ( + len(mojo.root_element.mjcf.asset.mesh) - initial_count + <= MESH_STORE_CAPACITY + )