Merge pull request #43 from jbarrow/main
BERT implementation
awni authored Dec 9, 2023
2 parents 8b2a6fe + d873e10 commit 46c6bbe
Showing 6 changed files with 413 additions and 0 deletions.
78 changes: 78 additions & 0 deletions bert/README.md
@@ -0,0 +1,78 @@
# BERT

An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) within MLX.

## Downloading and Converting Weights

The `convert.py` script relies on `transformers` to download the weights and exports them as a single `.npz` file.

```sh
python convert.py \
--bert-model bert-base-uncased \
--mlx-model weights/bert-base-uncased.npz
```
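
Once the script finishes, the exported archive can be sanity-checked directly with NumPy. This is a quick sketch, assuming the default output path shown above:

```python
import numpy as np

# Load the exported archive and peek at the renamed parameter keys.
weights = np.load("weights/bert-base-uncased.npz")
print(len(weights.files))          # number of exported tensors
print(sorted(weights.files)[:5])   # a few of the converted key names
```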

## Usage

To use the `Bert` model in your own code, you can load it with:

```python
import mlx.core as mx

from model import Bert, load_model

model, tokenizer = load_model(
    "bert-base-uncased",
    "weights/bert-base-uncased.npz")

batch = ["This is an example of BERT working on MLX."]
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}

output, pooled = model(**tokens)
```

`output` is a `Batch x Tokens x Dims` tensor containing a vector for every input token.
If you want to train anything at the **token level**, you'll want to use this.

`pooled` is a `Batch x Dims` tensor holding the pooled representation of each input sequence.
If you want to train a **classification** model, you'll want to use this.
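
For example, a classification head on top of `pooled` could look like the following sketch (the head and `num_classes` are illustrative and not part of this example):

```python
import mlx.core as mx
import mlx.nn as nn

num_classes = 2  # hypothetical label count
classifier = nn.Linear(pooled.shape[-1], num_classes)

logits = classifier(pooled)          # Batch x num_classes
probs = mx.softmax(logits, axis=-1)  # per-class probabilities
```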

## Comparison with 🤗 `transformers` Implementation

To run the model and perform a forward pass over a batch of examples:

```sh
python model.py \
--bert-model bert-base-uncased \
--mlx-model weights/bert-base-uncased.npz
```

This prints the following output:
```
MLX BERT:
[[[-0.52508914 -0.1993871 -0.28210318 ... -0.61125606 0.19114694
0.8227601 ]
[-0.8783862 -0.37107834 -0.52238125 ... -0.5067165 1.0847603
0.31066895]
[-0.70010054 -0.5424497 -0.26593682 ... -0.2688697 0.38338926
0.6557663 ]
...
```

They can be compared against the 🤗 implementation with:

```sh
python hf_model.py \
--bert-model bert-base-uncased
```

This prints:
```
HF BERT:
[[[-0.52508944 -0.1993877 -0.28210333 ... -0.6112575 0.19114678
0.8227603 ]
[-0.878387 -0.371079 -0.522381 ... -0.50671494 1.0847601
0.31066933]
[-0.7001008 -0.5424504 -0.26593733 ... -0.26887015 0.38339025
0.65576553]
...
```
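
To check agreement numerically rather than by eye, the two forward passes can be compared with `np.allclose`. This is a sketch, assuming the converted weights exist locally and reusing `load_model` from `model.py`:

```python
import mlx.core as mx
import numpy as np
from transformers import AutoModel, AutoTokenizer

from model import load_model

batch = ["This is an example of BERT working on MLX."]

# MLX forward pass.
mlx_model, tokenizer = load_model(
    "bert-base-uncased", "weights/bert-base-uncased.npz")
tokens = tokenizer(batch, return_tensors="np", padding=True)
mlx_output, _ = mlx_model(**{k: mx.array(v) for k, v in tokens.items()})

# Hugging Face forward pass.
hf_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
hf_model = AutoModel.from_pretrained("bert-base-uncased")
hf_tokens = hf_tokenizer(batch, return_tensors="pt", padding=True)
hf_output = hf_model(**hf_tokens).last_hidden_state.detach().numpy()

# Should print True: the outputs agree to within float32 tolerance.
print(np.allclose(np.array(mlx_output), hf_output, atol=1e-4))
```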
47 changes: 47 additions & 0 deletions bert/convert.py
@@ -0,0 +1,47 @@
from transformers import BertModel

import argparse
import numpy


def replace_key(key: str) -> str:
    # Map Hugging Face parameter names onto the names used by the MLX model.
    key = key.replace(".layer.", ".layers.")
    key = key.replace(".self.key.", ".key_proj.")
    key = key.replace(".self.query.", ".query_proj.")
    key = key.replace(".self.value.", ".value_proj.")
    key = key.replace(".attention.output.dense.", ".attention.out_proj.")
    key = key.replace(".attention.output.LayerNorm.", ".ln1.")
    key = key.replace(".output.LayerNorm.", ".ln2.")
    key = key.replace(".intermediate.dense.", ".linear1.")
    key = key.replace(".output.dense.", ".linear2.")
    key = key.replace(".LayerNorm.", ".norm.")
    key = key.replace("pooler.dense.", "pooler.")
    return key


def convert(bert_model: str, mlx_model: str) -> None:
    model = BertModel.from_pretrained(bert_model)
    # save the tensors
    tensors = {
        replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
    }
    numpy.savez(mlx_model, **tensors)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
    parser.add_argument(
        "--bert-model",
        choices=["bert-base-uncased", "bert-base-cased", "bert-large-uncased", "bert-large-cased"],
        default="bert-base-uncased",
        help="The huggingface name of the BERT model to save.",
    )
    parser.add_argument(
        "--mlx-model",
        type=str,
        default="weights/bert-base-uncased.npz",
        help="The output path for the MLX BERT weights.",
    )
    args = parser.parse_args()

    convert(args.bert_model, args.mlx_model)
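
As an illustration of what `replace_key` does, here is a typical 🤗 parameter name and the MLX name it maps to (a sketch, assuming it is run from the `bert/` directory; the example key is one of the keys in the `bert-base-uncased` state dict):

```python
from convert import replace_key

# ".layer." -> ".layers." and ".self.query." -> ".query_proj."
print(replace_key("encoder.layer.0.attention.self.query.weight"))
# encoder.layers.0.attention.query_proj.weight
```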
36 changes: 36 additions & 0 deletions bert/hf_model.py
@@ -0,0 +1,36 @@
from transformers import AutoModel, AutoTokenizer

import argparse


def run(bert_model: str):
    batch = [
        "This is an example of BERT working on MLX.",
        "A second string",
        "This is another string.",
    ]

    tokenizer = AutoTokenizer.from_pretrained(bert_model)
    torch_model = AutoModel.from_pretrained(bert_model)
    torch_tokens = tokenizer(batch, return_tensors="pt", padding=True)
    torch_forward = torch_model(**torch_tokens)
    torch_output = torch_forward.last_hidden_state.detach().numpy()
    torch_pooled = torch_forward.pooler_output.detach().numpy()

    print("\n HF BERT:")
    print(torch_output)
    print("\n\n HF Pooled:")
    print(torch_pooled[0, :20])


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the BERT model using HuggingFace Transformers.")
    parser.add_argument(
        "--bert-model",
        choices=["bert-base-uncased", "bert-base-cased", "bert-large-uncased", "bert-large-cased"],
        default="bert-base-uncased",
        help="The huggingface name of the BERT model to run.",
    )
    args = parser.parse_args()

    run(args.bert_model)