-
Notifications
You must be signed in to change notification settings - Fork 918
/
lstm.py
105 lines (82 loc) · 4.12 KB
/
lstm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Implements the long-short term memory character model.
This version vectorizes over multiple examples, but each string
has a fixed length."""
from os.path import dirname, join
from rnn import build_dataset, concat_and_multiply, one_hot_to_string, sigmoid, string_to_one_hot
import autograd.numpy as np
import autograd.numpy.random as npr
from autograd import grad
from autograd.misc.optimizers import adam
from autograd.scipy.special import logsumexp
def init_lstm_params(input_size, state_size, output_size, param_scale=0.01, rs=None):
    """Build a dict of randomly initialised LSTM parameters.

    Args:
        input_size: Width of one input vector.
        state_size: Width of the hidden state and of the cell state.
        output_size: Width of one prediction vector.
        param_scale: Factor applied to every standard-normal draw.
        rs: Object exposing ``randn(*shape)`` (e.g. ``npr.RandomState``).
            When omitted, a fresh ``npr.RandomState(0)`` is created for this
            call, so repeated default calls return identical parameters
            instead of sharing (and advancing) one generator created at
            function-definition time.

    Returns:
        Dict with keys "init cells", "init hiddens", "change", "forget",
        "ingate", "outgate" and "predict". The four gate matrices have shape
        ``(input_size + state_size + 1, state_size)`` and "predict" has shape
        ``(state_size + 1, output_size)``.
    """
    if rs is None:
        rs = npr.RandomState(0)

    def rp(*shape):
        return rs.randn(*shape) * param_scale

    # The "+ 1" row presumably carries a bias term added by
    # concat_and_multiply (defined in rnn, not visible here) -- TODO confirm.
    gate_shape = (input_size + state_size + 1, state_size)

    # Draw order matters for reproducibility: keep the keys in this order.
    return {
        "init cells": rp(1, state_size),
        "init hiddens": rp(1, state_size),
        "change": rp(*gate_shape),
        "forget": rp(*gate_shape),
        "ingate": rp(*gate_shape),
        "outgate": rp(*gate_shape),
        "predict": rp(state_size + 1, output_size),
    }
def lstm_predict(params, inputs):
    """Run the LSTM over `inputs` and return per-step log-probabilities.

    `inputs` is iterated along its first axis (one slice per time step) and
    its second axis gives the number of sequences processed in parallel.

    Returns a list of normalised log-probability arrays: the first entry is
    computed from the initial hidden state, followed by one entry per time
    step, so the list is one element longer than `inputs`.
    """

    def gate(name, squash, step_input, prev_hiddens):
        # Apply the named weight matrix to the current input and previous
        # hidden state, then squash with the given nonlinearity.
        return squash(concat_and_multiply(params[name], step_input, prev_hiddens))

    def update_lstm(step_input, prev_hiddens, prev_cells):
        change = gate("change", np.tanh, step_input, prev_hiddens)
        forget = gate("forget", sigmoid, step_input, prev_hiddens)
        ingate = gate("ingate", sigmoid, step_input, prev_hiddens)
        outgate = gate("outgate", sigmoid, step_input, prev_hiddens)
        next_cells = prev_cells * forget + ingate * change
        next_hiddens = outgate * np.tanh(next_cells)
        return next_hiddens, next_cells

    def hiddens_to_output_probs(state):
        scores = concat_and_multiply(params["predict"], state)
        # Subtracting the log-sum-exp along axis 1 normalises the log-probs.
        log_normaliser = logsumexp(scores, axis=1, keepdims=True)
        return scores - log_normaliser

    num_sequences = inputs.shape[1]
    # Give every sequence its own copy of the learned initial state.
    hiddens = np.repeat(params["init hiddens"], num_sequences, axis=0)
    cells = np.repeat(params["init cells"], num_sequences, axis=0)

    log_probs = [hiddens_to_output_probs(hiddens)]
    for step_input in inputs:  # One slice per time step.
        hiddens, cells = update_lstm(step_input, hiddens, cells)
        log_probs.append(hiddens_to_output_probs(hiddens))
    return log_probs
def lstm_log_likelihood(params, inputs, targets):
    """Return the mean log-likelihood of `targets` under the LSTM.

    The prediction at list position ``t`` from `lstm_predict` is weighted
    elementwise by ``targets[t]`` and summed; the total is divided by the
    number of time steps times the number of examples, both taken from
    ``inputs.shape``.
    """
    logprobs = lstm_predict(params, inputs)
    num_time_steps, num_examples, _ = inputs.shape
    # Only the first num_time_steps predictions are scored; the list from
    # lstm_predict holds one extra trailing entry.
    total = sum(
        (np.sum(logprobs[t] * targets[t]) for t in range(num_time_steps)),
        0.0,
    )
    return total / (num_time_steps * num_examples)
if __name__ == "__main__":
    num_chars = 128

    # The training corpus is this script's own source text.
    text_filename = join(dirname(__file__), "lstm.py")
    train_inputs = build_dataset(
        text_filename,
        sequence_length=30,
        alphabet_size=num_chars,
        max_lines=60,
    )

    init_params = init_lstm_params(
        input_size=num_chars,
        output_size=num_chars,
        state_size=40,
        param_scale=0.01,
    )

    def print_training_prediction(weights):
        """Print each training sequence beside the model's decoded prediction."""
        print("Training text Predicted text")
        logprobs = np.asarray(lstm_predict(weights, train_inputs))
        for seq_idx in range(logprobs.shape[1]):
            training_text = one_hot_to_string(train_inputs[:, seq_idx, :])
            predicted_text = one_hot_to_string(logprobs[:, seq_idx, :])
            # Newlines are flattened so every sequence stays on one row.
            left = training_text.replace("\n", " ")
            right = predicted_text.replace("\n", " ")
            print(left + "|" + right)

    def training_loss(params, step):
        """Negative log-likelihood of the training text predicting itself."""
        return -lstm_log_likelihood(params, train_inputs, train_inputs)

    def callback(weights, step, gradient):
        """Report the loss and sample predictions on every tenth iteration."""
        if step % 10 != 0:
            return
        print("Iteration", step, "Train loss:", training_loss(weights, 0))
        print_training_prediction(weights)

    # autograd differentiates the loss with respect to its first argument.
    training_loss_grad = grad(training_loss)

    print("Training LSTM...")
    trained_params = adam(
        training_loss_grad,
        init_params,
        step_size=0.1,
        num_iters=1000,
        callback=callback,
    )

    print()
    print("Generating text from LSTM...")
    num_letters = 30
    for _sample in range(20):
        text = ""
        for _position in range(num_letters):
            # Feed the text generated so far as a single sequence and sample
            # the next character from the final predicted distribution.
            seqs = string_to_one_hot(text, num_chars)[:, np.newaxis, :]
            logprobs = lstm_predict(trained_params, seqs)[-1].ravel()
            next_code = npr.choice(len(logprobs), p=np.exp(logprobs))
            text += chr(next_code)
        print(text)