diff --git a/dm2gym/__init__.py b/dm2gym/__init__.py index 11846f2..57a700f 100644 --- a/dm2gym/__init__.py +++ b/dm2gym/__init__.py @@ -1 +1,2 @@ from .dm_control_env import DMControlEnv +from .opencv_image_viewer import OpenCVImageViewer diff --git a/dm2gym/dm_control_env.py b/dm2gym/dm_control_env.py index 02a0a1b..9ab27aa 100644 --- a/dm2gym/dm_control_env.py +++ b/dm2gym/dm_control_env.py @@ -1,9 +1,11 @@ from collections import OrderedDict import numpy as np + import gym from gym import spaces from gym.envs.registration import EnvSpec + from dm_control.rl.specs import ArraySpec from dm_control.rl.specs import BoundedArraySpec @@ -55,16 +57,27 @@ def reset(self): timestep = self.env.reset() return timestep.observation - def render(self, mode='human', **kwargs): + def render(self, mode='human', *, render_window_mode='gym', **kwargs): if 'camera_id' not in kwargs: - kwargs['camera_id'] = 0 # tracking camera + # Tracking camera + kwargs['camera_id'] = 0 + # Verify render window mode + assert render_window_mode in ['gym', 'opencv'],\ + "Invalid value for render_window_mode: {}".format( + render_window_mode + ) img = self.env.physics.render(**kwargs) if mode == 'rgb_array': return img elif mode == "human": - from gym.envs.classic_control import rendering if self.viewer is None: - self.viewer = rendering.SimpleImageViewer(maxwidth=1024) + # Open viewer + if render_window_mode == 'gym': + from gym.envs.classic_control import rendering + self.viewer = rendering.SimpleImageViewer(maxwidth=1024) + elif render_window_mode == 'opencv': + from dm2gym import OpenCVImageViewer + self.viewer = OpenCVImageViewer() self.viewer.imshow(img) return self.viewer.isopen else: diff --git a/dm2gym/opencv_image_viewer.py b/dm2gym/opencv_image_viewer.py new file mode 100644 index 0000000..acb90f4 --- /dev/null +++ b/dm2gym/opencv_image_viewer.py @@ -0,0 +1,36 @@ +"""A simple OpenCV based viewer for dm_control images""" + +import cv2 +import uuid + + +class OpenCVImageViewer(): + """A simple OpenCV highgui based dm_control image viewer + + This class is meant to be a drop-in replacement for + `gym.envs.classic_control.rendering.SimpleImageViewer` + """ + def __init__(self, *, escape_to_exit=False): + """Construct the viewing window""" + self._escape_to_exit = escape_to_exit + self._window_name = str(uuid.uuid4()) + cv2.namedWindow(self._window_name, cv2.WINDOW_AUTOSIZE) + self._isopen = True + + def __del__(self): + """Close the window""" + cv2.destroyWindow(self._window_name) + self._isopen = False + + def imshow(self, img): + """Show an image""" + # Convert image to BGR format + cv2.imshow(self._window_name, img[:, :, [2, 1, 0]]) + # Listen for escape key, then exit if pressed + if cv2.waitKey(1) in [27] and self._escape_to_exit: + exit() + + @property + def isopen(self): + """Is the window open?""" + return self._isopen diff --git a/setup.py b/setup.py index f2a585b..7fbca8d 100644 --- a/setup.py +++ b/setup.py @@ -7,9 +7,10 @@ with codecs.open('README.md', encoding='utf-8') as f: long_description = f.read() -# Minimal requried dependencies (full dependencies in requirements.txt) +# Minimal requried dependencies install_requires = ['numpy', - 'gym'] + 'gym', + 'opencv-python'] tests_require = ['pytest', 'flake8', 'sphinx',