From f61ef912c9fd29ffdc8295073f1fa1fc31cdec7e Mon Sep 17 00:00:00 2001
From: stephantul
Date: Wed, 25 Sep 2024 19:08:49 +0200
Subject: [PATCH] Fix test, docstrings

---
 model2vec/distill/distillation.py |  4 ++--
 tests/conftest.py                 | 13 +++++++------
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/model2vec/distill/distillation.py b/model2vec/distill/distillation.py
index b1421a2..dd6e44b 100644
--- a/model2vec/distill/distillation.py
+++ b/model2vec/distill/distillation.py
@@ -37,10 +37,10 @@ def distill(
     :param device: The device to use.
     :param pca_dims: The number of components to use for PCA. If this is None, we don't apply PCA.
     :param apply_zipf: Whether to apply Zipf weighting to the embeddings.
-    :param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary.
+    :param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
     :raises: ValueError if the PCA dimension is larger than the number of dimensions in the embeddings.
     :raises: ValueError if the vocabulary contains duplicate tokens.
-    :return: A StaticModdel
+    :return: A StaticModel
     """
 
     if not use_subword and vocabulary is None:
diff --git a/tests/conftest.py b/tests/conftest.py
index 3c28d82..c97ccf3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,18 +1,19 @@
 import numpy as np
 import pytest
-from transformers import PreTrainedTokenizerFast
-
-from model2vec.distill.tokenizer import create_tokenizer_from_vocab
+from tokenizers import Tokenizer
+from tokenizers.models import WordLevel
+from tokenizers.pre_tokenizers import Whitespace
 
 
 @pytest.fixture
-def mock_tokenizer() -> PreTrainedTokenizerFast:
+def mock_tokenizer() -> Tokenizer:
     """Create a mock tokenizer."""
     vocab = ["word1", "word2", "word3", "[UNK]", "[PAD]"]
     unk_token = "[UNK]"
-    pad_token = "[PAD]"
 
-    tokenizer = create_tokenizer_from_vocab(vocab, unk_token, pad_token)
+    model = WordLevel(vocab={word: idx for idx, word in enumerate(vocab)}, unk_token=unk_token)
+    tokenizer = Tokenizer(model)
+    tokenizer.pre_tokenizer = Whitespace()
 
     return tokenizer
 
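Note (not part of the patch): a minimal sketch of what the rebuilt mock_tokenizer fixture produces, assuming the tokenizers-library calls shown in the diff above; the example input string is made up for illustration.

    # Sketch: a WordLevel model with a Whitespace pre-tokenizer splits input on
    # whitespace/punctuation and maps any out-of-vocabulary word to unk_token.
    from tokenizers import Tokenizer
    from tokenizers.models import WordLevel
    from tokenizers.pre_tokenizers import Whitespace

    vocab = ["word1", "word2", "word3", "[UNK]", "[PAD]"]  # same vocab as the fixture
    model = WordLevel(vocab={word: idx for idx, word in enumerate(vocab)}, unk_token="[UNK]")
    tokenizer = Tokenizer(model)
    tokenizer.pre_tokenizer = Whitespace()

    encoding = tokenizer.encode("word1 word3 somethingelse")  # hypothetical input
    print(encoding.tokens)  # ['word1', 'word3', '[UNK]']
    print(encoding.ids)     # [0, 2, 3]

Because the model is word-level, it only recognizes full words, which is what the updated use_subword docstring in distillation.py describes for the use_subword=False case.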