probing_model.py
# external libs
from typing import Dict

import tensorflow as tf
import tensorflow.keras.layers as layers
import tensorflow.keras.models as models
# project imports
from util import load_pretrained_model

class ProbingClassifier(models.Model):
    def __init__(self,
                 pretrained_model_path: str,
                 layer_num: int,
                 classes_num: int) -> None:
        """
        Loads a pretrained main model. On the given input, it takes
        the representations that model generates at a certain layer
        and learns a linear classifier on top of these frozen
        features.

        Parameters
        ----------
        pretrained_model_path : ``str``
            Serialization directory of the main model which you
            want to probe at one of its layers.
        layer_num : ``int``
            Layer number of the pretrained model on which to learn
            a linear classifier probe.
        classes_num : ``int``
            Number of classes that the ProbingClassifier chooses from.
        """
        super(ProbingClassifier, self).__init__()
        # Load the pretrained model and freeze it so that only the
        # probe's own parameters are updated during training.
        self._pretrained_model = load_pretrained_model(pretrained_model_path)
        self._pretrained_model.trainable = False
        self._layer_num = layer_num
        self.classes_num = classes_num
        # TODO(students): start
        # A single Dense layer with linear activation: the probe outputs
        # raw logits over the classes.
        self._layer = layers.Dense(self.classes_num, activation="linear")
        # TODO(students): end

    def call(self, inputs: tf.Tensor, training: bool = False) -> Dict[str, tf.Tensor]:
        """
        Forward pass of the Probing Classifier.

        Parameters
        ----------
        inputs : ``tf.Tensor``
            Tensorized version of the batched input text. It is of shape
            (batch_size, max_tokens_num) and entries are indices of tokens
            into the vocabulary. 0 means a padding token. max_tokens_num
            is the maximum number of tokens in any text sequence in this batch.
        training : ``bool``
            Whether this call is in training mode or prediction mode.
            This flag is useful when applying dropout, because dropout should
            only be applied during training.
        """
        # TODO(students): start
        # Run the frozen pretrained model; it is expected to return a dict
        # whose "layer_representations" entry holds one tensor per layer.
        output = self._pretrained_model(inputs, training=training)
        # Probe the requested (1-indexed) layer with the linear classifier.
        logits = self._layer(output["layer_representations"][self._layer_num - 1])
        # TODO(students): end
        return {"logits": logits}
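

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, with assumptions). It assumes that
# `load_pretrained_model` can restore a model from the hypothetical directory
# "serialization_dirs/main", and that the restored model returns a dict with
# a "layer_representations" key, as `call` above expects. The dummy tensors
# stand in for real tokenized data. Because the pretrained model is frozen,
# `probe.trainable_variables` contains only the Dense probe's weights.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    probe = ProbingClassifier(
        pretrained_model_path="serialization_dirs/main",  # hypothetical path
        layer_num=2,
        classes_num=2,
    )

    # Dummy batch: token indices padded with 0, one label per sequence.
    dummy_inputs = tf.constant([[3, 7, 12, 0, 0],
                                [5, 2, 0, 0, 0]], dtype=tf.int32)
    dummy_labels = tf.constant([0, 1], dtype=tf.int32)

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tf.keras.optimizers.Adam()

    # One gradient step on the probe; only the linear layer is updated.
    with tf.GradientTape() as tape:
        logits = probe(dummy_inputs, training=True)["logits"]
        loss = loss_fn(dummy_labels, logits)
    grads = tape.gradient(loss, probe.trainable_variables)
    optimizer.apply_gradients(zip(grads, probe.trainable_variables))
    print("probe loss:", float(loss))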