Skip to content

Commit

Permalink
Moved function
Browse files Browse the repository at this point in the history
  • Loading branch information
Pringled committed Oct 31, 2024
1 parent 084d8a5 commit 62ac69c
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 70 deletions.
61 changes: 1 addition & 60 deletions model2vec/hf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
logger = logging.getLogger(__name__)


import json


def save_pretrained(
folder_path: Path,
embeddings: np.ndarray,
Expand All @@ -42,63 +39,7 @@ def save_pretrained(
folder_path.mkdir(exist_ok=True, parents=True)
save_file({"embeddings": embeddings}, folder_path / "model.safetensors")
tokenizer.save(str(folder_path / "tokenizer.json"))
with open(folder_path / "config.json", "w") as config_file:
json.dump(config, config_file, indent=4, sort_keys=True)

# Save vocab.txt
with open(folder_path / "vocab.txt", "w") as vocab_file:
vocab = tokenizer.get_vocab()
for token in sorted(vocab, key=vocab.get):
vocab_file.write(f"{token}\n")

# Load tokenizer.json to use for generating tokenizer_config.json
with open(folder_path / "tokenizer.json", "r") as f:
tokenizer_data = json.load(f)

# Save special_tokens_map.json
special_tokens = {
"cls_token": "[CLS]",
"sep_token": "[SEP]",
"pad_token": "[PAD]",
"unk_token": "[UNK]",
"mask_token": "[MASK]",
}
with open(folder_path / "special_tokens_map.json", "w") as special_tokens_file:
json.dump(special_tokens, special_tokens_file, indent=4, sort_keys=True)

# Set fallback values for normalizer attributes in case normalizer is None
normalizer = tokenizer_data.get("normalizer")
do_lower_case = normalizer.get("lowercase") if normalizer else config.get("do_lower_case", True)
strip_accents = normalizer.get("strip_accents") if normalizer else None
tokenize_chinese_chars = normalizer.get("handle_chinese_chars", True) if normalizer else True

# Save tokenizer_config.json based on tokenizer.json
tokenizer_config = {
"added_tokens_decoder": {
str(token["id"]): {
"content": token["content"],
"lstrip": token.get("lstrip", False),
"normalized": token.get("normalized", False),
"rstrip": token.get("rstrip", False),
"single_word": token.get("single_word", False),
"special": token.get("special", True),
}
for token in tokenizer_data.get("added_tokens", [])
},
"clean_up_tokenization_spaces": True,
"cls_token": special_tokens["cls_token"],
"do_lower_case": do_lower_case,
"mask_token": special_tokens["mask_token"],
"model_max_length": config.get("seq_length", 512),
"pad_token": special_tokens["pad_token"],
"sep_token": special_tokens["sep_token"],
"strip_accents": strip_accents,
"tokenize_chinese_chars": tokenize_chinese_chars,
"tokenizer_class": "BertTokenizer",
"unk_token": special_tokens["unk_token"],
}
with open(folder_path / "tokenizer_config.json", "w") as tokenizer_config_file:
json.dump(tokenizer_config, tokenizer_config_file, indent=4, sort_keys=True)
json.dump(config, open(folder_path / "config.json", "w"))

logger.info(f"Saved model to {folder_path}")

Expand Down
71 changes: 61 additions & 10 deletions scripts/export_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from pathlib import Path

import torch
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast

from model2vec import StaticModel

Expand Down Expand Up @@ -70,30 +72,40 @@ def tokenize(self, sentences: list[str], max_length: int | None = None) -> tuple
encodings_ids = [token_ids[:max_length] for token_ids in encodings_ids]
# Flatten input_ids and compute offsets
offsets = torch.tensor([0] + [len(ids) for ids in encodings_ids[:-1]], dtype=torch.long).cumsum(dim=0)
input_ids = torch.tensor([token_id for token_ids in encodings_ids for token_id in token_ids], dtype=torch.long)
input_ids = torch.tensor(
[token_id for token_ids in encodings_ids for token_id in token_ids],
dtype=torch.long,
)
return input_ids, offsets


def export_model_to_onnx(model_path: str, save_path: str) -> None:
def export_model_to_onnx(model_path: str, save_path: Path) -> None:
"""
Export the StaticModel to ONNX format.
Export the StaticModel to ONNX format and save tokenizer files.
:param model_path: The path to the pretrained StaticModel.
:param save_path: The path to save the exported ONNX model
:param save_path: The directory to save the model and related files.
"""
# Convert the StaticModel to TorchStaticModel
save_path.mkdir(parents=True, exist_ok=True)

# Load the StaticModel
model = StaticModel.from_pretrained(model_path)
torch_model = TorchStaticModel(model)

# Save the model using save_pretrained
model.save_pretrained(save_path)

# Prepare dummy input data
texts = ["hello", "hello world"]
input_ids, offsets = torch_model.tokenize(texts)

# Export the model to ONNX
onnx_model_path = save_path / "onnx/model.onnx"
onnx_model_path.parent.mkdir(parents=True, exist_ok=True)
torch.onnx.export(
torch_model,
(input_ids, offsets),
save_path,
str(onnx_model_path),
export_params=True,
opset_version=14,
do_constant_folding=True,
Expand All @@ -106,13 +118,52 @@ def export_model_to_onnx(model_path: str, save_path: str) -> None:
},
)

logger.info(f"Model has been successfully exported to {save_path}")
logger.info(f"Model has been successfully exported to {onnx_model_path}")

# Save the tokenizer files required for transformers.js
save_tokenizer(model.tokenizer, save_path)
logger.info(f"Tokenizer files have been saved to {save_path}")


def save_tokenizer(tokenizer: Tokenizer, save_directory: Path) -> None:
"""
Save tokenizer files in a format compatible with Transformers.
:param tokenizer: The tokenizer from the StaticModel.
:param save_directory: The directory to save the tokenizer files.
"""
# Convert the tokenizers.Tokenizer to a PreTrainedTokenizerFast and save
tokenizer_json_path = save_directory / "tokenizer.json"
tokenizer.save(str(tokenizer_json_path))

# Load the tokenizer using PreTrainedTokenizerFast
fast_tokenizer = PreTrainedTokenizerFast(
tokenizer_file=str(tokenizer_json_path),
unk_token="[UNK]",
pad_token="[PAD]",
cls_token="[CLS]",
sep_token="[SEP]",
mask_token="[MASK]",
)

# Save the tokenizer files
fast_tokenizer.save_pretrained(str(save_directory))


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Export StaticModel to ONNX format")
parser.add_argument("--model_path", type=Path, required=True, help="Path to the pretrained StaticModel")
parser.add_argument("--save_path", type=Path, required=True, help="Path to save the exported ONNX model")
parser.add_argument(
"--model_path",
type=str,
required=True,
help="Path to the pretrained StaticModel",
)
parser.add_argument(
"--save_path",
type=str,
required=True,
help="Directory to save the exported model and files",
)
args = parser.parse_args()

export_model_to_onnx(args.model_path, args.save_path)
export_model_to_onnx(args.model_path, Path(args.save_path))

0 comments on commit 62ac69c

Please sign in to comment.