Skip to content

Commit

Permalink
Merge pull request #408 from robertknight/kv-cache-grow-after-run
Browse files Browse the repository at this point in the history
Reserve KV cache capacity after the first model run
  • Loading branch information
robertknight authored Nov 15, 2024
2 parents 773f728 + c24e15f commit 408ebe0
Showing 1 changed file with 99 additions and 14 deletions.
113 changes: 99 additions & 14 deletions rten-generate/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,76 @@ enum KvCacheData {
BatchHeadSeqChans(NdTensor<f32, 4>),
}

impl KvCacheData {
    /// Create a KV cache buffer for the given batch size, optional head
    /// count and embedding size.
    ///
    /// When `n_heads` is provided the buffer uses the 4-dimensional
    /// `[batch, heads, seq, size]` layout, otherwise the 3-dimensional
    /// `[batch, seq, size]` layout.
    ///
    /// The buffer is allocated so that it can be extended up to a sequence
    /// length of `seq_len_capacity`.
    fn with_capacity(
        batch_size: usize,
        n_heads: Option<usize>,
        size: usize,
        seq_len_capacity: usize,
    ) -> KvCacheData {
        match n_heads {
            Some(n_heads) => {
                let shape = [batch_size, n_heads, seq_len_capacity, size];
                Self::BatchHeadSeqChans(NdTensor::with_capacity(shape, 2 /* seq dim */))
            }
            None => {
                let shape = [batch_size, seq_len_capacity, size];
                Self::BatchSeqChans(NdTensor::with_capacity(shape, 1 /* seq dim */))
            }
        }
    }

    /// Return the size of the sequence dimension of the cached data.
    fn sequence_len(&self) -> usize {
        match self {
            Self::BatchSeqChans(tensor) => tensor.size(1 /* seq dim */),
            Self::BatchHeadSeqChans(tensor) => tensor.size(2 /* seq dim */),
        }
    }

    /// Return true if the underlying buffer reports capacity for
    /// `sequence_len` along its sequence dimension.
    fn has_capacity(&self, sequence_len: usize) -> bool {
        match self {
            Self::BatchSeqChans(tensor) => tensor.has_capacity(1 /* seq dim */, sequence_len),
            Self::BatchHeadSeqChans(tensor) => tensor.has_capacity(2 /* seq dim */, sequence_len),
        }
    }

    /// Copy this cache into a freshly allocated buffer that has room for
    /// sequences of length `max_sequence_len`.
    ///
    /// If `max_sequence_len` is smaller than the current sequence length,
    /// the current sequence length is used as the capacity instead.
    fn clone_with_capacity(&self, max_sequence_len: usize) -> KvCacheData {
        // Never request less capacity than the data already held, otherwise
        // the append below could not fit.
        let capacity = self.sequence_len().max(max_sequence_len);
        match self {
            Self::BatchSeqChans(src) => {
                let [batch, _, chans] = src.shape();
                let mut grown = NdTensor::with_capacity([batch, capacity, chans], 1 /* seq dim */);
                grown.append(1, src).expect("should have capacity");
                Self::BatchSeqChans(grown)
            }
            Self::BatchHeadSeqChans(src) => {
                let [batch, n_heads, _, chans] = src.shape();
                let mut grown =
                    NdTensor::with_capacity([batch, n_heads, capacity, chans], 2 /* seq dim */);
                grown.append(2, src).expect("should have capacity");
                Self::BatchHeadSeqChans(grown)
            }
        }
    }
}

/// Key-value cache for a single layer of a transformer model.
struct KvCache {
/// Input ID for this cache entry.
Expand Down Expand Up @@ -440,23 +510,28 @@ impl<'a> Generator<'a> {
.find_node(&output_name)
.ok_or(GeneratorError::OutputNotFound(output_name))?;

// This value should be configurable.
let max_seq_len = 512;
// Initial sequence length capacity for KV cache buffer.
//
// For models that execute different operations on the first vs
// subsequent iterations (e.g. Hugging Face "merged" models with
// past and no-past branches) the input buffer may not be used in
// the first iteration. Instead we need to reserve capacity once
// the model returns the initial KV cache.
//
// For other simpler models the input KV cache buffer is used for
// all iterations, in which case we would ideally reserve capacity
// up-front based on the max expected sequence length.
let max_seq_len = 1;

let kv_cache_entry = KvCache {
input_id,
output_id,
cache: if let Some(n_heads) = n_heads {
Some(KvCacheData::BatchHeadSeqChans(NdTensor::with_capacity(
[batch_size, n_heads, max_seq_len, size],
2, /* seq dim */
)))
} else {
Some(KvCacheData::BatchSeqChans(NdTensor::with_capacity(
[batch_size, max_seq_len, size],
1, /* seq dim */
)))
},
cache: Some(KvCacheData::with_capacity(
batch_size,
n_heads,
size,
max_seq_len,
)),
};

if kv_pattern.encoder {
Expand Down Expand Up @@ -717,7 +792,7 @@ impl<'a> Generator<'a> {
let output = outputs.remove(0);

let err_context = "failed to save self-attention KV-cache";
let kv_cache = match output.ndim() {
let mut kv_cache = match output.ndim() {
3 => KvCacheData::BatchSeqChans(
output.try_into().map_err(|e| wrap_error(e, err_context))?,
),
Expand All @@ -731,6 +806,16 @@ impl<'a> Generator<'a> {
));
}
};

// Grow the KV cache buffer if it has reached the limit of its
// pre-allocated sequence length.
//
// Double the capacity each time to amortize the costs of copying
// the previous buffer.
if !kv_cache.has_capacity(kv_cache.sequence_len() + 1) {
kv_cache = kv_cache.clone_with_capacity(kv_cache.sequence_len() * 2);
}

cache_entry.cache = Some(kv_cache);
}

Expand Down

0 comments on commit 408ebe0

Please sign in to comment.