Merge pull request #43 from jbarrow/main

BERT implementation

Showing 6 changed files with 413 additions and 0 deletions.

README.md
@@ -0,0 +1,78 @@
# BERT

An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) within MLX.

## Downloading and Converting Weights

The `convert.py` script relies on `transformers` to download the weights, and exports them as a single `.npz` file.

```
python convert.py \
    --bert-model bert-base-uncased \
    --mlx-model weights/bert-base-uncased.npz
```
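
To sanity-check the export, the saved archive can be inspected with NumPy. A minimal sketch (the key name below follows the renaming performed in `convert.py`):

```python
import numpy

weights = numpy.load("weights/bert-base-uncased.npz")
print(len(weights.files))              # number of exported tensors
print(weights["pooler.weight"].shape)  # (768, 768) for bert-base models
```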

## Usage

To use the `Bert` model in your own code, you can load it with:

```python
import mlx.core as mx

from model import Bert, load_model

model, tokenizer = load_model(
    "bert-base-uncased",
    "weights/bert-base-uncased.npz")

batch = ["This is an example of BERT working on MLX."]
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}

output, pooled = model(**tokens)
```

`output` is a `Batch x Tokens x Dims` tensor, holding a vector for every input token. If you want to train anything at the **token level**, you'll want to use this.
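
For example, a token-level task such as tagging would put a small head on top of these per-token vectors. A minimal sketch (the head and label count are hypothetical, not part of this example):

```python
import mlx.nn as nn

num_labels = 9                           # hypothetical tag set
token_head = nn.Linear(768, num_labels)  # 768 = Dims for bert-base

# output: (Batch, Tokens, Dims) -> token_logits: (Batch, Tokens, num_labels)
token_logits = token_head(output)
```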

`pooled` is a `Batch x Dims` tensor, the pooled representation of each input. If you want to train a **classification** model, you'll want to use this.
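
A sequence classifier would put a similar head on the pooled vector instead. Again a hypothetical sketch, not part of this example:

```python
import mlx.nn as nn

classifier = nn.Linear(768, 2)  # hypothetical binary classification head

# pooled: (Batch, Dims) -> logits: (Batch, 2)
logits = classifier(pooled)
```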

## Comparison with 🤗 `transformers` Implementation

To run the model and forward inference on a batch of examples:

```sh
python model.py \
    --bert-model bert-base-uncased \
    --mlx-model weights/bert-base-uncased.npz
```

This will show the following outputs:
```
MLX BERT:
[[[-0.52508914 -0.1993871  -0.28210318 ... -0.61125606  0.19114694
    0.8227601 ]
  [-0.8783862  -0.37107834 -0.52238125 ... -0.5067165   1.0847603
    0.31066895]
  [-0.70010054 -0.5424497  -0.26593682 ... -0.2688697   0.38338926
    0.6557663 ]
  ...
```

They can be compared against the 🤗 implementation with:

```sh
python hf_model.py \
    --bert-model bert-base-uncased
```

This will show:
```
HF BERT:
[[[-0.52508944 -0.1993877  -0.28210333 ... -0.6112575   0.19114678
    0.8227603 ]
  [-0.878387   -0.371079   -0.522381   ... -0.50671494  1.0847601
    0.31066933]
  [-0.7001008  -0.5424504  -0.26593733 ... -0.26887015  0.38339025
    0.65576553]
  ...
```
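
The two outputs agree to within float32 rounding error. For a programmatic check, a sketch along these lines could compare them directly (it assumes the `load_model` API shown in the Usage section above):

```python
import numpy as np
import mlx.core as mx
from transformers import AutoModel, AutoTokenizer
from model import load_model

batch = ["This is an example of BERT working on MLX."]

# Forward pass through the MLX implementation.
mlx_model, mlx_tokenizer = load_model(
    "bert-base-uncased", "weights/bert-base-uncased.npz")
tokens = mlx_tokenizer(batch, return_tensors="np", padding=True)
mlx_output, _ = mlx_model(**{k: mx.array(v) for k, v in tokens.items()})

# Forward pass through the PyTorch implementation.
hf_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
hf_model = AutoModel.from_pretrained("bert-base-uncased")
hf_out = hf_model(**hf_tokenizer(batch, return_tensors="pt", padding=True))

# Differences should be at the level of float32 rounding error.
print(np.allclose(
    np.array(mlx_output),
    hf_out.last_hidden_state.detach().numpy(),
    atol=1e-4,
))
```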

convert.py
@@ -0,0 +1,47 @@
from transformers import BertModel

import argparse
import numpy


def replace_key(key: str) -> str:
    # Map Hugging Face BERT parameter names onto the MLX module tree.
    key = key.replace(".layer.", ".layers.")
    key = key.replace(".self.key.", ".key_proj.")
    key = key.replace(".self.query.", ".query_proj.")
    key = key.replace(".self.value.", ".value_proj.")
    key = key.replace(".attention.output.dense.", ".attention.out_proj.")
    key = key.replace(".attention.output.LayerNorm.", ".ln1.")
    key = key.replace(".output.LayerNorm.", ".ln2.")
    key = key.replace(".intermediate.dense.", ".linear1.")
    key = key.replace(".output.dense.", ".linear2.")
    key = key.replace(".LayerNorm.", ".norm.")
    key = key.replace("pooler.dense.", "pooler.")
    return key
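

# For illustration (not executed): replace_key renames checkpoint keys like
#   encoder.layer.0.attention.self.query.weight
#     -> encoder.layers.0.attention.query_proj.weight
#   encoder.layer.0.attention.output.dense.bias
#     -> encoder.layers.0.attention.out_proj.bias
#   pooler.dense.weight -> pooler.weight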


def convert(bert_model: str, mlx_model: str) -> None:
    model = BertModel.from_pretrained(bert_model)
    # Rename every parameter and save all tensors into a single .npz archive.
    tensors = {
        replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
    }
    numpy.savez(mlx_model, **tensors)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
    parser.add_argument(
        "--bert-model",
        choices=[
            "bert-base-uncased",
            "bert-base-cased",
            "bert-large-uncased",
            "bert-large-cased",
        ],
        default="bert-base-uncased",
        help="The huggingface name of the BERT model to save.",
    )
    parser.add_argument(
        "--mlx-model",
        type=str,
        default="weights/bert-base-uncased.npz",
        help="The output path for the MLX BERT weights.",
    )
    args = parser.parse_args()

    convert(args.bert_model, args.mlx_model)

hf_model.py
@@ -0,0 +1,36 @@
from transformers import AutoModel, AutoTokenizer

import argparse


def run(bert_model: str):
    batch = [
        "This is an example of BERT working on MLX.",
        "A second string",
        "This is another string.",
    ]

    # Tokenize the batch with padding and run a PyTorch forward pass.
    tokenizer = AutoTokenizer.from_pretrained(bert_model)
    torch_model = AutoModel.from_pretrained(bert_model)
    torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
    torch_forward = torch_model(**torch_tokens)
    torch_output = torch_forward.last_hidden_state.detach().numpy()
    torch_pooled = torch_forward.pooler_output.detach().numpy()

    print("\n HF BERT:")
    print(torch_output)
    print("\n\n HF Pooled:")
    print(torch_pooled[0, :20])


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run the BERT model using HuggingFace Transformers."
    )
    parser.add_argument(
        "--bert-model",
        choices=[
            "bert-base-uncased",
            "bert-base-cased",
            "bert-large-uncased",
            "bert-large-cased",
        ],
        default="bert-base-uncased",
        help="The huggingface name of the BERT model to run.",
    )
    args = parser.parse_args()

    run(args.bert_model)