Skip to content

Commit

Permalink
Merge pull request #2 from aaronsnoswell/opencv-render-window
Browse files Browse the repository at this point in the history
Add OpenCV render window mode
  • Loading branch information
zuoxingdong authored May 23, 2019
2 parents f010373 + 8e1e2f7 commit e16048a
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 6 deletions.
1 change: 1 addition & 0 deletions dm2gym/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .dm_control_env import DMControlEnv
from .opencv_image_viewer import OpenCVImageViewer
21 changes: 17 additions & 4 deletions dm2gym/dm_control_env.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from collections import OrderedDict

import numpy as np

import gym
from gym import spaces
from gym.envs.registration import EnvSpec

from dm_control.rl.specs import ArraySpec
from dm_control.rl.specs import BoundedArraySpec

Expand Down Expand Up @@ -55,16 +57,27 @@ def reset(self):
timestep = self.env.reset()
return timestep.observation

def render(self, mode='human', **kwargs):
def render(self, mode='human', *, render_window_mode='gym', **kwargs):
if 'camera_id' not in kwargs:
kwargs['camera_id'] = 0 # tracking camera
# Tracking camera
kwargs['camera_id'] = 0
# Verify render window mode
assert render_window_mode in ['gym', 'opencv'],\
"Invalid value for render_window_mode: {}".format(
render_window_mode
)
img = self.env.physics.render(**kwargs)
if mode == 'rgb_array':
return img
elif mode == "human":
from gym.envs.classic_control import rendering
if self.viewer is None:
self.viewer = rendering.SimpleImageViewer(maxwidth=1024)
# Open viewer
if render_window_mode == 'gym':
from gym.envs.classic_control import rendering
self.viewer = rendering.SimpleImageViewer(maxwidth=1024)
elif render_window_mode == 'opencv':
from dm2gym import OpenCVImageViewer
self.viewer = OpenCVImageViewer()
self.viewer.imshow(img)
return self.viewer.isopen
else:
Expand Down
36 changes: 36 additions & 0 deletions dm2gym/opencv_image_viewer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""A simple OpenCV based viewer for dm_control images"""

import cv2
import uuid


class OpenCVImageViewer():
"""A simple OpenCV highgui based dm_control image viewer
This class is meant to be a drop-in replacement for
`gym.envs.classic_control.rendering.SimpleImageViewer`
"""
def __init__(self, *, escape_to_exit=False):
"""Construct the viewing window"""
self._escape_to_exit = escape_to_exit
self._window_name = str(uuid.uuid4())
cv2.namedWindow(self._window_name, cv2.WINDOW_AUTOSIZE)
self._isopen = True

def __del__(self):
"""Close the window"""
cv2.destroyWindow(self._window_name)
self._isopen = False

def imshow(self, img):
"""Show an image"""
# Convert image to BGR format
cv2.imshow(self._window_name, img[:, :, [2, 1, 0]])
# Listen for escape key, then exit if pressed
if cv2.waitKey(1) in [27] and self._escape_to_exit:
exit()

@property
def isopen(self):
"""Is the window open?"""
return self._isopen
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
with codecs.open('README.md', encoding='utf-8') as f:
long_description = f.read()

# Minimal requried dependencies (full dependencies in requirements.txt)
# Minimal requried dependencies
install_requires = ['numpy',
'gym']
'gym',
'opencv-python']
tests_require = ['pytest',
'flake8',
'sphinx',
Expand Down

0 comments on commit e16048a

Please sign in to comment.