# test.py
import imageio.v2 as imageio
import re
import numpy as np
import os
from atc_gym.envs.conflict_art import ConflictArtEnv
from atc_gym.envs.conflict_gen_art import ConflictGenArtEnv
from atc_gym.envs.conflict_urban_art import ConflictUrbanArtEnv
from atc_gym.envs.conflict_multi_art import ConflictMultiArtEnv
def natural_sort(l):
    """Return the strings in *l* sorted in natural (numeric-aware) order.

    Runs of ASCII digits are compared as integers and the text between them
    is compared case-insensitively, so ``'img2.png'`` sorts before
    ``'img10.png'``.

    Args:
        l: Iterable of strings to sort.

    Returns:
        A new sorted list; the input is not modified.
    """
    def alphanum_key(key):
        # re.split with a capturing group always alternates
        # text, digits, text, ... so odd positions are exactly the [0-9]+
        # runs. Deciding by position (instead of str.isdigit) keeps every
        # key's element types aligned and avoids int() on non-ASCII digit
        # characters that isdigit() accepts but the regex never isolates.
        chunks = re.split('([0-9]+)', key)
        return [
            int(chunk) if position % 2 else chunk.lower()
            for position, chunk in enumerate(chunks)
        ]

    return sorted(l, key=alphanum_key)
def make_gif() -> None:
    """Combine the rendered PNG frames into ``output/render.gif``.

    Frames are read from the ``atc_gym/envs/data/images`` folder located
    relative to this file's directory, ordered with ``natural_sort`` so that
    numbered frames appear in numeric order. After the GIF has been written,
    every entry in the frame folder is removed.
    """
    # Build the folder with os.path.join: concatenating dirname(__file__)
    # directly with the relative part omits the path separator and points at
    # a non-existent location whenever the dirname is non-empty.
    png_folder = os.path.join(
        os.path.dirname(__file__), 'atc_gym', 'envs', 'data', 'images'
    )
    # Match on the extension rather than a substring so names such as
    # 'frame.png.bak' are not treated as frames.
    png_list = natural_sort(
        [img for img in os.listdir(png_folder) if img.endswith('.png')]
    )

    # Create a gif
    images = [imageio.imread(os.path.join(png_folder, img)) for img in png_list]
    # Make sure the destination folder exists before writing into it.
    os.makedirs('output', exist_ok=True)
    imageio.mimsave('output/render.gif', images)

    # Clean up
    for filename in os.listdir(png_folder):
        os.remove(os.path.join(png_folder, filename))
# Testing
if __name__ == "__main__":
    # Environment settings
    n_intruders = 3
    image_mode = 'rgb'
    image_pixel_size = 128

    # Manual switches selecting which of the checks below are executed.
    run_image_test = False
    run_creation_test = False
    run_reward_test = True

    # Build the environment once and bring it to its initial state.
    env = ConflictMultiArtEnv('images', n_intruders, image_mode, image_pixel_size)
    env.reset()

    # Test images: play one episode with action 0, then turn the frames
    # into a gif.
    if run_image_test:
        finished = False
        while not finished:
            _, reward, done, truncated, _ = env.step(0)
            finished = done or truncated
        make_gif()

    # Test env creation: reset and take a single step, 100 times over.
    if run_creation_test:
        for _ in range(100):
            env.reset()
            env.step(0)

    # Test average dumb reward: run episodes with an all-zero action and
    # report the accumulated reward per episode.
    if run_reward_test:
        rolling_avg = []
        rew_list = []
        tests_num = 200
        for a in range(tests_num):
            env.reset()
            rew_sum, steps = 0, 0
            finished = False
            while not finished:
                # All-zero action of shape (3, 1); presumably one row per
                # intruder (n_intruders == 3) - TODO confirm against the env.
                spd = np.array([[0], [0], [0]])
                _, reward, done, truncated, _ = env.step(spd)
                rew_sum += reward
                steps += 1
                finished = done or truncated
            rew_list.append(rew_sum)
            rolling_avg.append(np.average(rew_list))
            # NOTE(review): the "avg" field shows this episode's summed
            # reward, and "rolling avg" is the mean of the running means
            # collected so far, not the running mean itself - confirm that
            # this is the intended output.
            print(f'Episode: {a+1}/{tests_num} | avg: {rew_sum} | rolling avg: {np.average(rolling_avg)} | steps: {steps}')