Merge pull request #487 from robertknight/scaled-mul-fusion-only
Fuse MatMul + Mul/Div by constant
robertknight authored Dec 26, 2024
2 parents ac4b5d1 + 6654523 commit 7f74d0c
Showing 3 changed files with 189 additions and 36 deletions.
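
This commit replaces the bias-only `MatMulAdd` fusion with a more general `FusedMatMul` operator that also carries an optional `alpha` scaling factor, so a `Mul` or `Div` by a constant that follows a `MatMul` can be folded into the same GEMM call. Below is a minimal usage sketch based on the test added in this diff; `new_pool`, `Tensor`, `XorShiftRng`, and `matmul_fused` are crate-internal items from the changed file, and the exact imports are assumed rather than shown.

    // Fused computation: output = alpha * (a matmul b) + bias, in one GEMM call.
    let pool = new_pool();
    let mut rng = XorShiftRng::new(1234);
    let a = Tensor::rand(&[10, 15], &mut rng);
    let b = Tensor::rand(&[15, 5], &mut rng);

    // `Mul(MatMul(a, b), 0.5)` becomes a single call with `alpha = Some(0.5)`;
    // a `Div` by a constant `c` would instead pass `alpha = Some(1.0 / c)`.
    let scaled = matmul_fused(&pool, a.view(), b.view(), None, None, Some(0.5)).unwrap();
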
93 changes: 67 additions & 26 deletions src/ops/matmul.rs
@@ -140,7 +140,7 @@ pub fn matmul<LhsT: GemmInT, RhsT: GemmInT, OutT: Default + GemmOutT>(
where
GemmExecutor<LhsT, RhsT, OutT>: Default,
{
matmul_impl(pool, a, b, packed_b, MatmulStrategy::Auto, None)
matmul_impl(pool, a, b, packed_b, MatmulStrategy::Auto, None, None)
}

fn matmul_impl<LhsT: GemmInT, RhsT: GemmInT, OutT: Default + GemmOutT>(
@@ -150,6 +150,7 @@ fn matmul_impl<LhsT: GemmInT, RhsT: GemmInT, OutT: Default + GemmOutT>(
packed_b: Option<&PackedBMatrix<RhsT>>,
strategy: MatmulStrategy,
bias: Option<BiasVector<OutT>>,
alpha: Option<f32>,
) -> Result<Tensor<OutT>, OpError>
where
GemmExecutor<LhsT, RhsT, OutT>: Default,
@@ -202,7 +203,15 @@ where
// nb. We assume `a` is likely already contiguous, so this will be cheap.
let a_contig = a.to_contiguous_in(pool).auto_return(pool);
let a_matrix = a_contig.reshaped([num_a_matrices * a_rows, a_cols].as_slice());
let mut output = matmul_impl(pool, a_matrix.view(), b.clone(), packed_b, strategy, bias)?;
let mut output = matmul_impl(
pool,
a_matrix.view(),
b.clone(),
packed_b,
strategy,
bias,
alpha,
)?;
output.reshape(out_shape);
return Ok(output);
}
@@ -273,7 +282,7 @@ where
out_row_stride,
a_input,
b_input,
1., // alpha
alpha.unwrap_or(1.),
bias,
);
});
@@ -322,26 +331,30 @@ impl Operator for MatMul {
}
}

pub fn matmul_add<LhsT: GemmInT, RhsT: GemmInT, OutT: Default + GemmOutT>(
pub fn matmul_fused<LhsT: GemmInT, RhsT: GemmInT, OutT: Default + GemmOutT>(
pool: &TensorPool,
a: TensorView<LhsT>,
b: TensorView<RhsT>,
packed_b: Option<&PackedBMatrix<RhsT>>,
bias: BiasVector<OutT>,
bias: Option<BiasVector<OutT>>,
alpha: Option<f32>,
) -> Result<Tensor<OutT>, OpError>
where
GemmExecutor<LhsT, RhsT, OutT>: Default,
{
matmul_impl(pool, a, b, packed_b, MatmulStrategy::Auto, Some(bias))
matmul_impl(pool, a, b, packed_b, MatmulStrategy::Auto, bias, alpha)
}

/// Fusion for `Add(MatMul(a, b), bias)` subgraphs, where `bias` is a vector.
/// MatMul with fused addition of bias and scaling of result.
#[derive(Clone, Debug)]
pub struct MatMulAdd {}
pub struct FusedMatMul {
/// Scaling factor to apply to result of matrix multiplication. Defaults to 1.
pub alpha: Option<f32>,
}

impl Operator for MatMulAdd {
impl Operator for FusedMatMul {
fn name(&self) -> &str {
"MatMulAdd"
"FusedMatMul"
}

fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
@@ -352,10 +365,14 @@ impl Operator for MatMulAdd {
_ => None,
};

let bias = inputs.require_as(2)?;
let bias = static_dims!(bias, 1, "N")?.to_contiguous_in(pool);
let bias = inputs
.get_as::<f32>(2)?
.map(|bias| static_dims!(bias, 1, "N"))
.transpose()?
.map(|b| b.to_contiguous_in(pool));
let bias = bias.as_ref().map(|b| BiasVector::Row(b.data().unwrap()));

matmul_add(pool, a, b, packed_b, BiasVector::Row(bias.data().unwrap())).into_op_result()
matmul_fused(pool, a, b, packed_b, bias, self.alpha).into_op_result()
}

fn prepack_inputs(&self) -> SmallVec<[usize; 1]> {
@@ -508,7 +525,7 @@ mod tests {
use crate::tensor_pool::AutoReturn;

use super::{
gemm_op, matmul, matmul_add, matmul_impl, matmul_integer, MatMul, MatMulAdd,
gemm_op, matmul, matmul_fused, matmul_impl, matmul_integer, FusedMatMul, MatMul,
MatmulStrategy, OpError,
};

@@ -533,6 +550,7 @@
mut a: TensorView,
mut b: TensorView,
bias: Option<BiasVector<f32>>,
alpha: Option<f32>,
) {
// Expand vector inputs to matrices. This follows the rules of
// `numpy.matmul`.
@@ -576,7 +594,7 @@
c_row_stride,
GemmInputA::Unpacked(a),
GemmInputB::Unpacked(b),
1., /* alpha */
alpha.unwrap_or(1.),
0., /* beta */
bias,
)
@@ -822,7 +840,7 @@

let mut expected = Tensor::zeros(out_shape);

reference_matmul(expected.view_mut(), a.view(), b.view(), None);
reference_matmul(expected.view_mut(), a.view(), b.view(), None, None);
let result = matmul(&pool, a.view(), b.view(), None).unwrap();
expect_equal(&result, &expected)?;
}
@@ -843,7 +861,7 @@
bias_input: false,
},
Case {
op: Box::new(MatMulAdd {}),
op: Box::new(FusedMatMul { alpha: None }),
bias_input: true,
},
];
@@ -865,7 +883,13 @@
let packed_b = op.prepack(1, packed_b_input.view().into()).unwrap();

let mut expected = Tensor::zeros(&[5, 3]);
reference_matmul(expected.view_mut(), a.view(), packed_b_input.view(), None);
reference_matmul(
expected.view_mut(),
a.view(),
packed_b_input.view(),
None,
None,
);

let pool = new_pool();
let get_prepacked = |idx| {
@@ -891,18 +915,35 @@
}

#[test]
fn test_matmul_add() -> Result<(), Box<dyn Error>> {
fn test_matmul_fused() -> Result<(), Box<dyn Error>> {
let mut rng = XorShiftRng::new(1234);
let a = Tensor::rand(&[10, 15], &mut rng);
let b = Tensor::rand(&[15, 5], &mut rng);
let bias_data: Vec<f32> = (0..b.size(b.ndim() - 1)).map(|_| rng.next_f32()).collect();
let bias = Some(BiasVector::Row(&bias_data));

let pool = new_pool();
let mut expected = Tensor::zeros(&[10, 5]);
reference_matmul(expected.view_mut(), a.view(), b.view(), bias.clone());
let result = matmul_add(&pool, a.view(), b.view(), None, bias.unwrap()).unwrap();
expect_equal(&result, &expected)?;
struct Case<'a> {
bias: Option<BiasVector<'a, f32>>,
alpha: Option<f32>,
}

let cases = [
Case {
bias: Some(BiasVector::Row(&bias_data)),
alpha: None,
},
Case {
bias: None,
alpha: Some(0.5),
},
];

for Case { bias, alpha } in cases {
let pool = new_pool();
let mut expected = Tensor::zeros(&[10, 5]);
reference_matmul(expected.view_mut(), a.view(), b.view(), bias.clone(), alpha);
let result = matmul_fused(&pool, a.view(), b.view(), None, bias, alpha).unwrap();
expect_equal(&result, &expected)?;
}

Ok(())
}
@@ -1153,7 +1194,7 @@ mod tests {
);
let pool = new_pool();
run_bench(trials, Some(&desc), || {
matmul_impl(&pool, a.view(), b.view(), None, strategy, None)
matmul_impl(&pool, a.view(), b.view(), None, strategy, None, None)
.unwrap()
.auto_return(&pool);
});
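
For graph execution, the same fusion is exposed through the `FusedMatMul` operator defined in the diff above. A hedged sketch of constructing it, assuming `FusedMatMul` and the `Operator` trait are imported from the crate's ops module:

    // `alpha: None` keeps the old `MatMulAdd` behavior (bias only, no scaling);
    // a fused `Div(MatMul(a, b), 2.0)` would be represented as `alpha: Some(0.5)`.
    let op = FusedMatMul { alpha: Some(0.5) };
    assert_eq!(op.name(), "FusedMatMul");
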
2 changes: 1 addition & 1 deletion src/ops/mod.rs
@@ -84,7 +84,7 @@ pub use layout::{
depth_to_space, expand, flatten, reshape, squeeze, squeeze_in_place, DepthToSpace,
DepthToSpaceMode, Expand, Flatten, Reshape, Shape, Size, Squeeze, Transpose, Unsqueeze,
};
pub use matmul::{gemm_op, matmul, Gemm, MatMul, MatMulAdd, MatMulInteger};
pub use matmul::{gemm_op, matmul, FusedMatMul, Gemm, MatMul, MatMulInteger};
pub use non_max_suppression::{non_max_suppression, BoxOrder, NonMaxSuppression};
pub use norm::{
batch_norm, batch_norm_in_place, instance_normalization, layer_normalization, log_softmax,
