Implement the ModernBert model #459

Open
wants to merge 14 commits into main
feature: flashmodernbert
kozistr committed Dec 25, 2024
commit 8cc712085fe52b33bcb485b09bb7bffb3c355c8d
191 changes: 120 additions & 71 deletions backends/candle/src/models/flash_modernbert.rs
@@ -2,6 +2,7 @@ use crate::flash_attn::flash_attn_varlen;
use crate::layers::{LayerNorm, Linear};
use crate::models::modernbert::{
ClassificationHead, ModernBertClassificationHead, ModernBertConfig, ModernBertEmbeddings,
ModernBertMLP,
};
use crate::models::Model;
use candle::{DType, Device, IndexOp, Result, Tensor};
@@ -12,10 +13,6 @@ struct ModernBertAttention {
wqkv: Linear,
wo: Linear,

local_attention: (i64, i64),
cos: Tensor,
sin: Tensor,

num_attention_heads: usize,
attention_head_size: usize,
softmax_scale: f64,
@@ -25,37 +22,45 @@ struct ModernBertAttention {

impl ModernBertAttention {
pub fn load(vb: VarBuilder, config: &BertConfig) -> Result<Self> {
let wi_weight = vb
.pp("Wi")
.get((config.hidden_size, config.intermediate_size * 2), "weight")?;
let wi_bias = vb
.pp("Wi")
.get((config.intermediate_size * 2,), "bias")
.ok();
let wi = Linear::new(wi_weight, wi_bias, None);

let wo_weight = vb
.pp("Wo")
.get((config.intermediate_size * 2, config.hidden_size), "weight")?;
let wo_bias = vb.pp("Wo").get((config.hidden_size,), "bias").ok();
let attention_head_size = config.hidden_size / config.num_attention_heads;
let hidden_size = config.hidden_size;

let wqkv_weight = vb
.pp("Wqkv")
.get((hidden_size * 3, hidden_size), "weight")?;
let wqkv_bias = if config.attention_bias {
vb.pp("Wqkv").get(hidden_size * 3, "bias").ok()
} else {
None
};
let wqkv: Linear = Linear::new(wqkv_weight, wqkv_bias, None);

let wo_weight = vb.pp("Wo").get((hidden_size, hidden_size), "weight")?;
let wo_bias = if config.attention_bias {
vb.pp("Wo").get(hidden_size, "bias").ok()
} else {
None
};
let wo = Linear::new(wo_weight, wo_bias, None);

let activation = Some(config.hidden_activation.clone());
let softmax_scale = 1. / (attention_head_size as f64).sqrt();

Ok(Self {
wi,
wqkv,
wo,
activation,
intermediate_size: config.intermediate_size,
span: tracing::span!(tracing::Level::TRACE, "mlp"),
num_attention_heads: config.num_attention_heads,
attention_head_size,
softmax_scale,
span: tracing::span!(tracing::Level::TRACE, "attention"),
})
}
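
Note on the dimension bookkeeping above: the fused `Wqkv` weight is `(3 * hidden_size, hidden_size)` and the softmax scale is `1 / sqrt(head_size)`. A standalone sketch of just that arithmetic (illustrative names, not code from this crate):

```rust
/// Standalone sketch of the shape bookkeeping for a fused QKV projection.
/// `hidden_size` is assumed to be divisible by `num_attention_heads`.
struct QkvShapes {
    attention_head_size: usize,
    wqkv_weight_shape: (usize, usize), // (3 * hidden, hidden)
    softmax_scale: f64,
}

fn qkv_shapes(hidden_size: usize, num_attention_heads: usize) -> QkvShapes {
    let attention_head_size = hidden_size / num_attention_heads;
    QkvShapes {
        attention_head_size,
        // One matmul produces query, key and value at once.
        wqkv_weight_shape: (hidden_size * 3, hidden_size),
        // Attention scores are scaled by 1 / sqrt(head_size) before the softmax.
        softmax_scale: 1.0 / (attention_head_size as f64).sqrt(),
    }
}

fn main() {
    // ModernBERT-base style configuration: 768 hidden, 12 heads.
    let s = qkv_shapes(768, 12);
    assert_eq!(s.attention_head_size, 64);
    assert_eq!(s.wqkv_weight_shape, (2304, 768));
    println!("softmax scale = {:.4}", s.softmax_scale);
}
```
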

pub fn forward(
&self,
hidden_states: &Tensor,
cu_seqlens: &Tensor,
cos: &Tensor,
sin: &Tensor,
max_s: usize,
) -> Result<Tensor> {
let _enter = self.span.enter();
@@ -73,9 +78,8 @@ impl ModernBertAttention {
let key_layer = &qkv[1].contiguous()?;
let value_layer = &qkv[2];

let query_layer =
apply_rotary(query_layer, &self.cos, &self.sin, self.attention_head_size)?;
let key_layer = apply_rotary(key_layer, &self.cos, &self.sin, self.attention_head_size)?;
let query_layer = apply_rotary(query_layer, cos, sin, self.attention_head_size)?;
let key_layer = apply_rotary(key_layer, cos, sin, self.attention_head_size)?;

let attention = flash_attn_varlen(
&query_layer,
@@ -88,8 +92,7 @@ impl ModernBertAttention {
max_s,
self.softmax_scale,
false,
self.local_attention[0],
self.local_attention[1],
self.local_attention,
)?;
let attention = attention.flatten_from(candle::D::Minus2)?;

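
`flash_attn_varlen` runs over the packed, padding-free token layout and locates each sequence through cumulative sequence lengths. A minimal sketch of how such a `cu_seqlens` vector is derived from per-sequence lengths (plain Rust illustration, not the batching code from this repo):

```rust
/// Build cumulative sequence lengths for a varlen attention kernel:
/// for lengths [3, 5, 2] the result is [0, 3, 8, 10], so sequence `i`
/// occupies token rows cu[i]..cu[i + 1] of the packed batch.
fn cu_seqlens(seq_lengths: &[u32]) -> Vec<u32> {
    let mut cu = Vec::with_capacity(seq_lengths.len() + 1);
    let mut total = 0u32;
    cu.push(0);
    for &len in seq_lengths {
        total += len;
        cu.push(total);
    }
    cu
}

fn main() {
    assert_eq!(cu_seqlens(&[3, 5, 2]), vec![0, 3, 8, 10]);
}
```
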
@@ -110,7 +113,7 @@ struct ModernBertEncoderLayer {

impl ModernBertEncoderLayer {
pub fn load(vb: VarBuilder, index: usize, config: &ModernBertConfig) -> Result<Self> {
let attn_norm = if index > 0 {
let attn_norm = if index != 0 {
Some(LayerNorm::load(
vb.pp("attn_norm"),
config.hidden_size,
@@ -120,7 +123,7 @@ impl ModernBertEncoderLayer {
None
};

let attn = ModernBertAttention::load(vb.pp("attn"), index, config)?;
let attn = ModernBertAttention::load(vb.pp("attn"), config)?;

let mlp_norm = LayerNorm::load(
vb.pp("mlp_norm"),
@@ -143,30 +146,38 @@ impl ModernBertEncoderLayer {
fn forward(
&self,
hidden_states: &Tensor,
attention_mask: &Tensor,
silding_attention_mask: &Tensor,
cu_seqlens: &Tensor,
cos: &Tensor,
sin: &Tensor,
max_s: usize,
) -> Result<Tensor> {
let _enter = self.span.enter();

let mut hidden_states = hidden_states.clone();
let residual = hidden_states.clone();

if let Some(attn_norm) = &self.attn_norm {
hidden_states = attn_norm.forward(&hidden_states, None)?;
}
let attn_norm = if let Some(attn_norm) = &self.attn_norm {
attn_norm.forward(hidden_states, None)?
} else {
hidden_states.clone()
};

let attn_outputs = self.attn.forward(&attn_norm, cu_seqlens, cos, sin, max_s)?;

let hidden_states = residual.add(&attn_outputs)?;

let hidden_states =
self.attn
.forward(&hidden_states, attention_mask, silding_attention_mask)?;
let mlp_output = self
.mlp
.forward(&self.mlp_norm.forward(&hidden_states, None)?)?;

hidden_states.broadcast_add(&mlp_output)
hidden_states.add(&mlp_output)
}
}
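
The layer above is the usual pre-norm residual block: attention and MLP each read a normalized copy of the stream and their outputs are added back onto the un-normalized residual, with layer 0 skipping `attn_norm`. A minimal sketch of that control flow over plain vectors (the `norm`, `attn` and `mlp` closures are stand-ins, and a single `norm` here plays the role of both `attn_norm` and `mlp_norm`):

```rust
/// Pre-norm residual block: x <- x + f(norm(x)), applied once for the
/// attention sub-module and once for the MLP.
fn encoder_layer(
    x: &[f32],
    norm: impl Fn(&[f32]) -> Vec<f32>,
    attn: impl Fn(&[f32]) -> Vec<f32>,
    mlp: impl Fn(&[f32]) -> Vec<f32>,
    skip_attn_norm: bool, // true for layer 0, which has no attn_norm
) -> Vec<f32> {
    let normed = if skip_attn_norm { x.to_vec() } else { norm(x) };
    // Residual add around the attention output.
    let after_attn: Vec<f32> = x
        .iter()
        .zip(attn(normed.as_slice()))
        .map(|(r, a)| r + a)
        .collect();
    // Residual add around the MLP output, fed the normalized stream.
    let mlp_in = norm(after_attn.as_slice());
    after_attn
        .iter()
        .zip(mlp(mlp_in.as_slice()))
        .map(|(r, m)| r + m)
        .collect()
}

fn main() {
    let x = vec![1.0_f32, 2.0];
    // Identity stand-ins just to exercise the wiring.
    let id = |v: &[f32]| v.to_vec();
    let y = encoder_layer(&x, id, id, id, false);
    assert_eq!(y, vec![4.0, 8.0]);
}
```
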

struct ModernBertEncoder {
layers: Vec<ModernBertEncoderLayer>,

global_attn_every_n_layers: usize,

span: tracing::Span,
}

@@ -178,22 +189,29 @@ impl ModernBertEncoder {

let span = tracing::span!(tracing::Level::TRACE, "encoder");

Ok(ModernBertEncoder { layers, span })
Ok(ModernBertEncoder {
layers,
global_attn_every_n_layers: config.global_attn_every_n_layers,
span,
})
}

fn forward(
&self,
hidden_states: &Tensor,
attention_mask: &Tensor,
silding_attention_mask: &Tensor,
cu_seqlens: &Tensor,
rotary_cache: &HashMap<bool, (Tensor, Tensor)>,
max_s: usize,
) -> Result<Tensor> {
let _enter = self.span.enter();

let mut hidden_states = hidden_states.clone();

for layer in self.layers.iter() {
hidden_states =
layer.forward(&hidden_states, attention_mask, silding_attention_mask)?;
for (index, layer) in self.layers.iter().enumerate() {
let use_local_attention = index % self.global_attn_every_n_layers != 0;
let (cos, sin) = &rotary_cache[&use_local_attention];

hidden_states = layer.forward(&hidden_states, cu_seqlens, cos, sin, max_s)?;
}

Ok(hidden_states)
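
Which `(cos, sin)` pair a layer pulls from the cache depends only on its index: every `global_attn_every_n_layers`-th layer attends globally, the rest use the local sliding window. A standalone sketch of that schedule (the helper name is hypothetical):

```rust
/// ModernBERT alternation: layers 0, n, 2n, ... use global attention,
/// every other layer uses the local sliding window.
fn uses_local_attention(layer_index: usize, global_attn_every_n_layers: usize) -> bool {
    layer_index % global_attn_every_n_layers != 0
}

fn main() {
    // With one global layer every 3 layers (the ModernBERT default):
    let schedule: Vec<bool> = (0..6).map(|i| uses_local_attention(i, 3)).collect();
    assert_eq!(schedule, vec![false, true, true, false, true, true]);
}
```
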
@@ -208,6 +226,10 @@ pub struct FlashModernBertModel {
classifier: Option<Box<dyn ClassificationHead + Send>>,

local_attention: usize,
rotary_dim: usize,
rotary_cache: HashMap<bool, (Tensor, Tensor)>,
pad_token_id: u32,
num_attention_heads: usize,

device: Device,
dtype: DType,
@@ -251,18 +273,45 @@ impl FlashModernBertModel {
let embeddings = ModernBertEmbeddings::load(vb.pp("model.embeddings"), config)?;
let encoder = ModernBertEncoder::load(vb.pp("model.layers"), config)?;
let final_norm = LayerNorm::load(
vb.pp("final_norm"),
vb.pp("model.final_norm"),
config.hidden_size,
config.norm_eps as f32,
)?;

let rotary_dim = config.hidden_size / config.num_attention_heads;
let mut rotary_cache: HashMap<bool, (Tensor, Tensor)> = HashMap::new();

for use_local_attention in [true, false] {
let rope_theta = if use_local_attention {
config.local_rope_theta
} else {
config.global_rope_theta
};

let max_position_embeddings = if use_local_attention {
config.max_position_embeddings
} else {
config.local_attention
};

let inv_freqs = get_inv_freqs(rotary_dim, rope_theta as f32, vb.device(), None)?;

let (cos, sin) = get_cos_sin(max_position_embeddings, &inv_freqs, vb.dtype(), true)?;

rotary_cache.insert(use_local_attention, (cos, sin));
}

Ok(Self {
embeddings,
encoder,
final_norm,
pool,
classifier,
local_attention: config.local_attention,
rotary_dim,
rotary_cache,
pad_token_id: config.pad_token_id as u32,
num_attention_heads: config.num_attention_heads,
device: vb.device().clone(),
dtype: vb.dtype(),
span: tracing::span!(tracing::Level::TRACE, "model"),
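
The rotary cache built above stores one `(cos, sin)` table per attention flavour, since local and global layers can use different RoPE thetas. A standalone sketch of what such a table contains, assuming the conventional RoPE definition (inverse frequencies `1 / theta^(2i / dim)`, then cos/sin of position * frequency); these helpers only illustrate what `get_inv_freqs` / `get_cos_sin` presumably compute:

```rust
/// Inverse frequencies for rotary embeddings: 1 / theta^(2i / dim)
/// for i in 0..dim/2.
fn inv_freqs(rotary_dim: usize, theta: f32) -> Vec<f32> {
    (0..rotary_dim / 2)
        .map(|i| 1.0 / theta.powf(2.0 * i as f32 / rotary_dim as f32))
        .collect()
}

/// cos/sin tables of shape [max_positions][dim/2]: row p holds
/// cos(p * f) and sin(p * f) for each inverse frequency f.
fn cos_sin(max_positions: usize, inv_freqs: &[f32]) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let mut cos = Vec::with_capacity(max_positions);
    let mut sin = Vec::with_capacity(max_positions);
    for p in 0..max_positions {
        cos.push(inv_freqs.iter().map(|f| (p as f32 * f).cos()).collect());
        sin.push(inv_freqs.iter().map(|f| (p as f32 * f).sin()).collect());
    }
    (cos, sin)
}

fn main() {
    // One table per attention flavour, e.g. a different theta for local layers.
    for (label, theta) in [("global", 160_000.0_f32), ("local", 10_000.0_f32)] {
        let freqs = inv_freqs(64, theta);
        let (cos, _sin) = cos_sin(4, &freqs);
        println!("{label}: {} positions x {} freqs", cos.len(), cos[0].len());
    }
}
```
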
@@ -273,6 +322,7 @@ impl FlashModernBertModel {
&self,
attention_mask: Option<&Tensor>,
input_shape: &Shape,
num_attention_heads: usize,
) -> Result<Tensor> {
let extended_attention_mask = if let Some(attention_mask) = attention_mask {
attention_mask.squeeze(2)?
@@ -283,16 +333,9 @@ impl FlashModernBertModel {
.unsqueeze(1)?
.to_dtype(self.dtype)?;

let min_value = match self.dtype {
DType::F32 => f32::MIN as f64,
_ => -65504.0_f64, // f16 minumum value
};

let extended_attention_mask = ((1.0 - extended_attention_mask)? * min_value)?;

let (bs, seq_len) = input_shape.dims2()?;
let extended_attention_mask =
extended_attention_mask.broadcast_as((bs, 1, seq_len, seq_len))?;
extended_attention_mask.broadcast_as((bs, num_attention_heads, seq_len, seq_len))?;

Ok(extended_attention_mask)
}
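
For the padding mask above: a per-token validity mask is unsqueezed and broadcast so every head and every query row sees the same per-key validity. A standalone sketch of that broadcast over nested vectors (illustrative only; the candle code uses `broadcast_as` rather than materializing nested vectors):

```rust
/// Expand a per-token padding mask of shape [batch][seq] into a
/// [batch][heads][seq][seq] mask where entry (b, h, i, j) is true
/// when key token j of batch b is real (not padding).
fn extended_attention_mask(
    padding_mask: &[Vec<bool>],
    num_attention_heads: usize,
) -> Vec<Vec<Vec<Vec<bool>>>> {
    padding_mask
        .iter()
        .map(|tokens| {
            let seq_len = tokens.len();
            // Same key-validity row repeated for every query position and head.
            vec![vec![tokens.clone(); seq_len]; num_attention_heads]
        })
        .collect()
}

fn main() {
    // One sequence of length 3 with the last token padded, 2 heads.
    let mask = extended_attention_mask(&[vec![true, true, false]], 2);
    assert_eq!(mask[0].len(), 2);    // heads
    assert_eq!(mask[0][0].len(), 3); // query positions
    assert!(!mask[0][1][0][2]);      // padded key is masked for every head/query
}
```
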
@@ -302,28 +345,24 @@ impl FlashModernBertModel {
attention_mask: &Tensor,
local_attention: usize,
) -> Result<Tensor> {
let attention_mask = attention_mask.to_dtype(DType::U8)?;
let mask_shape = attention_mask.shape();
let (_, _, seq_len, _) = mask_shape.dims4()?;

let rows = Tensor::arange(0, seq_len as i64, attention_mask.device())?.unsqueeze(0)?;
let rows = rows.broadcast_as((seq_len, seq_len))?;

let distance = (&rows - &rows.t()?)?.abs()?;

let window_size = local_attention / 2;
let window_mask = distance
.le(window_size as i64)?
.unsqueeze(0)?
.unsqueeze(0)?;

let dtype = attention_mask.dtype();
let min_value = match dtype {
DType::F32 => f32::MIN as f64,
_ => -65504.0, // f16 minimum value
};
.unsqueeze(0)?
.broadcast_as(mask_shape)?;

let inverted_window_mask = window_mask.eq(0_i64)?;
let min_value_tensor = Tensor::full(min_value, mask_shape, attention_mask.device())?;
let sliding_window_mask =
attention_mask.where_cond(&inverted_window_mask, &min_value_tensor)?;
let zero_tensor = Tensor::zeros_like(&attention_mask)?;
let sliding_window_mask = attention_mask.where_cond(&window_mask, &zero_tensor)?;

Ok(sliding_window_mask)
}
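
The sliding-window mask above is a band matrix: token `i` may only attend to tokens `j` with `|i - j| <= local_attention / 2`. A minimal sketch of that band over plain indices, independent of the tensor code:

```rust
/// Boolean sliding-window mask: entry (i, j) is true when token i may
/// attend to token j, i.e. |i - j| <= local_attention / 2.
fn sliding_window_mask(seq_len: usize, local_attention: usize) -> Vec<Vec<bool>> {
    let window = (local_attention / 2) as i64;
    (0..seq_len as i64)
        .map(|i| (0..seq_len as i64).map(|j| (i - j).abs() <= window).collect())
        .collect()
}

fn main() {
    // Window of 2 on each side (local_attention = 4), 5 tokens.
    let mask = sliding_window_mask(5, 4);
    assert!(mask[0][2] && !mask[0][3]);
    assert!(mask[4][2] && !mask[4][1]);
}
```
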
@@ -335,22 +374,32 @@ impl FlashModernBertModel {
let shape = batch.input_ids.len();

let input_ids = Tensor::from_vec(batch.input_ids, shape, &self.device)?;
let position_ids = Tensor::from_vec(batch.position_ids, shape, &self.device)?;
let cu_seqlens = Tensor::from_vec(
batch.cumulative_seq_lengths.clone(),
batch_size + 1,
&self.device,
)?;

let global_attention_mask =
self.get_global_attention_mask(attention_mask.as_ref(), input_ids.shape())?;
let silding_attention_mask =
self.get_silding_window_mask(&global_attention_mask, self.local_attention)?;
let mut rotary_cache: HashMap<bool, (Tensor, Tensor)> = HashMap::new();
for use_local_attention in [true, false] {
let (cos, sin) = &self.rotary_cache[&use_local_attention];

let cos = cos.index_select(&position_ids, 0)?;
let sin = sin.index_select(&position_ids, 0)?;

let cos = cos.reshape((batch_size, 1, max_length, self.rotary_dim))?;
let sin = sin.reshape((batch_size, 1, max_length, self.rotary_dim))?;

rotary_cache.insert(use_local_attention, (cos, sin));
}

let hidden_states = self.embeddings.forward(&input_ids)?;
let hidden_states = self.encoder.forward(
&hidden_states,
&global_attention_mask,
&silding_attention_mask,
&cu_seqlens,
&rotary_cache,
batch.max_length as usize,
)?;
let outputs = self.final_norm.forward(&hidden_states, None)?;
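
Per batch, the cached tables are gathered row by row with the flattened position ids (the `index_select` above) before being reshaped for the attention kernel. A standalone sketch of that gather over plain vectors (hypothetical helper, not the candle call):

```rust
/// Gather rows of a precomputed [max_positions][rotary_dim / 2] table
/// according to per-token position ids, mirroring an index_select on dim 0.
fn gather_rows(table: &[Vec<f32>], position_ids: &[usize]) -> Vec<Vec<f32>> {
    position_ids.iter().map(|&p| table[p].clone()).collect()
}

fn main() {
    let table = vec![vec![0.0], vec![1.0], vec![2.0], vec![3.0]];
    // Example: position ids for two packed sequences of lengths 3 and 2,
    // each starting again at position 0.
    let position_ids = [0, 1, 2, 0, 1];
    let rows = gather_rows(&table, &position_ids);
    assert_eq!(rows, vec![vec![0.0], vec![1.0], vec![2.0], vec![0.0], vec![1.0]]);
}
```
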