forked from machinaut/azero
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_model.py
executable file
·68 lines (60 loc) · 2.58 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#!/usr/bin/env python
import random
import unittest
import numpy as np
from itertools import product
from model import models, MLP
from game import games, MNOP
from azero import AlphaZero
from nn import loss_fwd
from util import sample_logits, sample_games
N = 10  # number of independent playthroughs per (model, game) pairing
class TestModel(unittest.TestCase):
    """Exercise every (model, game) pairing, NaN propagation, and MLP training."""

    def test_random_play(self):
        """Every model can drive every game to completion via sampled actions."""
        for model_cls, game_cls in product(models, games):
            game = game_cls()
            for _ in range(N):
                model = model_cls(game.n_action, game.n_view, game.n_player)
                state, player, outcome = game.start()
                while outcome is None:
                    obs = game.view(state, player)
                    valid = game.valid(state, player)
                    logits, _ = model.model(obs)
                    action = sample_logits(logits, valid)
                    state, player, outcome = game.step(state, player, action)

    def test_nan_propagation(self):
        """A NaN anywhere in the observation must poison all model outputs."""
        for model_cls, game_cls in product(models, games):
            game = game_cls()
            if game.n_view == 0:
                continue  # no observations to corrupt for this game
            for _ in range(N):
                model = model_cls(game.n_action, game.n_view, game.n_player)
                state, player, outcome = game.start()
                while outcome is None:
                    obs = game.view(state, player)
                    # Corrupt a single randomly chosen observation entry.
                    bad_obs = obs.copy().flatten()
                    bad_obs[random.randrange(model.n_obs)] = np.nan
                    bad_obs = bad_obs.reshape(obs.shape)
                    bad_logits, bad_value = model.model(bad_obs)
                    # Use unittest assertions: bare `assert` statements are
                    # stripped under `python -O`, which would silently disable
                    # these checks.
                    self.assertTrue(np.isnan(bad_value).all())
                    self.assertTrue(np.isnan(bad_logits).all())
                    # Continue the game with the clean observation.
                    valid = game.valid(state, player)
                    logits, _ = model.model(obs)
                    action = sample_logits(logits, valid)
                    state, player, outcome = game.step(state, player, action)

    def test_mlp_overfit(self):
        """Sparse updates on a fixed batch must monotonically reduce the loss."""
        azero = AlphaZero.make(MNOP, MLP, seed=0)
        # Renamed from `games` so the local does not shadow the module-level
        # `games` list imported from the `game` module.
        played = azero.play_multi()
        obs, q, z = sample_games(played, rs=azero.rs)
        loss, _ = azero._model._loss(obs, q, z)
        for _ in range(1000):
            last = loss
            azero._model._sparse_update(obs, q, z)
            loss, _ = azero._model._loss(obs, q, z)
            self.assertLess(loss, last)
        # We should have a better score than the correct answer:
        # cross-entropy loss is lower if we extremize toward the targets.
        true, _ = loss_fwd(np.c_[q, z], q, z, azero._model.c)
        self.assertLess(loss, np.mean(true))
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()