Merge pull request #256 from huggingface/feature/cpu-rope

Feature/cpu rope

FL33TW00D authored Oct 29, 2024
2 parents 721f3c6 + b74e4b2 commit a0bd0f1
Showing 10 changed files with 376 additions and 70 deletions.
59 changes: 43 additions & 16 deletions crates/ratchet-core/src/cpu/gemm.rs
@@ -1,10 +1,10 @@
 use crate::{
-    cpu_store_result, CPUOperation, DType, InvariantError, Matmul, MatmulSpec, OperationError,
-    Shape, Tensor, TensorDType,
+    cpu::cpu_store_result, CPUOperation, DType, InvariantError, Matmul, MatmulSpec, OperationError,
+    Shape, Strides, Tensor, TensorDType,
 };
 use anyhow::{anyhow, Result};
 use core::str::FromStr;
-use gemm::{gemm, Parallelism};
+use gemm::{gemm as gemm_kernel, Parallelism};
 use half::{bf16, f16};
 use std::num::NonZeroUsize;

@@ -56,21 +56,19 @@ fn calculate_skips(
     Ok((lhs_skip, rhs_skip))
 }

-fn gemm_impl<T: TensorDType>(
-    spec: MatmulSpec,
+pub(crate) fn gemm<T: TensorDType>(
     lhs: &[T],
+    lhs_shape: &Shape,
+    lhs_strides: &Strides,
     rhs: &[T],
+    rhs_shape: &Shape,
+    rhs_strides: &Strides,
+    dst_strides: &Strides,
+    b: usize,
+    m: usize,
+    n: usize,
+    k: usize,
 ) -> Result<Vec<T>, OperationError> {
-    let lhs_shape = spec.lhs_shape();
-    let rhs_shape = spec.rhs_shape();
-    let lhs_strides = spec.lhs_strides();
-    let rhs_strides = spec.rhs_strides();
-    let dst_strides = spec.dst_strides();
-    let b = spec.stacks();
-    let m = spec.m();
-    let n = spec.n();
-    let k = spec.k();
-
     let lhs_strides = lhs_strides.to_vec();
     let rhs_strides = rhs_strides.to_vec();
     let rank = lhs_shape.rank();
@@ -102,7 +100,7 @@ fn gemm_impl<T: TensorDType>(
             let rhs_p = &rhs[step * rhs_skip..];
             let dst_p = &mut dst[step * dst_skip..];
             unsafe {
-                gemm(
+                gemm_kernel(
                     m,
                     n,
                     k,
@@ -128,6 +126,35 @@
     Ok(dst)
 }

+fn gemm_impl<T: TensorDType>(
+    spec: MatmulSpec,
+    lhs: &[T],
+    rhs: &[T],
+) -> Result<Vec<T>, OperationError> {
+    let lhs_shape = spec.lhs_shape();
+    let rhs_shape = spec.rhs_shape();
+    let lhs_strides = spec.lhs_strides();
+    let rhs_strides = spec.rhs_strides();
+    let dst_strides = spec.dst_strides();
+    let b = spec.stacks();
+    let m = spec.m();
+    let n = spec.n();
+    let k = spec.k();
+    gemm(
+        lhs,
+        lhs_shape,
+        lhs_strides,
+        rhs,
+        rhs_shape,
+        rhs_strides,
+        dst_strides,
+        b,
+        m,
+        n,
+        k,
+    )
+}
+
 impl CPUOperation for Matmul {
     fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
         fn run_gemm<T: TensorDType>(
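To make the refactor concrete: `gemm` is now a raw, `pub(crate)` entry point over explicit shapes and strides, with `gemm_impl` reduced to unpacking the `MatmulSpec` (second hunk above). Below is a minimal sketch of a direct call from inside the crate; `gemm_demo` and its values are hypothetical, modeled on the call in `compute_theta` in rope.rs.

```rust
// Hypothetical in-crate usage of the new raw `gemm` entry point
// (it is pub(crate), so callers live inside ratchet-core).
use crate::cpu::gemm::gemm;
use crate::{shape, Strides};

fn gemm_demo() -> Result<(), crate::OperationError> {
    // A: 2x3, B: 3x2, both row-major (contiguous strides).
    let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b = vec![1.0f32, 0.0, 0.0, 1.0, 1.0, 1.0];
    let a_shape = shape!(2, 3);
    let a_strides = Strides::from(&a_shape);
    let b_shape = shape!(3, 2);
    let b_strides = Strides::from(&b_shape);
    let dst_strides = Strides::from(&shape!(2, 2));
    // One stack (b = 1), m = 2, n = 2, k = 3.
    let c = gemm(
        &a, &a_shape, &a_strides,
        &b, &b_shape, &b_strides,
        &dst_strides,
        1, 2, 2, 3,
    )?;
    assert_eq!(c, vec![4.0, 5.0, 10.0, 11.0]);
    Ok(())
}
```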
73 changes: 44 additions & 29 deletions crates/ratchet-core/src/cpu/mod.rs
@@ -1,22 +1,28 @@
 mod binary;
 pub mod gemm;
+pub mod rope;
 mod unary;
+mod utils;

 use crate::{
-    dequantize, Binary, CPUBuffer, Cast, Concat, DType, IndexSelect, InvariantError, LazyOp,
-    Operation, OperationError, RVec, Storage, Tensor, TensorDType,
+    dequantize, Binary, BinaryOp, CPUBuffer, Cast, Concat, DType, IndexSelect, InvariantError,
+    LazyOp, OpGuards, Operation, OperationError, RVec, Shape, Storage, StorageView, Strides,
+    Tensor, TensorDType, Unary, UnaryOp,
 };
 use anyhow::anyhow;
 use bytemuck::NoUninit;
 use core::marker::PhantomData;
 use half::{bf16, f16};
 use num_traits::Float;
+use rope::cpu_rope;
+use utils::cpu_store_result;

 pub fn apply_operation(op: LazyOp, dst: Tensor) -> Result<Tensor, OperationError> {
     match op {
         LazyOp::Binary(b) => b.apply_cpu(dst),
         LazyOp::Cast(c) => cpu_cast(c, dst),
         LazyOp::Matmul(m) => m.apply_cpu(dst),
         LazyOp::Softmax(_s) => todo!(),
-        LazyOp::RoPE(_r) => todo!(),
+        LazyOp::RoPE(r) => cpu_rope(r, dst),
         LazyOp::Unary(u) => u.apply_cpu(dst),
         LazyOp::Reindex(_r) => todo!(),
         LazyOp::Concat(c) => cpu_concat(c, dst),
@@ -148,44 +154,57 @@ pub fn cpu_cast(cast: Cast, dst: Tensor) -> Result<Tensor, OperationError> {
     Ok(dst)
 }

-fn concat_inner<T: TensorDType>(
-    inputs: RVec<Tensor>,
+pub(crate) fn concat<T: TensorDType>(
+    inputs: &[(Shape, Vec<T>)],
     dim: usize,
-    dst: Tensor,
-) -> Result<Tensor, OperationError> {
-    let dst_size = dst.shape().clone().product();
-    let mut result = vec![T::zero(); dst_size];
-
-    let dst_dim_len = dst.shape()[dim];
-    let block: usize = dst.shape().iter().skip(1 + dim).product();
+    dst_shape: &Shape,
+    dst: &mut [T],
+) -> Result<(), OperationError> {
+    let dst_dim_len = dst_shape[dim];
+    let block: usize = dst_shape.iter().skip(1 + dim).product();
     let dst_s = block * dst_dim_len;
     let src_o = 0;
     let mut dst_o = 0;
-    for t in inputs {
-        let src = t.to_vec::<T>()?;
-
-        let t_dims = t.shape().as_slice();
-        let a_dim: usize = t_dims.iter().take(dim).product();
-        let b_dim = block * t_dims[dim];
-
+    for (src_s, src) in inputs {
+        let a_dim: usize = src_s.iter().take(dim).product();
+        let b_dim = block * src_s[dim];
         for idx in 0..a_dim {
             let dst_idx = idx * dst_s + dst_o;
             let src_idx = idx * b_dim + src_o;
-            let dst = &mut result[dst_idx..dst_idx + b_dim];
+            let dst_t = &mut dst[dst_idx..dst_idx + b_dim];
             let src = &src[src_idx..src_idx + b_dim];
-            dst.copy_from_slice(src)
+            dst_t.copy_from_slice(src)
         }
         dst_o += b_dim;
     }
+    Ok(())
+}
+
+pub(crate) fn apply_concat<T: TensorDType>(
+    inputs: RVec<Tensor>,
+    dim: usize,
+    dst: Tensor,
+) -> Result<Tensor, OperationError> {
+    let dst_size = dst.shape().numel();
+    let mut result = vec![T::zero(); dst_size];
+
+    let inputs = inputs
+        .iter()
+        .map(|t| match t.to_vec::<T>() {
+            Ok(v) => Ok((t.shape().clone(), v)),
+            Err(e) => Err(e.into()),
+        })
+        .collect::<Result<Vec<_>, OperationError>>();
+
+    concat(&inputs?, dim, dst.shape(), &mut result)?;
     cpu_store_result(&dst, &result);
     Ok(dst)
 }

 pub fn cpu_concat(Concat { inputs, dim }: Concat, dst: Tensor) -> Result<Tensor, OperationError> {
     match dst.dt() {
-        DType::F32 => concat_inner::<f32>(inputs, dim, dst),
-        DType::F16 => concat_inner::<f16>(inputs, dim, dst),
-        DType::BF16 => concat_inner::<bf16>(inputs, dim, dst),
+        DType::F32 => apply_concat::<f32>(inputs, dim, dst),
+        DType::F16 => apply_concat::<f16>(inputs, dim, dst),
+        DType::BF16 => apply_concat::<bf16>(inputs, dim, dst),
         dtype => Err(InvariantError::UnsupportedDType(dtype).into()),
     }
 }
@@ -266,7 +285,3 @@ pub fn binary_apply_inplace<T: TensorDType>(
     cpu_store_result(dst, &lhs);
     Ok(())
 }
-
-pub fn cpu_store_result<T: NoUninit>(dst: &Tensor, data: &[T]) {
-    dst.update_storage(Storage::CPU(CPUBuffer::from_slice(data, dst.shape())));
-}
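The concat refactor mirrors the gemm one: a `pub(crate)` core that works on plain shapes and slices (reused by the CPU RoPE kernel below), plus an `apply_concat` wrapper that handles `Tensor` conversion and storage. A minimal sketch of calling the core directly, with hypothetical example values (the `crate::concat` path follows the import in rope.rs):

```rust
// Hypothetical in-crate usage of the shape-and-slice `concat` core:
// concatenate a 1x2 and a 1x3 block along dim 1 into a 1x5 buffer.
use crate::{concat, shape};

fn concat_demo() -> Result<(), crate::OperationError> {
    let inputs = vec![
        (shape!(1, 2), vec![1.0f32, 2.0]),
        (shape!(1, 3), vec![3.0f32, 4.0, 5.0]),
    ];
    let dst_shape = shape!(1, 5);
    let mut dst = vec![0.0f32; dst_shape.numel()];
    concat(&inputs, 1, &dst_shape, &mut dst)?;
    assert_eq!(dst, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    Ok(())
}
```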
170 changes: 170 additions & 0 deletions crates/ratchet-core/src/cpu/rope.rs
@@ -0,0 +1,170 @@
+use crate::{
+    concat,
+    cpu::{cpu_store_result, gemm::gemm},
+    shape, DType, OperationError, RoPE, Shape, Strides, Tensor,
+};
+use anyhow::anyhow;
+
+pub fn cpu_rope(op: RoPE, dst: Tensor) -> Result<Tensor, OperationError> {
+    match op.input().dt() {
+        DType::F32 => {
+            let dim = op.dim();
+            let base = op.base();
+            let offset = op.offset();
+            let src = op.input().to_vec::<f32>()?;
+            let result = rope(src, op.input().shape(), dim, base, offset)?;
+            cpu_store_result(&dst, &result)
+        }
+        _ => todo!(),
+    }
+
+    Ok(dst)
+}
+
+fn compute_theta(
+    dim: usize,
+    seq_len: usize,
+    base: f32,
+    offset: usize,
+) -> Result<Vec<f32>, OperationError> {
+    let half_dim = dim / 2;
+
+    let positions = (offset..seq_len + offset)
+        .map(|x| x as f32)
+        .collect::<Vec<f32>>();
+
+    let inv_freqs = (0..half_dim)
+        .map(|i| -(i as f32))
+        .map(|i| i * base.ln() / half_dim as f32)
+        .map(f32::exp)
+        .collect::<Vec<f32>>();
+
+    let p_shape = shape!(seq_len, 1);
+    let p_strides = Strides::from(&p_shape);
+    let i_shape = shape!(1, half_dim);
+    let i_strides = Strides::from(&i_shape);
+    let dst_strides = Strides::from(&shape!(seq_len, half_dim));
+    let theta = gemm(
+        &positions,
+        &p_shape,
+        &p_strides,
+        &inv_freqs,
+        &i_shape,
+        &i_strides,
+        &dst_strides,
+        1,
+        seq_len,
+        half_dim,
+        1,
+    )?;
+
+    Ok(theta)
+}
+
+fn slice(src: &[f32], src_strides: &Strides, start: &[usize], stop: &[usize]) -> Vec<f32> {
+    assert!(start.len() == stop.len());
+    assert!(start.len() == src_strides.rank());
+    start.iter().zip(stop.iter()).for_each(|(s, t)| {
+        assert!(s < t);
+    });
+
+    let dst_shape: Vec<usize> = stop.iter().zip(start.iter()).map(|(s, t)| s - t).collect();
+    let dst_numel: usize = dst_shape.iter().product();
+
+    let mut dst = vec![0.0; dst_numel];
+
+    for i in 0..dst_numel {
+        let mut src_index = 0;
+        let mut tmp = i;
+        for d in 0..dst_shape.len() {
+            let coord = tmp / dst_shape[d + 1..].iter().product::<usize>().max(1);
+            tmp %= dst_shape[d + 1..].iter().product::<usize>().max(1);
+            src_index += (coord + start[d]) * src_strides[d] as usize;
+        }
+        dst[i] = src[src_index];
+    }
+
+    dst
+}
+
+fn rope(
+    src: Vec<f32>,
+    shape: &Shape,
+    dim: usize,
+    base: f32,
+    offset: usize,
+) -> Result<Vec<f32>, OperationError> {
+    let [batches, num_heads, seq_len, head_dim] = shape.try_into().unwrap();
+
+    let half_dim = dim / 2;
+    let theta = compute_theta(dim, seq_len, base, offset)?;
+    let (sin, cos): (Vec<f32>, Vec<f32>) = theta.iter().map(|i| i.sin_cos()).unzip();
+    let src_strides = Strides::from(shape);
+    let x1 = slice(
+        &src,
+        &src_strides,
+        &[0, 0, 0, 0],
+        &[batches, num_heads, seq_len, half_dim],
+    );
+    let x2 = slice(
+        &src,
+        &src_strides,
+        &[0, 0, 0, half_dim],
+        &[batches, num_heads, seq_len, dim],
+    );
+
+    // `multiply` as an operation that deals with broadcasting
+    let x1_cos = x1
+        .iter()
+        .zip(cos.iter().cycle())
+        .map(|(x, c)| x * c)
+        .collect::<Vec<f32>>();
+    let x2_sin = x2
+        .iter()
+        .zip(sin.iter().cycle())
+        .map(|(x, s)| x * s)
+        .collect::<Vec<f32>>();
+
+    let mut r1 = x1_cos
+        .iter()
+        .zip(x2_sin.iter())
+        .map(|(x1, x2)| x1 - x2)
+        .collect::<Vec<f32>>();
+    r1.extend(vec![0.0; shape.numel() - r1.len()]);
+
+    let x1_sin = x1
+        .iter()
+        .zip(sin.iter().cycle())
+        .map(|(x, s)| x * s)
+        .collect::<Vec<f32>>();
+    let x2_cos = x2
+        .iter()
+        .zip(cos.iter().cycle())
+        .map(|(x, c)| x * c)
+        .collect::<Vec<f32>>();
+    let mut r2 = x1_sin
+        .iter()
+        .zip(x2_cos.iter())
+        .map(|(x1, x2)| x1 + x2)
+        .collect::<Vec<f32>>();
+    r2.extend(vec![0.0; shape.numel() - r2.len()]);
+
+    let mut to_cat = vec![
+        (shape![batches, num_heads, seq_len, half_dim], r1),
+        (shape![batches, num_heads, seq_len, half_dim], r2),
+    ];
+    if dim < shape[3] {
+        let r3 = slice(
+            &src,
+            &src_strides,
+            &[0, 0, 0, dim],
+            &[batches, num_heads, seq_len, head_dim],
+        );
+        to_cat.push((shape![batches, num_heads, seq_len, head_dim - dim], r3));
+    }
+
+    let dst_shape = shape![batches, num_heads, seq_len, head_dim];
+    let mut dst = vec![0.0f32; dst_shape.numel()];
+    concat(to_cat.as_slice(), 3, &dst_shape, &mut dst)?;
+    Ok(dst)
+}
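For reference, the new kernel implements the "rotate-half" form of rotary embeddings: `compute_theta` builds the position-frequency outer product with the raw CPU `gemm` above (its `inv_freqs[i]` equals `base^(-2i/dim)`), and `rope` then rotates the first and second halves of the rotary dimensions as pairs, passing any trailing `head_dim - dim` features through unrotated. In formulas, for position $p$ and frequency index $i < \mathrm{dim}/2$:

```math
\theta_{p,i} = p \cdot \mathrm{base}^{-2i/\mathrm{dim}}, \qquad
y^{(1)}_{p,i} = x^{(1)}_{p,i}\cos\theta_{p,i} - x^{(2)}_{p,i}\sin\theta_{p,i}, \qquad
y^{(2)}_{p,i} = x^{(1)}_{p,i}\sin\theta_{p,i} + x^{(2)}_{p,i}\cos\theta_{p,i}
```

where $x^{(1)}$ and $x^{(2)}$ are the `x1`/`x2` slices in the code above.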
3 changes: 3 additions & 0 deletions crates/ratchet-core/src/cpu/slice.rs
@@ -0,0 +1,3 @@
+use crate::{Slice, Tensor};
+
+pub fn cpu_slice(op: Slice, dst: Tensor) -> Result<Tensor, OperationError> {}
6 changes: 6 additions & 0 deletions crates/ratchet-core/src/cpu/utils.rs
@@ -0,0 +1,6 @@
+use crate::{CPUBuffer, Storage, Tensor};
+use bytemuck::NoUninit;
+
+pub fn cpu_store_result<T: NoUninit>(dst: &Tensor, data: &[T]) {
+    dst.update_storage(Storage::CPU(CPUBuffer::from_slice(data, dst.shape())));
+}
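The relocated helper supports the pattern every CPU op in this PR follows: compute into a host `Vec`, then publish it as the destination tensor's CPU storage. A minimal sketch of that pattern; `apply_zero_fill` is a hypothetical op body (`cpu_rope` and `apply_concat` above do the same):

```rust
// Hypothetical CPU op body showing the compute-then-store pattern.
use crate::cpu::cpu_store_result;
use crate::{OperationError, Tensor};

fn apply_zero_fill(dst: Tensor) -> Result<Tensor, OperationError> {
    // Compute the result on the host...
    let result = vec![0.0f32; dst.shape().numel()];
    // ...then hand it to the destination tensor's CPU buffer.
    cpu_store_result(&dst, &result);
    Ok(dst)
}
```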