-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathtests.py
63 lines (39 loc) · 1.33 KB
/
tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import unittest
import drnn
import torch
class TestForward(unittest.TestCase):
    """Smoke-test the shape of the first element returned by DRNN.forward."""

    def test(self):
        # Positional args presumably (n_input, n_hidden, n_layers, dropout,
        # cell_type) — TODO confirm against drnn.DRNN's signature.
        model = drnn.DRNN(10, 10, 4, 0, 'GRU')
        # Sequence length 23 is not a multiple of 2, 4 or 8; presumably chosen
        # to exercise padding for the larger dilations — confirm.
        x = torch.randn(23, 3, 10)
        out = model(x)[0]
        # A single assertEqual reports the actual shape on failure (three
        # assertTrue(a == b) calls only said "False is not true") and also
        # checks that the output is exactly 3-D.
        self.assertEqual(tuple(out.size()), (23, 3, 10))
class TestReshape(unittest.TestCase):
    """Check DRNN's dilation split helper against slicing, and its inverse."""

    def test(self):
        model = drnn.DRNN(10, 10, 4, 0, 'GRU')
        rate = 2
        # Length 24 is divisible by `rate`; presumably this avoids any padding
        # inside _prepare_inputs — confirm.
        x = torch.randn(24, 3, 10)
        batch = x.size(1)
        split_x = model._prepare_inputs(x, rate)
        # The test expects the phase-1 subsequence x[1::rate] to sit in the
        # second batch-sized block along dim 1 of the split tensor. Presumably
        # _prepare_inputs concatenates the `rate` strided subsequences along
        # the batch axis — TODO confirm against drnn.
        second_block = x[1::rate]
        check = split_x[:, batch:2 * batch, :]
        # torch.equal requires identical shapes; `(a == b).all()` would
        # silently broadcast mismatched-but-compatible shapes.
        self.assertTrue(torch.equal(second_block, check))
        # _split_outputs should undo _prepare_inputs exactly.
        unsplit_x = model._split_outputs(split_x, rate)
        self.assertTrue(torch.equal(x, unsplit_x))
class TestHidden(unittest.TestCase):
    """Check that DRNN.forward returns one hidden entry per layer."""

    def test(self):
        n_layers = 4
        net = drnn.DRNN(10, 10, n_layers, 0, 'GRU')
        inputs = torch.randn(23, 3, 10)
        # The second element of the forward result holds the per-layer states.
        layer_states = net(inputs)[1]
        self.assertEqual(len(layer_states), n_layers)
        # Expected per-layer shapes are not pinned here; dump them for a
        # manual look.
        for state in layer_states:
            print(state.size())
class TestPassHidden(unittest.TestCase):
    """Check that DRNN.forward accepts an explicit list of initial states."""

    def test(self):
        model = drnn.DRNN(10, 10, 4, 0, 'GRU')
        # One initial state per layer. The (2 ** i, 3, 10) shape presumably
        # matches layer i having dilation 2 ** i — TODO confirm against drnn.
        hidden = [torch.randn(2 ** i, 3, 10) for i in range(4)]
        x = torch.randn(24, 3, 10)
        # Keep the result separate from `hidden` (the original rebound it) and
        # actually check it. Previously the test asserted nothing beyond "no
        # exception was raised".
        out = model(x, hidden)[0]
        self.assertEqual(tuple(out.size()), (24, 3, 10))
if __name__ == "__main__":
    # Run every TestCase in this module when executed directly as a script.
    unittest.main()