diff --git a/src/ops/matmul.rs b/src/ops/matmul.rs
index 16f64128..c3b4f0ed 100644
--- a/src/ops/matmul.rs
+++ b/src/ops/matmul.rs
@@ -140,7 +140,7 @@ pub fn matmul(
 where
     GemmExecutor: Default,
 {
-    matmul_impl(pool, a, b, packed_b, MatmulStrategy::Auto, None)
+    matmul_impl(pool, a, b, packed_b, MatmulStrategy::Auto, None, None)
 }
 
 fn matmul_impl(
@@ -150,6 +150,7 @@ fn matmul_impl(
     packed_b: Option<&PackedBMatrix>,
     strategy: MatmulStrategy,
     bias: Option<BiasVector<f32>>,
+    alpha: Option<f32>,
 ) -> Result<Tensor, OpError>
 where
     GemmExecutor: Default,
@@ -202,7 +203,15 @@ where
         // nb. We assume `a` is likely already contiguous, so this will be cheap.
         let a_contig = a.to_contiguous_in(pool).auto_return(pool);
         let a_matrix = a_contig.reshaped([num_a_matrices * a_rows, a_cols].as_slice());
-        let mut output = matmul_impl(pool, a_matrix.view(), b.clone(), packed_b, strategy, bias)?;
+        let mut output = matmul_impl(
+            pool,
+            a_matrix.view(),
+            b.clone(),
+            packed_b,
+            strategy,
+            bias,
+            alpha,
+        )?;
         output.reshape(out_shape);
         return Ok(output);
     }
@@ -273,7 +282,7 @@ where
                 out_row_stride,
                 a_input,
                 b_input,
-                1., // alpha
+                alpha.unwrap_or(1.),
                 bias,
             );
         });
@@ -322,26 +331,30 @@ impl Operator for MatMul {
     }
 }
 
-pub fn matmul_add(
+pub fn matmul_fused(
     pool: &TensorPool,
     a: TensorView,
     b: TensorView,
     packed_b: Option<&PackedBMatrix>,
-    bias: BiasVector,
+    bias: Option<BiasVector<f32>>,
+    alpha: Option<f32>,
 ) -> Result<Tensor, OpError>
 where
     GemmExecutor: Default,
 {
-    matmul_impl(pool, a, b, packed_b, MatmulStrategy::Auto, Some(bias))
+    matmul_impl(pool, a, b, packed_b, MatmulStrategy::Auto, bias, alpha)
 }
 
-/// Fusion for `Add(MatMul(a, b), bias)` subgraphs, where `bias` is a vector.
+/// MatMul with fused addition of bias and scaling of result.
 #[derive(Clone, Debug)]
-pub struct MatMulAdd {}
+pub struct FusedMatMul {
+    /// Scaling factor to apply to result of matrix multiplication. Defaults to 1.
+    pub alpha: Option<f32>,
+}
 
-impl Operator for MatMulAdd {
+impl Operator for FusedMatMul {
     fn name(&self) -> &str {
-        "MatMulAdd"
+        "FusedMatMul"
     }
 
     fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
@@ -352,10 +365,14 @@ impl Operator for MatMulAdd {
             _ => None,
         };
 
-        let bias = inputs.require_as(2)?;
-        let bias = static_dims!(bias, 1, "N")?.to_contiguous_in(pool);
+        let bias = inputs
+            .get_as::<TensorView>(2)?
+            .map(|bias| static_dims!(bias, 1, "N"))
+            .transpose()?
+            .map(|b| b.to_contiguous_in(pool));
+        let bias = bias.as_ref().map(|b| BiasVector::Row(b.data().unwrap()));
 
-        matmul_add(pool, a, b, packed_b, BiasVector::Row(bias.data().unwrap())).into_op_result()
+        matmul_fused(pool, a, b, packed_b, bias, self.alpha).into_op_result()
     }
 
     fn prepack_inputs(&self) -> SmallVec<[usize; 1]> {
@@ -508,7 +525,7 @@ mod tests {
     use crate::tensor_pool::AutoReturn;
 
     use super::{
-        gemm_op, matmul, matmul_add, matmul_impl, matmul_integer, MatMul, MatMulAdd,
+        gemm_op, matmul, matmul_fused, matmul_impl, matmul_integer, FusedMatMul, MatMul,
         MatmulStrategy, OpError,
     };
 
@@ -533,6 +550,7 @@ mod tests {
         mut a: TensorView,
         mut b: TensorView,
         bias: Option<BiasVector<f32>>,
+        alpha: Option<f32>,
     ) {
         // Expand vector inputs to matrices. This follows the rules of
         // `numpy.matmul`.
@@ -576,7 +594,7 @@ mod tests {
             c_row_stride,
             GemmInputA::Unpacked(a),
             GemmInputB::Unpacked(b),
-            1., /* alpha */
+            alpha.unwrap_or(1.),
            0., /* beta */
             bias,
         )
@@ -822,7 +840,7 @@ mod tests {
 
             let mut expected = Tensor::zeros(out_shape);
-            reference_matmul(expected.view_mut(), a.view(), b.view(), None);
+            reference_matmul(expected.view_mut(), a.view(), b.view(), None, None);
 
             let result = matmul(&pool, a.view(), b.view(), None).unwrap();
             expect_equal(&result, &expected)?;
         }
@@ -843,7 +861,7 @@ mod tests {
                 bias_input: false,
             },
             Case {
-                op: Box::new(MatMulAdd {}),
+                op: Box::new(FusedMatMul { alpha: None }),
                 bias_input: true,
             },
         ];
@@ -865,7 +883,13 @@ mod tests {
            let packed_b = op.prepack(1, packed_b_input.view().into()).unwrap();
 
            let mut expected = Tensor::zeros(&[5, 3]);
-            reference_matmul(expected.view_mut(), a.view(), packed_b_input.view(), None);
+            reference_matmul(
+                expected.view_mut(),
+                a.view(),
+                packed_b_input.view(),
+                None,
+                None,
+            );
 
            let pool = new_pool();
            let get_prepacked = |idx| {
@@ -891,18 +915,35 @@ mod tests {
     }
 
     #[test]
-    fn test_matmul_add() -> Result<(), Box<dyn Error>> {
+    fn test_matmul_fused() -> Result<(), Box<dyn Error>> {
         let mut rng = XorShiftRng::new(1234);
         let a = Tensor::rand(&[10, 15], &mut rng);
         let b = Tensor::rand(&[15, 5], &mut rng);
         let bias_data: Vec<f32> = (0..b.size(b.ndim() - 1)).map(|_| rng.next_f32()).collect();
-        let bias = Some(BiasVector::Row(&bias_data));
-        let pool = new_pool();
-        let mut expected = Tensor::zeros(&[10, 5]);
-        reference_matmul(expected.view_mut(), a.view(), b.view(), bias.clone());
-        let result = matmul_add(&pool, a.view(), b.view(), None, bias.unwrap()).unwrap();
-        expect_equal(&result, &expected)?;
+        struct Case<'a> {
+            bias: Option<BiasVector<'a, f32>>,
+            alpha: Option<f32>,
+        }
+
+        let cases = [
+            Case {
+                bias: Some(BiasVector::Row(&bias_data)),
+                alpha: None,
+            },
+            Case {
+                bias: None,
+                alpha: Some(0.5),
+            },
+        ];
+
+        for Case { bias, alpha } in cases {
+            let pool = new_pool();
+            let mut expected = Tensor::zeros(&[10, 5]);
+            reference_matmul(expected.view_mut(), a.view(), b.view(), bias.clone(), alpha);
+            let result = matmul_fused(&pool, a.view(), b.view(), None, bias, alpha).unwrap();
+            expect_equal(&result, &expected)?;
+        }
 
         Ok(())
     }
 
@@ -1153,7 +1194,7 @@ mod tests {
         );
         let pool = new_pool();
         run_bench(trials, Some(&desc), || {
-            matmul_impl(&pool, a.view(), b.view(), None, strategy, None)
+            matmul_impl(&pool, a.view(), b.view(), None, strategy, None, None)
                 .unwrap()
                 .auto_return(&pool);
         });
diff --git a/src/ops/mod.rs b/src/ops/mod.rs
index 99cd1b6c..f9bf1bd6 100644
--- a/src/ops/mod.rs
+++ b/src/ops/mod.rs
@@ -84,7 +84,7 @@ pub use layout::{
     depth_to_space, expand, flatten, reshape, squeeze, squeeze_in_place, DepthToSpace,
     DepthToSpaceMode, Expand, Flatten, Reshape, Shape, Size, Squeeze, Transpose, Unsqueeze,
 };
-pub use matmul::{gemm_op, matmul, Gemm, MatMul, MatMulAdd, MatMulInteger};
+pub use matmul::{gemm_op, matmul, FusedMatMul, Gemm, MatMul, MatMulInteger};
 pub use non_max_suppression::{non_max_suppression, BoxOrder, NonMaxSuppression};
 pub use norm::{
     batch_norm, batch_norm_in_place, instance_normalization, layer_normalization, log_softmax,
diff --git a/src/optimize.rs b/src/optimize.rs
index d2c3266c..a39562e3 100644
--- a/src/optimize.rs
+++ b/src/optimize.rs
@@ -9,12 +9,12 @@ use crate::graph::{
     CaptureEnv, Constant, ConstantNode, Graph, Node, NodeId, OperatorNode, RunError,
     TypedConstant,
 };
 use crate::ops::fused::FusedTranspose;
-use crate::ops::{Gelu, LayerNormalization, MatMulAdd, Operator, ReduceMean, Silu, Transpose};
+use crate::ops::{FusedMatMul, Gelu, LayerNormalization, Operator, ReduceMean, Silu, Transpose};
 use crate::Output;
 
 mod pattern_matcher;
 
-use pattern_matcher::{binary_op, const_symbol, symbol, unary_op, unary_op_key};
+use pattern_matcher::{binary_op, const_symbol, symbol, unary_op, unary_op_key, Match};
 
 /// Errors that occur while applying graph optimizations.
 #[derive(Debug, PartialEq)]
@@ -294,11 +294,12 @@ impl GraphOptimizer {
         }
 
         self.propagate_constants(&mut graph_mut)?;
-        self.fuse_transpose(&mut graph_mut)?;
         self.fuse_silu(&mut graph_mut)?;
         self.fuse_gelu(&mut graph_mut)?;
         self.fuse_layer_norm(&mut graph_mut)?;
         self.fuse_matmul_add(&mut graph_mut)?;
+        self.fuse_matmul_scaled(&mut graph_mut)?;
+        self.fuse_transpose(&mut graph_mut)?;
 
         Ok(graph_mut.finalize_graph())
     }
@@ -389,7 +390,7 @@ impl GraphOptimizer {
 
             // Filter against a set of operators which are known to efficiently
             // handle transposed inputs.
-            if !["MatMul"].contains(&transpose_target.operator().name()) {
+            if !["MatMul", "FusedMatMul"].contains(&transpose_target.operator().name()) {
                 return None;
             }
 
@@ -477,8 +478,78 @@ impl GraphOptimizer {
 
             Some(Fusion::from_op(
                 op_node.name(),
-                MatMulAdd {},
-                vec![Some(a_input), Some(b_input), Some(bias_input)],
+                FusedMatMul { alpha: None },
+                [Some(a_input), Some(b_input), Some(bias_input)].into(),
+                op_output,
+            ))
+        });
+
+        Ok(())
+    }
+
+    /// Fuse multiplication or division of MatMul inputs and outputs by
+    /// scalars.
+    ///
+    /// A subgraph of the form `Mul(MatMul(Mul(X, c), Mul(Y, d)), e)` where c, d
+    /// and e are constants can be rewritten as `FusedMatMul(X, Y, alpha=c * d *
+    /// e)`. Each `Mul(X, c)` can also be expressed as `Div(X, 1/c)`.
+    fn fuse_matmul_scaled(&self, graph: &mut GraphMutator) -> Result<(), OptimizeError> {
+        let x = symbol("x");
+        let y = symbol("y");
+
+        let c = const_symbol("c");
+        let d = const_symbol("d");
+
+        // We currently recognize two common patterns of specifying a scaled
+        // matmul in the graph, but there are many other permutations (eg.
+        // only one input scaled, Muls swapped for Divs and vice versa) that
+        // ideally this transform would recognize.
+        //
+        // A more complex situation is where the Mul or Div operand is constant
+        // in practice, but computed from a value's shape at runtime. Consider
+        // standard scaled dot product attention:
+        //
+        //   Attention(Q, K, V) = Softmax(MatMul(Q, Transpose(K)) / Sqrt(dK))
+        //
+        // `Sqrt(dK)` is always constant at runtime, but depending on how the
+        // model was exported may be computed from `Shape(K)` output. Such
+        // cases are not currently handled.
+
+        // MatMul(Mul(X, c), Mul(Y, d))
+        let matmul_mul_pat = binary_op("MatMul", c.clone() * x.clone(), d.clone() * y.clone());
+
+        // Div(MatMul(X, Y), c)
+        let div_matmul_pat = binary_op("MatMul", x.clone(), y.clone()) / c.clone();
+
+        graph.apply_fusion(|graph, op_node_id, op_node| {
+            let get_scalar = |match_: &Match, name: &str| -> Option<f32> {
+                let scalar_node_id = match_.resolved_symbol(name)?;
+                match graph.graph().get_node(scalar_node_id) {
+                    Some(Node::Constant(const_node)) => const_node.as_scalar(),
+                    _ => None,
+                }
+            };
+
+            let (alpha, match_) =
+                if let Some(match_) = matmul_mul_pat.test(op_node_id, graph.graph()) {
+                    let c = get_scalar(&match_, "c")?;
+                    let d = get_scalar(&match_, "d")?;
+                    (c * d, match_)
+                } else if let Some(match_) = div_matmul_pat.test(op_node_id, graph.graph()) {
+                    let c = get_scalar(&match_, "c")?;
+                    (1. / c, match_)
+                } else {
+                    return None;
+                };
+
+            let x_input = match_.resolved_symbol("x").unwrap();
+            let y_input = match_.resolved_symbol("y").unwrap();
+            let op_output = op_node.output_id()?;
+
+            Some(Fusion::from_op(
+                op_node.name(),
+                FusedMatMul { alpha: Some(alpha) },
+                [Some(x_input), Some(y_input)].into(),
                 op_output,
             ))
         });
@@ -637,8 +708,8 @@ mod tests {
     use crate::downcast::DowncastDyn;
     use crate::graph::{CaptureEnv, Constant, Graph, Node, NodeId};
     use crate::ops::{
-        Add, Div, Erf, LayerNormalization, MatMul, Mul, Pow, ReduceMean, Sigmoid, Sqrt, Sub,
-        Transpose,
+        Add, Div, Erf, FusedMatMul, LayerNormalization, MatMul, Mul, Pow, ReduceMean, Sigmoid,
+        Sqrt, Sub, Transpose,
     };
 
     fn optimize_graph(graph: Graph) -> Result<Graph, OptimizeError> {
@@ -804,10 +875,51 @@ mod tests {
         let graph = optimize_graph(graph).unwrap();
 
         let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
-        assert_eq!(op.operator().name(), "MatMulAdd");
+        assert_eq!(op.operator().name(), "FusedMatMul");
         assert_eq!(op.name(), Some("add"));
     }
 
+    #[test]
+    fn test_fuse_matmul_scaled() {
+        // Pattern 1: MatMul(Mul(A, c), Mul(B, d))
+        let mut graph = Graph::new();
+        let a = graph.add_value(None, None, None);
+        let b = graph.add_value(None, None, None);
+        let c = graph.add_constant(None, Tensor::from(0.5));
+        let d = graph.add_constant(None, Tensor::from(0.3));
+        let (_, mul_a_out) = graph.add_simple_op("scale-a", Mul {}, &[a, c]);
+        let (_, mul_b_out) = graph.add_simple_op("scale-b", Mul {}, &[b, d]);
+        let (_, matmul_out) = graph.add_simple_op("matmul", MatMul {}, &[mul_a_out, mul_b_out]);
+        graph.set_input_ids(&[a, b]);
+        graph.set_output_ids(&[matmul_out]);
+
+        let graph = optimize_graph(graph).unwrap();
+
+        let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
+        assert_eq!(op.operator().name(), "FusedMatMul");
+        assert_eq!(op.name(), Some("matmul"));
+        let fused_matmul_op = op.operator().downcast_ref::<FusedMatMul>().unwrap();
+        assert_eq!(fused_matmul_op.alpha, Some(0.5 * 0.3));
+
+        // Pattern 2: Div(MatMul(A, B), c)
+        let mut graph = Graph::new();
+        let a = graph.add_value(None, None, None);
+        let b = graph.add_value(None, None, None);
+        let c = graph.add_constant(None, Tensor::from(0.5));
+        let (_, matmul_out) = graph.add_simple_op("matmul", MatMul {}, &[a, b]);
+        let (_, div_out) = graph.add_simple_op("div", Div {}, &[matmul_out, c]);
+        graph.set_input_ids(&[a, b]);
+        graph.set_output_ids(&[div_out]);
+
+        let graph = optimize_graph(graph).unwrap();
+
+        let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
+        assert_eq!(op.operator().name(), "FusedMatMul");
+        assert_eq!(op.name(), Some("div"));
+        let fused_matmul_op = op.operator().downcast_ref::<FusedMatMul>().unwrap();
+        assert_eq!(fused_matmul_op.alpha, Some(1. / 0.5));
+    }
+
     #[test]
     fn test_chained_fused_ops() {
         let mut graph = Graph::new();
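
Note (illustration only, not part of the patch): the `fuse_matmul_scaled` pass relies on the identity that a scalar multiplied into either MatMul input, or divided out of its result, can be folded into a single `alpha` factor applied once to the product. The standalone Rust sketch below demonstrates that identity with a naive matmul; the helper `matmul_alpha` and the example values are hypothetical and are not part of the rten API.

// Illustrative sketch only: a naive row-major matmul with an `alpha` scale,
// mirroring what FusedMatMul folds into the GEMM call. `matmul_alpha` is a
// hypothetical helper, not an rten API.
fn matmul_alpha(a: &[f32], b: &[f32], m: usize, k: usize, n: usize, alpha: f32) -> Vec<f32> {
    let mut out = vec![0.0; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0;
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            out[i * n + j] = alpha * acc;
        }
    }
    out
}

fn main() {
    let (m, k, n) = (2, 3, 2);
    let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b = vec![0.5, 1.0, 1.5, 2.0, 2.5, 3.0];
    let (c, d) = (0.5f32, 0.3f32);

    // Unfused form: Mul(A, c) and Mul(B, d) feeding a plain MatMul.
    let a_scaled: Vec<f32> = a.iter().map(|x| x * c).collect();
    let b_scaled: Vec<f32> = b.iter().map(|x| x * d).collect();
    let unfused = matmul_alpha(&a_scaled, &b_scaled, m, k, n, 1.0);

    // Fused form: original inputs, with alpha = c * d, as fuse_matmul_scaled emits.
    let fused = matmul_alpha(&a, &b, m, k, n, c * d);

    for (u, f) in unfused.iter().zip(&fused) {
        assert!((u - f).abs() < 1e-6);
    }
}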