Added support for the ArcticForCausalLM. #7020
Conversation
It's possible to offload only the dense part of the model onto the GPU.
convert-hf-to-gguf.py
Outdated
model_arch = gguf.MODEL_ARCH.ARCTIC

def set_vocab(self):
    self._set_vocab_llama_hf()
re: #6877 (comment), this should be:
-    self._set_vocab_llama_hf()
+    try:
+        self._set_vocab_sentencepiece()
+    except FileNotFoundError:
+        self._set_vocab_llama_hf()
The assertion exists because LlamaHfVocab was primarily written to convert HF "fast" tokenizers with a tokenizer.json. Since before it existed, "slow" sentencepiece tokenizers with a tokenizer.model have (almost?) always been converted using SentencePieceProcessor, which doesn't depend on HF transformers and directly preserves the token types and scores.
If you want to start converting slow tokenizers using HfVocab as well, I won't stop you, but in order to be consistent you'd have to remove all references to SentencePieceProcessor in the convert scripts, and make HF transformers a hard requirement for converting models with a Llama vocab. Otherwise, we'd be making an exception for this model for no clear reason.
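For illustration, here is a minimal sketch of that SentencePieceProcessor path (the file name and the loop are assumptions for the example, not the actual convert-script code): pieces, scores, and token types come straight from tokenizer.model, with no dependency on HF transformers.

# Minimal sketch: read pieces, scores, and token types directly from a
# sentencepiece model. This mirrors the idea behind _set_vocab_sentencepiece,
# not its exact implementation.
from sentencepiece import SentencePieceProcessor

sp = SentencePieceProcessor()
sp.LoadFromFile("tokenizer.model")  # path is an assumption for the example

for token_id in range(sp.vocab_size()):
    piece = sp.IdToPiece(token_id)
    score = sp.GetScore(token_id)
    if sp.IsUnknown(token_id):
        token_type = "UNKNOWN"
    elif sp.IsControl(token_id):
        token_type = "CONTROL"
    elif sp.IsByte(token_id):
        token_type = "BYTE"
    elif sp.IsUnused(token_id):
        token_type = "UNUSED"
    else:
        token_type = "NORMAL"
    # piece, score, and token_type can then be written to the GGUF vocab as-is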
My reason is that the official tokenizer.model file for snowflake-arctic-instruct contains wrong BOS and EOS tokens, as confirmed in: https://huggingface.co/Snowflake/snowflake-arctic-instruct/discussions/12
That's why I used the llama_hf vocab, which reads tokens from JSON files instead. If there is a better solution for this, I'm fully open to any suggestions.
@cebtenzzre What if I implement ArcticModel::set_vocab() myself, like XverseForCausalLM did? Would that be acceptable?
@cebtenzzre I now load the vocabulary with SentencePieceProcessor as you suggested and apply the necessary token modifications based on the added_tokens_decoder field from tokenizer_config.json.
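As a rough sketch of that approach (field names follow the usual HF tokenizer_config.json layout; the helper function and its signature are hypothetical, not the PR code):

# Hypothetical helper: override tokens from added_tokens_decoder after the
# base vocabulary has been read with SentencePieceProcessor.
from __future__ import annotations

import json
from pathlib import Path

def apply_added_tokens_overrides(tokens: list[bytes], model_dir: Path) -> None:
    config_path = model_dir / "tokenizer_config.json"
    if not config_path.is_file():
        return
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    for token_id_str, token_data in config.get("added_tokens_decoder", {}).items():
        token_id = int(token_id_str)
        content = token_data["content"].encode("utf-8")
        if token_id < len(tokens) and tokens[token_id] != content:
            # redefine the piece (e.g. wrong BOS/EOS in tokenizer.model)
            tokens[token_id] = content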
…d of HF tokenizer. Add/redefine tokens according to added_tokens_decoder from tokenizer_config.json
…erganov#7075 (shameless copy from LlamaModel).
gguf-py/gguf/tensor_mapping.py
Outdated
if arch in self.arch_block_mappings_cfg:
    block_mappings = self.arch_block_mappings_cfg[arch]
This means architecture-specific block mappings can't partially override the common mappings (they have to totally re-define everything)?
Maybe this is fixable by adding the common mappings first to self.mapping, then the architecture-specific mappings?
So maybe using the union operator for dicts would be appropriate here:

if arch in self.arch_block_mappings_cfg:
    block_mappings = self.block_mappings_cfg | self.arch_block_mappings_cfg[arch]

But that's only supported since Python 3.9, and gguf-py targets python = ">=3.8". In this case, using {**x, **y} instead of x | y would be more compatible with versions of Python older than 3.9, and would still make a new dict with the content of x augmented/overridden by y. But the new syntax is clearer in my opinion.
After that, the architecture-specific mappings of MODEL_ARCH.ARCTIC should be simpler (since they won't need to include duplicates of the common mappings).
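For reference, a tiny example of the two merge spellings with throwaway dicts (not the real mapping tables):

x = {"a": 1, "b": 2}
y = {"b": 20, "c": 3}

merged_new = x | y       # union operator, Python 3.9+
merged_old = {**x, **y}  # unpacking, works on Python 3.5+

# In both cases keys from y override keys from x, and x and y stay untouched.
assert merged_new == merged_old == {"a": 1, "b": 20, "c": 3}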
So the idea is to keep only "conflicting" block mappings in the architecture-specific mappings and "non-conflicting" mappings in the general mappings? I think using dict.update() is a better idea, then. The mappings for the ARCTIC arch would be shortened to:
# architecture-specific block mappings
arch_block_mappings_cfg: dict[MODEL_ARCH, dict[MODEL_TENSOR, tuple[str, ...]]] = {
    MODEL_ARCH.ARCTIC: {
        MODEL_TENSOR.FFN_NORM: (
            "model.layers.{bid}.residual_layernorm",
        ),
        MODEL_TENSOR.FFN_NORM_EXP: (
            "model.layers.{bid}.post_attention_layernorm",
        ),
    },
}
while in the TensorNameMap init we would only have to add:
if arch in self.arch_block_mappings_cfg:
    self.block_mappings_cfg.update(self.arch_block_mappings_cfg[arch])
What do you think?
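A minimal, self-contained illustration of that in-place merge, with made-up table entries rather than the real gguf-py ones:

# Hypothetical data, not the actual gguf-py tables.
common = {
    "FFN_NORM": ("model.layers.{bid}.post_attention_layernorm",),
    "FFN_GATE": ("model.layers.{bid}.mlp.gate_proj",),
}
arch_specific = {
    "FFN_NORM":     ("model.layers.{bid}.residual_layernorm",),
    "FFN_NORM_EXP": ("model.layers.{bid}.post_attention_layernorm",),
}

common.update(arch_specific)
# Conflicting keys are overridden, non-conflicting ones are kept:
#   FFN_NORM     -> residual_layernorm         (overridden)
#   FFN_GATE     -> mlp.gate_proj              (kept from the common table)
#   FFN_NORM_EXP -> post_attention_layernorm   (added)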
So the idea is to keep only "conflicting" block mappings in architecture-specific mappings and "non-conflicting" mappings in general mappings?
Yes, exactly.
What do you think?
I think using dict.update() would be good. My proposed approach would have made a copy of the dict, but you're right, updating in-place would work too and would be better, since the original block_mappings_cfg isn't used later on (I think?).
I agree with using dict.update() for this.
OK, done
…ific ARCTIC mappings to general mappings.
I did not test this (the model is quite big), but the code looks good to me. Nice work @fairydreaming!
gguf-py/gguf/constants.py
Outdated
@@ -181,6 +182,7 @@ class MODEL_TENSOR(IntEnum):
     SSM_A = auto()
     SSM_D = auto()
     SSM_OUT = auto()
+    FFN_NORM_EXP = auto()
Since the actual numbers associated with the enum values of MODEL_TENSOR don't really matter (their names from TENSOR_NAMES are used instead in GGUF), maybe FFN_NORM_EXP could be placed right before FFN_GATE_EXP, a bit like FFN_NORM is right before FFN_GATE, for consistency.
If this is changed, it should also be placed similarly in TENSOR_NAMES and MODEL_TENSORS[MODEL.ARCTIC] in gguf-py/gguf/constants.py, as well as in the llm_tensor enum, the LLM_TENSOR_NAMES mapping, and the llama_layer struct (and maybe the LLM_ARCH_ARCTIC case in llm_load_tensors?) in llama.cpp.
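A rough sketch of the suggested placement (member names taken from the existing MODEL_TENSOR enum; the exact neighbours may differ):

from enum import IntEnum, auto

class MODEL_TENSOR(IntEnum):
    # ... other members ...
    FFN_NORM     = auto()
    FFN_GATE     = auto()
    FFN_DOWN     = auto()
    FFN_UP       = auto()
    FFN_NORM_EXP = auto()  # suggested position: right before FFN_GATE_EXP
    FFN_GATE_EXP = auto()
    FFN_DOWN_EXP = auto()
    FFN_UP_EXP   = auto()
    # ... other members ...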
I changed the order as requested, but in the llama_layer struct the order is different, so I didn't touch it. In llm_load_tensors I think it was already in the requested order.
I noticed that the Arctic model doesn't use bias tensors, so I removed the usage of bias tensors in the LLM_ARCH_ARCTIC-related code (they were all nulls anyway).
I haven't tested it either, but it seems good, so feel free to merge.
@ggerganov I noticed that Snowflake changed the Arctic model 2 weeks ago. The commit says "Fixes for GQA support", and num_key_value_heads in config.json changed from 56 to 8, so I have to redownload the model and check if it still works.
Add support for ArcticForCausalLM (ggerganov#7020)
* common : increase max number of experts to 128
* common : add tensor LLM_TENSOR_FFN_NORM_EXPS for normalization before MoE that runs in parallel to attention + ffn
* gguf-py : add architecture-specific block mappings that override selected general block mappings
* convert-hf : add model conversion support for ArcticForCausalLM
* convert-hf : use added_tokens_decoder from tokenizer_config.json to redefine tokens from SentencePiece model (only for ArcticForCausalLM)
* llama : add inference support for LLM_ARCH_ARCTIC

Co-authored-by: Stanisław Szymczyk <[email protected]>
Fixes #6877
Contains the following changes:
- common : increase max number of experts to 128
- common : add tensor LLM_TENSOR_FFN_NORM_EXPS for normalization before MoE that runs in parallel to attention + ffn
- gguf-py : add architecture-specific block mappings that override selected general block mappings
- convert-hf : add model conversion support for ArcticForCausalLM
- convert-hf : use added_tokens_decoder from tokenizer_config.json to redefine tokens from the SentencePiece model (only for ArcticForCausalLM)
- llama : add inference support for LLM_ARCH_ARCTIC
Model files for testing: https://huggingface.co/sszymczyk/snowflake-arctic-instruct-GGUF