Support pre-packing weights after model optimization
This reduces inference time for matmul operations at the cost of higher memory
usage.

 - Add methods to the `Operator` trait to declare which inputs can
   potentially be pre-packed, and to prepack those inputs.

 - Add `Graph::prepack_weights` method to traverse operators and prepack
   inputs whose values are constant nodes.

 - Implement prepacking methods for the MatMul and fused MatMul ops.
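The trait additions described above might look roughly like this. This is a minimal standalone sketch, not the actual rten API: the names `prepack_inputs`, `prepack`, and `PackedInput` are illustrative, and the "packing" here is just a copy standing in for the real blocked layout.

```rust
use std::any::Any;

/// Opaque, operator-specific pre-packed weight buffer (illustrative stand-in).
struct PackedInput(#[allow(dead_code)] Box<dyn Any>);

trait Operator {
    fn name(&self) -> &str;

    /// Indices of inputs that can potentially be pre-packed.
    /// Most operators have none, so the default is empty.
    fn prepack_inputs(&self) -> Vec<usize> {
        Vec::new()
    }

    /// Pre-pack the given input into an operator-specific layout, if supported.
    fn prepack(&self, index: usize, value: &[f32]) -> Option<PackedInput> {
        let _ = (index, value);
        None
    }
}

struct MatMul;

impl Operator for MatMul {
    fn name(&self) -> &str {
        "MatMul"
    }

    // The "B" (weights) input of a MatMul can be packed ahead of time.
    fn prepack_inputs(&self) -> Vec<usize> {
        vec![1]
    }

    fn prepack(&self, index: usize, value: &[f32]) -> Option<PackedInput> {
        // Stand-in for the real packing routine, which would reorder the
        // matrix into panels sized for the GEMM kernel.
        (index == 1).then(|| PackedInput(Box::new(value.to_vec())))
    }
}

fn main() {
    let op = MatMul;
    assert_eq!(op.prepack_inputs(), vec![1]);
    assert!(op.prepack(1, &[1.0, 2.0]).is_some());
    assert!(op.prepack(0, &[1.0]).is_none());
    println!("ok");
}
```

The default implementations mean existing operators opt in only by overriding both methods; everything else reports no prepackable inputs.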

There are some caveats:

 - Non-MatMul operations which use matmuls internally (Conv, ConvTranspose,
   LSTM, GRU, etc.) currently don't prepack their weights.

 - MatMul operations which turn out to be matrix-vector (gemv) products don't
   use the prepacked weights. This affects transformer decoders doing
   non-batched generation after the initial prompt encoding step.
robertknight committed Dec 24, 2024
1 parent 996d062 commit 817097b
Showing 11 changed files with 401 additions and 50 deletions.
10 changes: 10 additions & 0 deletions src/gemm.rs
@@ -108,6 +108,16 @@ impl<T> PackedBMatrix<T> {
    fn panel_len(&self) -> usize {
        self.panel_width * self.rows
    }

    /// Number of rows in the unpacked matrix.
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Number of columns in the unpacked matrix.
    pub fn cols(&self) -> usize {
        self.cols
    }
}

impl<T> ExtractBuffer for PackedBMatrix<T> {
