
Commit 8822c4b: Resolved conflict

Pringled committed Nov 2, 2024
2 parents d92390b + 54a6460
Showing 27 changed files with 1,787 additions and 979 deletions.
12 changes: 9 additions & 3 deletions .github/workflows/ci.yaml
@@ -10,7 +10,14 @@ jobs:
     strategy:
       matrix:
         os: ["ubuntu-latest", "windows-latest"]
-        python-version: ["3.10"]
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
+        exclude:
+          - os: windows-latest
+            python-version: "3.9"
+          - os: windows-latest
+            python-version: "3.11"
+          - os: windows-latest
+            python-version: "3.12"
       fail-fast: false

     steps:
@@ -42,8 +49,7 @@ jobs:
       # Install dependencies using uv pip
       - name: Install dependencies
-        run: make install
-        # run: uv pip install -e ".[pytest]"
+        run: make install-no-pre-commit

       # Run tests with coverage
       - name: Run tests under coverage
1 change: 0 additions & 1 deletion .gitignore
@@ -168,7 +168,6 @@ models
 checkpoints/*
 features/*
 model2vec_models
-results/*
 counts/*
 results_old/*
 local/*
3 changes: 3 additions & 0 deletions Makefile
@@ -8,6 +8,9 @@ install:
 	uv sync --all-extras
 	uv run pre-commit install

+install-no-pre-commit:
+	uv pip install ".[dev,distill]"
+
 install-base:
 	uv sync --extra dev
245 changes: 174 additions & 71 deletions README.md

Large diffs are not rendered by default.

Binary file added assets/images/speed_vs_accuracy_v4.png
Binary file added assets/images/speed_vs_mteb_score.png
Binary file added assets/images/speed_vs_mteb_score_v2.png
4 changes: 2 additions & 2 deletions model2vec/__init__.py
@@ -1,4 +1,4 @@
-from model2vec.distill import distill
 from model2vec.model import StaticModel
+from model2vec.version import __version__

-__all__ = ["distill", "StaticModel"]
+__all__ = ["StaticModel", "__version__"]
7 changes: 7 additions & 0 deletions model2vec/distill/__init__.py
@@ -1,3 +1,10 @@
+from model2vec.utils import get_package_extras, importable
+
+_REQUIRED_EXTRA = "distill"
+
+for extra_dependency in get_package_extras("model2vec", _REQUIRED_EXTRA):
+    importable(extra_dependency, _REQUIRED_EXTRA)
+
 from model2vec.distill.distillation import distill, distill_from_model

 __all__ = ["distill", "distill_from_model"]
53 changes: 0 additions & 53 deletions model2vec/distill/__main__.py

This file was deleted.

14 changes: 12 additions & 2 deletions model2vec/distill/distillation.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 import logging
-from typing import Literal
+from typing import Literal, Union

 import numpy as np
 from huggingface_hub import model_info
@@ -26,7 +28,7 @@
 logger = logging.getLogger(__name__)


-PCADimType = int | None | Literal["auto"]
+PCADimType = Union[int, None, Literal["auto"]]


 def distill_from_model(
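Both type-alias changes in this commit (`PCADimType` here, `PathLike` in inference.py) swap PEP 604 syntax for `typing.Union`. The reason: `from __future__ import annotations` only defers evaluation of annotations, while a module-level alias assignment runs eagerly, so `int | None` raises a `TypeError` on Python 3.9, which this commit re-adds to the CI matrix:

from __future__ import annotations

from typing import Literal, Union

# Fine on Python 3.9+: Union is evaluated through typing, not PEP 604.
PCADimType = Union[int, None, Literal["auto"]]

# TypeError on Python 3.9: the alias is a runtime expression, and the
# future import does not make `int | None` legal outside annotations.
# PCADimType = int | None | Literal["auto"]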
@@ -214,9 +216,17 @@ def _post_process_embeddings(embeddings: np.ndarray, pca_dims: PCADimType, apply
     elif pca_dims <= embeddings.shape[1]:
         logger.info(f"Applying PCA with n_components {pca_dims}")

+        orig_dims = embeddings.shape[1]
         p = PCA(n_components=pca_dims, whiten=False)
         embeddings = p.fit_transform(embeddings)

+        if embeddings.shape[1] < orig_dims:
+            explained_variance_ratio = np.sum(p.explained_variance_ratio_)
+            explained_variance = np.sum(p.explained_variance_)
+            logger.info(f"Reduced dimensionality from {orig_dims} to {embeddings.shape[1]}.")
+            logger.info(f"Explained variance ratio: {explained_variance_ratio:.3f}.")
+            logger.info(f"Explained variance: {explained_variance:.3f}.")
+
     if apply_zipf:
         logger.info("Applying Zipf weighting")
         embeddings *= np.log(1 + np.arange(embeddings.shape[0]))[:, None]
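The new logging reports how much variance survives the reduction. A self-contained sketch of the same post-processing steps on random data (all shapes and values illustrative):

import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(seed=0)
embeddings = rng.normal(size=(5_000, 768))

orig_dims = embeddings.shape[1]
p = PCA(n_components=256, whiten=False)
embeddings = p.fit_transform(embeddings)

print(f"Reduced dimensionality from {orig_dims} to {embeddings.shape[1]}.")
print(f"Explained variance ratio: {np.sum(p.explained_variance_ratio_):.3f}.")

# Zipf weighting assumes the row index correlates with token rank, so
# multiplying by log(1 + rank) damps frequent (low-rank) tokens relative
# to rare ones.
embeddings *= np.log(1 + np.arange(embeddings.shape[0]))[:, None]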
19 changes: 15 additions & 4 deletions model2vec/distill/inference.py
@@ -1,8 +1,10 @@
 # -*- coding: utf-8 -*-
+from __future__ import annotations

+import inspect
 import logging
 from pathlib import Path
-from typing import Protocol
+from typing import Protocol, Union

 import numpy as np
 import torch
@@ -13,7 +15,7 @@
 logger = logging.getLogger(__name__)


-PathLike = str | Path
+PathLike = Union[Path, str]

 _DEFAULT_BATCH_SIZE = 1024

@@ -113,7 +115,15 @@ def create_output_embeddings_from_model_name(
     :return: The tokens and output embeddings.
     """
     model = model.to(device)
-    ids = torch.arange(tokenizer.vocab_size)
+
+    # Quick check to see if the tokenizer is consistent.
+    vocab_length = len(tokenizer.get_vocab())
+    if vocab_length != tokenizer.vocab_size:
+        logger.warning(
+            f"Reported vocab size {tokenizer.vocab_size} is inconsistent with the vocab size {vocab_length}."
+        )
+
+    ids = torch.arange(vocab_length)

     # Work-around to get the eos and bos token ids without having to go into tokenizer internals.
     dummy_encoding = tokenizer.encode("A")
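On Hugging Face tokenizers, `vocab_size` reports only the base vocabulary, while `len(tokenizer.get_vocab())` also counts added tokens, so the two can drift apart; sizing `ids` from the full vocab avoids missing embeddings for added tokens. A quick demonstration (model name illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokenizer.add_tokens(["<new_token>"])  # makes the two counts diverge

vocab_length = len(tokenizer.get_vocab())
if vocab_length != tokenizer.vocab_size:
    print(
        f"Reported vocab size {tokenizer.vocab_size} is inconsistent "
        f"with the vocab size {vocab_length}."
    )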
@@ -122,7 +132,8 @@
     bos = torch.full([len(ids)], fill_value=bos_token_id)
     eos = torch.full([len(ids)], fill_value=eos_token_id)

-    stacked = torch.stack([bos, ids, eos], dim=1)
+    # NOTE: reversing the bos and eos tokens works better on our benchmarks.
+    stacked = torch.stack([eos, ids, bos], dim=1)

     intermediate_weights: list[np.ndarray] = []
     for batch_idx in tqdm(range(0, len(stacked), _DEFAULT_BATCH_SIZE)):
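Each vocabulary id becomes a three-token pseudo-sequence. A sketch of the reversed stacking with illustrative BERT-style special token ids:

import torch

vocab_length, bos_token_id, eos_token_id = 30_522, 101, 102  # illustrative

ids = torch.arange(vocab_length)
bos = torch.full([len(ids)], fill_value=bos_token_id)
eos = torch.full([len(ids)], fill_value=eos_token_id)

# Reversed [eos, id, bos] order, per the NOTE above: it benchmarked
# better than the natural [bos, id, eos] order.
stacked = torch.stack([eos, ids, bos], dim=1)
print(stacked.shape)  # torch.Size([30522, 3])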
87 changes: 46 additions & 41 deletions model2vec/distill/tokenizer.py
@@ -2,6 +2,7 @@

 import json
 import logging
+from typing import Any

 from tokenizers import Tokenizer

@@ -36,11 +37,11 @@ def remove_tokens(tokenizer: Tokenizer, tokens_to_remove: list[str]) -> Tokenize
         logger.info("No tokens to remove.")
         return Tokenizer.from_str(tokenizer.to_str())

-    tokenizer_data = json.loads(tokenizer.to_str())
+    tokenizer_data: dict[str, Any] = json.loads(tokenizer.to_str())

     # Find all added tokens
-    added_tokens = tokenizer_data["added_tokens"]
-    added_tokens_str = {token["content"] for token in added_tokens}
+    added_tokens: list[dict[str, Any]] = tokenizer_data.get("added_tokens", [])
+    added_tokens_str: set[str] = {token["content"] for token in added_tokens}

     # Remove all added tokens from the list of tokens to remove.
     # Things will go bad if we keep them.
@@ -49,34 +50,36 @@
     # Load the vocabulary.
     model_type = tokenizer_data["model"]["type"]

-    match model_type:
-        case "WordPiece":
-            # Vocab is a dictionary.
-            vocab: dict[str, int] = tokenizer_data["model"]["vocab"]
-            n_tokens = len(vocab)
-
-            # Remove the tokens.
-            for token in tokens_to_remove:
-                if vocab.pop(token, None) is None:
-                    logger.warning(f"Token {token} was not in the vocabulary.")
-
-            n_removed = n_tokens - len(vocab)
-            logger.info(f"Removed {n_removed} tokens from the vocabulary.")
-
-            # Reindex the vocabulary so that it is contiguous.
-            reindexed = {token: idx for idx, (token, _) in enumerate(sorted(vocab.items(), key=lambda x: x[1]))}
-            tokenizer_data["model"]["vocab"] = reindexed
-        case "Unigram":
-            raise ValueError("Removing tokens from a unigram tokenizer is not supported.")
-        case "BPE":
-            raise ValueError("Removing tokens from a bpe tokenizer is not supported.")
-        case _:
-            raise ValueError(f"Unknown model type {model_type}")
+    if model_type == "WordPiece":
+        # Vocab is a dictionary.
+        vocab: dict[str, int] = tokenizer_data["model"]["vocab"]
+        n_tokens = len(vocab)
+
+        # Remove the tokens.
+        for token in tokens_to_remove:
+            if vocab.pop(token, None) is None:
+                logger.warning(f"Token {token} was not in the vocabulary.")
+
+        n_removed = n_tokens - len(vocab)
+        logger.info(f"Removed {n_removed} tokens from the vocabulary.")
+
+        # Reindex the vocabulary so that it is contiguous.
+        reindexed = {token: idx for idx, (token, _) in enumerate(sorted(vocab.items(), key=lambda x: x[1]))}
+        tokenizer_data["model"]["vocab"] = reindexed
+
+    elif model_type == "Unigram":
+        raise ValueError("Removing tokens from a unigram tokenizer is not supported.")
+
+    elif model_type == "BPE":
+        raise ValueError("Removing tokens from a BPE tokenizer is not supported.")
+
+    else:
+        raise ValueError(f"Unknown model type {model_type}")

     # Reindex the special tokens (i.e., CLS and SEP for BertTokenizers.)
     special_tokens_post_processor: dict[str, dict] = tokenizer_data["post_processor"]["special_tokens"]
     for token, token_data in special_tokens_post_processor.items():
         token_data["ids"] = [reindexed[token] for token in token_data["tokens"]]
     added_tokens = tokenizer_data.get("added_tokens", [])
     for token_data in added_tokens:
         token_data["id"] = reindexed[token_data["content"]]

     # Reinitialize the tokenizer from the json.
     tokenizer = Tokenizer.from_str(json.dumps(tokenizer_data))
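The reindexing step in miniature: after tokens are popped, the remaining ids are compacted in their original order so the vocabulary stays contiguous (toy vocabulary, purely illustrative):

vocab = {"[PAD]": 0, "[CLS]": 1, "hello": 2, "world": 3, "##ly": 4}

for token in ["hello"]:
    vocab.pop(token, None)

# Compact the surviving ids while preserving their relative order.
reindexed = {token: idx for idx, (token, _) in enumerate(sorted(vocab.items(), key=lambda x: x[1]))}
print(reindexed)  # {'[PAD]': 0, '[CLS]': 1, 'world': 2, '##ly': 3}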
@@ -97,18 +100,20 @@ def add_tokens(tokenizer: Tokenizer, tokens_to_add: list[str]) -> Tokenizer:

     model = data["model"]["type"]

-    match model:
-        case "WordPiece":
-            wordpiece_vocab: dict[str, int] = data["model"]["vocab"]
-            for token in tokens_to_add:
-                if token not in wordpiece_vocab:
-                    wordpiece_vocab[token] = len(wordpiece_vocab)
-        case "Unigram":
-            raise ValueError("Adding tokens to a unigram tokenizer is not supported.")
-        case "BPE":
-            raise ValueError("Adding tokens to a bpe tokenizer is not supported.")
-        case _:
-            raise ValueError(f"Unknown model type {model}")
+    if model == "WordPiece":
+        wordpiece_vocab: dict[str, int] = data["model"]["vocab"]
+        for token in tokens_to_add:
+            if token not in wordpiece_vocab:
+                wordpiece_vocab[token] = len(wordpiece_vocab)
+
+    elif model == "Unigram":
+        raise ValueError("Adding tokens to a unigram tokenizer is not supported.")
+
+    elif model == "BPE":
+        raise ValueError("Adding tokens to a BPE tokenizer is not supported.")
+
+    else:
+        raise ValueError(f"Unknown model type {model}")

     tokenizer = Tokenizer.from_str(json.dumps(data))
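The `match` statements here were rewritten as `if`/`elif` chains for the same reason as the `Union` changes: `match` requires Python 3.10, and this commit restores Python 3.9 support in CI. A hypothetical usage sketch of the two helpers (tokenizer and tokens illustrative; only WordPiece models are supported):

from tokenizers import Tokenizer

from model2vec.distill.tokenizer import add_tokens, remove_tokens

tokenizer = Tokenizer.from_pretrained("bert-base-uncased")

smaller = remove_tokens(tokenizer, tokens_to_remove=["walking", "##ly"])
larger = add_tokens(tokenizer, tokens_to_add=["model2vec"])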
2 changes: 2 additions & 0 deletions model2vec/distill/utils.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from logging import getLogger

 import torch