Add BERT family of models #15

Merged on Feb 13, 2024 (10 commits)

11 changes: 6 additions & 5 deletions README.md
@@ -5,17 +5,18 @@
examples constructed using the [Burn](https://github.com/burn-rs/burn) deep learning framework.

## Collection of Official Models

- | Model | Description | Repository Link |
- | ---------------------------------------------- | ------------------------------------------------- | -------------------------------------------- |
- | [SqueezeNet](https://arxiv.org/abs/1602.07360) | A small CNN-based model for image classification. | [squeezenet-burn](squeezenet-burn/README.md) |
- | [ResNet](https://arxiv.org/abs/1512.03385) | A CNN based on residual blocks with skip connections. | [resnet-burn](resnet-burn/README.md) |
+ | Model | Description | Repository Link |
+ |------------------------------------------------|-------------------------------------------------------|----------------------------------------------|
+ | [SqueezeNet](https://arxiv.org/abs/1602.07360) | A small CNN-based model for image classification. | [squeezenet-burn](squeezenet-burn/README.md) |
+ | [ResNet](https://arxiv.org/abs/1512.03385) | A CNN based on residual blocks with skip connections. | [resnet-burn](resnet-burn/README.md) |
+ | [RoBERTa](https://arxiv.org/abs/1907.11692) | A robustly optimized BERT pretraining approach. | [bert-burn](bert-burn/README.md) |

## Community Contributions

Explore the curated list of models developed by the community ♥.

| Model | Description | Repository Link |
- | ------------------------------------------- | ----------------------------------------------------------------- | --------------------------------------------------------------------------------- |
+ |---------------------------------------------|-------------------------------------------------------------------|-----------------------------------------------------------------------------------|
| [Llama 2](https://arxiv.org/abs/2307.09288) | LLMs by Meta AI, ranging from 7 billion to 70 billion parameters. | [Gadersd/llama2-burn](https://github.com/Gadersd/llama2-burn) |
| [Whisper](https://arxiv.org/abs/2212.04356) | A general-purpose speech recognition model by OpenAI. | [Gadersd/whisper-burn](https://github.com/Gadersd/whisper-burn) |
| Stable Diffusion v1.4 | An image generation model developed by Stability AI. | [Gadersd/stable-diffusion-burn](https://github.com/Gadersd/stable-diffusion-burn) |
37 changes: 37 additions & 0 deletions bert-burn/Cargo.toml
@@ -0,0 +1,37 @@
[package]
authors = ["Aasheesh Singh [email protected]"]
license = "MIT OR Apache-2.0"
name = "bert-burn"
version = "0.1.0"
edition = "2021"

[features]
default = ["burn/dataset"]
f16 = []
ndarray = ["burn/ndarray"]
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]
fusion = ["burn/fusion"]
# To be replaced by burn-safetensors once supported: https://github.com/tracel-ai/burn/issues/626
safetensors = ["candle-core/default"]


[dependencies]
# Burn
burn = { version = "0.12.1", default-features = false }
candle-core = { version = "0.3.2", optional = true }
# Tokenizer
tokenizers = { version = "0.15.0", default-features = false, features = [
"onig",
"http",
] }
burn-import = "0.12.1"
derive-new = "0.6.0"
hf-hub = { version = "0.3.2", features = ["tokio"] }

# Utils
serde = { version = "1.0.196", features = ["std", "derive"] }
libm = "0.2.8"
serde_json = "1.0.113"
tokio = "1.35.1"
1 change: 1 addition & 0 deletions bert-burn/LICENSE-APACHE
1 change: 1 addition & 0 deletions bert-burn/LICENSE-MIT
39 changes: 39 additions & 0 deletions bert-burn/README.md
@@ -0,0 +1,39 @@
# Bert-Burn Model

This project provides an example implementation for inference with the BERT family of models. The following
BERT variants are compatible: `roberta-base` (**default**), `roberta-large`, `bert-base-uncased`,
`bert-large-uncased`, `bert-base-cased`, and `bert-large-cased`. Pre-trained weights and config files are
downloaded automatically from the [HuggingFace Model Hub](https://huggingface.co/FacebookAI/roberta-base/tree/main).

### To include the model in your project

Add this to your `Cargo.toml`:

```toml
[dependencies]
bert-burn = { git = "https://github.com/burn-rs/models", package = "bert-burn", default-features = false }
```

## Example Usage

The example below computes sentence embeddings for the given input text. The model supports multiple Burn
backends (e.g. `ndarray`, `wgpu`, `tch-gpu`, `tch-cpu`), selected with the `--features` flag. An example with
the `wgpu` backend is shown below. The `fusion` feature enables kernel fusion for the `wgpu` backend and is not
required with other backends. The `safetensors` feature enables loading weights in `safetensors` format via the
`candle-core` crate.

### WGPU backend

```bash
cd bert-burn/
# Get sentence embeddings from the RoBERTa encoder (default)
cargo run --example infer-embedding --release --features wgpu,fusion,safetensors

# Using bert-base-uncased model
cargo run --example infer-embedding --release --features wgpu,fusion,safetensors bert-base-uncased

# Using roberta-large model
cargo run --example infer-embedding --release --features wgpu,fusion,safetensors roberta-large
```
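
For library consumers, here is a minimal sketch of the same flow driven from code, mirroring
`examples/infer-embedding.rs` in this PR; the `ndarray` backend and the enabled `safetensors` feature are
assumptions, and exact signatures may differ:

```rust
use bert_burn::data::{BertInputBatcher, BertTokenizer};
use bert_burn::loader::{download_hf_model, load_model_config};
use bert_burn::model::BertModel;
use burn::backend::ndarray::{NdArray, NdArrayDevice};
use burn::data::dataloader::batcher::Batcher;
use std::sync::Arc;

fn main() {
    let device = NdArrayDevice::Cpu;

    // Download config + weights from the HuggingFace Hub and build the encoder.
    let (config_file, model_file) = download_hf_model("roberta-base");
    let config = load_model_config(config_file);
    let model: BertModel<NdArray> =
        BertModel::from_safetensors(model_file, &device, config.clone());

    // Tokenize and pad the input batch, then run a forward pass.
    let tokenizer = Arc::new(BertTokenizer::new("roberta-base".to_string(), config.pad_token_id));
    let batcher = BertInputBatcher::<NdArray>::new(tokenizer, device, config.max_seq_len.unwrap());
    let input = batcher.batch(vec!["Hello from Burn!".to_string()]);

    let output = model.forward(input); // [batch_size, seq_len, hidden_size]
    println!("{:?}", output.shape());
}
```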


154 changes: 154 additions & 0 deletions bert-burn/examples/infer-embedding.rs
@@ -0,0 +1,154 @@
use bert_burn::data::{BertInputBatcher, BertTokenizer};
use bert_burn::loader::{download_hf_model, load_model_config};
use bert_burn::model::BertModel;
use burn::data::dataloader::batcher::Batcher;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use std::env;
use std::sync::Arc;

#[cfg(not(feature = "f16"))]
#[allow(dead_code)]
type ElemType = f32;
#[cfg(feature = "f16")]
type ElemType = burn::tensor::f16;

pub fn launch<B: Backend>(device: B::Device) {
let args: Vec<String> = env::args().collect();
let default_model = "roberta-base".to_string();
let model_variant = if args.len() > 1 {
// Use the argument provided by the user
// Possible values: "bert-base-uncased", "roberta-large" etc.
&args[1]
} else {
// Use the default value if no argument is provided
&default_model
};

println!("Model variant: {}", model_variant);

let text_samples = vec![
"Jays power up to take finale Contrary to popular belief, the power never really \
snapped back at SkyDome on Sunday. The lights came on after an hour delay, but it \
took some extra time for the batting orders to provide some extra wattage."
.to_string(),
"Yemen Sentences 15 Militants on Terror Charges A court in Yemen has sentenced one \
man to death and 14 others to prison terms for a series of attacks and terrorist \
plots in 2002, including the bombing of a French oil tanker."
.to_string(),
"IBM puts grids to work at U.S. Open IBM will put a collection of its On \
Demand-related products and technologies to this test next week at the U.S. Open \
tennis championships, implementing a grid-based infrastructure capable of running \
multiple workloads including two not associated with the tournament."
.to_string(),
];

let (config_file, model_file) = download_hf_model(model_variant);
let model_config = load_model_config(config_file);

let model: BertModel<B> =
BertModel::from_safetensors(model_file, &device, model_config.clone());

let tokenizer = Arc::new(BertTokenizer::new(
model_variant.to_string(),
model_config.pad_token_id,
));

// Batch the input samples to max sequence length with padding
let batcher = Arc::new(BertInputBatcher::<B>::new(
tokenizer.clone(),
device.clone(),
model_config.max_seq_len.unwrap(),
));

// Batch the input samples using the batcher. Shape: [batch_size, seq_len]
let input = batcher.batch(text_samples.clone());
let [batch_size, _seq_len] = input.tokens.dims();
println!("Input: {:?} // (Batch Size, Seq_len)", input.tokens.shape());

let output = model.forward(input);

// Get the sentence embedding from the first [CLS] token
let cls_token_idx = 0;

// Embedding size
let d_model = model_config.hidden_size;
let sentence_embedding =
output
.clone()
.slice([0..batch_size, cls_token_idx..cls_token_idx + 1, 0..d_model]);

let sentence_embedding: Tensor<B, 2> = sentence_embedding.squeeze(1);
println!(
"Roberta Sentence embedding {:?} // (Batch Size, Embedding_dim)",
sentence_embedding.shape()
);
}

#[cfg(any(
feature = "ndarray",
feature = "ndarray-blas-netlib",
feature = "ndarray-blas-openblas",
feature = "ndarray-blas-accelerate",
))]
mod ndarray {
use burn::backend::ndarray::{NdArray, NdArrayDevice};

use crate::{launch, ElemType};

pub fn run() {
launch::<NdArray<ElemType>>(NdArrayDevice::Cpu);
}
}

#[cfg(feature = "tch-gpu")]
mod tch_gpu {
use crate::{launch, ElemType};
use burn::backend::libtorch::{LibTorch, LibTorchDevice};

pub fn run() {
#[cfg(not(target_os = "macos"))]
let device = LibTorchDevice::Cuda(0);
#[cfg(target_os = "macos")]
let device = LibTorchDevice::Mps;

launch::<LibTorch<ElemType>>(device);
}
}

#[cfg(feature = "tch-cpu")]
mod tch_cpu {
use crate::{launch, ElemType};
use burn::backend::libtorch::{LibTorch, LibTorchDevice};

pub fn run() {
launch::<LibTorch<ElemType>>(LibTorchDevice::Cpu);
}
}

#[cfg(feature = "wgpu")]
mod wgpu {
use crate::{launch, ElemType};
use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
use burn::backend::Fusion;

pub fn run() {
launch::<Fusion<Wgpu<AutoGraphicsApi, ElemType, i32>>>(WgpuDevice::default());
}
}

fn main() {
#[cfg(any(
feature = "ndarray",
feature = "ndarray-blas-netlib",
feature = "ndarray-blas-openblas",
feature = "ndarray-blas-accelerate",
))]
ndarray::run();
#[cfg(feature = "tch-gpu")]
tch_gpu::run();
#[cfg(feature = "tch-cpu")]
tch_cpu::run();
#[cfg(feature = "wgpu")]
wgpu::run();
}
51 changes: 51 additions & 0 deletions bert-burn/src/data/batcher.rs
@@ -0,0 +1,51 @@
use super::tokenizer::Tokenizer;
use burn::{
data::dataloader::batcher::Batcher,
nn::attention::generate_padding_mask,
tensor::{backend::Backend, Bool, Int, Tensor},
};
use std::sync::Arc;

#[derive(new)]
pub struct BertInputBatcher<B: Backend> {
/// Tokenizer for converting input text string to token IDs
tokenizer: Arc<dyn Tokenizer>,
/// Device on which to perform computation (e.g., CPU or CUDA device)
device: B::Device,
/// Maximum sequence length for tokenized text
max_seq_length: usize,
}

#[derive(Debug, Clone, new)]
pub struct BertInferenceBatch<B: Backend> {
/// Tokenized text as 2D tensor: [batch_size, max_seq_length]
pub tokens: Tensor<B, 2, Int>,
/// Padding mask for the tokenized text containing booleans for padding locations
pub mask_pad: Tensor<B, 2, Bool>,
}

impl<B: Backend> Batcher<String, BertInferenceBatch<B>> for BertInputBatcher<B> {
/// Batches a vector of strings into an inference batch
fn batch(&self, items: Vec<String>) -> BertInferenceBatch<B> {
let mut tokens_list = Vec::with_capacity(items.len());

// Tokenize each string
for item in items {
tokens_list.push(self.tokenizer.encode(&item));
}

// Generate padding mask for tokenized text
let mask = generate_padding_mask(
self.tokenizer.pad_token(),
tokens_list,
Some(self.max_seq_length),
&self.device,
);

// Create and return inference batch
BertInferenceBatch {
tokens: mask.tensor,
mask_pad: mask.mask,
}
}
}
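
For reference, a hedged sketch of using the batcher in isolation (the backend, maximum sequence length, and
the pad token ID of 1 for `roberta-base` are assumptions):

```rust
use bert_burn::data::{BertInferenceBatch, BertInputBatcher, BertTokenizer};
use burn::backend::ndarray::{NdArray, NdArrayDevice};
use burn::data::dataloader::batcher::Batcher;
use std::sync::Arc;

fn demo() {
    let tokenizer = Arc::new(BertTokenizer::new("roberta-base".to_string(), 1));
    let batcher = BertInputBatcher::<NdArray>::new(tokenizer, NdArrayDevice::Cpu, 512);

    let batch: BertInferenceBatch<NdArray> = batcher.batch(vec![
        "a short sentence".to_string(),
        "a somewhat longer input sentence".to_string(),
    ]);

    // `tokens` holds padded token IDs ([2, padded_len]); `mask_pad` flags the padded positions.
    println!("{:?}", batch.tokens.dims());
}
```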
5 changes: 5 additions & 0 deletions bert-burn/src/data/mod.rs
@@ -0,0 +1,5 @@
mod batcher;
mod tokenizer;

pub use batcher::*;
pub use tokenizer::*;
64 changes: 64 additions & 0 deletions bert-burn/src/data/tokenizer.rs
@@ -0,0 +1,64 @@
pub trait Tokenizer: Send + Sync {
/// Converts a text string into a sequence of tokens.
fn encode(&self, value: &str) -> Vec<usize>;

/// Converts a sequence of tokens back into a text string.
fn decode(&self, tokens: &[usize]) -> String;

/// Gets the size of the tokenizer's vocabulary.
fn vocab_size(&self) -> usize;

/// Gets the token used for padding sequences to a consistent length.
fn pad_token(&self) -> usize;

/// Gets the string representation of the padding token.
/// The default implementation uses `decode` on the padding token.
fn pad_token_value(&self) -> String {
self.decode(&[self.pad_token()])
}
}

/// A tokenizer based on the RoBERTa BPE tokenization strategy.
pub struct BertTokenizer {
// The underlying tokenizer from the `tokenizers` library.
tokenizer: tokenizers::Tokenizer,
pad_token: usize,
}

// Constructor for BertTokenizer.
// Downloads the tokenizer files for the given model_name (e.g. "roberta-base").
// pad_token_id is the ID of the padding token used to pad sequences to a
// consistent length, as specified in the model's config.json.
impl BertTokenizer {
pub fn new(model_name: String, pad_token_id: usize) -> Self {
Self {
tokenizer: tokenizers::Tokenizer::from_pretrained(model_name, None).unwrap(),
pad_token: pad_token_id,
}
}
}

// Implementation of the Tokenizer trait for BertTokenizer.
impl Tokenizer for BertTokenizer {
/// Convert a text string into a sequence of tokens using the BERT model's tokenization strategy.
fn encode(&self, value: &str) -> Vec<usize> {
let tokens = self.tokenizer.encode(value, true).unwrap();
tokens.get_ids().iter().map(|t| *t as usize).collect()
}

/// Converts a sequence of tokens back into a text string.
fn decode(&self, tokens: &[usize]) -> String {
let tokens = tokens.iter().map(|t| *t as u32).collect::<Vec<u32>>();
self.tokenizer.decode(&tokens, false).unwrap()
}

/// Gets the size of the BERT tokenizer's vocabulary.
fn vocab_size(&self) -> usize {
self.tokenizer.get_vocab_size(true)
}

/// Gets the token used for padding sequences to a consistent length.
fn pad_token(&self) -> usize {
self.pad_token
}
}
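
A quick hedged usage sketch (the pad token ID of 1 for `roberta-base` is an assumption taken from its
`config.json`; note that `decode` retains special tokens since it passes `false` for `skip_special_tokens`):

```rust
use bert_burn::data::{BertTokenizer, Tokenizer};

fn demo() {
    let tokenizer = BertTokenizer::new("roberta-base".to_string(), 1);

    let ids = tokenizer.encode("Hello, Burn!"); // includes <s> ... </s> special tokens
    println!("{} token IDs: {:?}", ids.len(), ids);
    println!("decoded: {}", tokenizer.decode(&ids)); // special tokens are kept
    println!("vocab size: {}", tokenizer.vocab_size());
    println!("pad token: {:?}", tokenizer.pad_token_value());
}
```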