atari_utils.py
from collections import deque
from gym import Wrapper, ObservationWrapper, ActionWrapper
from gym.spaces.box import Box
import numpy as np
import cv2


def _process_frame42(frame):
    """Convert a raw 210x160 RGB Atari frame to a 42x42x1 grayscale frame in [0, 1]."""
    reshaped_screen = np.reshape(frame, [210, 160, 3]).astype(np.float32).mean(2)
    resized_screen = cv2.resize(reshaped_screen, (84, 110))
    x_t = resized_screen[18:102, :]  # crop to the 84x84 playing area
    x_t = cv2.resize(x_t, (42, 42))
    x_t *= (1.0 / 255.0)
    x_t = np.reshape(x_t, [42, 42, 1])
    return x_t


class AtariRescale42x42Env(ObservationWrapper):
    """Observation wrapper that emits 42x42x1 grayscale frames with values in [0, 1]."""

    def __init__(self, env=None):
        super(AtariRescale42x42Env, self).__init__(env)
        # _process_frame42 returns values in [0, 1], so the bounds match that range
        self.observation_space = Box(0.0, 1.0, [42, 42, 1])

    def _observation(self, observation):
        return _process_frame42(observation)


def _process_frame84(frame):
    """Convert a raw 210x160 RGB Atari frame to an 84x84x1 grayscale frame in [0, 1]."""
    img = np.reshape(frame, [210, 160, 3]).astype(np.float32)
    # Rec. 709 luma: Y = 0.2126 R + 0.7152 G + 0.0722 B
    img = img[:, :, 0] * 0.2126 + img[:, :, 1] * 0.7152 + img[:, :, 2] * 0.0722
    resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_LINEAR)
    x_t = resized_screen[18:102, :]  # crop to the 84x84 playing area
    x_t /= 255.0
    # x_t -= 0.5
    x_t = np.reshape(x_t, [84, 84, 1])
    return x_t


class AtariRescale84x84Env(ObservationWrapper):
    """Observation wrapper that emits 84x84x1 grayscale frames with values in [0, 1]."""

    def __init__(self, env=None):
        super(AtariRescale84x84Env, self).__init__(env)
        # _process_frame84 returns values in [0, 1], so the bounds match that range
        self.observation_space = Box(0.0, 1.0, [84, 84, 1])

    def _observation(self, observation):
        return _process_frame84(observation)


class RandomizedResetEnv(Wrapper):
    def __init__(self, env, no_op_max=7):
        super(RandomizedResetEnv, self).__init__(env)
        self._no_op_max = no_op_max

    def _reset(self):
        ob = self.env.reset()
        action = 0
        # randomize initial state
        if self._no_op_max > 0:
            no_op = np.random.randint(0, self._no_op_max + 1)
            for _ in range(no_op):
                ob, _, _, _ = self.env.step(action)
        return ob


class OneLiveResetEnv(Wrapper):
    """End the episode as soon as a life is lost, instead of waiting for game over."""

    def _step(self, action):
        lives = self.env.unwrapped.ale.lives()
        observation, reward, done, info = self.env.step(action)
        if lives != self.env.unwrapped.ale.lives():
            done = True
        return observation, reward, done, info


class UnstuckPolicyEnv(ActionWrapper):
    """Substitute action 1 (typically FIRE) when the same action repeats 30 times."""

    def __init__(self, env):
        super(UnstuckPolicyEnv, self).__init__(env)
        # per-instance history; a class-level deque would be shared across envs
        self.actions = deque(maxlen=30)

    def _action(self, action):
        if self.actions.count(action) == 30:
            action = 1
        self.actions.append(action)
        return action

    def _reverse_action(self, action):
        return action


class ObservationBuffer(Wrapper):
    """Stack the last `buffer_size` observations along the channel axis."""

    def __init__(self, env, buffer_size=4):
        super(ObservationBuffer, self).__init__(env)
        self.buffer_size = buffer_size
        self.buffer = deque(maxlen=self.buffer_size)
        assert len(self.env.observation_space.shape) == 3
        self._shape = list(self.env.observation_space.shape)
        self._num_channels = self._shape[2]
        self._shape[2] *= self.buffer_size
        # NOTE: these bounds assume zero-centered frames (the commented-out
        # `x_t -= 0.5` in _process_frame84); the rescale wrappers emit [0, 1].
        self.observation_space = Box(-0.5, 0.5, self._shape)

    def _step(self, action):
        observation, reward, done, info = self.env.step(action)
        self.buffer.append(observation)
        return np.concatenate(self.buffer, axis=2), reward, done, info

    def _reset(self):
        # fill the buffer with copies of the first observation
        obs = self.env.reset()
        for _ in range(self.buffer_size):
            self.buffer.append(obs)
        return np.concatenate(self.buffer, axis=2)

    # def _render(self, mode='human', close=False):
    #     if mode == "rgb_array":
    #         return self.buffer[-1]
    #     return self.env.render(mode, close)
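

# A minimal usage sketch (an illustration, not part of the original file): these
# wrappers target the old gym API with underscore wrapper methods, and the env id
# "PongDeterministic-v4" below is only an assumed example.
if __name__ == '__main__':
    import gym

    env = gym.make('PongDeterministic-v4')
    env = RandomizedResetEnv(env, no_op_max=7)
    env = OneLiveResetEnv(env)
    env = AtariRescale42x42Env(env)
    env = ObservationBuffer(env, buffer_size=4)

    obs = env.reset()
    print(obs.shape)  # four stacked 42x42 grayscale frames -> (42, 42, 4)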