Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BERT implementation #43

Merged
merged 6 commits into from
Dec 9, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions bert/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# BERT

An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) within MLX.

## Downloading and Converting Weights

The `convert.py` script relies on `transformers` to download the weights, and exports them as a single `.npz` file.

```
python convert.py \
--bert-model bert-base-uncased
--mlx-model weights/bert-base-uncased.npz
```

## Usage

To use the `Bert` model in your own code, you can load it with:

```python
from model import Bert, load_model

model, tokenizer = load_model(
"bert-base-uncased",
"weights/bert-base-uncased.npz")

batch = ["This is an example of BERT working on MLX."]
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}

output, pooled = model(**tokens)
```

The `output` contains a `Batch x Tokens x Dims` tensor, representing a vector for every input token.
If you want to train anything at a **token-level**, you'll want to use this.

The `pooled` contains a `Batch x Dims` tensor, which is the pooled representation for each input.
If you want to train a **classification** model, you'll want to use this.

## Comparison with 🤗 `transformers` Implementation

In order to run the model, and have it forward inference on a batch of examples:

```sh
python model.py \
--bert-model bert-base-uncased \
--mlx-model weights/bert-base-uncased.npz
```

Which will show the following outputs:
```
MLX BERT:
[[[-0.17057164 0.08602728 -0.12471077 ... -0.09469379 -0.00275938
0.28314582]
[ 0.15222196 -0.48997563 -0.26665813 ... -0.19935863 -0.17162783
-0.51360303]
[ 0.9460105 0.1358298 -0.2945672 ... 0.00868467 -0.90271163
-0.2785422 ]]]
```
jbarrow marked this conversation as resolved.
Show resolved Hide resolved

They can be compared against the 🤗 implementation with:

```sh
python hf_model.py \
--bert-model bert-base-uncased
```

Which will show:
```
HF BERT:
[[[-0.17057131 0.08602707 -0.12471108 ... -0.09469365 -0.00275959
0.28314728]
[ 0.15222463 -0.48997375 -0.26665992 ... -0.19936043 -0.17162988
-0.5136028 ]
[ 0.946011 0.13582966 -0.29456618 ... 0.00868565 -0.90271175
-0.27854213]]]
```
47 changes: 47 additions & 0 deletions bert/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from transformers import BertModel

import argparse
import numpy


def replace_key(key: str) -> str:
key = key.replace(".layer.", ".layers.")
key = key.replace(".self.key.", ".key_proj.")
key = key.replace(".self.query.", ".query_proj.")
key = key.replace(".self.value.", ".value_proj.")
key = key.replace(".attention.output.dense.", ".attention.out_proj.")
key = key.replace(".attention.output.LayerNorm.", ".ln1.")
key = key.replace(".output.LayerNorm.", ".ln2.")
key = key.replace(".intermediate.dense.", ".linear1.")
key = key.replace(".output.dense.", ".linear2.")
key = key.replace(".LayerNorm.", ".norm.")
key = key.replace("pooler.dense.", "pooler.")
return key


def convert(bert_model: str, mlx_model: str) -> None:
model = BertModel.from_pretrained(bert_model)
# save the tensors
tensors = {
replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
}
numpy.savez(mlx_model, **tensors)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
parser.add_argument(
"--bert-model",
choices=["bert-base-uncased", "bert-base-cased", "bert-large-uncased", "bert-large-cased"],
default="bert-base-uncased",
help="The huggingface name of the BERT model to save.",
)
parser.add_argument(
"--mlx-model",
type=str,
default="weights/bert-base-uncased.npz",
help="The output path for the MLX BERT weights.",
)
args = parser.parse_args()

convert(args.bert_model, args.mlx_model)
36 changes: 36 additions & 0 deletions bert/hf_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from transformers import AutoModel, AutoTokenizer

import argparse


def run(bert_model: str):
batch = [
"This is an example of BERT working on MLX.",
"A second string",
"This is another string.",
]

tokenizer = AutoTokenizer.from_pretrained(bert_model)
torch_model = AutoModel.from_pretrained(bert_model)
torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
torch_forward = torch_model(**torch_tokens)
torch_output = torch_forward.last_hidden_state.detach().numpy()
torch_pooled = torch_forward.pooler_output.detach().numpy()

print("\n HF BERT:")
print(torch_output)
print("\n\n HF Pooled:")
print(torch_pooled[0, :20])


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run the BERT model using HuggingFace Transformers.")
parser.add_argument(
"--bert-model",
choices=["bert-base-uncased", "bert-base-cased", "bert-large-uncased", "bert-large-cased"],
default="bert-base-uncased",
help="The huggingface name of the BERT model to save.",
)
args = parser.parse_args()

run(args.bert_model)
Loading