-
Notifications
You must be signed in to change notification settings - Fork 96
/
run_model.py
32 lines (29 loc) · 853 Bytes
/
run_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import theano
import theano.tensor as T
import numpy as np
from theano_toolkit.parameters import Parameters
from theano_toolkit import updates
from theano_toolkit import utils as U
from theano_toolkit import hinton
import controller
import model
import tasks
import random
import math
def make_model(
        input_size=8,
        output_size=8,
        mem_size=128,
        mem_width=20,
        hidden_sizes=None):
    """Build the controller and model, and compile a function to run them.

    A fresh ``Parameters`` container is created and handed to
    ``controller.build`` and ``model.build``; the resulting ``predict``
    callable is applied to a symbolic matrix named ``'input_sequence'`` and
    compiled with ``theano.function``.

    Args:
        input_size: Forwarded to ``controller.build``.
        output_size: Forwarded to ``controller.build``.
        mem_size: Forwarded to both ``controller.build`` and ``model.build``.
        mem_width: Forwarded to both ``controller.build`` and ``model.build``.
        hidden_sizes: Non-empty sequence forwarded to ``controller.build``;
            its last element is also passed to ``model.build``. Defaults to
            ``[100]`` when omitted or ``None``.

    Returns:
        A ``(P, test_fun)`` pair: the ``Parameters`` container and a compiled
        Theano function that takes one matrix input and returns
        ``[weights, output]``.
    """
    # A None sentinel avoids sharing one mutable default list across calls
    # (and with whatever controller.build does with it).
    if hidden_sizes is None:
        hidden_sizes = [100]

    P = Parameters()
    ctrl = controller.build(P, input_size, output_size,
                            mem_size, mem_width, hidden_sizes)
    predict = model.build(P, mem_size, mem_width, hidden_sizes[-1], ctrl)

    input_seq = T.matrix('input_sequence')
    # predict yields three values; the first (presumably the memory state --
    # TODO confirm against model.build) is not part of the compiled outputs.
    _, weights, output = predict(input_seq)

    test_fun = theano.function(
        inputs=[input_seq],
        outputs=[weights, output]
    )
    return P, test_fun