-
Notifications
You must be signed in to change notification settings - Fork 0
/
simclr.py
225 lines (193 loc) · 9.58 KB
/
simclr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
# The SimCLR class implementation is based on the tutorial in https://keras.io/examples/vision/semisupervised_simclr/
import tensorflow as tf
import math
from transforms import ColorDrop, RandomBlur, RandomColorAffine
# Module-level TensorFlow random generator with a fixed seed (42).
# NOTE(review): not referenced by the code visible in this file; presumably
# used by importers (e.g. the augmentation transforms) — confirm before removing.
rng = tf.random.Generator.from_seed(seed=42)
def get_projection_head(width, input_shape):
    """Build the SimCLR projection MLP.

    The head is two dense layers of ``width`` units: a ReLU-activated hidden
    layer followed by a linear output layer.

    Args:
        width: number of units in both dense layers.
        input_shape: per-sample input shape, excluding the batch axis.

    Returns:
        A ``tf.keras.Sequential`` model named ``"projection_head"``.
    """
    head_layers = [
        tf.keras.Input(shape=input_shape),
        tf.keras.layers.Dense(width, activation="relu"),
        tf.keras.layers.Dense(width),  # linear output
    ]
    return tf.keras.Sequential(head_layers, name="projection_head")
def get_encoder(width, input_shape):
    """Build the small convolutional encoder.

    The encoder is four 3x3, stride-2, ReLU ``Conv2D`` layers of ``width``
    filters each. They are followed by ``Flatten`` and a ReLU ``Dense`` layer
    that produces a ``width``-dimensional feature vector.

    Args:
        width: number of conv filters and of output features.
        input_shape: per-sample image shape, excluding the batch axis.

    Returns:
        A ``tf.keras.Sequential`` model named ``"encoder"``.
    """
    image_input = tf.keras.Input(shape=input_shape)
    # Four identical downsampling convolutions, created in order.
    downsampling = [
        tf.keras.layers.Conv2D(width, kernel_size=3, strides=2, activation="relu")
        for _ in range(4)
    ]
    return tf.keras.Sequential(
        [
            image_input,
            *downsampling,
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(width, activation="relu"),
        ],
        name="encoder",
    )
def get_augmenter(input_shape, min_area, brightness, jitter):
    """Build the random augmentation pipeline used to create SimCLR views.

    The pipeline applies, in order:

    1. a horizontal flip,
    2. a random translation,
    3. a random zoom-in,
    4. ``RandomColorAffine(brightness, jitter)``,
    5. ``RandomBlur(p=0.5, kernel_size=9)``.

    Args:
        input_shape: per-sample image shape, excluding the batch axis.
        min_area: sets the maximum zoom-in. The zoom factor is
            ``1 - sqrt(min_area)``, so the zoomed crop keeps at least
            ``sqrt(min_area)`` of each side, i.e. about ``min_area`` of the
            image area.
        brightness: forwarded to ``RandomColorAffine``.
        jitter: forwarded to ``RandomColorAffine``.

    Returns:
        An (unnamed) ``tf.keras.Sequential`` model.
    """
    max_zoom = 1.0 - math.sqrt(min_area)
    max_shift = max_zoom / 2
    augmentations = [
        tf.keras.Input(shape=input_shape),
        # Resizing / rescaling are disabled (commented out) in this version.
        # NOTE(review): no /255 happens here, so pixel scaling must be done by
        # the caller's input pipeline — confirm against the data loader.
        # tf.keras.layers.experimental.preprocessing.Resizing(input_shape[0], input_shape[1]),
        # tf.keras.layers.experimental.preprocessing.Rescaling(1.0 / 255),
        tf.keras.layers.RandomFlip("horizontal"),
        tf.keras.layers.RandomTranslation(height_factor=max_shift, width_factor=max_shift),
        # Negative factors zoom in; the range (-max_zoom, 0.0) never zooms out.
        tf.keras.layers.RandomZoom(
            height_factor=(-max_zoom, 0.0), width_factor=(-max_zoom, 0.0)
        ),
        RandomColorAffine(brightness, jitter),
        # ColorDrop(p=0.2),
        RandomBlur(p=0.5, kernel_size=9),
    ]
    return tf.keras.Sequential(augmentations)
