
Commit

fix: rotary embedding
kozistr committed Jan 19, 2025
1 parent 3b20211 commit 63c4224
Showing 1 changed file with 11 additions and 18 deletions.
29 changes: 11 additions & 18 deletions backends/candle/src/models/modernbert.rs
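In brief (as read from the diff below): instead of precomputing (cos, sin) rotary tables per attention type at load time, the model now caches only the RoPE inverse frequencies and computes cos/sin in the forward pass from the batch's actual max_length. The removed constructor code sized the local-attention tables with config.max_position_embeddings and the global tables with config.local_attention, which looks inverted and would cap the global rotary tables at the sliding-window length. The mask helper also now takes a plain (batch_size, seq_len) tuple instead of candle's Shape.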
@@ -4,7 +4,7 @@ use crate::layers::{
     apply_rotary, get_cos_sin, get_cublas_lt_wrapper, get_inv_freqs, HiddenAct, LayerNorm, Linear,
 };
 use crate::models::Model;
-use candle::{DType, Device, IndexOp, Module, Result, Shape, Tensor, D};
+use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::{Embedding, VarBuilder};
 use serde::Deserialize;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};
@@ -454,7 +454,7 @@ pub struct ModernBertModel {
 
     local_attention: usize,
     rotary_dim: usize,
-    rotary_cache: HashMap<bool, (Tensor, Tensor)>,
+    inv_freqs_cache: HashMap<bool, Tensor>,
    pad_token_id: u32,
    num_attention_heads: usize,
 
@@ -506,7 +506,7 @@ impl ModernBertModel {
         })?;
 
         let rotary_dim = config.hidden_size / config.num_attention_heads;
-        let mut rotary_cache: HashMap<bool, (Tensor, Tensor)> = HashMap::new();
+        let mut inv_freqs_cache: HashMap<bool, Tensor> = HashMap::new();
 
         for use_local_attention in [true, false] {
             let rope_theta = if use_local_attention {
Expand All @@ -515,17 +515,9 @@ impl ModernBertModel {
config.global_rope_theta
};

let max_position_embeddings = if use_local_attention {
config.max_position_embeddings
} else {
config.local_attention
};

let inv_freqs = get_inv_freqs(rotary_dim, rope_theta as f32, vb.device(), None)?;

let (cos, sin) = get_cos_sin(max_position_embeddings, &inv_freqs, vb.dtype(), true)?;

rotary_cache.insert(use_local_attention, (cos, sin));
inv_freqs_cache.insert(use_local_attention, inv_freqs);
}

Ok(Self {
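get_inv_freqs itself is outside this diff; below is a minimal plain-Rust sketch of the standard RoPE inverse-frequency formula it presumably implements. The helper name, dimensions, and theta values here are illustrative stand-ins, not the crate's actual code.

    // inv_freq[i] = 1 / theta^(2i / rotary_dim), one frequency per rotated pair.
    fn inv_freqs(rotary_dim: usize, rope_theta: f32) -> Vec<f32> {
        (0..rotary_dim / 2)
            .map(|i| 1.0 / rope_theta.powf(2.0 * i as f32 / rotary_dim as f32))
            .collect()
    }

    fn main() {
        // ModernBERT-style setup: one theta for local attention, a larger one
        // for global attention; the cache now stores only these vectors.
        let local = inv_freqs(64, 10_000.0);    // hypothetical local rope_theta
        let global = inv_freqs(64, 160_000.0);  // hypothetical global_rope_theta
        println!("{} local / {} global frequencies", local.len(), global.len());
    }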
@@ -536,7 +528,7 @@
             classifier,
             local_attention: config.local_attention,
             rotary_dim,
-            rotary_cache,
+            inv_freqs_cache,
             pad_token_id: config.pad_token_id as u32,
             num_attention_heads: config.num_attention_heads,
             device: vb.device().clone(),
@@ -548,18 +540,18 @@ impl ModernBertModel {
     fn get_global_attention_mask(
         &self,
         attention_mask: Option<&Tensor>,
-        input_shape: &Shape,
+        input_shape: &(usize, usize),
     ) -> Result<Tensor> {
         let extended_attention_mask = if let Some(attention_mask) = attention_mask {
             attention_mask.squeeze(2)?
         } else {
-            Tensor::ones(input_shape, DType::F32, &self.device)?
+            Tensor::ones(*input_shape, DType::F32, &self.device)?
         }
         .unsqueeze(1)?
         .unsqueeze(1)?
         .to_dtype(self.dtype)?;
 
-        let (bs, seq_len) = input_shape.dims2()?;
+        let (bs, seq_len) = *input_shape;
         let extended_attention_mask = extended_attention_mask.broadcast_as((
             bs,
             self.num_attention_heads,
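For context, the shape expansion this helper performs, sketched standalone against the candle API (bs, seq_len, and the head count are made-up inputs here, and the no-mask branch is shown):

    use candle::{DType, Device, Result, Tensor};

    fn expand_mask(bs: usize, seq_len: usize, num_heads: usize) -> Result<Tensor> {
        let device = Device::Cpu;
        // No attention mask supplied: start from all ones, shape [bs, seq_len].
        let mask = Tensor::ones((bs, seq_len), DType::F32, &device)?;
        // [bs, seq_len] -> [bs, 1, 1, seq_len] -> [bs, heads, seq_len, seq_len].
        mask.unsqueeze(1)?
            .unsqueeze(1)?
            .broadcast_as((bs, num_heads, seq_len, seq_len))
    }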
@@ -664,7 +656,7 @@
             Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?.to_dtype(self.dtype)?;
 
         let global_attention_mask = self
-            .get_global_attention_mask(attention_mask.as_ref(), input_ids.shape())?
+            .get_global_attention_mask(attention_mask.as_ref(), &shape)?
             .to_dtype(self.dtype)?;
         let silding_attention_mask = self
             .get_silding_window_mask(&global_attention_mask)?
Expand All @@ -680,7 +672,8 @@ impl ModernBertModel {

let mut rotary_cache: HashMap<bool, (Tensor, Tensor)> = HashMap::new();
for use_local_attention in [true, false] {
let (cos, sin) = &self.rotary_cache[&use_local_attention];
let inv_freq = &self.inv_freqs_cache[&use_local_attention];
let (cos, sin) = get_cos_sin(max_length, inv_freq, self.dtype, true)?;

let cos = cos.index_select(&position_ids, 0)?;
let sin = sin.index_select(&position_ids, 0)?;
Expand Down
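The forward pass now rebuilds the rotary tables per batch instead of reading precomputed ones. A plain-Rust stand-in for what get_cos_sin is assumed to compute follows (the real helper lives in crate::layers; the behavior of its final bool argument is not visible in this diff and is omitted):

    // cos[pos][i] = cos(pos * inv_freq[i]); likewise for sin. One row per
    // position up to the batch's max_length; a row is then gathered for each
    // token's position id, mirroring the index_select calls above.
    fn cos_sin(max_length: usize, inv_freqs: &[f32]) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
        let mut cos = Vec::with_capacity(max_length);
        let mut sin = Vec::with_capacity(max_length);
        for pos in 0..max_length {
            let angles: Vec<f32> = inv_freqs.iter().map(|f| pos as f32 * f).collect();
            cos.push(angles.iter().map(|a| a.cos()).collect());
            sin.push(angles.iter().map(|a| a.sin()).collect());
        }
        (cos, sin)
    }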
