Add from model #57

Merged (5 commits) on Oct 3, 2024
17 changes: 16 additions & 1 deletion README.md
@@ -88,7 +88,22 @@ m2v_model = distill(model_name=model_name, pca_dims=256)
m2v_model.save_pretrained("m2v_model")
```

Distillation is really fast, and only takes about 30 seconds on a 2024 macbook using the MPS backend. Best of all, distillation requires no training data.
If you already have a model loaded, or need to load a model in some special way, we also offer an interface to distill models in memory.

```python
from model2vec.distill import distill_from_model

# Assuming a loaded model and tokenizer
model = load_my_model()
tokenizer = load_my_tokenizer()

m2v_model = distill_from_model(model=model, tokenizer=tokenizer, pca_dims=256)

m2v_model.save_pretrained("m2v_model")
```

Distillation is really fast: it takes only about 5 seconds on a 2024 MacBook using the MPS backend, and about 30 seconds on CPU. Best of all, distillation requires no training data.
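
For readers following along, here is a minimal sketch of loading the distilled model back and using it; it assumes `StaticModel.from_pretrained` and `encode` exist as counterparts to the `save_pretrained` call above, which this diff does not itself confirm.

```python
from model2vec.model import StaticModel

# Reload the model saved by save_pretrained above (path reused from the example).
m2v_model = StaticModel.from_pretrained("m2v_model")

# Embed a couple of sentences; the result is one static vector per sentence.
embeddings = m2v_model.encode(["Distillation is fast.", "No training data is required."])
print(embeddings.shape)
```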

## What is Model2Vec?

4 changes: 2 additions & 2 deletions model2vec/distill/__init__.py
@@ -1,3 +1,3 @@
from model2vec.distill.__main__ import distill
from model2vec.distill.distillation import distill, distill_from_model

__all__ = ["distill"]
__all__ = ["distill", "distill_from_model"]
79 changes: 58 additions & 21 deletions model2vec/distill/distillation.py
@@ -3,7 +3,7 @@
import numpy as np
from huggingface_hub import model_info
from sklearn.decomposition import PCA
from tokenizers.models import BPE, Unigram, WordPiece
from tokenizers.models import BPE, Unigram
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerFast

from model2vec.distill.inference import (
@@ -16,8 +16,9 @@
logger = logging.getLogger(__name__)


def distill(
model_name: str,
def distill_from_model(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerFast,
vocabulary: list[str] | None = None,
device: str = "cpu",
pca_dims: int | None = 256,
@@ -33,7 +34,8 @@ def distill(
If you pass through a vocabulary, we create a custom word tokenizer for that vocabulary.
If you don't pass a vocabulary, we use the model's tokenizer directly.

:param model_name: The model name to use. Any sentencetransformer compatible model works.
:param model: The loaded transformer model to distill.
:param tokenizer: The loaded (fast) tokenizer that belongs to the model.
:param vocabulary: The vocabulary to use. If this is None, we use the model's vocabulary.
:param device: The device to use.
:param pca_dims: The number of components to use for PCA. If this is None, we don't apply PCA.
@@ -44,53 +46,46 @@ def distill(
:return: A StaticModel

"""
""""""
if not use_subword and vocabulary is None:
raise ValueError(
"You must pass a vocabulary if you don't use subword tokens. Either pass a vocabulary, or set use_subword to True."
)

# Load original tokenizer. We need to keep this to tokenize any tokens in the vocabulary.
original_tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(model_name)

if vocabulary and isinstance(original_tokenizer.backend_tokenizer.model, (BPE, Unigram)):
if vocabulary and isinstance(tokenizer.backend_tokenizer.model, (BPE, Unigram)):
raise ValueError(
"You passed a vocabulary, but the model you are using does not use a WordPiece tokenizer. "
"This is not supported yet."
"Feel free to open an issue if this is a blocker: https://github.com/MinishLab/model2vec/issues"
)

original_model: PreTrainedModel = AutoModel.from_pretrained(model_name)
# Make a base list of tokens.
tokens: list[str] = []
if use_subword:
# Create the subword embeddings.
tokens, embeddings = create_output_embeddings_from_model_name(
model=original_model, tokenizer=original_tokenizer, device=device
)
tokens, embeddings = create_output_embeddings_from_model_name(model=model, tokenizer=tokenizer, device=device)

# Remove any unused tokens from the tokenizer and embeddings.
wrong_tokens = [x for x in tokens if x.startswith("[unused")]
vocab = original_tokenizer.get_vocab()
vocab = tokenizer.get_vocab()
# Get the ids of the unused token.
wrong_token_ids = [vocab[token] for token in wrong_tokens]
# Remove the unused tokens from the tokenizer.
new_tokenizer = remove_tokens(original_tokenizer.backend_tokenizer, wrong_tokens)
new_tokenizer = remove_tokens(tokenizer.backend_tokenizer, wrong_tokens)
# Remove the embeddings of the unused tokens.
embeddings = np.delete(embeddings, wrong_token_ids, axis=0)
logger.info(f"Removed {len(wrong_tokens)} unused tokens from the tokenizer and embeddings.")
else:
# We need to keep the unk token in the tokenizer.
unk_token = original_tokenizer.backend_tokenizer.model.unk_token
unk_token = tokenizer.backend_tokenizer.model.unk_token
# Remove all tokens except the UNK token.
new_tokenizer = remove_tokens(
original_tokenizer.backend_tokenizer, list(set(original_tokenizer.get_vocab()) - {unk_token})
)
new_tokenizer = remove_tokens(tokenizer.backend_tokenizer, list(set(tokenizer.get_vocab()) - {unk_token}))
# We need to set embeddings to None because we don't know the dimensions of the embeddings yet.
embeddings = None

if vocabulary:
# Preprocess the vocabulary with the original tokenizer.
preprocessed_vocabulary = preprocess_vocabulary(original_tokenizer.backend_tokenizer, vocabulary)
preprocessed_vocabulary = preprocess_vocabulary(tokenizer.backend_tokenizer, vocabulary)
n_tokens_before = len(preprocessed_vocabulary)
# Clean the vocabulary by removing duplicate tokens and tokens that are in the subword vocabulary.
cleaned_vocabulary = _clean_vocabulary(preprocessed_vocabulary, tokens)
@@ -102,8 +97,8 @@ def distill(
if cleaned_vocabulary:
# Create the embeddings.
_, token_embeddings = create_output_embeddings_from_model_name_and_tokens(
model=original_model,
tokenizer=original_tokenizer,
model=model,
tokenizer=tokenizer,
tokens=cleaned_vocabulary,
device=device,
)
@@ -121,6 +116,8 @@ def distill(
# Post process the embeddings by applying PCA and Zipf weighting.
embeddings = _post_process_embeddings(np.asarray(embeddings), pca_dims, apply_zipf)

model_name = getattr(model, "name_or_path", "")

config = {"tokenizer_name": model_name, "apply_pca": pca_dims, "apply_zipf": apply_zipf}
# Get the language from the model card
info = model_info(model_name)
@@ -131,6 +128,46 @@
)


def distill(
model_name: str,
vocabulary: list[str] | None = None,
device: str = "cpu",
pca_dims: int | None = 256,
apply_zipf: bool = True,
use_subword: bool = True,
) -> StaticModel:
"""
Distill a StaticModel from a sentence transformer.

This function creates a set of embeddings from a sentence transformer. It does this by doing either
a forward pass for all subword tokens in the tokenizer, or by doing a forward pass for all tokens in a passed vocabulary.

If you pass through a vocabulary, we create a custom word tokenizer for that vocabulary.
If you don't pass a vocabulary, we use the model's tokenizer directly.

:param model_name: The model name to use. Any sentence-transformer-compatible model works.
:param vocabulary: The vocabulary to use. If this is None, we use the model's vocabulary.
:param device: The device to use.
:param pca_dims: The number of components to use for PCA. If this is None, we don't apply PCA.
:param apply_zipf: Whether to apply Zipf weighting to the embeddings.
:param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
:return: A StaticModel

"""
model: PreTrainedModel = AutoModel.from_pretrained(model_name)
tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(model_name)

return distill_from_model(
model=model,
tokenizer=tokenizer,
vocabulary=vocabulary,
device=device,
pca_dims=pca_dims,
apply_zipf=apply_zipf,
use_subword=use_subword,
)


def _post_process_embeddings(embeddings: np.ndarray, pca_dims: int | None, apply_zipf: bool) -> np.ndarray:
"""Post process embeddings by applying PCA and Zipf weighting."""
if pca_dims is not None:
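The docstrings above describe a word-level mode (`use_subword=False` with a custom vocabulary) that the README example does not exercise. Below is a minimal sketch of that path; the checkpoint name and vocabulary are illustrative placeholders, and the checkpoint is assumed to use a WordPiece tokenizer, since the code raises for BPE and Unigram backends.

```python
from model2vec.distill import distill

# Illustrative word-level vocabulary; a real one would contain thousands of words.
vocabulary = ["paris", "london", "berlin", "madrid"]

m2v_model = distill(
    model_name="BAAI/bge-base-en-v1.5",  # placeholder; assumed WordPiece-based checkpoint
    vocabulary=vocabulary,
    use_subword=False,  # drop subword tokens: the output tokenizer only matches the passed words
    pca_dims=None,  # skip PCA here; the toy vocabulary is far too small for 256 components
)

m2v_model.save_pretrained("m2v_model_word_level")
```

Per the docstring, setting `use_subword=False` without passing a vocabulary raises a ValueError, which is exactly the case the tests below parametrize.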
8 changes: 4 additions & 4 deletions tests/conftest.py
@@ -9,7 +9,7 @@
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from transformers import AutoModel, BertTokenizerFast
from transformers import AutoModel, AutoTokenizer


@pytest.fixture
@@ -26,10 +26,9 @@ def mock_tokenizer() -> Tokenizer:


@pytest.fixture
def mock_berttokenizer() -> BertTokenizerFast:
def mock_berttokenizer() -> AutoTokenizer:
"""Load the real BertTokenizerFast from the provided tokenizer.json file."""
tokenizer_path = Path("tests/data/test_tokenizer/tokenizer.json")
return BertTokenizerFast(tokenizer_file=str(tokenizer_path))
return AutoTokenizer.from_pretrained("tests/data/test_tokenizer")


@pytest.fixture
@@ -39,6 +38,7 @@ def mock_transformer() -> AutoModel:
class MockPreTrainedModel:
def __init__(self) -> None:
self.device = "cpu"
self.name_or_path = "mock-model"

def to(self, device: str) -> MockPreTrainedModel:
self.device = device
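The `name_or_path` attribute added to the mock matters because `distill_from_model` above now derives the config's `tokenizer_name` via `getattr(model, "name_or_path", "")`. A tiny self-contained sketch of that lookup (the class mirrors the test fixture and is not a real transformers model):

```python
class MockPreTrainedModel:
    def __init__(self) -> None:
        self.device = "cpu"
        self.name_or_path = "mock-model"


# Models loaded via from_pretrained carry name_or_path; the default covers objects that don't.
print(getattr(MockPreTrainedModel(), "name_or_path", ""))  # -> mock-model
print(getattr(object(), "name_or_path", ""))                # -> ""
```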
83 changes: 77 additions & 6 deletions tests/test_distillation.py
@@ -1,3 +1,4 @@
import json
from importlib import import_module
from unittest.mock import MagicMock, patch

@@ -6,12 +7,86 @@
from pytest import LogCaptureFixture
from transformers import AutoModel, BertTokenizerFast

from model2vec.distill.distillation import _clean_vocabulary, _post_process_embeddings, distill
from model2vec.distill.distillation import _clean_vocabulary, _post_process_embeddings, distill, distill_from_model
from model2vec.model import StaticModel

rng = np.random.default_rng()


@pytest.mark.parametrize(
"vocabulary, use_subword, pca_dims, apply_zipf",
[
(None, True, 256, True), # Output vocab with subwords, PCA applied
(
["wordA", "wordB"],
False,
4,
False,
), # Custom vocab without subwords, PCA applied
(["wordA", "wordB"], True, 4, False), # Custom vocab with subwords, PCA applied
(None, True, None, True), # No PCA applied
(["wordA", "wordB"], False, 4, True), # Custom vocab without subwords, PCA and Zipf applied
(None, False, 256, True), # use_subword = False without passing a vocabulary should raise an error
],
)
@patch.object(import_module("model2vec.distill.distillation"), "model_info")
@patch("transformers.AutoModel.from_pretrained")
def test_distill_from_model(
mock_auto_model: MagicMock,
mock_model_info: MagicMock,
mock_berttokenizer: BertTokenizerFast,
mock_transformer: AutoModel,
vocabulary: list[str] | None,
use_subword: bool,
pca_dims: int | None,
apply_zipf: bool,
) -> None:
"""Test distill function with different parameters."""
# Mock the return value of model_info to avoid calling the Hugging Face API
mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}})

# Patch AutoModel.from_pretrained to return the mock model instance
mock_auto_model.return_value = mock_transformer

if vocabulary is None and not use_subword:
with pytest.raises(ValueError):
static_model = distill_from_model(
model=mock_transformer,
tokenizer=mock_berttokenizer,
vocabulary=vocabulary,
device="cpu",
pca_dims=pca_dims,
apply_zipf=apply_zipf,
use_subword=use_subword,
)
else:
# Call the distill function with the parametrized inputs
static_model = distill_from_model(
model=mock_transformer,
tokenizer=mock_berttokenizer,
vocabulary=vocabulary,
device="cpu",
pca_dims=pca_dims,
apply_zipf=apply_zipf,
use_subword=use_subword,
)

static_model2 = distill(
model_name="tests/data/test_tokenizer",
vocabulary=vocabulary,
device="cpu",
pca_dims=pca_dims,
apply_zipf=apply_zipf,
use_subword=use_subword,
)

assert static_model.embedding.weight.shape == static_model2.embedding.weight.shape
assert static_model.config == static_model2.config
assert json.loads(static_model.tokenizer.to_str()) == json.loads(static_model2.tokenizer.to_str())
assert static_model.base_model_name == static_model2.base_model_name


@pytest.mark.parametrize(
"vocabulary, use_subword, pca_dims, apply_zipf, expected_shape",
[
@@ -30,13 +105,10 @@
],
)
@patch.object(import_module("model2vec.distill.distillation"), "model_info")
@patch("transformers.AutoTokenizer.from_pretrained")
@patch("transformers.AutoModel.from_pretrained")
def test_distill(
mock_auto_model: MagicMock,
mock_auto_tokenizer: MagicMock,
mock_model_info: MagicMock,
mock_berttokenizer: BertTokenizerFast,
mock_transformer: AutoModel,
vocabulary: list[str] | None,
use_subword: bool,
@@ -49,10 +121,9 @@ def test_distill(
mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}})

# Patch the tokenizers and models to return the real BertTokenizerFast and mock model instances
mock_auto_tokenizer.return_value = mock_berttokenizer
mock_auto_model.return_value = mock_transformer

model_name = "mock-model"
model_name = "tests/data/test_tokenizer"

if vocabulary is None and not use_subword:
with pytest.raises(ValueError):
2 changes: 1 addition & 1 deletion uv.lock
