Skip to content

Commit

Permalink
Support fusing LayerNormalization without bias
Browse files (browse the repository at this point in the history)
Support matching and fusing LayerNormalization sub-graphs that do not have a
final `Add` operation to add a bias. The LayerNormalization operator already
supports this case, so we just have to recognize it in the optimizer.

Tested with the [ModernBERT](https://huggingface.co/answerdotai/ModernBERT-base)
ONNX model.
robertknight committed Dec 19, 2024
1 parent 1fd414f commit 9845034
Showing 1 changed file with 57 additions and 19 deletions.
76 changes: 57 additions & 19 deletions src/optimize.rs
Original file line number Diff line number Diff line change
@@ -531,10 +531,11 @@ impl GraphOptimizer {
epsilon + unary_op_key("ReduceMean", binary_op("Pow", x.clone(), 2.0), "norm_mean"),
);

// Final step: Shift and scale the normalized values
// Final step: Scale, and optionally shift, the normalized values
let bias = const_symbol("bias");
let scale = const_symbol("scale");
let shift_scale_pat = (x.clone() * scale) + bias;
let shift_scale_pat = (x.clone() * scale.clone()) + bias;
let scale_pat = x.clone() * scale;

graph.apply_fusion(|graph, op_node_id, op_node| {
// Test if a node is a `ReduceMean` operator that reduces over its
@@ -562,10 +563,21 @@ impl GraphOptimizer {
_ => false,
};

let shift_scale_match = shift_scale_pat.test(op_node_id, graph.graph())?;
let shift_scale_input = shift_scale_match.resolved_symbol("x").unwrap();
let bias_input = shift_scale_match.resolved_symbol("bias").unwrap();
let scale_input = shift_scale_match.resolved_symbol("scale").unwrap();
let (shift_scale_input, bias_input, scale_input) =
if let Some(shift_scale_match) = shift_scale_pat.test(op_node_id, graph.graph()) {
// Found match for scale + bias.
let shift_scale_input = shift_scale_match.resolved_symbol("x").unwrap();
let bias_input = shift_scale_match.resolved_symbol("bias").unwrap();
let scale_input = shift_scale_match.resolved_symbol("scale").unwrap();
(shift_scale_input, Some(bias_input), scale_input)
} else if let Some(scale_match) = scale_pat.test(op_node_id, graph.graph()) {
// Found match for scale only.
let x_input = scale_match.resolved_symbol("x").unwrap();
let scale_input = scale_match.resolved_symbol("scale").unwrap();
(x_input, None, scale_input)
} else {
return None;
};

let norm_match = normalize_variance_pat.test(shift_scale_input, graph.graph())?;
let norm_input = norm_match.resolved_symbol("x").unwrap();
@@ -598,7 +610,7 @@ impl GraphOptimizer {
axis: -1,
epsilon: Some(epsilon),
},
vec![Some(center_input), Some(scale_input), Some(bias_input)],
vec![Some(center_input), Some(scale_input), bias_input],
op_output,
))
});
@@ -844,7 +856,7 @@ mod tests {
assert_eq!(op.name(), Some("mul_half"));
}

fn layer_norm_graph() -> Graph {
fn layer_norm_graph(with_bias: bool) -> Graph {
let mut graph = Graph::new();
let input = graph.add_value(None, None, None);

@@ -876,27 +888,53 @@ mod tests {
let (_, div_out) = graph.add_simple_op("div", Div {}, &[sub_out, sqrt_out]);

// Scale, and optionally shift (only when `with_bias` is set)
let bias = graph.add_constant(None, Tensor::from([1., 2., 3.]));
let scale = graph.add_constant(None, Tensor::from([3., 4., 5.]));
let (_, mul_out) = graph.add_simple_op("mul", Mul {}, &[div_out, scale]);
let (_, add_out) = graph.add_simple_op("final_add", Add {}, &[mul_out, bias]);

graph.set_input_ids(&[input]);
graph.set_output_ids(&[add_out]);
if with_bias {
let bias = graph.add_constant(None, Tensor::from([1., 2., 3.]));
let (_, add_out) = graph.add_simple_op("final_add", Add {}, &[mul_out, bias]);
graph.set_output_ids(&[add_out]);
} else {
graph.set_output_ids(&[mul_out]);
}

graph.set_input_ids(&[input]);
graph
}

#[test]
fn test_fuse_layer_norm() {
let graph = layer_norm_graph();
let graph = optimize_graph(graph).unwrap();
let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
assert_eq!(op.operator().name(), "LayerNormalization");
assert_eq!(op.name(), Some("final_add"));
struct Case<'a> {
with_bias: bool,
output_name: &'a str,
}

let cases = [
Case {
with_bias: true,
output_name: "final_add",
},
Case {
with_bias: false,
output_name: "mul",
},
];

let layer_norm = op.operator().downcast_ref::<LayerNormalization>().unwrap();
assert_eq!(layer_norm.epsilon, Some(1e-6));
for Case {
with_bias,
output_name,
} in cases
{
let graph = layer_norm_graph(with_bias);
let graph = optimize_graph(graph).unwrap();
let (_, op) = graph.get_source_node(graph.output_ids()[0]).unwrap();
assert_eq!(op.operator().name(), "LayerNormalization");
assert_eq!(op.name(), Some(output_name));

let layer_norm = op.operator().downcast_ref::<LayerNormalization>().unwrap();
assert_eq!(layer_norm.epsilon, Some(1e-6));
}
}

#[test]

0 comments on commit 9845034

Please sign in to comment.