
Commit

fix: Attention mask being None crashes distillation (#58)
stephantul authored Oct 3, 2024
1 parent 57ffb58 commit 649ab51
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion model2vec/distill/inference.py
@@ -127,7 +127,13 @@ def create_output_embeddings_from_model_name(
     for batch_idx in tqdm(range(0, len(stacked), _DEFAULT_BATCH_SIZE)):
         batch = stacked[batch_idx : batch_idx + _DEFAULT_BATCH_SIZE].to(model.device)
         with torch.no_grad():
-            encoded: BaseModelOutputWithPoolingAndCrossAttentions = model(input_ids=batch.to(device))
+            # NOTE: we create these masks because nomic embed requires them.
+            # Normally, we could set them to None
+            token_type_ids = torch.zeros_like(batch)
+            attention_mask = torch.ones_like(batch)
+            encoded: BaseModelOutputWithPoolingAndCrossAttentions = model(
+                input_ids=batch.to(device), attention_mask=attention_mask, token_type_ids=token_type_ids
+            )
         out: torch.Tensor = encoded.last_hidden_state
         # NOTE: If the dtype is bfloat16, we convert to float32,
         # because numpy does not support bfloat16
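
For context outside the diff, below is a minimal, self-contained sketch of the same call pattern, assuming a standard Hugging Face encoder; the model name (bert-base-uncased), example texts, and variable names are illustrative and not part of this commit. The point of the fix is the same: build explicit attention_mask and token_type_ids tensors (all ones and all zeros for an unpadded, equal-length batch, as in the distillation loop) rather than passing None, which crashes distillation for models such as nomic embed that require these inputs.

import torch
from transformers import AutoModel, AutoTokenizer

# Illustrative model; any Hugging Face encoder that accepts token_type_ids works the same way.
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model.eval()

# Single-word inputs tokenize to the same length here, so no padding is added,
# mirroring the per-token batches used during distillation.
texts = ["hello", "world", "token"]
batch = tokenizer(texts, return_tensors="pt", padding=True)["input_ids"]

with torch.no_grad():
    # Build the masks explicitly instead of passing None: an all-ones attention
    # mask and all-zeros token type ids are valid because the batch is unpadded.
    token_type_ids = torch.zeros_like(batch)
    attention_mask = torch.ones_like(batch)
    encoded = model(
        input_ids=batch,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
    )
    embeddings = encoded.last_hidden_state  # (batch_size, sequence_length, hidden_size)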