class ContrastiveModel(tf.keras.Model):
    """TensorFlow SimCLR model for contrastive learning.

    ``fit`` feeds batches of (unlabeled) images to ``train_step``. Each step:

    * draws two augmented views of the batch,
    * encodes and projects both views,
    * minimises the NT-Xent contrastive loss between them.

    When ``linear_probe`` is set, the same step also trains that probe on one
    batch from ``labeled_dataset``, using encoder features computed in
    inference mode. ``evaluate`` (``test_step``) reports only the probe
    metrics.

    Attributes configured by the caller after construction:
        linear_probe: Keras model mapping encoder features to class logits, or
            ``None`` to disable probe training.
        labeled_dataset: ``tf.data.Dataset`` yielding ``(images, labels)``
            batches for probe training. Assigning it creates one persistent,
            repeating iterator, so consecutive steps see consecutive batches.
        class_weights: object whose ``lookup(labels)`` returns per-sample
            weights for class balancing, or ``None`` for unweighted training.
    """

    def __init__(self, augmenter, encoder, projection_head, temperature=0.1,
                 loss_implementation="simple", **kwargs):
        """Create the model.

        Args:
            augmenter: model producing a random view of an image batch; it is
                called twice per training step with ``training=True``.
            encoder: model mapping images to feature vectors.
            projection_head: model mapping encoder features to the vectors on
                which the contrastive loss is computed.
            temperature: divisor applied to cosine similarities to form logits.
            loss_implementation: ``"simple"`` uses only cross-view
                similarities as logits, as in the Keras tutorial.
                ``"complicated"`` also uses same-view negatives, as in the
                SimCLR paper.
            **kwargs: forwarded to ``tf.keras.Model``. ``name`` defaults to
                ``"simclr"`` but may now be overridden.

        Raises:
            ValueError: if ``loss_implementation`` is not recognised.
        """
        # Fail at construction time instead of on the first traced train step.
        if loss_implementation not in ("simple", "complicated"):
            raise ValueError(
                "loss_implementation must be 'simple' or 'complicated', "
                f"got {loss_implementation!r}"
            )
        kwargs.setdefault("name", "simclr")
        super().__init__(**kwargs)
        # === Hyperparameters ===
        self.temperature = temperature
        self.loss_implementation = loss_implementation
        # === Architecture ===
        self.augmenter = augmenter
        self.encoder = encoder
        self.projection_head = projection_head
        # === Linear probe (configured by the caller after construction) ===
        self.linear_probe = None
        # Previously read in train_step without ever being initialised.
        self.class_weights = None
        # Backing state for the ``labeled_dataset`` property. These are set
        # directly rather than through the setter, so no iterator is built yet.
        self._labeled_dataset = None
        self._next_labeled_batch = None

    @property
    def labeled_dataset(self):
        """``tf.data.Dataset`` of ``(images, labels)`` batches for the probe."""
        return self._labeled_dataset

    @labeled_dataset.setter
    def labeled_dataset(self, dataset):
        self._labeled_dataset = dataset
        if dataset is None:
            self._next_labeled_batch = None
            return
        # Build ONE iterator, eagerly, and reuse it on every step. Calling
        # iter(dataset) inside the traced train step would create a fresh
        # iterator each call, restarting the pipeline every step.
        # repeat() keeps the iterator from running dry across epochs.
        with tf.init_scope():
            iterator = iter(dataset.repeat())
        # Keep the iterator in a closure rather than as a direct attribute, so
        # that Keras attribute tracking does not register it as a checkpoint
        # dependency.
        self._next_labeled_batch = lambda: next(iterator)

    def compile(self, optimizer, probe_optimizer, **kwargs):
        """Set up the optimizers, the probe loss and the metric trackers.

        ``contrastive_loss`` is a method of this class, not a compiled loss.

        Args:
            optimizer: optimizer for the encoder and projection head.
            probe_optimizer: optimizer for the linear probe.
            **kwargs: forwarded to ``tf.keras.Model.compile``.
        """
        super().compile(**kwargs)
        self.contrastive_optimizer = optimizer
        self.probe_optimizer = probe_optimizer
        self.probe_loss = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True
        )
        self.contrastive_loss_tracker = tf.keras.metrics.Mean(name="contrastive_loss")
        self.contrastive_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name="contrastive_acc"
        )
        self.probe_loss_tracker = tf.keras.metrics.Mean(name="prediction_loss")
        self.probe_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name="prediction_acc"
        )

    @property
    def metrics(self):
        """Inherited metrics followed by this model's four trackers.

        Keras may already report the trackers through ``super().metrics``,
        because metric objects assigned as attributes can be auto-tracked.
        They are therefore de-duplicated by identity rather than listed twice.
        """
        own = [
            self.contrastive_loss_tracker,
            self.contrastive_accuracy,
            self.probe_loss_tracker,
            self.probe_accuracy,
        ]
        inherited = [m for m in super().metrics if not any(m is o for o in own)]
        return inherited + own

    def contrastive_loss(self, projections_1, projections_2):
        """Per-sample NT-Xent loss between two batches of projections.

        NT-Xent is the normalized temperature-scaled cross-entropy, closely
        related to InfoNCE. Row ``i`` of each batch is the positive pair of
        row ``i`` of the other batch. This method also updates
        ``contrastive_accuracy``.

        Args:
            projections_1: projections of the first augmented view.
            projections_2: projections of the second augmented view.

        Returns:
            The per-sample loss, symmetrised over both view orderings
            (a tensor with one entry per batch row).

        Raises:
            ValueError: if ``self.loss_implementation`` is not recognised.
        """
        batch_size = tf.shape(projections_1)[0]
        # Cosine similarity = dot product of L2-normalized feature vectors.
        projections_1 = tf.math.l2_normalize(projections_1, axis=1)
        projections_2 = tf.math.l2_normalize(projections_2, axis=1)
        if self.loss_implementation == "simple":
            # Similarities as in the Keras tutorial: cross-view only.
            similarities1 = (
                tf.matmul(projections_1, projections_2, transpose_b=True) / self.temperature
            )
            similarities2 = tf.transpose(similarities1)
        elif self.loss_implementation == "complicated":
            # Similarities as in the SimCLR paper. Negatives come from BOTH
            # views, and self-similarities (the diagonals of sim_11 and
            # sim_22) are pushed down by 1e9 so they never act as negatives.
            large_value = tf.eye(batch_size) * 1e9
            sim_12 = (
                tf.matmul(projections_1, projections_2, transpose_b=True) / self.temperature
            )
            sim_21 = (
                tf.matmul(projections_2, projections_1, transpose_b=True) / self.temperature
            )
            sim_11 = (
                tf.matmul(projections_1, projections_1, transpose_b=True) / self.temperature
            ) - large_value
            sim_22 = (
                tf.matmul(projections_2, projections_2, transpose_b=True) / self.temperature
            ) - large_value
            # Logits are (batch, 2 * batch). Label i selects column i, which is
            # the cross-view positive in the first half.
            similarities1 = tf.concat([sim_12, sim_11], axis=1)
            similarities2 = tf.concat([sim_21, sim_22], axis=1)
        else:
            raise ValueError("loss_implementation must be 'simple' or 'complicated'")
        # The two views of the same image should be more similar to each other
        # than to any other view.
        contrastive_labels = tf.range(batch_size)
        self.contrastive_accuracy.update_state(contrastive_labels, similarities1)
        self.contrastive_accuracy.update_state(contrastive_labels, similarities2)
        # The temperature-scaled similarities are used as logits. Both
        # directions are averaged to symmetrise the loss.
        loss_1_2 = tf.keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, similarities1, from_logits=True
        )
        loss_2_1 = tf.keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, similarities2, from_logits=True
        )
        return (loss_1_2 + loss_2_1) / 2

    def train_step(self, images):
        """Run one contrastive step, plus one probe step if a probe is set.

        The Keras tutorial concatenates labeled and unlabeled images for
        contrastive learning. Here only the unlabeled ``images`` passed to
        ``fit`` are used for it.

        Returns:
            Dict mapping every metric name to its current value.
        """
        self._contrastive_step(images)
        if self.linear_probe is not None:
            self._probe_step()
        return {m.name: m.result() for m in self.metrics}

    def _contrastive_step(self, images):
        """Update the encoder and projection head on the NT-Xent loss of two views."""
        # NOTE(review): unlike the probe and test paths, these images are not
        # divided by 255 here, and the augmenter's Rescaling layer is disabled.
        # Presumably the unlabeled pipeline already scales them — confirm.
        augmented_images_1 = self.augmenter(images, training=True)
        augmented_images_2 = self.augmenter(images, training=True)
        with tf.GradientTape() as tape:
            # Each augmented view is encoded, then projected by the MLP head.
            features_1 = self.encoder(augmented_images_1, training=True)
            features_2 = self.encoder(augmented_images_2, training=True)
            projections_1 = self.projection_head(features_1, training=True)
            projections_2 = self.projection_head(features_2, training=True)
            contrastive_loss = self.contrastive_loss(projections_1, projections_2)
        trainable = (
            self.encoder.trainable_weights + self.projection_head.trainable_weights
        )
        # The loss is per-sample. tape.gradient differentiates its sum, as in
        # the Keras tutorial.
        gradients = tape.gradient(contrastive_loss, trainable)
        self.contrastive_optimizer.apply_gradients(zip(gradients, trainable))
        self.contrastive_loss_tracker.update_state(contrastive_loss)

    def _probe_step(self):
        """Train the linear probe on the next labeled batch using frozen encoder features."""
        if self._next_labeled_batch is None:
            raise ValueError(
                "linear_probe is set but labeled_dataset is None; "
                "assign model.labeled_dataset before calling fit()."
            )
        # Labels are used only here, for an on-the-fly logistic regression.
        labeled_images, labels = self._next_labeled_batch()
        labeled_images = labeled_images / 255.0
        # Optional per-sample weights for class balancing.
        if self.class_weights is None:
            sample_weight = None
        else:
            sample_weight = self.class_weights.lookup(labels)
        with tf.GradientTape() as tape:
            # The encoder runs in inference mode. This avoids regularization
            # and stops batch-normalization statistics from updating, if used.
            features = self.encoder(labeled_images, training=False)
            class_logits = self.linear_probe(features, training=True)
            probe_loss = self.probe_loss(
                labels, class_logits, sample_weight=sample_weight
            )
        gradients = tape.gradient(probe_loss, self.linear_probe.trainable_weights)
        self.probe_optimizer.apply_gradients(
            zip(gradients, self.linear_probe.trainable_weights)
        )
        self.probe_loss_tracker.update_state(probe_loss)
        self.probe_accuracy.update_state(labels, class_logits)

    def test_step(self, data):
        """Evaluate the linear probe on one labeled batch ``(images, labels)``.

        All components run with ``training=False``. Only the probe metrics
        are returned.

        Raises:
            ValueError: if no ``linear_probe`` has been set.
        """
        if self.linear_probe is None:
            raise ValueError(
                "test_step requires a linear probe; "
                "assign model.linear_probe before calling evaluate()."
            )
        labeled_images, labels = data
        preprocessed_images = labeled_images / 255.0
        features = self.encoder(preprocessed_images, training=False)
        class_logits = self.linear_probe(features, training=False)
        probe_loss = self.probe_loss(labels, class_logits)
        self.probe_loss_tracker.update_state(probe_loss)
        self.probe_accuracy.update_state(labels, class_logits)
        # Select the probe metrics explicitly. Slicing self.metrics by
        # position breaks as soon as super().metrics is non-empty.
        probe_metrics = (self.probe_loss_tracker, self.probe_accuracy)
        return {m.name: m.result() for m in probe_metrics}