This repository has been archived by the owner on Apr 27, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathbert.py
110 lines (92 loc) · 3.73 KB
/
bert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
import torch_mlir
import iree_torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
def prepare_sentence_tokens(hf_model: str, sentence: str):
    """Encode `sentence` with the tokenizer published for `hf_model`.

    Args:
        hf_model: Name passed to `AutoTokenizer.from_pretrained`.
        sentence: Text to encode.

    Returns:
        A `torch.Tensor` built from a one-element list holding the encoded
        sentence (presumably shape (1, num_tokens) -- the tokenizer's
        `encode` output is wrapped in an outer list before conversion).
    """
    token_ids = AutoTokenizer.from_pretrained(hf_model).encode(sentence)
    # Wrap the single encoded sentence in an outer list before converting.
    single_sentence_batch = [token_ids]
    return torch.tensor(single_sentence_batch)
class OnlyLogitsHuggingFaceModel(torch.nn.Module):
    """Module wrapping a HuggingFace classifier so `forward` yields logits only.

    The wrapped model is loaded once in the constructor and switched to eval
    mode; `forward` then returns just the first element of its output.
    """

    def __init__(self, model_name: str):
        """Load the pretrained sequence-classification model `model_name`."""
        super().__init__()
        load_options = dict(
            # Two output labels, i.e. binary classification.
            num_labels=2,
            # Do not return attention weights.
            output_attentions=False,
            # Do not return all hidden states.
            output_hidden_states=False,
            # NOTE(review): presumably required so the model can be traced
            # with torch.jit -- confirm against the transformers docs.
            torchscript=True,
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name, **load_options)
        self.model.eval()

    def forward(self, input):
        """Run the wrapped model on `input` and return only the logits."""
        outputs = self.model(input)
        # The logits are the first element of the model's output.
        return outputs[0]
def _suppress_warnings():
    """Install a blanket "ignore" warning filter and set TOKENIZERS_PARALLELISM.

    Both effects are process-wide: the warnings filter applies to every
    warning category, and the environment variable is written to
    `os.environ`.
    """
    from os import environ
    from warnings import simplefilter

    simplefilter("ignore")
    # NOTE(review): presumably set explicitly so the HuggingFace tokenizers
    # library does not print its parallelism notice -- confirm.
    environ["TOKENIZERS_PARALLELISM"] = "true"
def _get_argparse():
    """Build the command-line parser for this script.

    Options (all optional):
        --model-name: HuggingFace model name; defaults to
            "philschmid/MiniLM-L6-H384-uncased-sst2".
        --sentence: Sentence to run the model on; defaults to
            "The quick brown fox jumps over the lazy dog.".
        --iree-backend: One of "llvm-cpu", "vmvx", "vulkan" or "cuda";
            defaults to "llvm-cpu".

    Returns:
        The configured `argparse.ArgumentParser`. Nothing is parsed here.
    """
    # RawTextHelpFormatter keeps the line breaks inside help strings. The
    # default formatter re-wraps them, which turned the per-backend list
    # below into a single run-on paragraph in `--help` output.
    parser = argparse.ArgumentParser(
        description="Run a HuggingFace BERT Model.",
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--model-name",
                        default="philschmid/MiniLM-L6-H384-uncased-sst2",
                        help="The HuggingFace model name to use.")
    parser.add_argument("--sentence",
                        default="The quick brown fox jumps over the lazy dog.",
                        help="sentence to run the model on.")
    iree_backend_choices = ["llvm-cpu", "vmvx", "vulkan", "cuda"]
    # One entry per output line; a plain string, since nothing is interpolated.
    iree_backend_help = "\n".join((
        "Meaning of options:",
        "llvm-cpu - cpu, native code",
        "vmvx - cpu, interpreted",
        "vulkan - GPU for general GPU devices",
        "cuda - GPU for NVIDIA devices",
    ))
    parser.add_argument("--iree-backend",
                        choices=iree_backend_choices,
                        default="llvm-cpu",
                        help=iree_backend_help)
    return parser
def main():
    """Compile a HuggingFace classifier through Torch-MLIR and IREE, then run it.

    Steps, in order: parse the command line, tokenize the sentence, load the
    model, trace it with `torch.jit.trace`, lower it to linalg-on-tensors
    MLIR with `torch_mlir.compile`, compile that to an IREE VM flatbuffer,
    load the flatbuffer, and invoke `forward` on the tokenized sentence.
    Progress messages, the result and the execution time of the final
    invocation are printed to stdout.
    """
    import time

    _suppress_warnings()
    args = _get_argparse().parse_args()

    print("Parsing sentence tokens.")
    example_input = prepare_sentence_tokens(args.model_name, args.sentence)

    print("Instantiating model.")
    model = OnlyLogitsHuggingFaceModel(args.model_name)

    # TODO: Wrap up all these steps into a convenient, well-tested API.
    # TODO: Add ability to run on IREE CUDA backend.
    print("Tracing model.")
    traced = torch.jit.trace(model, example_input)

    print("Compiling with Torch-MLIR")
    linalg_on_tensors_mlir = torch_mlir.compile(
        traced,
        example_input,
        output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)

    print("Compiling with IREE")
    iree_vmfb = iree_torch.compile_to_vmfb(linalg_on_tensors_mlir,
                                           args.iree_backend)

    print("Loading in IREE")
    invoker = iree_torch.load_vmfb(iree_vmfb, args.iree_backend)

    print("Running on IREE")
    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring short durations; time.time() follows the wall clock and can
    # jump if the system time is adjusted during the run.
    start = time.perf_counter()
    result = invoker.forward(example_input)
    end = time.perf_counter()

    print("RESULT:", result)
    print(f"Model execution took {end - start} seconds.")


if __name__ == "__main__":
    main()