
Commit

Fix test, docstrings
stephantul committed Sep 25, 2024
1 parent 60e202f commit f61ef91
Showing 2 changed files with 9 additions and 8 deletions.
4 changes: 2 additions & 2 deletions model2vec/distill/distillation.py
@@ -37,10 +37,10 @@ def distill(
     :param device: The device to use.
     :param pca_dims: The number of components to use for PCA. If this is None, we don't apply PCA.
     :param apply_zipf: Whether to apply Zipf weighting to the embeddings.
-    :param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary.
+    :param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
     :raises: ValueError if the PCA dimension is larger than the number of dimensions in the embeddings.
     :raises: ValueError if the vocabulary contains duplicate tokens.
-    :return: A StaticModdel
+    :return: A StaticModel
     """
     if not use_subword and vocabulary is None:
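For context on the clarified docstring, here is a rough sketch of how distill might be called with an explicit word-level vocabulary when use_subword is False. The import path, the model_name argument, and the example model name are assumptions, not taken from this diff; vocabulary, use_subword, pca_dims, and apply_zipf come from the docstring above.

from model2vec.distill import distill

# Hypothetical call: distill a static model over a custom word vocabulary.
# With use_subword=False, the returned tokenizer only detects full words
# from this vocabulary (per the docstring change in this commit).
vocabulary = ["cat", "dog", "house"]
static_model = distill(
    model_name="BAAI/bge-base-en-v1.5",  # assumed argument name and model
    vocabulary=vocabulary,
    use_subword=False,
    pca_dims=256,
    apply_zipf=True,
)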
13 changes: 7 additions & 6 deletions tests/conftest.py
@@ -1,18 +1,19 @@
 import numpy as np
 import pytest
-from transformers import PreTrainedTokenizerFast
-
-from model2vec.distill.tokenizer import create_tokenizer_from_vocab
+from tokenizers import Tokenizer
+from tokenizers.models import WordLevel
+from tokenizers.pre_tokenizers import Whitespace
 
 
 @pytest.fixture
-def mock_tokenizer() -> PreTrainedTokenizerFast:
+def mock_tokenizer() -> Tokenizer:
     """Create a mock tokenizer."""
     vocab = ["word1", "word2", "word3", "[UNK]", "[PAD]"]
     unk_token = "[UNK]"
     pad_token = "[PAD]"
 
-    tokenizer = create_tokenizer_from_vocab(vocab, unk_token, pad_token)
+    model = WordLevel(vocab={word: idx for idx, word in enumerate(vocab)}, unk_token=unk_token)
+    tokenizer = Tokenizer(model)
+    tokenizer.pre_tokenizer = Whitespace()
 
     return tokenizer

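As a quick sanity check on the rewritten fixture, the snippet below rebuilds the same word-level tokenizer outside pytest and encodes a short string. The expected outputs in the comments reflect standard tokenizers-library behavior (unknown words falling back to [UNK]) and are illustrative only; nothing in this commit asserts them.

from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace

# Same construction as the new mock_tokenizer fixture.
vocab = ["word1", "word2", "word3", "[UNK]", "[PAD]"]
tokenizer = Tokenizer(WordLevel(vocab={word: idx for idx, word in enumerate(vocab)}, unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()

# Known words map to their vocabulary ids; anything else falls back to [UNK].
encoding = tokenizer.encode("word1 word4")
print(encoding.tokens)  # expected: ['word1', '[UNK]']
print(encoding.ids)     # expected: [0, 3]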
