-
Notifications
You must be signed in to change notification settings - Fork 6
/
udc_model.py
132 lines (110 loc) · 4.23 KB
/
udc_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import tensorflow as tf
import sys
def get_id_feature(features, key, len_key, max_len):
    """Fetch an id tensor and its clipped length tensor from `features`.

    Args:
      features: Mapping from feature name to tensor.
      key: Name of the feature holding the ids; returned unchanged.
      len_key: Name of the feature holding the lengths. Axis 1 of this
        tensor is squeezed away, so it must have size 1 along that axis.
      max_len: Upper bound applied to every length (converted to an int64
        constant before the comparison).

    Returns:
      A pair `(ids, ids_len)` where `ids_len` is the squeezed length tensor
      with every entry capped at `max_len`.
    """
    token_ids = features[key]
    squeezed_lengths = tf.squeeze(features[len_key], [1])
    length_cap = tf.constant(max_len, dtype=tf.int64)
    return token_ids, tf.minimum(squeezed_lengths, length_cap)
def create_train_op(loss, hparams):
    """Build the training op that minimizes `loss`.

    Delegates to `tf.contrib.layers.optimize_loss`, using the graph's global
    step, `hparams.learning_rate`, `hparams.optimizer`, and a fixed
    `clip_gradients` value of 10.0.

    Args:
      loss: Loss tensor to minimize.
      hparams: Object exposing `learning_rate` and `optimizer` attributes.

    Returns:
      Whatever `tf.contrib.layers.optimize_loss` returns for these settings.
    """
    optimizer_settings = dict(
        loss=loss,
        global_step=tf.contrib.framework.get_global_step(),
        learning_rate=hparams.learning_rate,
        clip_gradients=10.0,
        optimizer=hparams.optimizer,
    )
    return tf.contrib.layers.optimize_loss(**optimizer_settings)
def _stack_candidates(features, hparams, context, context_len, utterance,
                      utterance_len, persona, persona_len, candidate_keys):
    """Pair one context/persona with several candidate utterances.

    The first candidate is `utterance`. Each further candidate is read from
    `features` under a key in `candidate_keys`, with its length stored under
    "<key>_len" and capped at `hparams.max_utterance_len`. The context and
    persona tensors are repeated once per candidate and everything is
    concatenated along axis 0, so a single `model_impl` call can score all
    candidates.

    Returns:
      A 6-tuple `(contexts, context_lens, utterances, utterance_lens,
      personas, persona_lens)` of concatenated tensors.
    """
    utterances = [utterance]
    utterance_lens = [utterance_len]
    for key in candidate_keys:
        candidate, candidate_len = get_id_feature(
            features, key, "{}_len".format(key), hparams.max_utterance_len)
        utterances.append(candidate)
        utterance_lens.append(candidate_len)

    num_candidates = len(utterances)
    return (
        tf.concat([context] * num_candidates, 0),
        tf.concat([context_len] * num_candidates, 0),
        tf.concat(utterances, 0),
        tf.concat(utterance_lens, 0),
        tf.concat([persona] * num_candidates, 0),
        tf.concat([persona_len] * num_candidates, 0),
    )


def create_model_fn(hparams, model_impl):
    """Return a `model_fn(features, targets, mode)` wrapping `model_impl`.

    Args:
      hparams: Object exposing `max_context_len`, `max_utterance_len` and
        `max_persona_len` (plus whatever `create_train_op` and `model_impl`
        read from it).
      model_impl: Callable invoked as
        `model_impl(hparams, mode, context, context_len, utterance,
        utterance_len, persona, persona_len, targets)` and returning
        `(probs, loss)`.

    Returns:
      A function returning `(predictions, loss, train_op)` for the TRAIN,
      INFER and EVAL modes of `tf.contrib.learn.ModeKeys`.
    """

    def model_fn(features, targets, mode):
        """Build the graph for `mode`; raise ValueError on an unknown mode."""
        context, context_len = get_id_feature(
            features, "context", "context_len", hparams.max_context_len)
        utterance, utterance_len = get_id_feature(
            features, "utterance", "utterance_len", hparams.max_utterance_len)
        persona, persona_len = get_id_feature(
            features, "persona", "persona_len", hparams.max_persona_len)

        if mode == tf.contrib.learn.ModeKeys.TRAIN:
            probs, loss = model_impl(
                hparams,
                mode,
                context,
                context_len,
                utterance,
                utterance_len,
                persona,
                persona_len,
                targets)
            train_op = create_train_op(loss, hparams)
            return probs, loss, train_op

        if mode == tf.contrib.learn.ModeKeys.INFER:
            # NOTE(review): features["len"] is used as a Python integer here
            # (in range() and as the split count) -- confirm the input
            # function supplies a plain int rather than a tensor.
            num_candidates = features["len"]
            (contexts, context_lens, utterances, utterance_lens,
             personas, persona_lens) = _stack_candidates(
                 features, hparams, context, context_len, utterance,
                 utterance_len, persona, persona_len,
                 ["utterance_{}".format(i) for i in range(1, num_candidates)])
            # No targets at inference time; the loss from model_impl is unused.
            probs, _ = model_impl(
                hparams,
                mode,
                contexts,
                context_lens,
                utterances,
                utterance_lens,
                personas,
                persona_lens,
                None)
            # Undo the axis-0 stacking: one column of scores per candidate.
            split_probs = tf.split(probs, num_candidates, 0)
            probs = tf.concat(split_probs, 1)
            return probs, 0.0, None

        if mode == tf.contrib.learn.ModeKeys.EVAL:
            # Taken from the static shape of `targets`.
            # NOTE(review): presumably requires a fixed eval batch size -- confirm.
            batch_size = targets.get_shape().as_list()[0]

            # We have 10 examples per record (the true utterance plus
            # 9 distractors), so we accumulate them into one batch.
            num_distractors = 9
            (contexts, context_lens, utterances, utterance_lens,
             personas, persona_lens) = _stack_candidates(
                 features, hparams, context, context_len, utterance,
                 utterance_len, persona, persona_len,
                 ["distractor_{}".format(i) for i in range(num_distractors)])

            # Label 1 for the true utterance block, 0 for every distractor block.
            all_targets = [tf.ones([batch_size, 1], dtype=tf.int64)]
            for _ in range(num_distractors):
                all_targets.append(tf.zeros([batch_size, 1], dtype=tf.int64))

            probs, loss = model_impl(
                hparams,
                mode,
                contexts,
                context_lens,
                utterances,
                utterance_lens,
                personas,
                persona_lens,
                tf.concat(all_targets, 0))

            # Undo the axis-0 stacking: one column of scores per candidate.
            split_probs = tf.split(probs, num_distractors + 1, 0)
            shaped_probs = tf.concat(split_probs, 1)

            # Add summaries. The "incorrect" summaries only look at the first
            # distractor (split_probs[1]), not all nine.
            tf.summary.histogram("eval_correct_probs_hist", split_probs[0])
            tf.summary.scalar(
                "eval_correct_probs_average", tf.reduce_mean(split_probs[0]))
            tf.summary.histogram("eval_incorrect_probs_hist", split_probs[1])
            tf.summary.scalar(
                "eval_incorrect_probs_average", tf.reduce_mean(split_probs[1]))
            return shaped_probs, loss, None

        # Fail loudly instead of implicitly returning None to the caller.
        raise ValueError("Unsupported mode: {!r}".format(mode))

    return model_fn