-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_with_classifier.py
33 lines (29 loc) · 1.21 KB
/
train_with_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from .utils import common_functions as c_f
from .metric_loss_only import MetricLossOnly
import torch
class TrainWithClassifier(MetricLossOnly):
    """Trainer that adds an optional classification loss to metric learning.

    Extends ``MetricLossOnly``: each batch's embeddings are used for the
    metric loss and, when a ``"classifier"`` model is registered and the
    ``"classifier_loss"`` weight is positive, are also passed through that
    classifier. The resulting logits are scored with
    ``self.loss_funcs["classifier_loss"]``.
    """

    def calculate_loss(self, curr_batch):
        """Compute the metric and classifier losses for one batch.

        Args:
            curr_batch: A ``(data, labels)`` pair.

        The losses are written to ``self.losses`` under the keys
        ``"metric_loss"`` and ``"classifier_loss"``; nothing is returned.
        """
        data, labels = curr_batch
        embeddings = self.compute_embeddings(data)
        logits = self.maybe_get_logits(embeddings)
        indices_tuple = self.maybe_mine_embeddings(embeddings, labels)
        self.losses["metric_loss"] = self.maybe_get_metric_loss(
            embeddings, labels, indices_tuple
        )
        self.losses["classifier_loss"] = self.maybe_get_classifier_loss(logits, labels)

    def maybe_get_classifier_loss(self, logits, labels):
        """Return the classifier loss for ``logits``, or ``0`` if there are none.

        Args:
            logits: Classifier output, or ``None`` when the classifier was
                skipped (see ``maybe_get_logits``).
            labels: Target labels for the batch.

        Returns:
            The value of ``self.loss_funcs["classifier_loss"](logits, labels)``,
            or the integer ``0`` when ``logits`` is ``None``.
        """
        if logits is None:
            return 0
        # Presumably moves labels onto the same device as the logits so the
        # loss function does not mix devices -- confirm against c_f.to_device.
        return self.loss_funcs["classifier_loss"](
            logits, c_f.to_device(labels, logits)
        )

    def maybe_get_logits(self, embeddings):
        """Run the classifier on ``embeddings`` when it is enabled.

        The classifier is enabled when ``self.models`` has a ``"classifier"``
        entry that is not ``None`` and ``self.loss_weights["classifier_loss"]``
        (default ``0``) is greater than zero.

        Returns:
            The classifier output, or ``None`` when the classifier is disabled.
        """
        classifier = self.models.get("classifier", None)
        # Compare with None explicitly rather than relying on truthiness:
        # container modules such as an empty ``torch.nn.Sequential`` define
        # ``__len__`` and would otherwise be mistaken for "no classifier".
        # The weight check skips the forward pass when its loss would be
        # weighted to zero anyway.
        if classifier is not None and self.loss_weights.get("classifier_loss", 0) > 0:
            return classifier(embeddings)
        return None

    def modify_schema(self):
        """Register the extra model and loss keys this trainer accepts.

        Appends ``"classifier"`` to the schema's model keys and
        ``"classifier_loss"`` to its loss-function keys. Presumably called by
        the base trainer while validating its inputs -- confirm in the parent.
        """
        self.schema["models"].keys += ["classifier"]
        self.schema["loss_funcs"].keys += ["classifier_loss"]