Fuse MatMul + Mul/Div by constant
A subgraph of the form `Mul(MatMul(Mul(A, c), Mul(B, d)), e)`, where each of
the `Mul`s is optional and c, d and e are constants, can be rewritten as
`MatMul(A, B, alpha = c * d * e)`, where `alpha` is the scaling factor in the
`C = alpha * AB + beta * C` computation that GEMM already performs. Such
scaling is common in transformers as part of attention operations (SDPA).

The initial implementation handles only two specific cases of this form that
have been seen in real models; it should be generalized to all possible cases.
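
For intuition, here is a minimal, self-contained sketch (plain Rust with a naive
triple-loop matrix multiply, not rten's GEMM or operator API) that checks the
identity the fusion relies on: scaling A by c, B by d and the product by e gives
the same result as a single A·B scaled by `alpha = c * d * e`.

```rust
// Naive row-major matrix multiply: `a` is m x k, `b` is k x n.
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut out = vec![0.0; m * n];
    for i in 0..m {
        for j in 0..n {
            for p in 0..k {
                out[i * n + j] += a[i * k + p] * b[p * n + j];
            }
        }
    }
    out
}

fn main() {
    let (m, k, n) = (2, 3, 2);
    let a = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2 x 3 matrix A
    let b = [7.0f32, 8.0, 9.0, 10.0, 11.0, 12.0]; // 3 x 2 matrix B
    let (c, d, e) = (0.5f32, 2.0, 0.25); // constant multipliers

    // Unfused form: Mul(MatMul(Mul(A, c), Mul(B, d)), e).
    let a_scaled: Vec<f32> = a.iter().map(|x| x * c).collect();
    let b_scaled: Vec<f32> = b.iter().map(|x| x * d).collect();
    let unfused: Vec<f32> = matmul(&a_scaled, &b_scaled, m, k, n)
        .iter()
        .map(|x| x * e)
        .collect();

    // Fused form: a single MatMul with alpha = c * d * e, i.e. the scaling
    // that GEMM's `C = alpha * AB + beta * C` already supports.
    let alpha = c * d * e;
    let fused: Vec<f32> = matmul(&a, &b, m, k, n).iter().map(|x| x * alpha).collect();

    assert!(unfused
        .iter()
        .zip(&fused)
        .all(|(u, f)| (u - f).abs() < 1e-6));
    println!("fused result: {:?}", fused);
}
```

Because the scaling folds into the `alpha` parameter that GEMM already accepts,
the fused form removes the elementwise `Mul` passes at no extra cost per
multiplication.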
robertknight committed Dec 26, 2024
1 parent ac4b5d1 commit 6654523
Showing 3 changed files with 189 additions and 36 deletions.
93 changes: 67 additions & 26 deletions src/ops/matmul.rs
@@ -140,7 +140,7 @@ pub fn matmul<LhsT: GemmInT, RhsT: GemmInT, OutT: Default + GemmOutT>(
where
GemmExecutor<LhsT, RhsT, OutT>: Default,
{
- matmul_impl(pool, a, b, packed_b, MatmulStrategy::Auto, None)
+ matmul_impl(pool, a, b, packed_b, MatmulStrategy::Auto, None, None)
}

fn matmul_impl<LhsT: GemmInT, RhsT: GemmInT, OutT: Default + GemmOutT>(
@@ -150,6 +150,7 @@ fn matmul_impl<LhsT: GemmInT, RhsT: GemmInT, OutT: Default + GemmOutT>(
packed_b: Option<&PackedBMatrix<RhsT>>,
strategy: MatmulStrategy,
bias: Option<BiasVector<OutT>>,
+ alpha: Option<f32>,
) -> Result<Tensor<OutT>, OpError>
where
GemmExecutor<LhsT, RhsT, OutT>: Default,
@@ -202,7 +203,15 @@ where
// nb. We assume `a` is likely already contiguous, so this will be cheap.
let a_contig = a.to_contiguous_in(pool).auto_return(pool);
let a_matrix = a_contig.reshaped([num_a_matrices * a_rows, a_cols].as_slice());
- let mut output = matmul_impl(pool, a_matrix.view(), b.clone(), packed_b, strategy, bias)?;
+ let mut output = matmul_impl(
+     pool,
+     a_matrix.view(),
+     b.clone(),
+     packed_b,
+     strategy,
+     bias,
+     alpha,
+ )?;
output.reshape(out_shape);
return Ok(output);
}
@@ -273,7 +282,7 @@ where
out_row_stride,
a_input,
b_input,
- 1., // alpha
+ alpha.unwrap_or(1.),
bias,
);
});
@@ -322,26 +331,30 @@ impl Operator for MatMul {
}
}

- pub fn matmul_add<LhsT: GemmInT, RhsT: GemmInT, OutT: Default + GemmOutT>(
+ pub fn matmul_fused<LhsT: GemmInT, RhsT: GemmInT, OutT: Default + GemmOutT>(
pool: &TensorPool,
a: TensorView<LhsT>,
b: TensorView<RhsT>,
packed_b: Option<&PackedBMatrix<RhsT>>,
- bias: BiasVector<OutT>,
+ bias: Option<BiasVector<OutT>>,
+ alpha: Option<f32>,
) -> Result<Tensor<OutT>, OpError>
where
GemmExecutor<LhsT, RhsT, OutT>: Default,
{
- matmul_impl(pool, a, b, packed_b, MatmulStrategy::Auto, Some(bias))
+ matmul_impl(pool, a, b, packed_b, MatmulStrategy::Auto, bias, alpha)
}

- /// Fusion for `Add(MatMul(a, b), bias)` subgraphs, where `bias` is a vector.
+ /// MatMul with fused addition of bias and scaling of result.
#[derive(Clone, Debug)]
- pub struct MatMulAdd {}
+ pub struct FusedMatMul {
+     /// Scaling factor to apply to result of matrix multiplication. Defaults to 1.
+     pub alpha: Option<f32>,
+ }

- impl Operator for MatMulAdd {
+ impl Operator for FusedMatMul {
fn name(&self) -> &str {
"MatMulAdd"
"FusedMatMul"
}

fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
@@ -352,10 +365,14 @@ impl Operator for MatMulAdd {
_ => None,
};

- let bias = inputs.require_as(2)?;
- let bias = static_dims!(bias, 1, "N")?.to_contiguous_in(pool);
+ let bias = inputs
+     .get_as::<f32>(2)?
+     .map(|bias| static_dims!(bias, 1, "N"))
+     .transpose()?
+     .map(|b| b.to_contiguous_in(pool));
+ let bias = bias.as_ref().map(|b| BiasVector::Row(b.data().unwrap()));

- matmul_add(pool, a, b, packed_b, BiasVector::Row(bias.data().unwrap())).into_op_result()
+ matmul_fused(pool, a, b, packed_b, bias, self.alpha).into_op_result()
}

fn prepack_inputs(&self) -> SmallVec<[usize; 1]> {
@@ -508,7 +525,7 @@ mod tests {
use crate::tensor_pool::AutoReturn;

use super::{
- gemm_op, matmul, matmul_add, matmul_impl, matmul_integer, MatMul, MatMulAdd,
+ gemm_op, matmul, matmul_fused, matmul_impl, matmul_integer, FusedMatMul, MatMul,
MatmulStrategy, OpError,
};

@@ -533,6 +550,7 @@ mod tests {
mut a: TensorView,
mut b: TensorView,
bias: Option<BiasVector<f32>>,
+ alpha: Option<f32>,
) {
// Expand vector inputs to matrices. This follows the rules of
// `numpy.matmul`.
@@ -576,7 +594,7 @@ mod tests {
c_row_stride,
GemmInputA::Unpacked(a),
GemmInputB::Unpacked(b),
- 1., /* alpha */
+ alpha.unwrap_or(1.),
0., /* beta */
bias,
)
@@ -822,7 +840,7 @@ mod tests {

let mut expected = Tensor::zeros(out_shape);

- reference_matmul(expected.view_mut(), a.view(), b.view(), None);
+ reference_matmul(expected.view_mut(), a.view(), b.view(), None, None);
let result = matmul(&pool, a.view(), b.view(), None).unwrap();
expect_equal(&result, &expected)?;
}
@@ -843,7 +861,7 @@ mod tests {
bias_input: false,
},
Case {
- op: Box::new(MatMulAdd {}),
+ op: Box::new(FusedMatMul { alpha: None }),
bias_input: true,
},
];
@@ -865,7 +883,13 @@ mod tests {
let packed_b = op.prepack(1, packed_b_input.view().into()).unwrap();

let mut expected = Tensor::zeros(&[5, 3]);
- reference_matmul(expected.view_mut(), a.view(), packed_b_input.view(), None);
+ reference_matmul(
+     expected.view_mut(),
+     a.view(),
+     packed_b_input.view(),
+     None,
+     None,
+ );

let pool = new_pool();
let get_prepacked = |idx| {
@@ -891,18 +915,35 @@ mod tests {
}

#[test]
- fn test_matmul_add() -> Result<(), Box<dyn Error>> {
+ fn test_matmul_fused() -> Result<(), Box<dyn Error>> {
let mut rng = XorShiftRng::new(1234);
let a = Tensor::rand(&[10, 15], &mut rng);
let b = Tensor::rand(&[15, 5], &mut rng);
let bias_data: Vec<f32> = (0..b.size(b.ndim() - 1)).map(|_| rng.next_f32()).collect();
- let bias = Some(BiasVector::Row(&bias_data));

- let pool = new_pool();
- let mut expected = Tensor::zeros(&[10, 5]);
- reference_matmul(expected.view_mut(), a.view(), b.view(), bias.clone());
- let result = matmul_add(&pool, a.view(), b.view(), None, bias.unwrap()).unwrap();
- expect_equal(&result, &expected)?;
+ struct Case<'a> {
+     bias: Option<BiasVector<'a, f32>>,
+     alpha: Option<f32>,
+ }
+
+ let cases = [
+     Case {
+         bias: Some(BiasVector::Row(&bias_data)),
+         alpha: None,
+     },
+     Case {
+         bias: None,
+         alpha: Some(0.5),
+     },
+ ];
+
+ for Case { bias, alpha } in cases {
+     let pool = new_pool();
+     let mut expected = Tensor::zeros(&[10, 5]);
+     reference_matmul(expected.view_mut(), a.view(), b.view(), bias.clone(), alpha);
+     let result = matmul_fused(&pool, a.view(), b.view(), None, bias, alpha).unwrap();
+     expect_equal(&result, &expected)?;
+ }

Ok(())
}
@@ -1153,7 +1194,7 @@ mod tests {
);
let pool = new_pool();
run_bench(trials, Some(&desc), || {
- matmul_impl(&pool, a.view(), b.view(), None, strategy, None)
+ matmul_impl(&pool, a.view(), b.view(), None, strategy, None, None)
.unwrap()
.auto_return(&pool);
});
2 changes: 1 addition & 1 deletion src/ops/mod.rs
@@ -84,7 +84,7 @@ pub use layout::{
depth_to_space, expand, flatten, reshape, squeeze, squeeze_in_place, DepthToSpace,
DepthToSpaceMode, Expand, Flatten, Reshape, Shape, Size, Squeeze, Transpose, Unsqueeze,
};
- pub use matmul::{gemm_op, matmul, Gemm, MatMul, MatMulAdd, MatMulInteger};
+ pub use matmul::{gemm_op, matmul, FusedMatMul, Gemm, MatMul, MatMulInteger};
pub use non_max_suppression::{non_max_suppression, BoxOrder, NonMaxSuppression};
pub use norm::{
batch_norm, batch_norm_in_place, instance_normalization, layer_normalization, log_softmax,
