From 4e68b69f51177a49c4c06daa3ccafd7dd4813e85 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Sun, 22 Dec 2024 12:45:09 +0000 Subject: [PATCH 1/2] Support pre-packing weights after model optimization This reduces inference time for matmul operations at the cost of higher memory usage. - Add methods to `Operator` trait to declare which inputs can potentially be pre-packed and to prepack those inputs. - Add `Graph::prepack_weights` method to traverse operators and prepack inputs whose values are constant nodes. - Implement prepacking methods for MatMul and fused MatMul ops - Add APIs to `ModelOptions` to enable prepacking. There are some caveats: - Non-MatMul operations which use matmuls internally (Conv, ConvTranspose, LSTM, GRU etc.) currently don't prepack their weights. - MatMul operations which turn out to be matrix-vector (gemv) products don't use the prepacked weights. This affects transformer decoders doing non-batched generation after the initial prompt encoding step. --- src/gemm.rs | 10 ++ src/graph.rs | 264 ++++++++++++++++++++++++++++++++++++---- src/lib.rs | 1 + src/model.rs | 70 ++++++++++- src/ops/control_flow.rs | 4 + src/ops/einsum.rs | 8 +- src/ops/fused.rs | 11 +- src/ops/matmul.rs | 155 ++++++++++++++++++++--- src/ops/mod.rs | 49 ++++++++ src/ops/operators.rs | 2 +- src/wasm_api.rs | 2 +- src/weight_cache.rs | 71 +++++++++++ 12 files changed, 590 insertions(+), 57 deletions(-) create mode 100644 src/weight_cache.rs diff --git a/src/gemm.rs b/src/gemm.rs index abcf5734..e8f656b1 100644 --- a/src/gemm.rs +++ b/src/gemm.rs @@ -108,6 +108,16 @@ impl PackedBMatrix { fn panel_len(&self) -> usize { self.panel_width * self.rows } + + /// Number of rows in the unpacked matrix. + pub fn rows(&self) -> usize { + self.rows + } + + /// Number of columns in the unpacked matrix. 
+ pub fn cols(&self) -> usize { + self.cols + } } impl ExtractBuffer for PackedBMatrix { diff --git a/src/graph.rs b/src/graph.rs index 45b66e8f..e3ed5818 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -19,10 +19,12 @@ use crate::constant_storage::ArcTensorView; use crate::env::env_flag; use crate::ops::{ DataType, Input, InputList, InputOrOutput, OpError, Operator, Output, OutputList, + PrepackedInput, }; use crate::tensor_pool::TensorPool; use crate::threading; use crate::timing::{InputShape, Instant, RunTiming, TimingRecord, TimingSort}; +use crate::weight_cache::WeightCache; /// Represents the size of a dimension of a runtime-provided value, such as /// an operator input, output or intermediate value. @@ -895,6 +897,56 @@ impl Graph { ) } + /// Pre-pack constant inputs (ie. weights) to operators. + /// + /// When loading models, prepacking should be performed after graph + /// optimization. There may be other nodes in between the weight constant + /// and the compute node, which would prevent prepacking. Graph optimization + /// can eliminate these. A common example is when weights are transposed. + pub fn prepack_weights(&self, cache: &mut WeightCache) { + for (op_node_id, op_node) in self.iter().filter_map(|(node_id, node)| match node { + Node::Operator(op) => Some((node_id, op)), + _ => None, + }) { + for input_index in op_node.operator.prepack_inputs() { + let Some(input_id) = op_node.input_ids().get(input_index).copied().flatten() else { + continue; + }; + + if cache.contains(input_id) { + // Input was already pre-packed. This might happen if the + // input is used by multiple operators. + continue; + } + + let Some(Node::Constant(const_node)) = self.get_node(input_id) else { + // Input is a value computed during inference, so we don't have it to prepack. + continue; + }; + + let Some(packed) = op_node.operator.prepack(input_index, const_node.as_input()) + else { + // Operator doesn't support or decided not to prepack this value. 
+ continue; + }; + + cache.insert(input_id, packed); + } + + let subgraph_caches: Vec<_> = op_node + .operator + .subgraphs() + .into_iter() + .map(|subgraph| { + let mut subgraph_cache = WeightCache::new(); + subgraph.prepack_weights(&mut subgraph_cache); + subgraph_cache + }) + .collect(); + cache.insert_subgraph_caches(op_node_id, subgraph_caches); + } + } + /// Add a constant node to the graph. /// /// Returns the ID of the added node. @@ -1005,6 +1057,7 @@ impl Graph { &self, inputs: Vec<(NodeId, InputOrOutput)>, outputs: &[NodeId], + weight_cache: Option<&WeightCache>, opts: Option, ) -> Result, RunError> { let input_ids: Vec<_> = inputs.iter().map(|(node_id, _)| *node_id).collect(); @@ -1016,6 +1069,7 @@ impl Graph { outputs, None, /* captures */ None, /* pool */ + weight_cache, opts, ) }) @@ -1031,11 +1085,20 @@ impl Graph { outputs: &[NodeId], captures: CaptureEnv, pool: Option<&TensorPool>, + weight_cache: Option<&WeightCache>, opts: Option, ) -> Result, RunError> { let input_ids: Vec<_> = inputs.iter().map(|(node_id, _)| *node_id).collect(); let plan = self.get_cached_plan(&input_ids, outputs, true /* is_subgraph */)?; - self.run_plan(inputs, plan.plan(), outputs, Some(captures), pool, opts) + self.run_plan( + inputs, + plan.plan(), + outputs, + Some(captures), + pool, + weight_cache, + opts, + ) } fn get_cached_plan( @@ -1075,6 +1138,7 @@ impl Graph { outputs: &[NodeId], mut captures: Option, pool: Option<&TensorPool>, + weight_cache: Option<&WeightCache>, opts: Option, ) -> Result, RunError> { let opts = opts.unwrap_or_default(); @@ -1103,7 +1167,7 @@ impl Graph { Some(Node::Constant(constant)) => Some(constant.as_input()), Some(Node::Value(_)) => inputs_by_id.get(&node_id).map(|input| input.as_input()), _ => { - panic!("node is not a value or constant"); + panic!("node {} is not a value or constant", node_id); } } }; @@ -1303,12 +1367,23 @@ impl Graph { pool, InputList::from_optional(&op_inputs), capture_env, + weight_cache.and_then(|wc| 
wc.get_subgraph_caches(op_node_id)), Some(opts.clone()), ) } else { + let get_prepacked = |input_index: usize| -> Option<&PrepackedInput> { + op_node + .input_ids() + .get(input_index) + .copied() + .flatten() + .and_then(|node_id| weight_cache.and_then(|wc| wc.get(node_id))) + }; + let input_list = + InputList::from_optional(&op_inputs).with_prepacked(&get_prepacked); op_node .operator - .run(pool, InputList::from_optional(&op_inputs)) + .run(pool, input_list) .map_err(op_error_to_run_error) }; std::mem::drop(op_inputs); @@ -1490,6 +1565,7 @@ impl Graph { &pruned_plan_output_ids, None, /* captures */ None, /* pool */ + None, /* weight cache */ opts, ) })?; @@ -1824,10 +1900,11 @@ mod tests { use super::{CachedPlan, CaptureEnv}; use crate::graph::{Dimension, Graph, Node, NodeId, RunError, RunOptions, TypedConstant}; use crate::ops::{ - Add, Concat, Conv, DataType, Identity, If, InputList, IntoOpResult, Mul, OpError, Operator, - Output, OutputList, Relu, Shape, + Add, Concat, Conv, DataType, Identity, If, Input, InputList, IntoOpResult, MatMul, Mul, + OpError, Operator, Output, OutputList, PrepackedInput, Relu, Shape, }; use crate::tensor_pool::TensorPool; + use crate::weight_cache::WeightCache; #[derive(Clone, Debug, Default)] struct Metrics { @@ -1928,7 +2005,7 @@ mod tests { ); let results = g - .run(vec![(input_id, input.into())], &[relu_out], None) + .run(vec![(input_id, input.into())], &[relu_out], None, None) .unwrap(); let expected = Tensor::from_data( @@ -2076,13 +2153,18 @@ mod tests { let input = Tensor::from([1.]); let results = g - .run(vec![(input_id, input.view().into())], &[op_c_out], None) + .run( + vec![(input_id, input.view().into())], + &[op_c_out], + None, + None, + ) .unwrap(); let expected = Tensor::from([2., 3.]); expect_equal(&results[0].as_tensor_view().unwrap(), &expected.view())?; let results = g - .run(vec![(input_id, input.into())], &[op_d_out], None) + .run(vec![(input_id, input.into())], &[op_d_out], None, None) .unwrap(); let 
expected = Tensor::from([3., 2.]); expect_equal(&results[0].as_tensor_view().unwrap(), &expected.view())?; @@ -2127,7 +2209,12 @@ mod tests { let input = Tensor::from(0.); let results = g - .run(vec![(input_id, input.into())], &[op_a_out, op_b_out], None) + .run( + vec![(input_id, input.into())], + &[op_a_out, op_b_out], + None, + None, + ) .unwrap(); assert_eq!( &results[0].as_tensor_view().unwrap(), @@ -2159,7 +2246,7 @@ mod tests { } let results = g - .run(vec![(input_id, input.into())], &[prev_output], None) + .run(vec![(input_id, input.into())], &[prev_output], None, None) .unwrap(); let expected = Tensor::from([101., 102., 103., 104., 105.]); @@ -2176,7 +2263,12 @@ mod tests { let input_id = g.add_value(Some("input"), None, None); let results = g - .run(vec![(input_id, input.view().into())], &[input_id], None) + .run( + vec![(input_id, input.view().into())], + &[input_id], + None, + None, + ) .unwrap(); expect_equal(&results[0].as_tensor_view().unwrap(), &input.view())?; @@ -2191,7 +2283,7 @@ mod tests { let value = Tensor::from([1., 2., 3., 4., 5.]); let const_id = g.add_constant(Some("weight"), value.clone()); - let results = g.run(vec![], &[const_id], None).unwrap(); + let results = g.run(vec![], &[const_id], None, None).unwrap(); expect_equal(&results[0].as_tensor_view().unwrap(), &value.view())?; @@ -2239,7 +2331,7 @@ mod tests { #[test] fn test_no_outputs() { let g = Graph::new(); - let results = g.run(vec![], &[], None).unwrap(); + let results = g.run(vec![], &[], None, None).unwrap(); assert_eq!(results.len(), 0); } @@ -2255,6 +2347,7 @@ mod tests { ], &[input_id], None, + None, ); assert_eq!( result, @@ -2271,7 +2364,12 @@ mod tests { let input = Tensor::from([1.]); - let result = g.run(vec![(input_id, input.into())], &[op_a_out, op_a_out], None); + let result = g.run( + vec![(input_id, input.into())], + &[op_a_out, op_a_out], + None, + None, + ); assert_eq!( result, @@ -2289,7 +2387,7 @@ mod tests { let output = g.add_value(None, None, None); 
g.add_op(Some("shape"), Box::new(Shape {}), &[None], &[Some(output)]); - let results = g.run(vec![], &[output], None); + let results = g.run(vec![], &[output], None, None); assert_eq!( results.err(), @@ -2303,7 +2401,7 @@ mod tests { #[test] fn test_err_if_invalid_output() { let g = Graph::new(); - let result = g.run(vec![], &[NodeId::from_u32(123)], None); + let result = g.run(vec![], &[NodeId::from_u32(123)], None, None); assert_eq!( result.err(), Some(RunError::PlanningError("Missing output 123".to_string())) @@ -2314,7 +2412,7 @@ mod tests { fn test_err_if_missing_operator_input() { let mut g = Graph::new(); let (_, output) = g.add_simple_op("op", Relu {}, &[NodeId::from_u32(42)]); - let result = g.run(vec![], &[output], None); + let result = g.run(vec![], &[output], None, None); assert_eq!( result.err(), Some(RunError::PlanningError( @@ -2370,14 +2468,24 @@ mod tests { // First operator should not be run in-place, since it has an // immutable input. The result should be the same as the input. let results = g - .run(vec![(input_id, input.view().into())], &[op1_out], None) + .run( + vec![(input_id, input.view().into())], + &[op1_out], + None, + None, + ) .unwrap(); assert_eq!(results[0].as_tensor_view::().unwrap()[[0, 0]], 0.0); // Second operator should be run in-place, as it meets all the // requirements for this optimization. 
let results = g - .run(vec![(input_id, input.view().into())], &[op2_out], None) + .run( + vec![(input_id, input.view().into())], + &[op2_out], + None, + None, + ) .unwrap(); assert_eq!(results[0].as_tensor_view::().unwrap()[[0, 0]], 1.0); @@ -2389,6 +2497,7 @@ mod tests { vec![(input_id, input.view().into())], &[op3_out, op4_out], None, + None, ) .unwrap(); assert_eq!(results[0].as_tensor_view::().unwrap()[[0, 0]], 1.0); @@ -2428,6 +2537,7 @@ mod tests { vec![(input_id, input.view().into()), (bias_id, bias.into())], &[op2_out], None, + None, ) .unwrap(); @@ -2510,6 +2620,7 @@ mod tests { vec![(input_id, input.into())], &[left_split_out, right_split_out], None, + None, ) .unwrap(); @@ -2683,6 +2794,7 @@ mod tests { pool: &TensorPool, inputs: InputList, captures: CaptureEnv, + weight_caches: Option<&[WeightCache]>, options: Option, ) -> Result { let inputs = self @@ -2698,6 +2810,7 @@ mod tests { self.graph.output_ids(), captures, Some(pool), + weight_caches.map(|wcs| &wcs[0]), options, ) .map(|xs| xs.into_iter().collect()) @@ -2741,6 +2854,7 @@ mod tests { ], &[if_out], None, + None, ) .unwrap(); let result: Tensor = result.remove(0).try_into().unwrap(); @@ -2755,6 +2869,7 @@ mod tests { ], &[if_out], None, + None, ) .unwrap(); let result: Tensor = result.remove(0).try_into().unwrap(); @@ -2785,7 +2900,12 @@ mod tests { let (_, sg_out) = g.add_simple_op("Subgraph", Subgraph { graph: subgraph }, &[]); let mut result = g - .run(vec![(input, Tensor::from(2.).into())], &[sg_out], None) + .run( + vec![(input, Tensor::from(2.).into())], + &[sg_out], + None, + None, + ) .unwrap(); let result: Tensor = result.remove(0).try_into().unwrap(); assert_eq!(result, Tensor::from(2.)); @@ -2809,7 +2929,7 @@ mod tests { let result = subgraph.partial_run(Vec::new(), &[sg_add], None).unwrap(); assert_eq!(result.len(), 0); - let result = subgraph.run(Vec::new(), &[sg_add], None); + let result = subgraph.run(Vec::new(), &[sg_add], None, None); assert_eq!( result, 
Err(RunError::PlanningError( @@ -2864,7 +2984,9 @@ mod tests { // Run the graph. The planner must account for captured dependencies // in the `Subgraph` op. let input = Tensor::from(3.); - let mut result = g.run(vec![(input_id, input.into())], &[out], None).unwrap(); + let mut result = g + .run(vec![(input_id, input.into())], &[out], None, None) + .unwrap(); let result: Tensor = result.remove(0).try_into().unwrap(); assert_eq!(result.item(), Some(&6.)); } @@ -2899,7 +3021,9 @@ mod tests { // Run the graph. The planner must account for captured dependencies // from the innermost graph in the `Subgraph` op. let input = Tensor::from(3.); - let mut result = g.run(vec![(input_id, input.into())], &[out], None).unwrap(); + let mut result = g + .run(vec![(input_id, input.into())], &[out], None, None) + .unwrap(); let result: Tensor = result.remove(0).try_into().unwrap(); assert_eq!(result.item(), Some(&6.)); } @@ -2927,7 +3051,9 @@ mod tests { let (_, out) = g.add_simple_op("Subgraph", Subgraph { graph: subgraph }, &[mul_out]); let input = Tensor::from(3.); - let mut result = g.run(vec![(input_id, input.into())], &[out], None).unwrap(); + let mut result = g + .run(vec![(input_id, input.into())], &[out], None, None) + .unwrap(); let result: Tensor = result.remove(0).try_into().unwrap(); assert_eq!(result.item(), Some(&3.)); } @@ -2951,7 +3077,9 @@ mod tests { // Run graph with an owned value as input. let input = Tensor::from(42.); - let mut result = g.run(vec![(input_id, input.into())], &[out], None).unwrap(); + let mut result = g + .run(vec![(input_id, input.into())], &[out], None, None) + .unwrap(); // Check result and that Identity operation was run in-place. let result: Tensor = result.remove(0).try_into().unwrap(); @@ -2966,7 +3094,7 @@ mod tests { // Run graph with view as input. 
let input = Tensor::from(42.); let mut result = g - .run(vec![(input_id, input.view().into())], &[out], None) + .run(vec![(input_id, input.view().into())], &[out], None, None) .unwrap(); // Check result and that Identity operation was not run in-place. @@ -2979,4 +3107,88 @@ mod tests { assert_eq!(id_op_metrics.run_in_place_count, 1); } } + + // MatMul wrapper that verifies its B input (ie. the weights) are prepacked. + #[derive(Debug)] + struct MatMulExpectPacked { + inner: MatMul, + } + + impl MatMulExpectPacked { + fn new() -> Self { + MatMulExpectPacked { inner: MatMul {} } + } + } + + impl Operator for MatMulExpectPacked { + fn name(&self) -> &str { + "MatMulExpectPacked" + } + + fn prepack_inputs(&self) -> SmallVec<[usize; 1]> { + [1].into() + } + + fn prepack(&self, index: usize, input: Input) -> Option { + self.inner.prepack(index, input) + } + + fn run(&self, pool: &TensorPool, inputs: InputList) -> Result { + let prepacked = inputs.get_prepacked(1); + assert!(prepacked.is_some()); + self.inner.run(pool, inputs) + } + } + + #[test] + fn test_prepack_weights() { + // Create a graph and a subgraph, both with operators that can + // use prepacked weights. 
+ let mut graph = Graph::new(); + let mut cache = WeightCache::new(); + + let input = graph.add_value(Some("input"), None, None); + let weights = graph.add_constant(None, Tensor::::zeros(&[10, 7])); + let (_, matmul_out) = + graph.add_simple_op("MatMul", MatMulExpectPacked::new(), &[input, weights]); + + let mut subgraph = Graph::new(); + let sg_input = subgraph.add_value(Some("sg-input"), None, None); + let sg_weights = subgraph.add_constant(None, Tensor::::zeros(&[7, 5])); + let (_, sg_matmul_out) = subgraph.add_simple_op( + "sg-MatMul", + MatMulExpectPacked::new(), + &[sg_input, sg_weights], + ); + subgraph.set_input_ids(&[sg_input]); + subgraph.set_output_ids(&[sg_matmul_out]); + + let (subgraph_op, subgraph_out) = + graph.add_simple_op("Subgraph", Subgraph { graph: subgraph }, &[matmul_out]); + graph.set_input_ids(&[input]); + graph.set_output_ids(&[subgraph_out]); + + // Prepack weights and verify that the cache was populated. + graph.prepack_weights(&mut cache); + assert_eq!(cache.len(), 2); + assert!(cache.get(weights).is_some()); + + let sg_cache = cache + .get_subgraph_caches(subgraph_op) + .map(|caches| &caches[0]) + .unwrap(); + assert!(sg_cache.get(sg_weights).is_some()); + + // Run the graph, passing the cache. The MatMul wrapper will verify + // that the B / RHS inputs were passed from the cache. 
+ let input_value = Tensor::::zeros(&[3, 10]); + graph + .run( + [(input, input_value.into())].into(), + &[subgraph_out], + Some(&cache), + None, + ) + .unwrap(); + } } diff --git a/src/lib.rs b/src/lib.rs index 43fc04af..176fb1a0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -109,6 +109,7 @@ mod slice_reductions; mod tensor_pool; mod threading; mod timing; +mod weight_cache; #[cfg(feature = "wasm_api")] mod wasm_api; diff --git a/src/model.rs b/src/model.rs index e428ad5e..d5987fa0 100644 --- a/src/model.rs +++ b/src/model.rs @@ -27,6 +27,7 @@ use crate::optimize::GraphOptimizer; use crate::schema_generated as sg; use crate::schema_generated::root_as_model; use crate::timing::TimingSort; +use crate::weight_cache::WeightCache; /// The central type used to execute RTen machine learning models. /// @@ -86,6 +87,19 @@ use crate::timing::TimingSort; /// the model's inputs and outputs, but other nodes may be replaced or /// eliminated. To configure or disable optimizations, use [`ModelOptions`]. /// +/// ## Weight prepacking +/// +/// In addition to optimizing the structure of the graph, RTen can create copies +/// of the weights with an optimized ("packed") data layout at model load time. +/// Enabling this will increase model load time and memory usage but reduce the +/// time taken per inference. When this option is disabled, weights are packed +/// temporarily on-demand just before they are used for computation. +/// +/// For generative transformer models (aka. "transformer decoders") prepacking +/// is generally only useful when processing multiple input tokens at a time. +/// +/// Prepacking is disabled by default but can be enabled using [`ModelOptions`]. +/// /// ## Partial evaluation /// /// Some models, such as transformer decoders, are evaluated repeatedly in a @@ -102,6 +116,7 @@ use crate::timing::TimingSort; pub struct Model { graph: Graph, metadata: ModelMetadata, + weight_cache: WeightCache, } /// Provides access to metadata about a graph node. 
@@ -181,6 +196,7 @@ struct SubgraphOptions<'a> { pub struct ModelOptions { registry: OpRegistry, optimize: bool, + prepack_weights: bool, } impl ModelOptions { @@ -197,6 +213,7 @@ impl ModelOptions { ModelOptions { registry: ops, optimize: true, + prepack_weights: false, } } @@ -206,6 +223,16 @@ impl ModelOptions { self } + /// Set whether weights are prepacked. + /// + /// Prepacking creates copies of the weights with an optimized data layout. + /// Enabling this will increase model load time and memory usage but allow + /// for faster inference. + pub fn prepack_weights(&mut self, prepack: bool) -> &mut Self { + self.prepack_weights = prepack; + self + } + /// Load the model from a file. See [`Model::load_file`]. pub fn load_file>(&self, path: P) -> Result { let data = std::fs::read(path).map_err(ModelLoadError::ReadFailed)?; @@ -340,12 +367,21 @@ impl Model { None, /* capture_env */ )?; + let mut weight_cache = WeightCache::new(); + if options.prepack_weights { + graph.prepack_weights(&mut weight_cache); + } + let metadata = model .metadata() .map(ModelMetadata::deserialize) .unwrap_or_default(); - let model = Model { graph, metadata }; + let model = Model { + graph, + metadata, + weight_cache, + }; Ok(model) } @@ -674,7 +710,8 @@ impl Model { let timing_var = timing_var.to_string_lossy(); parse_timing_config(&timing_var, &mut opts); } - self.graph.run(inputs, outputs, Some(opts)) + self.graph + .run(inputs, outputs, Some(&self.weight_cache), Some(opts)) } /// Run a model and retrieve `N` outputs. 
@@ -1013,21 +1050,46 @@ mod tests { fn test_load_and_run_model() { struct Case { format: ModelFormat, + opts: Option, } let cases = [ Case { format: ModelFormat::V1, + opts: None, }, Case { format: ModelFormat::V2, + opts: None, + }, + // Graph optimizations disabled + Case { + format: ModelFormat::V2, + opts: Some({ + let mut opts = ModelOptions::with_all_ops(); + opts.enable_optimization(false); + opts + }), + }, + // Prepacking enabled + Case { + format: ModelFormat::V2, + opts: Some({ + let mut opts = ModelOptions::with_all_ops(); + opts.prepack_weights(true); + opts + }), }, ]; - for Case { format } in cases { + for Case { format, opts } in cases { let buffer = generate_model_buffer(format); - let model = Model::load(buffer).unwrap(); + let model = if let Some(opts) = opts { + opts.load(buffer).unwrap() + } else { + Model::load(buffer).unwrap() + }; let input_id = model.input_ids()[0]; let output_id = model.output_ids()[0]; diff --git a/src/ops/control_flow.rs b/src/ops/control_flow.rs index ac1d7c2e..70f0b096 100644 --- a/src/ops/control_flow.rs +++ b/src/ops/control_flow.rs @@ -4,6 +4,7 @@ use smallvec::SmallVec; use crate::graph::{CaptureEnv, Graph, RunError, RunOptions}; use crate::ops::{InputList, OpError, Operator, Output, OutputList}; use crate::tensor_pool::TensorPool; +use crate::weight_cache::WeightCache; fn output_list_from_vec(xs: Vec) -> OutputList { xs.into_iter().collect() @@ -47,6 +48,7 @@ impl Operator for If { pool: &TensorPool, inputs: InputList, captures: CaptureEnv, + weight_caches: Option<&[WeightCache]>, run_opts: Option, ) -> Result { let cond: TensorView = inputs.require_as(0).map_err(run_error_from_op_error)?; @@ -63,6 +65,7 @@ impl Operator for If { self.then_branch.output_ids(), captures, Some(pool), + weight_caches.map(|wcs| &wcs[0]), run_opts, ) .map(output_list_from_vec) @@ -73,6 +76,7 @@ impl Operator for If { self.else_branch.output_ids(), captures, Some(pool), + weight_caches.map(|wcs| &wcs[1]), run_opts, ) 
.map(output_list_from_vec) diff --git a/src/ops/einsum.rs b/src/ops/einsum.rs index 369f8c37..3864c320 100644 --- a/src/ops/einsum.rs +++ b/src/ops/einsum.rs @@ -497,7 +497,7 @@ fn einsum_matmul( let xp = permute_and_insert_axes(x, term1, &x_order); let yp = permute_and_insert_axes(y, term2, &y_order); - let mut out = matmul(pool, xp, yp)?; + let mut out = matmul(pool, xp, yp, None)?; if !matmul_m.is_ascii_lowercase() { out.remove_axis(out.ndim() - 2); @@ -720,7 +720,7 @@ mod tests { let mat_a = Tensor::from([[1., 2., 3.], [4., 5., 6.]]); let mat_b = Tensor::from([[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 10., 11., 12.]]); - let matmul_ab = matmul(&pool, mat_a.view(), mat_b.view()).unwrap(); + let matmul_ab = matmul(&pool, mat_a.view(), mat_b.view(), None).unwrap(); let matmul_ba = matmul_ab.transposed().to_tensor(); let outer_mat_ab = mul( &pool, @@ -867,7 +867,7 @@ mod tests { Case { equation: "ij,j->i", inputs: vec![mat_a.view(), mat_b.slice((.., 0))], - expected: Ok(matmul(&pool, mat_a.view(), mat_b.slice((.., ..1))) + expected: Ok(matmul(&pool, mat_a.view(), mat_b.slice((.., ..1)), None) .unwrap() .into_shape([mat_a.size(0)].as_slice())), }, @@ -875,7 +875,7 @@ mod tests { Case { equation: "j,jk->k", inputs: vec![mat_a.slice(0), mat_b.view()], - expected: Ok(matmul(&pool, mat_a.slice((..1, ..)), mat_b.view()) + expected: Ok(matmul(&pool, mat_a.slice((..1, ..)), mat_b.view(), None) .unwrap() .into_shape([mat_b.size(1)].as_slice())), }, diff --git a/src/ops/fused.rs b/src/ops/fused.rs index 4604c0f4..fe479903 100644 --- a/src/ops/fused.rs +++ b/src/ops/fused.rs @@ -1,8 +1,9 @@ use std::sync::Arc; use rten_tensor::prelude::*; +use smallvec::SmallVec; -use crate::ops::{Input, InputList, OpError, Operator, OutputList}; +use crate::ops::{Input, InputList, OpError, Operator, OutputList, PrepackedInput}; use crate::tensor_pool::TensorPool; /// Specifies a permutation to an operator input. 
@@ -76,6 +77,14 @@ impl Operator for FusedTranspose { self.perm.apply(&mut inputs)?; self.inner.run(pool, inputs) } + + fn prepack_inputs(&self) -> SmallVec<[usize; 1]> { + self.inner.prepack_inputs() + } + + fn prepack(&self, index: usize, input: Input) -> Option { + self.inner.prepack(index, input) + } } #[cfg(test)] diff --git a/src/ops/matmul.rs b/src/ops/matmul.rs index 7005009f..16f64128 100644 --- a/src/ops/matmul.rs +++ b/src/ops/matmul.rs @@ -1,13 +1,18 @@ use rayon::prelude::*; use rten_tensor::prelude::*; -use rten_tensor::{Tensor, TensorView}; +use rten_tensor::{Matrix, Tensor, TensorView}; +use smallvec::SmallVec; -use crate::gemm::{BiasVector, GemmExecutor, GemmInT, GemmInputA, GemmInputB, GemmOutT}; +use crate::gemm::{ + BiasVector, GemmExecutor, GemmInT, GemmInputA, GemmInputB, GemmOutT, PackedBMatrix, +}; use crate::iter_util::range_chunks; use crate::ops::binary_elementwise::broadcast_shapes; use crate::ops::layout::expand_to; -use crate::ops::{static_dims, InputList, IntoOpResult, OpError, Operator, OutputList}; +use crate::ops::{ + static_dims, Input, InputList, IntoOpResult, OpError, Operator, OutputList, PrepackedInput, +}; use crate::tensor_pool::{AutoReturn, TensorPool}; /// Compute the General Matrix Multiplication (GEMM) `c = alpha * (ab) + beta * c`. 
@@ -119,21 +124,30 @@ enum MatmulStrategy { Batch, } +fn matmul_prepack_b(input: Input) -> Option { + let executor = GemmExecutor::default(); + let tensor: TensorView = input.try_into().ok()?; + let matrix: Matrix = tensor.try_into().ok()?; + Some(PrepackedInput::FloatBMatrix(executor.prepack_b(matrix))) +} + pub fn matmul( pool: &TensorPool, a: TensorView, b: TensorView, + packed_b: Option<&PackedBMatrix>, ) -> Result, OpError> where GemmExecutor: Default, { - matmul_impl(pool, a, b, MatmulStrategy::Auto, None) + matmul_impl(pool, a, b, packed_b, MatmulStrategy::Auto, None) } fn matmul_impl( pool: &TensorPool, mut a: TensorView, mut b: TensorView, + packed_b: Option<&PackedBMatrix>, strategy: MatmulStrategy, bias: Option>, ) -> Result, OpError> @@ -188,7 +202,7 @@ where // nb. We assume `a` is likely already contiguous, so this will be cheap. let a_contig = a.to_contiguous_in(pool).auto_return(pool); let a_matrix = a_contig.reshaped([num_a_matrices * a_rows, a_cols].as_slice()); - let mut output = matmul_impl(pool, a_matrix.view(), b.clone(), strategy, bias)?; + let mut output = matmul_impl(pool, a_matrix.view(), b.clone(), packed_b, strategy, bias)?; output.reshape(out_shape); return Ok(output); } @@ -214,7 +228,8 @@ where let gemm = GemmExecutor::default(); - // Prepack re-used inputs to amortize packing cost. + // Prepack inputs if they are re-used, to amortize packing cost, or we + // were already called with prepacked inputs. // // We don't prepack when the "A" matrix is a vector because that uses a // special case vector-matrix algorithm that doesn't benefit from packing. 
@@ -224,11 +239,16 @@ where }); let prepacked_a = prepacked_a.as_deref(); - let prepacked_b = (num_a_matrices > 1 && num_b_matrices == 1 && a_rows > 1).then(|| { - let b_matrix = b.inner_iter::<2>().next().unwrap(); - gemm.prepack_b_in(pool, b_matrix).auto_return(pool) - }); - let prepacked_b = prepacked_b.as_deref(); + let prepacked_b = + (num_b_matrices == 1 && num_a_matrices > 1 && a_rows > 1 && packed_b.is_none()).then( + || { + let b_matrix = b.inner_iter::<2>().next().unwrap(); + gemm.prepack_b_in(pool, b_matrix).auto_return(pool) + }, + ); + let prepacked_b = prepacked_b + .as_deref() + .or(if a_rows > 1 { packed_b } else { None }); a_broadcast .inner_iter::<2>() @@ -282,7 +302,23 @@ impl Operator for MatMul { fn run(&self, pool: &TensorPool, inputs: InputList) -> Result { let a = inputs.require_as(0)?; let b = inputs.require_as(1)?; - matmul::(pool, a, b).into_op_result() + let packed_b = match inputs.get_prepacked(1) { + Some(PrepackedInput::FloatBMatrix(pb)) => Some(pb), + _ => None, + }; + matmul::(pool, a, b, packed_b).into_op_result() + } + + fn prepack_inputs(&self) -> SmallVec<[usize; 1]> { + [1].into() + } + + fn prepack(&self, index: usize, input: Input) -> Option { + if index == 1 { + matmul_prepack_b(input) + } else { + None + } } } @@ -290,12 +326,13 @@ pub fn matmul_add( pool: &TensorPool, a: TensorView, b: TensorView, + packed_b: Option<&PackedBMatrix>, bias: BiasVector, ) -> Result, OpError> where GemmExecutor: Default, { - matmul_impl(pool, a, b, MatmulStrategy::Auto, Some(bias)) + matmul_impl(pool, a, b, packed_b, MatmulStrategy::Auto, Some(bias)) } /// Fusion for `Add(MatMul(a, b), bias)` subgraphs, where `bias` is a vector. 
@@ -310,11 +347,27 @@ impl Operator for MatMulAdd { fn run(&self, pool: &TensorPool, inputs: InputList) -> Result { let a = inputs.require_as(0)?; let b = inputs.require_as(1)?; + let packed_b = match inputs.get_prepacked(1) { + Some(PrepackedInput::FloatBMatrix(pb)) => Some(pb), + _ => None, + }; let bias = inputs.require_as(2)?; let bias = static_dims!(bias, 1, "N")?.to_contiguous_in(pool); - matmul_add(pool, a, b, BiasVector::Row(bias.data().unwrap())).into_op_result() + matmul_add(pool, a, b, packed_b, BiasVector::Row(bias.data().unwrap())).into_op_result() + } + + fn prepack_inputs(&self) -> SmallVec<[usize; 1]> { + [1].into() + } + + fn prepack(&self, index: usize, input: Input) -> Option { + if index == 1 { + matmul_prepack_b(input) + } else { + None + } } } @@ -451,10 +504,12 @@ mod tests { use crate::gemm::{gemm, BiasVector, GemmExecutor, GemmInputA, GemmInputB}; use crate::ops::binary_elementwise::broadcast_shapes; use crate::ops::tests::new_pool; + use crate::ops::{InputList, Operator}; use crate::tensor_pool::AutoReturn; use super::{ - gemm_op, matmul, matmul_add, matmul_impl, matmul_integer, MatmulStrategy, OpError, + gemm_op, matmul, matmul_add, matmul_impl, matmul_integer, MatMul, MatMulAdd, + MatmulStrategy, OpError, }; fn gemm_tensors(c: &mut Tensor, a: &Tensor, b: &Tensor, alpha: f32, beta: f32) { @@ -768,7 +823,67 @@ mod tests { let mut expected = Tensor::zeros(out_shape); reference_matmul(expected.view_mut(), a.view(), b.view(), None); - let result = matmul(&pool, a.view(), b.view()).unwrap(); + let result = matmul(&pool, a.view(), b.view(), None).unwrap(); + expect_equal(&result, &expected)?; + } + + Ok(()) + } + + #[test] + fn test_matmul_with_prepacked_inputs() -> Result<(), Box> { + struct Case { + op: Box, + bias_input: bool, + } + + let cases = [ + Case { + op: Box::new(MatMul {}), + bias_input: false, + }, + Case { + op: Box::new(MatMulAdd {}), + bias_input: true, + }, + ]; + + let mut rng = XorShiftRng::new(1234); + + let a = 
Tensor::rand(&[5, 10], &mut rng); + + // The unpacked and pre-packed versions of an input should use the + // same data. Here we intentionally use different tensors with + // the same shape so we can verify if the packed data was used. + let b = Tensor::::rand(&[10, 3], &mut rng); + + // Dummy zero bias. + let bias = Tensor::::zeros(&[3]); + + for Case { op, bias_input } in cases { + let packed_b_input = Tensor::rand(&[10, 3], &mut rng); + let packed_b = op.prepack(1, packed_b_input.view().into()).unwrap(); + + let mut expected = Tensor::zeros(&[5, 3]); + reference_matmul(expected.view_mut(), a.view(), packed_b_input.view(), None); + + let pool = new_pool(); + let get_prepacked = |idx| { + if idx == 1 { + Some(&packed_b) + } else { + None + } + }; + let mut inputs = + InputList::from(&[a.view().into(), b.view().into()]).with_prepacked(&get_prepacked); + if bias_input { + inputs.push(bias.view()); + } + + let mut result = op.run(&pool, inputs).unwrap(); + let result: Tensor = result.remove(0).try_into().unwrap(); + expect_equal(&result, &expected)?; } @@ -786,7 +901,7 @@ mod tests { let pool = new_pool(); let mut expected = Tensor::zeros(&[10, 5]); reference_matmul(expected.view_mut(), a.view(), b.view(), bias.clone()); - let result = matmul_add(&pool, a.view(), b.view(), bias.unwrap()).unwrap(); + let result = matmul_add(&pool, a.view(), b.view(), None, bias.unwrap()).unwrap(); expect_equal(&result, &expected)?; Ok(()) @@ -836,7 +951,7 @@ mod tests { let a = Tensor::rand(a_shape, &mut rng); let b = Tensor::rand(b_shape, &mut rng); - let result = matmul(&pool, a.view(), b.view()); + let result = matmul(&pool, a.view(), b.view(), None); assert_eq!(result, Err(error)); } @@ -862,7 +977,7 @@ mod tests { let mut rng = XorShiftRng::new(1234); let a = Tensor::rand(&[m, k], &mut rng); let b = Tensor::rand(&[k, n], &mut rng); - let result = matmul(&pool, a.view(), b.view()).unwrap(); + let result = matmul(&pool, a.view(), b.view(), None).unwrap(); assert_eq!(result.shape(), 
&[m, n]); if k == 0 { @@ -1038,7 +1153,7 @@ mod tests { ); let pool = new_pool(); run_bench(trials, Some(&desc), || { - matmul_impl(&pool, a.view(), b.view(), strategy, None) + matmul_impl(&pool, a.view(), b.view(), None, strategy, None) .unwrap() .auto_return(&pool); }); diff --git a/src/ops/mod.rs b/src/ops/mod.rs index b1ec466b..99cd1b6c 100644 --- a/src/ops/mod.rs +++ b/src/ops/mod.rs @@ -26,8 +26,10 @@ use rten_tensor::{ }; use crate::downcast::impl_downcastdyn; +use crate::gemm::PackedBMatrix; use crate::graph::{CaptureEnv, Graph, RunError, RunOptions}; use crate::tensor_pool::{ExtractBuffer, TensorPool}; +use crate::weight_cache::WeightCache; mod binary_elementwise; mod concat; @@ -336,6 +338,13 @@ impl<'a> From<&'a Output> for Input<'a> { } } +/// An operator input which has been pre-packed for more efficient use during +/// inference. +pub enum PrepackedInput { + /// Prepacked RHS / B input for matrix multiplication with f32 weights. + FloatBMatrix(PackedBMatrix), +} + /// Enum of the different types of output tensor that a model or operator can /// return. #[derive(Debug, Clone, PartialEq)] @@ -809,6 +818,23 @@ pub trait Operator: Any + Debug { SmallVec::new() } + /// Return the IDs of inputs which can be pre-packed using [`prepack`](Operator::prepack). + fn prepack_inputs(&self) -> SmallVec<[usize; 1]> { + SmallVec::new() + } + + /// Pre-pack an input for more efficient inference later. + /// + /// `index` specifies the input ID and should be one of the inputs returned + /// by [`prepack_inputs`](Operator::prepack_inputs). + fn prepack( + &self, + #[allow(unused)] index: usize, + #[allow(unused)] input: Input, + ) -> Option { + None + } + /// Execute the operator with the given inputs and captured values. 
/// /// This method will be called instead of `run` if the operator reports that @@ -825,6 +851,7 @@ pub trait Operator: Any + Debug { pool: &TensorPool, input: InputList, #[allow(unused)] captures: CaptureEnv, + #[allow(unused)] weight_cache: Option<&[WeightCache]>, #[allow(unused)] run_opts: Option<RunOptions>, ) -> Result { self.run(pool, input) @@ -847,6 +874,10 @@ impl_downcastdyn!(Operator); /// references using `into`. pub struct InputList<'a> { inputs: Cow<'a, [Option<Input<'a>>]>, + + /// Callback that retrieves the pre-packed copy of an input with a given + /// index. + get_prepacked: Option<&'a dyn Fn(usize) -> Option<&'a PrepackedInput>>, } impl<'a> InputList<'a> { @@ -854,6 +885,7 @@ impl<'a> InputList<'a> { pub fn new() -> InputList<'a> { InputList { inputs: Cow::Owned(vec![]), + get_prepacked: None, } } @@ -879,6 +911,7 @@ impl<'a> InputList<'a> { pub fn from(inputs: &[Input<'a>]) -> InputList<'a> { InputList { inputs: inputs.iter().cloned().map(Some).collect(), + get_prepacked: None, } } @@ -888,14 +921,30 @@ impl<'a> InputList<'a> { pub fn from_optional(inputs: &'a [Option<Input<'a>>]) -> InputList<'a> { InputList { inputs: Cow::Borrowed(inputs), + get_prepacked: None, } } + /// Configure a callback that will get or create a pre-packed copy of the + /// input with a given index. + pub fn with_prepacked( + mut self, + lookup: &'a dyn Fn(usize) -> Option<&'a PrepackedInput>, + ) -> Self { + self.get_prepacked = Some(lookup); + self + } + /// Get an optional input. pub fn get(&self, index: usize) -> Option<Input<'a>> { self.inputs.get(index).cloned().flatten() } + /// Get the pre-packed version of a weight input, if available. + pub fn get_prepacked(&self, index: usize) -> Option<&'a PrepackedInput> { + self.get_prepacked.and_then(|gp| gp(index)) + } + /// Get a mutable reference to an input. /// /// This will convert the list into an owned list of inputs first.
diff --git a/src/ops/operators.rs b/src/ops/operators.rs index 2631076c..66595ac0 100644 --- a/src/ops/operators.rs +++ b/src/ops/operators.rs @@ -194,7 +194,7 @@ impl, L: MutLayout> Operators for TensorBase impl, L: MutLayout> FloatOperators for TensorBase { fn matmul(&self, other: TensorView) -> Result { let view = self.as_dyn(); - use_thread_pool(|| matmul(&TensorPool::new(), view, other)) + use_thread_pool(|| matmul(&TensorPool::new(), view, other, None)) } fn reduce_l2(&self, axes: Option<&[i32]>, keep_dims: bool) -> Result { diff --git a/src/wasm_api.rs b/src/wasm_api.rs index edbceafe..03126942 100644 --- a/src/wasm_api.rs +++ b/src/wasm_api.rs @@ -214,7 +214,7 @@ impl Tensor { let a = self.as_float()?; let b = other.as_float()?; let pool = TensorPool::new(); - let out = matmul(&pool, a, b).map_err(|e| e.to_string())?; + let out = matmul(&pool, a, b, None).map_err(|e| e.to_string())?; Ok(Tensor::from_output(out.into())) } } diff --git a/src/weight_cache.rs b/src/weight_cache.rs new file mode 100644 index 00000000..cea70bb4 --- /dev/null +++ b/src/weight_cache.rs @@ -0,0 +1,71 @@ +use rustc_hash::FxHashMap; + +use crate::graph::NodeId; +use crate::ops::PrepackedInput; + +/// A cache of prepacked weights for graph operators. +/// +/// The weight cache has a hierarchical structure which mirrors the model +/// graph. At the top level is the root graph. For each operator with a +/// subgraph (eg. control flow operators) there are separate sub-caches. +pub struct WeightCache { + /// Map of constant node ID to prepacked weights. + cache: FxHashMap<NodeId, PrepackedInput>, + + /// Map of operator ID to caches for the operator's subgraphs. + subgraph_caches: FxHashMap<NodeId, Vec<WeightCache>>, +} + +impl WeightCache { + /// Create an empty cache. + pub fn new() -> WeightCache { + WeightCache { + cache: FxHashMap::default(), + subgraph_caches: FxHashMap::default(), + } + } + + /// Check if a pre-packed weight exists for a given constant node ID.
+ pub fn contains(&self, node: NodeId) -> bool { + self.cache.contains_key(&node) + } + + /// Add a prepacked weight to the cache. + pub fn insert(&mut self, node: NodeId, packed: PrepackedInput) { + self.cache.insert(node, packed); + } + + /// Look up weight in the cache. + pub fn get(&self, node: NodeId) -> Option<&PrepackedInput> { + self.cache.get(&node) + } + + /// Add caches for subgraphs belonging to an operator. + pub fn insert_subgraph_caches(&mut self, operator_id: NodeId, caches: Vec<WeightCache>) { + self.subgraph_caches.insert(operator_id, caches); + } + + /// Look up caches for an operator's subgraphs. + pub fn get_subgraph_caches(&self, operator_id: NodeId) -> Option<&[WeightCache]> { + self.subgraph_caches + .get(&operator_id) + .map(|wcs| wcs.as_slice()) + } + + /// Return the total number of cached weights, including in subgraphs. + pub fn len(&self) -> usize { + self.cache.len() + + self + .subgraph_caches + .values() + .flat_map(|caches| caches.iter()) + .map(|cache| cache.len()) + .sum::<usize>() + } +} + +impl Default for WeightCache { + fn default() -> Self { + WeightCache::new() + } +} From dd38a4aa6e055646129066268c86be1f2d6256d7 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Wed, 25 Dec 2024 06:41:48 +0000 Subject: [PATCH 2/2] Add CLI option to enable prepacking of model weights --- rten-cli/src/main.rs | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/rten-cli/src/main.rs b/rten-cli/src/main.rs index a4a67d1f..8c3d4ee5 100644 --- a/rten-cli/src/main.rs +++ b/rten-cli/src/main.rs @@ -16,6 +16,9 @@ struct Args { /// Whether to enable graph optimizations optimize: bool, + /// Whether to enable prepacking of weights + prepack_weights: bool, + /// Run model and don't produce other output quiet: bool, @@ -115,6 +118,7 @@ fn parse_args() -> Result { let mut verbose = false; let mut input_sizes = Vec::new(); let mut optimize = true; + let mut prepack_weights = false; let mut parser = lexopt::Parser::from_env(); while let
Some(arg) = parser.next()? { @@ -128,6 +132,7 @@ fn parse_args() -> Result { .map_err(|_| "Unable to parse `n_iters`".to_string())?; } Long("no-optimize") => optimize = false, + Short('p') | Long("prepack") => prepack_weights = true, Short('q') | Long("quiet") => quiet = true, Short('v') | Long("verbose") => verbose = true, Short('V') | Long("version") => { @@ -163,12 +168,15 @@ Options: -q, --quiet Run model and don't produce other output - -t, --timing Output timing info + -p, --prepack Enable prepacking of weights. + This requires additional memory but makes inference faster. -s, --size Specify size for a dynamic dimension in the form `dim_name=size` or `input_name.dim_name=size` + -t, --timing Output timing info + -v, --verbose Enable verbose logging -V, --version Display RTen version ", @@ -183,14 +191,15 @@ Options: let model = values.pop_front().ok_or("missing `` arg")?; Ok(Args { + input_sizes, + mmap, model, n_iters, - mmap, optimize, + prepack_weights, quiet, timing, verbose, - input_sizes, }) } @@ -461,6 +470,7 @@ fn main() -> Result<(), Box> { let mut model_opts = ModelOptions::with_all_ops(); model_opts.enable_optimization(args.optimize); + model_opts.prepack_weights(args.prepack_weights); let model = if args.mmap { unsafe { model_opts.load_mmap(args.model)? }