Merge pull request #489 from robertknight/generalize-matmul-scale
Generalize scaled MatMul fusion
robertknight authored Dec 27, 2024
2 parents 7f74d0c + ffe4c59 commit a33e0db
Showing 2 changed files with 95 additions and 49 deletions.
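
The optimize.rs change below relies on a simple identity: scaling either MatMul input, or the MatMul output, by a constant is equivalent to scaling the plain product by one combined factor, which is what `FusedMatMul`'s `alpha` captures. A minimal standalone sketch of that identity, using plain 2x2 `f32` arrays rather than rten's tensor or graph types:

```rust
// Illustrative sketch only (not rten code): checks the identity the fusion
// relies on, MatMul(Mul(X, c), Mul(Y, d)) == Mul(MatMul(X, Y), c * d).

fn matmul(a: &[[f32; 2]; 2], b: &[[f32; 2]; 2]) -> [[f32; 2]; 2] {
    let mut out = [[0.0; 2]; 2];
    for i in 0..2 {
        for j in 0..2 {
            for k in 0..2 {
                out[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    out
}

fn scale(m: &[[f32; 2]; 2], s: f32) -> [[f32; 2]; 2] {
    let mut out = *m;
    for row in out.iter_mut() {
        for v in row.iter_mut() {
            *v *= s;
        }
    }
    out
}

fn main() {
    let x = [[1.0, 2.0], [3.0, 4.0]];
    let y = [[5.0, 6.0], [7.0, 8.0]];
    let (c, d) = (0.5_f32, 0.3_f32);

    // Unfused form: scale each input, then multiply.
    let unfused = matmul(&scale(&x, c), &scale(&y, d));
    // Fused form: multiply first, then apply a single alpha = c * d.
    let fused = scale(&matmul(&x, &y), c * d);

    for i in 0..2 {
        for j in 0..2 {
            assert!((unfused[i][j] - fused[i][j]).abs() < 1e-4);
        }
    }
    println!("alpha = {}", c * d);
}
```

Given this identity, every recognized `Mul` or `Div` around a `MatMul` can be folded into a single `alpha` on the fused operator.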
6 changes: 6 additions & 0 deletions src/graph.rs
@@ -39,6 +39,7 @@ pub enum Dimension {
Symbolic(String),
}

#[derive(Debug)]
pub struct OperatorNode {
name: Option<String>,
inputs: Vec<Option<NodeId>>,
@@ -87,6 +88,7 @@ impl OperatorNode {
}
}

#[derive(Debug)]
pub struct ValueNode {
name: Option<String>,
shape: Option<Vec<Dimension>>,
@@ -100,6 +102,7 @@ impl ValueNode {
}

/// Data for a constant node (ie. model weights) in a [`Graph`].
#[derive(Debug)]
pub enum ConstantNodeData<T> {
Owned(Tensor<T>),
Arc(ArcTensorView<T>),
@@ -126,6 +129,7 @@ impl<T> From<ArcTensorView<T>> for ConstantNodeData<T> {
}
}

#[derive(Debug)]
pub struct ConstantNode<T> {
name: Option<String>,
data: ConstantNodeData<T>,
@@ -155,6 +159,7 @@ impl<T> ConstantNode<T> {
}
}

#[derive(Debug)]
pub enum Constant {
Float(ConstantNode<f32>),
Int32(ConstantNode<i32>),
@@ -268,6 +273,7 @@ impl_typed_constant!(i32, Int32);
impl_typed_constant!(i8, Int8);
impl_typed_constant!(u8, UInt8);

#[derive(Debug)]
pub enum Node {
Operator(OperatorNode),
Constant(Constant),
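
The src/graph.rs part of the change only derives `Debug` on the graph node types. A tiny illustration of what that enables when inspecting graphs during optimizer development, using hypothetical stand-in types rather than rten's:

```rust
// Hypothetical stand-ins for the graph node types; the derive is what matters.
#[derive(Debug)]
enum Node {
    Value { name: Option<String> },
    Constant { name: Option<String>, value: f32 },
}

fn main() {
    let nodes = vec![
        Node::Value { name: Some("x".into()) },
        Node::Constant { name: Some("c".into()), value: 0.5 },
    ];
    // `{:?}` / `{:#?}` formatting compiles only because of the Debug derive,
    // which helps when dumping intermediate state from optimizer passes.
    println!("{:#?}", nodes);
}
```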
138 changes: 89 additions & 49 deletions src/optimize.rs
@@ -14,7 +14,7 @@ use crate::Output

mod pattern_matcher;

use pattern_matcher::{binary_op, const_symbol, symbol, unary_op, unary_op_key, Match};
use pattern_matcher::{binary_op, const_symbol, symbol, unary_op, unary_op_key};

/// Errors that occur while applying graph optimizations.
#[derive(Debug, PartialEq)]
@@ -494,62 +494,100 @@ impl GraphOptimizer {
/// and e are constants can be rewritten as `FusedMatMul(X, Y, alpha=c * d *
/// e)`. Each `Mul(X, c)` can also be expressed as `Div(X, 1/c)`.
fn fuse_matmul_scaled(&self, graph: &mut GraphMutator) -> Result<(), OptimizeError> {
let x = symbol("x");
let y = symbol("y");

let c = const_symbol("c");
let d = const_symbol("d");
let binary_op_input_ids = |op: &OperatorNode| -> Option<[NodeId; 2]> {
match op.input_ids() {
[Some(lhs_id), Some(rhs_id)] => Some([*lhs_id, *rhs_id]),
_ => None,
}
};

// We currently recognize two common patterns of specifying a scaled
// matmul in the graph, but there are many other permutations (eg.
// only one input scaled, Muls swapped for Divs and vice versa) that
// ideally this transform would recognize.
//
// A more complex situation is where the Mul or Div operand is constant
// in practice, but computed from a value's shape at runtime. Consider
// standard scaled dot product attention:
//
// Attention(Q, K, V) = Softmax(MatMul(Q, Transpose(K)) / Sqrt(dK))
//
// `Sqrt(dK)` is always constant at runtime, but depending on how the
// model was exported may be computed from `Shape(K)` output. Such
// cases are not currently handled.
let get_const_scalar = |graph: &Graph, node_id: NodeId| -> Option<f32> {
graph.get_node(node_id).and_then(|node| match node {
Node::Constant(const_node) => const_node.as_scalar(),
_ => None,
})
};

// MatMul(Mul(X, c), Mul(Y, d))
let matmul_mul_pat = binary_op("MatMul", c.clone() * x.clone(), d.clone() * y.clone());
// Test if `op_node` is a Mul or Div node with one constant scalar
// input and one non-constant input. If so, returns the constant scalar
// which the node multiplies the other input by and the ID of the other
// input.
let get_scale_factor = |graph: &Graph, op_node: &OperatorNode| -> Option<(f32, NodeId)> {
let op_type = op_node.operator().name();
if !["Mul", "Div"].contains(&op_type) {
return None;
}

// Div(MatMul(X), c)
let div_matmul_pat = binary_op("MatMul", x.clone(), y.clone()) / c.clone();
let [lhs, rhs] = binary_op_input_ids(op_node)?;
let lhs_scalar = get_const_scalar(graph, lhs);
let rhs_scalar = get_const_scalar(graph, rhs);

graph.apply_fusion(|graph, op_node_id, op_node| {
let get_scalar = |match_: &Match, name: &str| -> Option<f32> {
let scalar_node_id = match_.resolved_symbol(name)?;
match graph.graph().get_node(scalar_node_id) {
Some(Node::Constant(const_node)) => const_node.as_scalar(),
match op_type {
"Mul" => match (lhs_scalar, rhs_scalar) {
(Some(lhs_scale), None) => Some((lhs_scale, rhs)),
(None, Some(rhs_scale)) => Some((rhs_scale, lhs)),
_ => None,
}
};
},
"Div" => match (lhs_scalar, rhs_scalar) {
(None, Some(rhs_scale)) => Some((1. / rhs_scale, lhs)),
_ => None,
},
_ => None,
}
};

let (alpha, match_) =
if let Some(match_) = matmul_mul_pat.test(op_node_id, graph.graph()) {
let c = get_scalar(&match_, "c")?;
let d = get_scalar(&match_, "d")?;
(c * d, match_)
} else if let Some(match_) = div_matmul_pat.test(op_node_id, graph.graph()) {
let c = get_scalar(&match_, "c")?;
(1. / c, match_)
} else {
return None;
};
graph.apply_fusion(|graph, _op_node_id, op_node| {
// Accumulated scale factor from scalings applied to MatMul inputs
// and outputs.
let mut alpha = 1.0;

let x_input = match_.resolved_symbol("x").unwrap();
let y_input = match_.resolved_symbol("y").unwrap();
let op_output = op_node.output_id()?;
let graph = graph.graph();

// Check if this is a Mul/Div node scaling the output of a MatMul.
let matmul_node = if ["Mul", "Div"].contains(&op_node.operator().name()) {
let (output_scale, scale_input) = get_scale_factor(graph, op_node)?;
alpha *= output_scale;
let (_, scale_input_op) = graph.get_source_node(scale_input)?;
scale_input_op
} else {
op_node
};

if matmul_node.operator().name() != "MatMul" {
return None;
}

let [matmul_lhs, matmul_rhs] = binary_op_input_ids(matmul_node)?;
let lhs_input = if let Some((_, lhs_source_op)) = graph.get_source_node(matmul_lhs) {
let (lhs_scale, lhs_input) =
get_scale_factor(graph, lhs_source_op).unwrap_or((1.0, matmul_lhs));
alpha *= lhs_scale;
lhs_input
} else {
// MatMul LHS is not computed by an upstream operator.
matmul_lhs
};

let rhs_input = if let Some((_, rhs_source_op)) = graph.get_source_node(matmul_rhs) {
let (rhs_scale, rhs_input) =
get_scale_factor(graph, rhs_source_op).unwrap_or((1.0, matmul_rhs));
alpha *= rhs_scale;
rhs_input
} else {
// MatMul RHS is not computed by an upstream operator.
matmul_rhs
};

if alpha == 1.0 {
// Scale factor of 1 has no effect.
return None;
}

Some(Fusion::from_op(
op_node.name(),
matmul_node.name(),
FusedMatMul { alpha: Some(alpha) },
[Some(x_input), Some(y_input)].into(),
[Some(lhs_input), Some(rhs_input)].into(),
op_output,
))
});
@@ -881,7 +919,8 @@ mod tests {

#[test]
fn test_fuse_matmul_scaled() {
// Pattern 1: MatMul(Mul(A, c), Mul(B, d))
// Pattern 1: MatMul(Mul(A, c), Mul(B, d)). This has scale applied to
// inputs via `Mul` ops.
let mut graph = Graph::new();
let a = graph.add_value(None, None, None);
let b = graph.add_value(None, None, None);
@@ -901,7 +940,8 @@
let fused_matmul_op = op.operator().downcast_ref::<FusedMatMul>().unwrap();
assert_eq!(fused_matmul_op.alpha, Some(0.5 * 0.3));

// Pattern 2: Div(MatMul(A, B), c)
// Pattern 2: Div(MatMul(A, B), c). This has scale applied to outputs
// via `Div` ops.
let mut graph = Graph::new();
let a = graph.add_value(None, None, None);
let b = graph.add_value(None, None, None);
@@ -915,7 +955,7 @@

let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
assert_eq!(op.operator().name(), "FusedMatMul");
assert_eq!(op.name(), Some("div"));
assert_eq!(op.name(), Some("matmul"));
let fused_matmul_op = op.operator().downcast_ref::<FusedMatMul>().unwrap();
assert_eq!(fused_matmul_op.alpha, Some(1. / 0.5));
}
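
The generalized pass effectively accumulates one scale factor per recognized `Mul` or `Div` around the `MatMul`. A rough standalone sketch of that accumulation, with hypothetical types rather than rten's graph API: a `Mul` by `c` contributes `c`, a `Div` by a constant `c` contributes `1/c`, and a net factor of 1.0 means there is nothing to fuse.

```rust
// Standalone sketch (not rten's API) of how input and output scalings
// combine into the single alpha carried by FusedMatMul.
#[derive(Clone, Copy)]
enum Scale {
    Mul(f32),
    Div(f32),
}

fn factor(s: Scale) -> f32 {
    match s {
        Scale::Mul(c) => c,
        Scale::Div(c) => 1.0 / c,
    }
}

/// Combine optional scalings found on the LHS input, RHS input and the
/// MatMul output. Returns None when the net factor is 1, i.e. nothing to fuse.
fn combined_alpha(lhs: Option<Scale>, rhs: Option<Scale>, output: Option<Scale>) -> Option<f32> {
    let alpha = [lhs, rhs, output]
        .into_iter()
        .flatten()
        .map(factor)
        .product::<f32>();
    (alpha != 1.0).then_some(alpha)
}

fn main() {
    // Mirrors pattern 1 above: MatMul(Mul(A, 0.5), Mul(B, 0.3)) -> alpha = 0.15.
    assert_eq!(
        combined_alpha(Some(Scale::Mul(0.5)), Some(Scale::Mul(0.3)), None),
        Some(0.5 * 0.3)
    );
    // Mirrors pattern 2 above: Div(MatMul(A, B), 0.5) -> alpha = 2.0.
    assert_eq!(
        combined_alpha(None, None, Some(Scale::Div(0.5))),
        Some(1.0 / 0.5)
    );
    // No scaling anywhere -> nothing to fuse.
    assert_eq!(combined_alpha(None, None, None), None);
}
```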
