import os

# The Keras backend must be selected before keras is imported.
os.environ["KERAS_BACKEND"] = "torch"

import keras
import sentence_transformers
import torch
from keras.layers import TorchModuleWrapper

from images import ImageModel

# Basic ELSA model as a Keras layer (usable inside other Keras models).
class LayerELSA(keras.layers.Layer):
    def __init__(self, n_dims, n_items, device):
        super().__init__()
        self.device = device
        # Embedding matrix A, initialized with Xavier (Glorot) uniform.
        self.A = torch.nn.Parameter(torch.nn.init.xavier_uniform_(torch.empty([n_dims, n_items])))

    def parameters(self, recurse=True):
        return [self.A]

    def track_module_parameters(self):
        # Register the torch parameters with Keras so they appear as trainable variables.
        for param in self.parameters():
            variable = keras.Variable(initializer=param, trainable=param.requires_grad)
            variable._value = param
            self._track_variable(variable)
        self.built = True

    def build(self, input_shape=None):
        self.to(self.device)
        sample_input = torch.ones([self.A.shape[0]]).to(self.device)
        _ = self.call(sample_input)
        self.track_module_parameters()

    def call(self, x):
        # ELSA forward pass: relu(xAA^T - x), i.e. x(AA^T - I) with row-normalized A.
        A = torch.nn.functional.normalize(self.A, dim=-1)
        xA = torch.matmul(x, A)
        xAAT = torch.matmul(xA, A.T)
        return keras.activations.relu(xAAT - x)

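# A minimal usage sketch for LayerELSA (illustrative only; the shapes and the
# helper below are assumptions, not part of the original file). Note that, given
# the matmul order in `call`, `n_dims` is the size of the input axis and
# `n_items` the size of the latent axis.
def _example_elsa():
    elsa = LayerELSA(n_dims=1000, n_items=64, device="cpu")
    elsa.build()
    x = torch.rand([8, 1000])  # batch of 8 input vectors
    return elsa.call(x)        # relu(x @ A @ A.T - x), shape [8, 1000]
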
# Keras wrapper around a SentenceTransformer (or ImageModel) object.
class LayerSBERT(keras.layers.Layer):
    def __init__(self, model, device, tokenized_sentences):
        super().__init__()
        self.device = device
        self.sbert = TorchModuleWrapper(model.to(device))
        self.tokenize_ = self.sb().tokenize
        # Pre-tokenized sample sentences, used to build the layer below.
        self.tokenized_sentences = tokenized_sentences
        self.build()

    def sb(self):
        # Return the wrapped SentenceTransformer (or ImageModel) submodule.
        for module in self.sbert.modules():
            if isinstance(module, (sentence_transformers.SentenceTransformer, ImageModel)):
                return module

    def parameters(self, recurse=True):
        return self.sbert.parameters()

    def track_module_parameters(self):
        # Register the torch parameters with Keras so they appear as trainable variables.
        for param in self.parameters():
            variable = keras.Variable(initializer=param, trainable=param.requires_grad)
            variable._value = param
            self._track_variable(variable)
        self.built = True

    def tokenize(self, inp):
        # Tokenize the input and move the resulting tensors to the target device.
        return {k: v.to(self.device) for k, v in self.tokenize_(inp).items()}

    def build(self, input_shape=None):
        self.to(self.device)
        # Build by running a two-example sample batch through the model.
        sample_input = {k: v[:2].to(self.device) for k, v in self.tokenized_sentences.items()}
        _ = self.call(sample_input)
        self.track_module_parameters()

    def call(self, x):
        # Forward through the wrapped model and return the pooled sentence embedding.
        return self.sbert.forward(x)["sentence_embedding"]

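# A minimal usage sketch for LayerSBERT (illustrative only; the model name,
# sentences, and the helper below are assumptions, not part of the original
# file). The layer builds itself eagerly in __init__ from the pre-tokenized
# sample batch it is given.
def _example_sbert():
    model = sentence_transformers.SentenceTransformer("all-MiniLM-L6-v2")
    sentences = ["first example sentence", "second example sentence"]
    tokenized = model.tokenize(sentences)
    layer = LayerSBERT(model, device="cpu", tokenized_sentences=tokenized)
    return layer.call(layer.tokenize(sentences))  # shape [2, embedding_dim]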