Fix distill model bos and eos token #78

Merged · 1 commit · Oct 12, 2024
model2vec/distill/distillation.py (1 addition, 1 deletion)

```diff
@@ -53,7 +53,7 @@ def distill_from_model(
     :param device: The device to use.
     :param pca_dims: The number of components to use for PCA.
         If this is None, we don't apply PCA.
-        If this is 'auto', we don't reduce dimenionality, but still apply PCA.
+        If this is 'auto', we don't reduce dimensionality, but still apply PCA.
     :param apply_zipf: Whether to apply Zipf weighting to the embeddings.
     :param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
     :raises: ValueError if the PCA dimension is larger than the number of dimensions in the embeddings.
```
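The docstring above describes three `pca_dims` modes: `None`, `'auto'`, and an integer. A minimal sketch of that documented behaviour, using scikit-learn as an assumed stand-in rather than model2vec's actual implementation:

```python
# Minimal sketch of the documented `pca_dims` semantics; this is an
# illustration built on scikit-learn, not model2vec's own code.
import numpy as np
from sklearn.decomposition import PCA


def apply_pca(embeddings: np.ndarray, pca_dims: int | str | None) -> np.ndarray:
    if pca_dims is None:
        return embeddings  # None: skip PCA entirely
    # 'auto': fit PCA (reorienting the space) without reducing dimensionality;
    # an integer reduces to that many components.
    n_components = embeddings.shape[1] if pca_dims == "auto" else pca_dims
    if n_components > embeddings.shape[1]:
        raise ValueError("PCA dimension is larger than the embedding dimension.")
    return PCA(n_components=n_components).fit_transform(embeddings)
```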
model2vec/distill/inference.py (2 additions, 2 deletions)

```diff
@@ -117,10 +117,10 @@ def create_output_embeddings_from_model_name(

     # Work-around to get the eos and bos token ids without having to go into tokenizer internals.
     dummy_encoding = tokenizer.encode("A")
-    eos_token_id, bos_token_id = dummy_encoding[0], dummy_encoding[-1]
+    bos_token_id, eos_token_id = dummy_encoding[0], dummy_encoding[-1]

-    eos = torch.full([len(ids)], fill_value=eos_token_id)
     bos = torch.full([len(ids)], fill_value=bos_token_id)
+    eos = torch.full([len(ids)], fill_value=eos_token_id)

     stacked = torch.stack([bos, ids, eos], dim=1)
```
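The swap above is the substance of the fix: when a single token is encoded, position 0 holds the beginning-of-sequence id and position -1 the end-of-sequence id, so the old assignment had the two names reversed. A small demonstration, assuming the `transformers` library and a BERT-style checkpoint:

```python
# Demonstrates the work-around used above; assumes the `transformers`
# library and the bert-base-uncased checkpoint.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Encoding a single token yields [bos_id, token_id, eos_id],
# i.e. [CLS] A [SEP] for BERT-style tokenizers.
dummy_encoding = tokenizer.encode("A")

bos_token_id, eos_token_id = dummy_encoding[0], dummy_encoding[-1]
print(bos_token_id, eos_token_id)  # 101 ([CLS]) and 102 ([SEP])

# The pre-fix assignment reversed the two names, so every sequence would
# have been wrapped as [SEP] token [CLS] instead of [CLS] token [SEP].
```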
model2vec/model.py (1 addition, 1 deletion)

```diff
@@ -184,7 +184,7 @@ def from_pretrained(

     :param path: The path to load your static model from.
     :param token: The huggingface token to use.
-    :return: A StaticEmbedder
+    :return: A StaticModel
     """
     embeddings, tokenizer, config, metadata = load_pretrained(path, token=token)
```
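For reference, a hedged usage sketch of the corrected docstring: `from_pretrained` returns a `StaticModel`. The model path and the `encode` call here are assumptions for illustration:

```python
# Usage sketch: `from_pretrained` returns a StaticModel, as the corrected
# docstring now states. The path below is a placeholder.
from model2vec import StaticModel

model = StaticModel.from_pretrained("path/to/distilled-model")
embeddings = model.encode(["hello world"])  # assumes StaticModel exposes encode()
```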