Skip to content

Commit

Permalink
refactor: GGUF + GGML Loaders with ModelKind (#356)
Browse files Browse the repository at this point in the history
* chore: Communicate actual difference

These methods are very verbose, but really only differ by two params to differentiate from Lora vs XLora.

* refactor: Introduce `ModelConfig` + `from_gguf` proxy

`ModelConfig` groups the common properties used across the `from_gguf()` methods. This will better communicate differences across impl of `from_gguf()`.

The quantized xlora models `from_gguf()` now have a prop to param forwarder as a workaround to minimize breakage elsewhere.

* refactor: Add `from_ggml` proxy

Very similar to the `from_gguf`, except only `quantized_llama.rs` xlora supports this. No `Device` params, slightly different `File` params from GGUF type.

* chore: DRY `ggml.rs` + `gguf.rs` common adapter config logic

Finally, all this extra boilerplate can be shifted into the model config `Adapter` struct to self-contain in a single method.

This required adjusting ownership a little to satisfy the compiler. The original `from_gguf()` and `from_ggml()` methods are unaffected, they still receive the expected params as reference.

* refactor(breaking): Leverage traits for `from_gguf()` / `from_ggml()`

This introduces a slight breaking change, in that using these `from_gguf()` / `from_ggml()` methods now requires importing the trait into scope.

The methods drop the `pub` prefix as they inherit `pub` from the trait requirement itself.

The purpose of this trait is to not require each model to duplicate the structs to params mapping helper method. Instead that can be centralized.

* chore: DRY - Dedupe prop mapping methods

These no longer need to be maintained as copies within the supported model modules.

They now leverage the common shared traits and take an annotated type parameter to handle.

The syntax for usage is presently a little more verbose than desired.

* chore: Contextual commit - Alternative prop mapping approaches

For reference, these alternatives could be considered.

* refactor: Add equivalent support for quant models without adapters

- Impl traits for the non-adapter quant models
- Since adapter variants only append parameters and that is now a distinct struct, `model_config` can be defined earlier and a helper `with_adapter()` can convert to the adapter type variant.

* chore: Fix typo

* refactor: Collapse Lora + XLora arms

With a rebase to adopt new methods for `ModelKind`, the shared logic can be hoisted out of the match arms.

XLora specific variables were moved into `Adapter::try_new()` (`model_config.rs`) as they can share the same `paths` parameter by adding a separate bool to toggle x-lora usage.

By hoisting `model_config` variable out of the match arm, the type would change when calling `with_adapter()`, thus to prevent that the separate `Adapter*` tuple structs have been dropped in favor of `ModelParams<Q>` which uses generic `Q` for the quantization type (trait marker) and a separate adapter optional that can be updated.

`MapParamsToModel<T>` also is no longer compatible as an approach since the trait bound is ambiguous as there is no distinct adapter type (eg: `for AdapterGGUF`) to impl upon, unique method names also become required to avoid conflict on the same type.
- `try_into_model()` + `try_into_model_with_adapter()` for the separate `Q` types (GGUF/GGML) with/without adapter.
- Due to new struct introduced, slight change to destructuring. The `impl` also bundles both methods now for each `Q` variant. Order was adjusted to basic followed by adapter methods for each `Q` variant, instead of both basic, then both adapter variations following afterwards.
- Likewise the `ggml.rs` and `gguf.rs` methods without the `MapParamsToModel<T>` trait now rely on `TryFrom` trait impl.

* refactor: Wrap `ModelParams` into enum for distinct adapter type

This approach introduces another generic parameter `MaybeAdapter` to emulate a `Option<Adapter>` that can be used as type to `impl` upon.

To continue the unified type usage with an adapter variant in `ggml.rs` / `gguf.rs` pipelines, this must now leverage an enum for the two variants.
- Slightly more complexity added as a result.
- Adapter `try_into_model()` methods no longer need to check for `Some(Adapter)` to unwrap, since that should always be the case. This is now guaranteed.
- However similar logic has bubbled up to the `TryFrom` for all impl due to the enum wrapper, thus this approach may not be much better beyond broader consistency. Likewise with the `with_adapter()` method.

To minimize boilerplate in handling unwrapping of the enum in the `TryFrom` methods, `Variantly` has been introduced for it's `expect_variant()` method.

As all four types are distinct, the `_with_adapter()` method can also be `try_into_model()` due to separate impl for the new generic param `MaybeAdapter`.

* chore: Minor improvements

Since the type constraint for `try_into_model()` methods is bound as the return type, it can be inferred without any hint in the `TryFrom`, no need to annotate with `Self`.

Use `derive_more` for terser construction of `Config` struct for `ModelParams` variants.

* refactor: Use `buildstructor` for builder API

This is an alternative approach to build the config. Construction of the config from the param inputs is handled at the end now, not dependent upon separate `new()` + optional `with_adapter()` calls on a mutable variable.

Unfortunately `buildstructor` and `typed-builder` APIs don't allow for easy flexibility of builder methods in different scopes (_due to moves_). `derive-builder` can do this but not with the more complex types due to lack of a `Copy` / `Clone`. Thus the `None` option is required as input regardless of if an adapter is needed.

* chore: Wrap `model_config` assignment into expression

This better communicates the block is only relevant to assigning this value. While the two `is_lora` / `is_xlora` variables are hoisted above due to usage later as metadata inputs.

* fix: Drop `is_lora` from `GeneralMetadata`

`pipeline/gguf.rs` + `pipeline/ggml.rs` now ensure that `activate_adapters()` works for X-LoRA too. This is assumed as a bugfix due to the `XLoraLlama` model the two adapter kinds share along with code everywhere else checking `is_xlora`, no other usage of `is_lora` seems to be used.
- To ensure further ambiguity is avoided, the condition is better communicated as `has_adapter`.
- It is unclear if all usage of `is_xlora` is specific to X-LoRA or also intended to be applicable to LoRA since `XLora*` models do impl `is_xlora() -> true` (except Gemma, which is a potential bug).

`pipeline/normal.rs` handled it's own `is_xlora` bool differently than `gguf.rs` / `ggml.rs` loaders.
- It relied upon`model.is_xlora() && !is_lora`, but we already assume X-LoRA via prior matching  on `ModelKind` which now provides this information via it's own `is_x_lora()` method.
- Only `xlora_models/gemma.rs` would behave differently with this change, but Gemma might have meant to return `true`?

* chore: Match on adapter

Matches are only for `Quantized` or `AdapterQuantized` variants with no difference in handling by `AdapterKind` variant used.

Additionally restores the `GGUF X-LoRA` bail formatted string. For consistency the non-adapter branch also appends `for GGUF` and the architecture in the lora branch now comes before the `ModelKind`.

* chore: Support params via tuple `into()` + add note of possible bug

* breaking: Replace `ModelKind` with new version

A better approach for the most part at encoding the kind info.

* lint(clippy): Appease the lint gods

`model_config.rs` GGUF and GGML structs prefixed with `Params`.

Two exceptions added as the concerns don't seem to warrant change:
- `#[allow(clippy::borrowed_box)]`
- `#[allow(clippy::large_enum_variant)]`

* chore: Convert from `CRLF` to `LF`

This file has no other change in the commit beyond line ending conversion. It was mistakenly using CRLF since creation.

* lint(rustfmt): Appease the lint gods

* fix: Restore `is_lora` condition

`GeneralMetadata` now stores the `ModelKind` for this type of information.

`activate_adapters()` error message revised. `mod.rs` version includes contextual comment about X-LoRA not being equivalent.

* fix: Gemma X-Lora model `is_xlora()` should return `true`

Most likely redundant with `GeneralMetadata` now having `ModelKind` to query, but fixing here until all queries replaced.

Additionally updates `model_config.rs` note to clarify not a bug.
  • Loading branch information
polarathene authored May 31, 2024
1 parent e1c3e6e commit 1d21c5f
Show file tree
Hide file tree
Showing 13 changed files with 519 additions and 337 deletions.
3 changes: 3 additions & 0 deletions mistralrs-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ once_cell.workspace = true
toml = "0.8.12"
strum = { version = "0.26", features = ["derive"] }
derive_more = { version = "0.99.17", default-features = false, features = ["from"] }
akin = "0.4.0"
variantly = "0.4.0"
buildstructor = "0.5.4"
tracing-subscriber.workspace = true
reqwest = { version = "0.12.4", features = ["blocking"] }

Expand Down
11 changes: 8 additions & 3 deletions mistralrs-core/src/models/quantized_llama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::layers::{
repeat_kv, verify_sanity_gguf, CausalMasker, MatMul, QRmsNorm, ScaledDotProductAttention,
};
use crate::pipeline::{extract_logits, Cache};
use crate::utils::model_config as ModelConfig;
use crate::DeviceMapMetadata;

const MAX_SEQ_LEN: u32 = 4096;
Expand Down Expand Up @@ -194,8 +195,8 @@ pub struct ModelWeights {
mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
}

impl ModelWeights {
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
impl ModelConfig::FromGGML for ModelWeights {
fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
let rotary = RotaryEmbedding::new_partial(
10000.,
Expand Down Expand Up @@ -254,8 +255,10 @@ impl ModelWeights {
mapper: None,
})
}
}

pub fn from_gguf<R: std::io::Seek + std::io::Read>(
impl ModelConfig::FromGGUF for ModelWeights {
fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
device: &Device,
Expand Down Expand Up @@ -383,7 +386,9 @@ impl ModelWeights {
mapper: Some(mapper),
})
}
}

