-
Notifications
You must be signed in to change notification settings - Fork 2
/
miniworld_gotoobj_env.py
124 lines (99 loc) · 3.73 KB
/
miniworld_gotoobj_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from pdb import set_trace
import numpy as np
from gymnasium import spaces, utils
from miniworld.entity import COLOR_NAMES, Ball, Box, Key
from miniworld.manual_control import ManualControl
from miniworld.miniworld import MiniWorldEnv
from miniworld.params import DEFAULT_PARAMS
class MiniworldGoToObjEnv(MiniWorldEnv, utils.EzPickle):
    """MiniWorld "go to the <color> <object>" task in a single square room.

    Each generated world contains four objects (each a ball, box or key, all
    with distinct colors) at four fixed positions, plus the agent at a fixed
    start pose.  One object is chosen as the target and described by
    ``self.mission``.  The episode terminates with a reward once the point
    just in front of the agent lands on the target.

    NOTE: the object and agent coordinates in ``_gen_world`` are hard-coded
    and line up with the walls/center only for the default ``size=9`` room.
    """

    # Distance ahead of the agent (along its heading) that is probed in step().
    _PROBE_DISTANCE = 0.8
    # Maximum distance between the probed point and the target to count as
    # "reached".
    _REACH_TOLERANCE = 0.2

    def __init__(
        self, size=9, max_episode_steps=100, fast=True, manual_control=False, **kwargs
    ):
        """Build the environment.

        Args:
            size: Side length of the square room; must be at least 2.
            max_episode_steps: Forwarded to ``MiniWorldEnv.__init__``.
            fast: If true, use a 0.9 forward step and 90-degree turns;
                otherwise a 0.3 forward step and 30-degree turns.
            manual_control: If true, the mission string is printed whenever a
                new world is generated.
            **kwargs: Forwarded unchanged to ``MiniWorldEnv.__init__``.
        """
        assert size >= 2, "size must be at least 2"
        # Both attributes are read by _gen_world(), so they have to exist
        # before the base-class constructor runs.
        self.size = size
        self.manual_control = manual_control

        # Movement granularity: coarse steps when `fast`, fine steps otherwise.
        params = DEFAULT_PARAMS.no_random()
        forward_step, turn_step = (0.9, 90) if fast else (0.3, 30)
        params.set("forward_step", forward_step)
        params.set("turn_step", turn_step)

        MiniWorldEnv.__init__(
            self, params=params, max_episode_steps=max_episode_steps, **kwargs
        )
        # Record *every* constructor argument so a pickled/copied environment
        # is rebuilt with the same `fast` and `manual_control` settings rather
        # than silently falling back to their defaults.
        utils.EzPickle.__init__(
            self,
            size=size,
            max_episode_steps=max_episode_steps,
            fast=fast,
            manual_control=manual_control,
            **kwargs,
        )

        # Allow only movement actions (left/right/forward)
        self.action_space = spaces.Discrete(self.actions.move_forward + 1)

    def _gen_world(self):
        """Create the room, the four objects, the target and the agent.

        All random draws use ``self.np_random`` (the environment's own seeded
        generator) so that ``reset(seed=...)`` reproduces the same colors,
        object kinds and target.
        """
        self.add_rect_room(min_x=0, max_x=self.size, min_z=0, max_z=self.size)

        rng = self.np_random

        # (mission name, entity class, extra constructor kwargs).  Keys are
        # built without an explicit size; balls and boxes use size 0.5.
        kinds = (
            ("ball", Ball, {"size": 0.5}),
            ("box", Box, {"size": 0.5}),
            ("key", Key, {}),
        )

        # Four distinct colors; object kinds may repeat.
        colors = rng.choice(COLOR_NAMES, size=4, replace=False)
        kind_ids = rng.choice(len(kinds), size=4, replace=True)

        # Fixed object positions (sized for the default 9x9 room).
        positions = [
            np.array([0.9, 0.5, 4.5]),
            np.array([4.5, 0.5, 8.1]),
            np.array([8.1, 0.5, 4.5]),
            np.array([4.5, 0.5, 0.9]),
        ]

        self.objs = []
        obj_names = []
        for slot, (kind_id, color, pos) in enumerate(zip(kind_ids, colors, positions)):
            name, entity_cls, extra_kwargs = kinds[int(kind_id)]
            entity = entity_cls(color=color, **extra_kwargs)
            # Each successive object is rotated a further quarter turn.
            self.objs.append(
                self.place_entity(entity, pos=pos, dir=slot * np.pi / 2)
            )
            obj_names.append(name)

        # Select a random target object
        target_idx = int(rng.integers(len(self.objs)))
        self.target_obj = self.objs[target_idx]
        self.target_obj_name = obj_names[target_idx]
        self.target_color = colors[target_idx]

        # Generate the mission string
        self.mission = f"go to the {self.target_color} {self.target_obj_name}"
        if self.manual_control:
            print(self.mission)

        self.place_agent(pos=np.array([4.5, 0.5, 4.5]), dir=0.0)

    def step(self, action):
        """Advance one step and terminate with a reward on reaching the target.

        After the base-class step, a point ``_PROBE_DISTANCE`` ahead of the
        agent in the x/z plane is compared with the target position; if it is
        within ``_REACH_TOLERANCE`` the result of ``self._reward()`` is added
        to the reward and ``termination`` is set.

        Returns:
            The ``(obs, reward, termination, truncation, info)`` tuple from the
            base-class ``step``, with ``reward``/``termination`` updated as
            described above.
        """
        obs, reward, termination, truncation, info = super().step(action)

        agent_x, agent_z = self.agent.pos[0], self.agent.pos[2]
        target_x, target_z = self.target_obj.pos[0], self.target_obj.pos[2]

        # Point directly in front of the agent.  The minus sign on the sine
        # term presumably follows MiniWorld's heading convention (z decreases
        # as sin(dir) grows) -- confirm against Entity.dir_vec.
        probe_x = agent_x + self._PROBE_DISTANCE * np.cos(self.agent.dir)
        probe_z = agent_z - self._PROBE_DISTANCE * np.sin(self.agent.dir)

        distance = np.sqrt((probe_x - target_x) ** 2 + (probe_z - target_z) ** 2)
        if distance <= self._REACH_TOLERANCE:
            reward += self._reward()
            termination = True

        return obs, reward, termination, truncation, info
def main():
    """Launch the environment under interactive manual control.

    Creates a top-view, human-rendered environment with mission printing
    enabled and hands it to MiniWorld's ``ManualControl`` loop.
    """
    env = MiniworldGoToObjEnv(
        view="top",
        render_mode="human",
        manual_control=True,
    )
    # The two positional booleans are passed through exactly as before; their
    # meaning is defined by ManualControl's signature (not visible here).
    controller = ManualControl(env, True, True)
    controller.run()


# Start the interactive session only when run as a script, not on import.
if __name__ == "__main__":
    main()