Implement prompt chunking #623

Merged · 20 commits · Aug 3, 2024
Changes from 19 commits
6 changes: 3 additions & 3 deletions Cargo.toml
@@ -74,6 +74,6 @@ opt-level = 3
inherits = "release"
lto = "thin"

[profile.release]
codegen-units = 1
lto = "fat"
#[profile.release]
#codegen-units = 1
#lto = "fat"
15 changes: 14 additions & 1 deletion mistralrs-bench/src/main.rs
@@ -7,8 +7,8 @@ use mistralrs_core::{
MistralRsBuilder, ModelDType, ModelSelected, NormalRequest, PagedAttentionConfig, Request,
RequestMessage, Response, SamplingParams, SchedulerConfig, TokenSource, Usage,
};
use std::fmt::Display;
use std::sync::Arc;
use std::{fmt::Display, num::NonZeroUsize};
use tokio::sync::mpsc::channel;
use tracing::{info, warn};

@@ -309,6 +309,10 @@ struct Args {
/// Disable PagedAttention on CUDA.
#[arg(long = "no_paged_attn", default_value_t = false)]
no_paged_attn: bool,

/// Number of tokens to batch the prompt step into. This can help with OOM errors when in the prompt step, but reduces performance.
#[arg(long = "prompt-batchsize")]
prompt_batchsize: Option<usize>,
}

fn main() -> anyhow::Result<()> {
@@ -322,8 +326,17 @@ fn main() -> anyhow::Result<()> {
#[cfg(feature = "flash-attn")]
let use_flash_attn = true;

let prompt_batchsize = match args.prompt_batchsize {
Some(0) => {
anyhow::bail!("`prompt_batchsize` must be a strictly positive integer, got 0.",)
}
Some(x) => Some(NonZeroUsize::new(x).unwrap()),
None => None,
};

let loader: Box<dyn Loader> = LoaderBuilder::new(args.model)
.with_use_flash_attn(use_flash_attn)
.with_prompt_batchsize(prompt_batchsize)
.build()?;
let model_name = loader.get_id();

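Note (editorial sketch, not part of the diff): the new `--prompt-batchsize` flag above is only validated and forwarded as a `NonZeroUsize`; the chunking itself happens inside the pipelines. The idea is that the prompt is prefilled in fixed-size chunks, so peak memory in the prompt step scales with the chunk size rather than with the full prompt length. A minimal Rust sketch of that splitting follows; `run_prompt_chunk` is a hypothetical stand-in for the real pipeline forward pass.

use std::num::NonZeroUsize;

// Hypothetical stand-in for one prefill forward pass over a chunk of prompt
// tokens, reusing the KV cache built by earlier chunks.
fn run_prompt_chunk(chunk: &[u32], past_len: usize) {
    println!("prefill {} tokens at KV offset {}", chunk.len(), past_len);
}

// Split the prompt into `batchsize`-token chunks and prefill them in order.
// Peak activation memory now depends on `batchsize`, not on `prompt.len()`,
// at the cost of extra forward passes.
fn chunked_prefill(prompt: &[u32], batchsize: NonZeroUsize) {
    let mut past_len = 0;
    for chunk in prompt.chunks(batchsize.get()) {
        run_prompt_chunk(chunk, past_len);
        past_len += chunk.len();
    }
}

fn main() {
    // Mirrors the CLI validation above: a batch size of 0 is rejected up front.
    let batchsize = NonZeroUsize::new(4).expect("prompt_batchsize must be > 0");
    let prompt: Vec<u32> = (0..10).collect();
    chunked_prefill(&prompt, batchsize);
}
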
50 changes: 42 additions & 8 deletions mistralrs-core/src/model_loader.rs
@@ -1,4 +1,7 @@
use std::fs::{self, File};
use std::{
fs::{self, File},
num::NonZeroUsize,
};

use crate::{
get_toml_selected_model_dtype,
@@ -13,6 +16,7 @@ pub struct LoaderBuilder {
no_kv_cache: bool,
chat_template: Option<String>,
use_flash_attn: bool,
prompt_batchsize: Option<NonZeroUsize>,
}

impl LoaderBuilder {
@@ -22,6 +26,7 @@ impl LoaderBuilder {
no_kv_cache: false,
chat_template: None,
use_flash_attn: false,
prompt_batchsize: None,
}
}

@@ -37,6 +42,10 @@ impl LoaderBuilder {
self.use_flash_attn = use_flash_attn;
self
}
pub fn with_prompt_batchsize(mut self, prompt_batchsize: Option<NonZeroUsize>) -> Self {
self.prompt_batchsize = prompt_batchsize;
self
}

pub fn build(self) -> anyhow::Result<Box<dyn Loader>> {
loader_from_model_selected(self)
@@ -102,6 +111,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
use_flash_attn,
chat_template: args.chat_template,
no_kv_cache: args.no_kv_cache,
prompt_batchsize: args.prompt_batchsize,
};
(selector, args).try_into()?
}
@@ -111,7 +121,10 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
arch,
dtype: _,
} => NormalLoaderBuilder::new(
NormalSpecificConfig { use_flash_attn },
NormalSpecificConfig {
use_flash_attn,
prompt_batchsize: args.prompt_batchsize,
},
args.chat_template,
tokenizer_json,
Some(model_id),
@@ -126,7 +139,10 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
arch,
dtype: _,
} => NormalLoaderBuilder::new(
NormalSpecificConfig { use_flash_attn },
NormalSpecificConfig {
use_flash_attn,
prompt_batchsize: args.prompt_batchsize,
},
args.chat_template,
tokenizer_json,
model_id,
@@ -149,7 +165,10 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
arch,
dtype: _,
} => NormalLoaderBuilder::new(
NormalSpecificConfig { use_flash_attn },
NormalSpecificConfig {
use_flash_attn,
prompt_batchsize: args.prompt_batchsize,
},
args.chat_template,
tokenizer_json,
model_id,
@@ -171,6 +190,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
tok_model_id,
quantized_model_id,
quantized_filename,
args.prompt_batchsize,
)
.build(),
ModelSelected::XLoraGGUF {
@@ -185,6 +205,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
tok_model_id,
quantized_model_id,
quantized_filename,
args.prompt_batchsize,
)
.with_xlora(
xlora_model_id,
@@ -207,6 +228,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
tok_model_id,
quantized_model_id,
quantized_filename,
args.prompt_batchsize,
)
.with_lora(
adapters_model_id,
@@ -223,7 +245,10 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
quantized_filename,
gqa,
} => GGMLLoaderBuilder::new(
GGMLSpecificConfig { gqa },
GGMLSpecificConfig {
gqa,
prompt_batchsize: args.prompt_batchsize,
},
args.chat_template,
tokenizer_json,
Some(tok_model_id),
@@ -241,7 +266,10 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
tgt_non_granular_index,
gqa,
} => GGMLLoaderBuilder::new(
GGMLSpecificConfig { gqa },
GGMLSpecificConfig {
gqa,
prompt_batchsize: args.prompt_batchsize,
},
args.chat_template,
tokenizer_json,
tok_model_id,
@@ -267,7 +295,10 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
order,
gqa,
} => GGMLLoaderBuilder::new(
GGMLSpecificConfig { gqa },
GGMLSpecificConfig {
gqa,
prompt_batchsize: args.prompt_batchsize,
},
args.chat_template,
tokenizer_json,
tok_model_id,
@@ -288,7 +319,10 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
arch,
dtype: _,
} => VisionLoaderBuilder::new(
VisionSpecificConfig { use_flash_attn },
VisionSpecificConfig {
use_flash_attn,
prompt_batchsize: args.prompt_batchsize,
},
args.chat_template,
tokenizer_json,
Some(model_id),
24 changes: 14 additions & 10 deletions mistralrs-core/src/models/phi3.rs
@@ -472,16 +472,20 @@ impl Model {
) -> Result<Tensor> {
let mut xs = self.embed_tokens.forward(input_ids)?;
let mut cache = self.cache.lock();
let attention_mask = CausalMasker.make_causal_mask_with_sliding_window_as_attn_bias(
input_ids,
metadata
.as_ref()
.map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
.unwrap_or(&*cache as &dyn PastKvLenCache),
self.sliding_window,
xs.dtype(),
self.layers[0].self_attn.num_heads,
)?;
let attention_mask = if seqlen_offsets[0] == 0 {
CausalMasker.make_causal_mask_with_sliding_window_as_attn_bias(
input_ids,
metadata
.as_ref()
.map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
.unwrap_or(&*cache as &dyn PastKvLenCache),
self.sliding_window,
xs.dtype(),
self.layers[0].self_attn.num_heads,
)?
} else {
None
};

for (i, layer) in self.layers.iter().enumerate() {
xs = self.mapper.map(xs, i)?;
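Note (editorial sketch, not part of the diff): the phi3 change above builds the causal attention bias only when `seqlen_offsets[0] == 0`, i.e. on a pass where no cached context exists yet, and passes `None` otherwise. A simplified sketch of that control flow with plain types follows; the real code builds a `Tensor` bias via `CausalMasker` with a sliding window, which is omitted here.

// Simplified stand-in for the attention bias: mask[i][j] is f32::NEG_INFINITY
// for future positions j > i that query i must not attend to, 0.0 otherwise.
type Mask = Vec<Vec<f32>>;

// Build a lower-triangular causal mask for `seq_len` query tokens.
fn make_causal_mask(seq_len: usize) -> Mask {
    (0..seq_len)
        .map(|i| {
            (0..seq_len)
                .map(|j| if j > i { f32::NEG_INFINITY } else { 0.0 })
                .collect()
        })
        .collect()
}

// Mirrors the branch above: only a pass with no cached context (offset 0)
// constructs a mask; later passes skip the work and use `None`.
fn maybe_causal_mask(seq_len: usize, seqlen_offset: usize) -> Option<Mask> {
    if seqlen_offset == 0 {
        Some(make_causal_mask(seq_len))
    } else {
        None
    }
}

fn main() {
    assert!(maybe_causal_mask(4, 0).is_some()); // first pass over the prompt
    assert!(maybe_causal_mask(1, 4).is_none()); // later steps run against the KV cache
}
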
6 changes: 4 additions & 2 deletions mistralrs-core/src/pipeline/amoe.rs
@@ -246,7 +246,7 @@ impl Pipeline for AnyMoePipeline {
async fn sample(
&self,
seqs: &mut [&mut Sequence],
logits: Tensor,
logits: Vec<Tensor>,
prefix_cacher: &mut PrefixCacheManager,
disable_eos_stop: bool,
rng: Arc<std::sync::Mutex<Isaac64Rng>>,
@@ -441,13 +441,15 @@ impl AnyMoePipelineMixin for AnyMoePipeline {
None,
input_processor_cfg.clone(),
None, // TODO: get block tables/handle it for PagedAttention
None, // TODO: prompt chunking doesn't work.
)
.nth(0)
.unwrap();

// === PREPARE AND RUN MODEL ==

// Run the model, ignoring the logits
let _ = target.forward_inputs(inputs)?;
let _ = target.forward_inputs(inputs.unwrap().inputs)?;

// Clear the KV cache
target.set_none_cache(true, true);
5 changes: 4 additions & 1 deletion mistralrs-core/src/pipeline/ggml.rs
@@ -36,6 +36,7 @@ use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::fs;
use std::num::NonZeroUsize;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
@@ -77,6 +78,7 @@ pub struct GGMLLoader {
/// Config for a GGML loader.
pub struct GGMLSpecificConfig {
pub gqa: usize,
pub prompt_batchsize: Option<NonZeroUsize>,
}

#[derive(Default)]
@@ -355,6 +357,7 @@ impl Loader for GGMLLoader {
sliding_window: None,
cache_config: None,
cache_engine: None,
prompt_batchsize: self.config.prompt_batchsize,
}),
})))
}
@@ -519,7 +522,7 @@ impl Pipeline for GGMLPipeline {
async fn sample(
&self,
seqs: &mut [&mut Sequence],
logits: Tensor,
logits: Vec<Tensor>,
prefix_cacher: &mut PrefixCacheManager,
disable_eos_stop: bool,
rng: Arc<std::sync::Mutex<Isaac64Rng>>,
11 changes: 10 additions & 1 deletion mistralrs-core/src/pipeline/gguf.rs
@@ -51,6 +51,7 @@ use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::fs;
use std::num::NonZeroUsize;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
@@ -89,6 +90,7 @@ pub struct GGUFLoader {
chat_template: Option<String>,
kind: ModelKind,
tgt_non_granular_index: Option<usize>,
prompt_batchsize: Option<NonZeroUsize>,
}

#[derive(Debug, EnumString)]
@@ -131,6 +133,7 @@ pub struct GGUFLoaderBuilder {
no_kv_cache: bool,
chat_template: Option<String>,
tgt_non_granular_index: Option<usize>,
prompt_batchsize: Option<NonZeroUsize>,
}

impl GGUFLoaderBuilder {
@@ -142,6 +145,7 @@ impl GGUFLoaderBuilder {
tok_model_id: Option<String>,
quantized_model_id: String,
quantized_filename: String,
prompt_batchsize: Option<NonZeroUsize>,
) -> Self {
let kind = ModelKind::Quantized {
quant: QuantizationKind::Gguf,
@@ -153,6 +157,7 @@ impl GGUFLoaderBuilder {
kind,
quantized_filename,
quantized_model_id,
prompt_batchsize,
..Default::default()
}
}
@@ -214,6 +219,7 @@ impl GGUFLoaderBuilder {
tgt_non_granular_index: self.tgt_non_granular_index,
quantized_filename: self.quantized_filename,
quantized_model_id: self.quantized_model_id,
prompt_batchsize: self.prompt_batchsize,
})
}
}
@@ -230,6 +236,7 @@ impl GGUFLoader {
no_kv_cache: bool,
chat_template: Option<String>,
tgt_non_granular_index: Option<usize>,
prompt_batchsize: Option<NonZeroUsize>,
) -> Self {
let model_id = if let Some(id) = model_id {
Some(id)
@@ -252,6 +259,7 @@ impl GGUFLoader {
chat_template,
kind,
tgt_non_granular_index,
prompt_batchsize,
}
}
}
@@ -578,6 +586,7 @@ impl Loader for GGUFLoader {
sliding_window: None,
cache_config,
cache_engine,
prompt_batchsize: self.prompt_batchsize,
}),
})))
}
@@ -769,7 +778,7 @@ impl Pipeline for GGUFPipeline {
async fn sample(
&self,
seqs: &mut [&mut Sequence],
logits: Tensor,
logits: Vec<Tensor>,
prefix_cacher: &mut PrefixCacheManager,
disable_eos_stop: bool,
rng: Arc<std::sync::Mutex<Isaac64Rng>>,