impl ModelWeights {
pub fn forward(
&mut self,
x: &Tensor,
Expand Down
7 changes: 5 additions & 2 deletions mistralrs-core/src/models/quantized_phi2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::device_map::DeviceMapper;
use crate::layers::ScaledDotProductAttention;
use crate::layers::{repeat_kv, CausalMasker, QLinear};
use crate::pipeline::{extract_logits, Cache};
use crate::utils::model_config as ModelConfig;
use crate::DeviceMapMetadata;

pub const MAX_SEQ_LEN: usize = 4096;
Expand Down Expand Up @@ -141,8 +142,8 @@ fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result<LayerNorm> {
Ok(ln)
}

impl ModelWeights {
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
impl ModelConfig::FromGGUF for ModelWeights {
fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
device: &Device,
Expand Down Expand Up @@ -211,7 +212,9 @@ impl ModelWeights {
mapper,
})
}
}

impl ModelWeights {
pub fn forward(
&mut self,
input_ids: &Tensor,
Expand Down
7 changes: 5 additions & 2 deletions mistralrs-core/src/models/quantized_phi3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::layers::{
repeat_kv, verify_sanity_gguf, CausalMasker, MatMul, RmsNorm, ScaledDotProductAttention,
};
use crate::pipeline::Cache;
use crate::utils::model_config as ModelConfig;
use crate::DeviceMapMetadata;
use candle_core::quantized::gguf_file;
use candle_core::quantized::QMatMul;
Expand Down Expand Up @@ -159,8 +160,8 @@ fn precomput_freqs_cis(
Ok((cos, sin))
}

impl ModelWeights {
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
impl ModelConfig::FromGGUF for ModelWeights {
fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
device: &Device,
Expand Down Expand Up @@ -248,7 +249,9 @@ impl ModelWeights {
max_seq_len: context_window,
})
}
}

impl ModelWeights {
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offsets: &[usize]) -> Result<Tensor> {
let (_b_sz, seq_len) = input_ids.dims2()?;
let mut xs = self.tok_embeddings.forward(input_ids)?;
Expand Down
124 changes: 46 additions & 78 deletions mistralrs-core/src/pipeline/ggml.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::cache_manager::DefaultCacheManager;
use super::{
get_model_paths, get_xlora_paths, CacheManager, GeneralMetadata, Loader, ModelInputs,
ModelKind, ModelPaths, Pipeline, TokenSource, XLoraPaths,
get_model_paths, get_xlora_paths, AdapterKind, CacheManager, GeneralMetadata, Loader,
ModelInputs, ModelKind, ModelPaths, Pipeline, QuantizationKind, TokenSource, XLoraPaths,
};
use crate::aici::bintokens::build_tok_trie;
use crate::aici::toktree::TokTrie;
Expand All @@ -11,8 +11,8 @@ use crate::pipeline::{get_chat_template, Cache};
use crate::pipeline::{ChatTemplate, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManager;
use crate::sequence::Sequence;
use crate::utils::model_config as ModelConfig;
use crate::utils::tokenizer::get_tokenizer;
use crate::utils::varbuilder_utils::{from_mmaped_safetensors, load_preload_adapters};
use crate::xlora_models::NonGranularState;
use crate::{do_sample, get_mut_arcmutex, get_paths, DeviceMapMetadata, DEBUG};
use crate::{
Expand Down Expand Up @@ -96,12 +96,16 @@ impl GGMLLoaderBuilder {
quantized_model_id: String,
quantized_filename: String,
) -> Self {
let kind = ModelKind::Quantized {
quant: QuantizationKind::Ggml,
};

Self {
config,
chat_template,
tokenizer_json,
model_id,
kind: ModelKind::QuantizedGGML,
kind,
quantized_filename,
quantized_model_id,
..Default::default()
Expand Down Expand Up @@ -138,7 +142,8 @@ impl GGMLLoaderBuilder {
no_kv_cache: bool,
tgt_non_granular_index: Option<usize>,
) -> Self {
self.kind = ModelKind::XLoraGGML;
self.kind = (AdapterKind::XLora, QuantizationKind::Ggml).into();

self.with_adapter(
xlora_model_id,
xlora_order,
Expand All @@ -148,7 +153,8 @@ impl GGMLLoaderBuilder {
}

pub fn with_lora(mut self, lora_model_id: String, lora_order: Ordering) -> Self {
self.kind = ModelKind::LoraGGML;
self.kind = (AdapterKind::Lora, QuantizationKind::Ggml).into();

self.with_adapter(lora_model_id, lora_order, false, None)
}

Expand Down Expand Up @@ -236,7 +242,7 @@ impl Loader for GGMLLoader {

if in_situ_quant.is_some() {
anyhow::bail!(
"You are trying to in-situ quantize a GGUF model. This will not do anything."
"You are trying to in-situ quantize a GGML model. This will not do anything."
);
}
if !mapper.is_dummy() {
Expand Down Expand Up @@ -267,69 +273,33 @@ impl Loader for GGMLLoader {
info!("Debug is enabled, wrote the names and information about each tensor to `mistralrs_ggml_tensors.txt`.");
}

let mut is_lora = false;
let model = match self.kind {
ModelKind::QuantizedGGML => Model::Llama(QLlama::from_ggml(model, self.config.gqa)?),
ModelKind::XLoraGGML => {
let vb = from_mmaped_safetensors(
vec![paths.get_classifier_path().as_ref().unwrap().to_path_buf()],
paths
.get_adapter_filenames()
.as_ref()
.unwrap()
.iter()
.map(|(_, x)| (*x).to_owned())
.collect::<Vec<_>>(),
DType::F32,
device,
silent,
)?;

Model::XLoraLlama(XLoraQLlama::from_ggml(
model,
self.config.gqa,
paths.get_adapter_configs().as_ref().unwrap(),
&vb,
paths.get_ordering().as_ref().unwrap(),
Some(paths.get_classifier_config().as_ref().unwrap().clone()),
&load_preload_adapters(
paths.get_lora_preload_adapter_info(),
DType::F32,
device,
silent,
)?,
)?)
let has_adapter = self.kind.is_adapted();
let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());

let model_config = {
// Base config (quantization only):
let quant = ModelConfig::ParamsGGML((model, self.config.gqa).into());

// With optional adapter config:
let mut adapter = None;
if has_adapter {
adapter.replace(ModelConfig::Adapter::try_new(
paths, device, silent, is_xlora,
)?);
}
ModelKind::LoraGGML => {
is_lora = true;
let vb = from_mmaped_safetensors(
vec![],
paths
.get_adapter_filenames()
.as_ref()
.unwrap()
.iter()
.map(|(_, x)| (*x).to_owned())
.collect::<Vec<_>>(),
DType::F32,
device,
silent,
)?;

Model::XLoraLlama(XLoraQLlama::from_ggml(
model,
self.config.gqa,
paths.get_adapter_configs().as_ref().unwrap(),
&vb,
paths.get_ordering().as_ref().unwrap(),
None,
&load_preload_adapters(
paths.get_lora_preload_adapter_info(),
DType::F32,
device,
silent,
)?,
)?)

ModelConfig::ModelParams::builder()
.quant(quant)
.and_adapter(adapter)
.build()
};

// Config into model:
// NOTE: No architecture to infer like GGUF, Llama model is implicitly matched
let model = match self.kind {
ModelKind::Quantized { .. } => Model::Llama(QLlama::try_from(model_config)?),
ModelKind::AdapterQuantized { .. } => {
Model::XLoraLlama(XLoraQLlama::try_from(model_config)?)
}
_ => unreachable!(),
};
Expand All @@ -345,10 +315,6 @@ impl Loader for GGMLLoader {
Model::XLoraLlama(ref xl) => xl.max_seq_len,
};
let tok_trie: Arc<TokTrie> = build_tok_trie(tokenizer.clone()).into();
let is_xlora = match &model {
Model::Llama(_) => false,
Model::XLoraLlama(_) => !is_lora,
};
let num_hidden_layers = match model {
Model::Llama(ref model) => model.cache.lock().len(),
Model::XLoraLlama(ref model) => model.cache.lock().len(),
Expand All @@ -372,10 +338,10 @@ impl Loader for GGMLLoader {
repeat_last_n: self.config.repeat_last_n,
tok_trie,
has_no_kv_cache: self.no_kv_cache,
is_xlora,
num_hidden_layers,
eos_tok: eos,
is_lora,
kind: self.kind.clone(),
is_xlora,
},
})))
}
Expand Down Expand Up @@ -508,14 +474,16 @@ impl Pipeline for GGMLPipeline {
}
}
fn activate_adapters(&mut self, adapter_names: Vec<String>) -> anyhow::Result<usize> {
if !self.metadata.is_lora {
anyhow::bail!("Cannot activate adapters non-LoRA models.")
let is_lora = self.metadata.kind.is_adapted_and(|a| a.is_lora());
if !is_lora {
anyhow::bail!("Activating adapters is only supported for models fine-tuned with LoRA.")
}

match self.model {
Model::Llama(_) => unreachable!(),
Model::XLoraLlama(ref mut model) => model
.activate_adapters(adapter_names)
.map_err(anyhow::Error::msg),
_ => unreachable!(),
}
}
}
Loading

0 comments on commit 1d21c5f

Please sign in to comment.