Skip to content

Commit

Permalink
Merge pull request #450 from tomaarsen/fix/normalize_with_diff_head
Browse files Browse the repository at this point in the history
Allow normalize_embeddings with a differentiable head
  • Loading branch information
tomaarsen authored Nov 29, 2023
2 parents f04e997 + 6f226e5 commit f021e13
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/setfit/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,9 @@ def fit(

outputs = self.model_body(features)
if self.normalize_embeddings:
outputs = nn.functional.normalize(outputs, p=2, dim=1)
outputs["sentence_embedding"] = nn.functional.normalize(
outputs["sentence_embedding"], p=2, dim=1
)
outputs = self.model_head(outputs)
logits = outputs["logits"]

Expand Down
18 changes: 18 additions & 0 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,24 @@ def setUp(self):
)
self.args = TrainingArguments(num_iterations=1)

def test_trainer_normalize(self):
    """Regression test: training and evaluation must work when
    ``normalize_embeddings=True`` is combined with a differentiable head.

    Before the fix, normalization was applied to the raw model-body output
    dict instead of its ``sentence_embedding`` tensor, breaking this path.
    """
    # Build a model that exercises both options at once.
    self.model = SetFitModel.from_pretrained(
        "sentence-transformers/paraphrase-albert-small-v2",
        use_differentiable_head=True,
        head_params={"out_features": 3},
        normalize_embeddings=True,
    )
    # The fixture dataset uses renamed columns; map them back for the Trainer.
    column_mapping = {"text_new": "text", "label_new": "label"}
    trainer = Trainer(
        model=self.model,
        args=self.args,
        train_dataset=self.dataset,
        eval_dataset=self.dataset,
        column_mapping=column_mapping,
    )
    trainer.train()
    # The tiny fixture is trivially separable, so a perfect score is expected.
    self.assertEqual(trainer.evaluate(), {"accuracy": 1.0})

def test_trainer_max_length_exceeds_max_acceptable_length(self):
trainer = Trainer(
model=self.model,
Expand Down

0 comments on commit f021e13

Please sign in to comment.