-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathpenntreebank_evaluate.py
84 lines (71 loc) · 3.16 KB
/
penntreebank_evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import sys
import theano, itertools, pprint, copy, numpy as np, theano.tensor as T, re
from collections import OrderedDict
from blocks.serialization import load
import util
# to make unpickling work :-(
from penntreebank import *
# argument: path to a checkpoint file
# NOTE: this is a Python 2 script (print statements, cPickle below); run it
# under the same environment the checkpoint was produced in.
main_loop = load(sys.argv[1])
print main_loop.log.current_row
# Pick out the algorithm updates that maintain batch-norm population
# statistics.  They are identified purely by variable-name suffix
# ("..._mean" / "..._var"), hence the FRAGILE marker: any other shared
# variable ending in _mean/_var would be caught too.
updates = [update for update in main_loop.algorithm.updates
           # FRAGILE
           if re.search("_(mean|var)$", update[0].name)]
print updates
# Snapshot the checkpointed (moving-average) population statistics before
# they are zeroed and re-estimated below.  Each update is a
# (shared_variable, new_value_expression) pair.
old_popstats = dict((popstat, popstat.get_value()) for popstat, _ in updates)
# baseline doesn't need all this
# (a model without batch norm has no _mean/_var updates, so `updates` is
# empty and the whole re-estimation pass is skipped)
if updates:
    train_stream = get_stream(which_set="train",
                              batch_size=100,
                              augment=False,
                              length=100)
    # number of batches in one epoch; used to turn the per-batch
    # accumulation below into a simple average over the epoch
    nbatches = len(list(train_stream.get_epoch_iterator()))
    # destructure moving average expression to construct a new expression
    # Each checkpointed update has the form
    #     popstat_new = alpha * batchstat + beta * popstat
    # (an elementwise add of two multiplications).  We tear that apart to
    # recover `batchstat` and replace the moving average with an exact
    # average over one full training epoch.
    new_updates = []
    for popstat, value in updates:
        # FRAGILE
        assert value.owner.op.scalar_op == theano.scalar.add
        terms = value.owner.inputs
        # right multiplicand of second term is popstat
        assert popstat in theano.gof.graph.ancestors([terms[1].owner.inputs[1]])
        # right multiplicand of first term is batchstat
        batchstat = terms[0].owner.inputs[1]
        # re-captures the same snapshot already taken above; harmless
        # redundancy kept for safety
        old_popstats[popstat] = popstat.get_value()
        # FRAGILE: assume population statistics not used in computation of batch statistics
        # otherwise popstat should always have a reasonable value
        # Zero the accumulator so that summing batchstat/nbatches over the
        # epoch yields the epoch mean.
        popstat.set_value(0 * popstat.get_value(borrow=True))
        new_updates.append((popstat, popstat + batchstat / float(nbatches)))
    # FRAGILE: assume all the other algorithm updates are unneeded for computation of batch statistics
    estimate_fn = theano.function(main_loop.algorithm.inputs, [],
                                  updates=new_updates, on_unused_input="warn")
    # one pass over the training set accumulates the exact population stats
    for batch in train_stream.get_epoch_iterator(as_dict=True):
        estimate_fn(**batch)
# record the freshly estimated statistics (empty dict for the baseline case)
new_popstats = dict((popstat, popstat.get_value()) for popstat, _ in updates)
from blocks.monitoring.evaluators import DatasetEvaluator
# Evaluate the model in both batch-norm modes ("training" = batch
# statistics, "inference" = population statistics) on all three splits,
# across a sweep of sequence lengths.
results = dict()
for situation in "training inference".split():
    results[situation] = dict()
    # Recover the monitored theano variables from the main loop's
    # validation extension for this situation.  The destructuring
    # `outputs, =` asserts exactly one matching extension exists.
    # NOTE(review): reaches into the private `_evaluator` attribute of a
    # Blocks extension — tied to a specific Blocks version.
    outputs, = [
        extension._evaluator.theano_variables
        for extension in main_loop.extensions
        if getattr(extension, "prefix", None) == "valid_%s" % situation]
    evaluator = DatasetEvaluator(outputs)
    for which_set in "train valid test".split():
        # map sequence length -> evaluation results for this split
        results[situation][which_set] = OrderedDict(
            (length, evaluator.evaluate(get_stream(
                which_set=which_set,
                batch_size=100,
                augment=False,
                length=length)))
            for length in [50, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000])
# Full-test-set evaluation as one long sequence (446184 = presumably the
# total length of the PTB test set — confirm against the dataset).
# NOTE(review): `evaluator` here is whichever one the loop above bound
# last, i.e. the "inference" evaluator — apparently intentional.
results["proper_test"] = evaluator.evaluate(
    get_stream(
        which_set="test",
        batch_size=1,
        length=446184))
import cPickle
# Persist everything needed for post-hoc analysis next to the checkpoint:
# the evaluation results plus the before/after population statistics.
# Fixes: the pickle file must be opened in *binary* mode ("wb" — text mode
# corrupts protocol data on Windows), and the `with` block guarantees the
# handle is flushed and closed even if dump() raises.
with open(sys.argv[1] + "_popstat_results.pkl", "wb") as results_file:
    cPickle.dump(dict(results=results,
                      old_popstats=old_popstats,
                      new_popstats=new_popstats),
                 results_file)