Add JetMoE model (#30005)
* init jetmoe code

* update archive maps

* remove flax import

* fix import error

* update README

* ruff fix

* update readme

* fix

* update config

* fix issue

* merge files

* fix model bug

* fix test

* auto fix

* model size

* add comments

* fix form

* add flash attention support

* fix attention head number

* fix init

* fix support list

* sort auto mapping

* fix test

* fix docs

* update test

* fix test

* fix test

* change variable name

* fix config

* fix init

* update format

* clean code

* fix config

* fix config

* change default config

* update config

* fix issues

* update format

* update config argument

* update format

* Update src/transformers/models/jetmoe/modeling_jetmoe.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py

Co-authored-by: Arthur <[email protected]>

* change to mixtral aux loss

* change to cache_position

* debug

* fix bugs

* debug

* fix format

* fix format

* fix copy

* fix format

* fix format

* fix sort

* fix sort

* fix sort

* add copy comment

* add copy from

* remove debug code

* revert readme update

* add copy

* debug

* remove debug code

* fix flash attention

* add comments

* clean code

* clean format

* fix format

* fix format

* Update src/transformers/models/jetmoe/modeling_jetmoe.py

Co-authored-by: Younes Belkada <[email protected]>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py

Co-authored-by: Younes Belkada <[email protected]>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py

Co-authored-by: Younes Belkada <[email protected]>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py

Co-authored-by: Younes Belkada <[email protected]>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py

Co-authored-by: Younes Belkada <[email protected]>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py

Co-authored-by: Younes Belkada <[email protected]>

* change variable name

* add copied from

* fix variable name

* remove deprecated functions

* sync to llama implementation

* fix format

* fix copy

* fix format

* update format

* remove repr

* add comment for moe weight

* fix copy

* Update src/transformers/models/jetmoe/configuration_jetmoe.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/jetmoe/modeling_jetmoe.py

Co-authored-by: Arthur <[email protected]>

* add comments and reformat config

* fix format

* fix format

* fix format

* update test

* update doc string in config

* Update src/transformers/models/jetmoe/modeling_jetmoe.py

Co-authored-by: Arthur <[email protected]>

* update config doc

* update attention cache

* fix format

* fix copy

---------

Co-authored-by: Arthur <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
3 people authored May 14, 2024
1 parent d84f34a commit ccdabc5
Showing 15 changed files with 2,455 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -386,6 +386,8 @@
title: I-BERT
- local: model_doc/jamba
title: Jamba
- local: model_doc/jetmoe
title: JetMoe
- local: model_doc/jukebox
title: Jukebox
- local: model_doc/led
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -166,6 +166,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Informer](model_doc/informer) | ✅ | ❌ | ❌ |
| [InstructBLIP](model_doc/instructblip) | ✅ | ❌ | ❌ |
| [Jamba](model_doc/jamba) | ✅ | ❌ | ❌ |
| [JetMoe](model_doc/jetmoe) | ✅ | ❌ | ❌ |
| [Jukebox](model_doc/jukebox) | ✅ | ❌ | ❌ |
| [KOSMOS-2](model_doc/kosmos-2) | ✅ | ❌ | ❌ |
| [LayoutLM](model_doc/layoutlm) | ✅ | ✅ | ❌ |
49 changes: 49 additions & 0 deletions docs/source/en/model_doc/jetmoe.md
@@ -0,0 +1,49 @@
<!--Copyright 2024 JetMoe team and The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# JetMoe

## Overview

**JetMoe-8B** is an 8B Mixture-of-Experts (MoE) language model developed by [Yikang Shen](https://scholar.google.com.hk/citations?user=qff5rRYAAAAJ) and [MyShell](https://myshell.ai/).
The JetMoe project aims to provide LLaMA2-level performance from an efficient language model trained on a limited budget.
To achieve this goal, JetMoe uses a sparsely activated architecture inspired by [ModuleFormer](https://arxiv.org/abs/2306.04640).
Each JetMoe block consists of two MoE layers: a Mixture of Attention Heads and a Mixture of MLP Experts.
Given the input tokens, it activates a subset of its experts to process them.
This sparse activation scheme lets JetMoe achieve much higher training throughput than dense models of similar size.
The training throughput of JetMoe-8B is around 100B tokens per day on a cluster of 96 H100 GPUs with a straightforward 3-way pipeline parallelism strategy.

This model was contributed by [Yikang Shen](https://huggingface.co/YikangS).
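Usage is the same as for other causal language models in the library. A minimal generation sketch, assuming the `jetmoe/jetmoe-8b` checkpoint referenced in the configuration docs is available on the Hub:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint name taken from the config docs; adjust if you use a different JetMoe checkpoint.
tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b")
model = AutoModelForCausalLM.from_pretrained("jetmoe/jetmoe-8b")

inputs = tokenizer("JetMoe is a sparsely activated language model that", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```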


## JetMoeConfig

[[autodoc]] JetMoeConfig

## JetMoeModel

[[autodoc]] JetMoeModel
- forward

## JetMoeForCausalLM

[[autodoc]] JetMoeForCausalLM
- forward

## JetMoeForSequenceClassification

[[autodoc]] JetMoeForSequenceClassification
- forward
2 changes: 2 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -50,6 +50,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj#transformers.GPTJModel)
* [Idefics2](https://huggingface.co/docs/transformers/model_doc/idefics2#transformers.Idefics2Model)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
@@ -198,6 +199,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
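Both attention backends listed above can be selected for JetMoe at load time through `attn_implementation`. A hedged sketch, assuming a GPU, FlashAttention-2 installed, and the `jetmoe/jetmoe-8b` checkpoint:

```python
import torch
from transformers import AutoModelForCausalLM

# FlashAttention-2 needs half precision on GPU; pass attn_implementation="sdpa" to use the PyTorch backend instead.
model = AutoModelForCausalLM.from_pretrained(
    "jetmoe/jetmoe-8b",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
```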
16 changes: 16 additions & 0 deletions src/transformers/__init__.py
@@ -448,6 +448,7 @@
"InstructBlipVisionConfig",
],
"models.jamba": ["JambaConfig"],
"models.jetmoe": ["JetMoeConfig"],
"models.jukebox": [
"JukeboxConfig",
"JukeboxPriorConfig",
@@ -2202,6 +2203,14 @@
"JambaPreTrainedModel",
]
)
_import_structure["models.jetmoe"].extend(
[
"JetMoeForCausalLM",
"JetMoeForSequenceClassification",
"JetMoeModel",
"JetMoePreTrainedModel",
]
)
_import_structure["models.jukebox"].extend(
[
"JukeboxModel",
@@ -4973,6 +4982,7 @@
InstructBlipVisionConfig,
)
from .models.jamba import JambaConfig
from .models.jetmoe import JetMoeConfig
from .models.jukebox import (
JukeboxConfig,
JukeboxPriorConfig,
@@ -6591,6 +6601,12 @@
JambaModel,
JambaPreTrainedModel,
)
from .models.jetmoe import (
JetMoeForCausalLM,
JetMoeForSequenceClassification,
JetMoeModel,
JetMoePreTrainedModel,
)
from .models.jukebox import (
JukeboxModel,
JukeboxPreTrainedModel,
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -117,6 +117,7 @@
informer,
instructblip,
jamba,
jetmoe,
jukebox,
kosmos2,
layoutlm,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -128,6 +128,7 @@
("informer", "InformerConfig"),
("instructblip", "InstructBlipConfig"),
("jamba", "JambaConfig"),
("jetmoe", "JetMoeConfig"),
("jukebox", "JukeboxConfig"),
("kosmos-2", "Kosmos2Config"),
("layoutlm", "LayoutLMConfig"),
@@ -399,6 +400,7 @@
("informer", "Informer"),
("instructblip", "InstructBLIP"),
("jamba", "Jamba"),
("jetmoe", "JetMoe"),
("jukebox", "Jukebox"),
("kosmos-2", "KOSMOS-2"),
("layoutlm", "LayoutLM"),
3 changes: 3 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -125,6 +125,7 @@
("imagegpt", "ImageGPTModel"),
("informer", "InformerModel"),
("jamba", "JambaModel"),
("jetmoe", "JetMoeModel"),
("jukebox", "JukeboxModel"),
("kosmos-2", "Kosmos2Model"),
("layoutlm", "LayoutLMModel"),
@@ -458,6 +459,7 @@
("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
("gptj", "GPTJForCausalLM"),
("jamba", "JambaForCausalLM"),
("jetmoe", "JetMoeForCausalLM"),
("llama", "LlamaForCausalLM"),
("mamba", "MambaForCausalLM"),
("marian", "MarianForCausalLM"),
@@ -860,6 +862,7 @@
("gptj", "GPTJForSequenceClassification"),
("ibert", "IBertForSequenceClassification"),
("jamba", "JambaForSequenceClassification"),
("jetmoe", "JetMoeForSequenceClassification"),
("layoutlm", "LayoutLMForSequenceClassification"),
("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
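With these mappings in place, the generic auto classes resolve the `jetmoe` model type to the JetMoe classes. A small sketch building a randomly initialized model from just the config (the layer count here is illustrative only):

```python
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.for_model("jetmoe", num_hidden_layers=2)  # resolves to JetMoeConfig via the config mapping above
model = AutoModelForCausalLM.from_config(config)              # resolves to JetMoeForCausalLM via the model mapping
print(type(config).__name__, type(model).__name__)
```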
7 changes: 7 additions & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -211,6 +211,13 @@
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"jetmoe",
(
"LlamaTokenizer" if is_sentencepiece_available() else None,
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
("jukebox", ("JukeboxTokenizer", None)),
(
"kosmos-2",
56 changes: 56 additions & 0 deletions src/transformers/models/jetmoe/__init__.py
@@ -0,0 +1,56 @@
# Copyright 2024 JetMoe AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available


_import_structure = {
"configuration_jetmoe": ["JetMoeConfig"],
}


try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_jetmoe"] = [
"JetMoeForCausalLM",
"JetMoeModel",
"JetMoePreTrainedModel",
"JetMoeForSequenceClassification",
]

if TYPE_CHECKING:
from .configuration_jetmoe import JetMoeConfig

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_jetmoe import (
JetMoeForCausalLM,
JetMoeForSequenceClassification,
JetMoeModel,
JetMoePreTrainedModel,
)

else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
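Because the package uses the `_LazyModule` pattern above, importing the configuration alone does not pull in the modeling code (or its torch dependency). A tiny sketch of that behaviour, with an illustrative small config:

```python
from transformers import JetMoeConfig  # resolved lazily; modeling_jetmoe is not imported by this line

config = JetMoeConfig(num_hidden_layers=2)  # hypothetical small config, just for illustration
print(config.model_type)  # "jetmoe"
```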
149 changes: 149 additions & 0 deletions src/transformers/models/jetmoe/configuration_jetmoe.py
@@ -0,0 +1,149 @@
# coding=utf-8
# Copyright 2024 JetMoe AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JetMoe model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class JetMoeConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`JetMoeModel`]. It is used to instantiate a
JetMoe model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a configuration similar to that of the JetMoe-4B.
[jetmoe/jetmoe-8b](https://huggingface.co/jetmoe/jetmoe-8b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the JetMoe model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`JetMoeModel`]
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each key and value in the Transformer decoder.
kv_channels (`int`, *optional*, defaults to 128):
Defines the number of channels for the key and value tensors.
intermediate_size (`int`, *optional*, defaults to 5632):
Dimension of the MLP representations.
max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model might ever be used with. JetMoe's attention allows sequences of
up to 4096 tokens.
activation_function (`string`, *optional*, defaults to `"silu"`):
Defines the activation function for MLP experts.
num_local_experts (`int`, *optional*, defaults to 8):
Defines the number of experts in the MoE and MoA.
num_experts_per_tok (`int`, *optional*, defaults to 2):
The number of experts to route each token to, in both the MoE and MoA layers.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether or not the router logits should be returned by the model. Enabling this will also
allow the model to output the auxiliary loss.
aux_loss_coef (`float`, *optional*, defaults to 0.01):
The coefficient for the auxiliary loss.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
bos_token_id (`int`, *optional*, defaults to 1):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 2):
The id of the "end-of-sequence" token.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
initializer_range (`float`, *optional*, defaults to 0.01):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import JetMoeModel, JetMoeConfig
>>> # Initializing a JetMoe 4B style configuration
>>> configuration = JetMoeConfig()
>>> # Initializing a model from the JetMoe 4B style configuration
>>> model = JetMoeModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""

model_type = "jetmoe"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
vocab_size=32000,
hidden_size=2048,
num_hidden_layers=12,
num_key_value_heads=16,
kv_channels=128,
intermediate_size=5632,
max_position_embeddings=4096,
activation_function="silu",
num_local_experts=8,
num_experts_per_tok=2,
output_router_logits=False,
aux_loss_coef=0.01,
use_cache=True,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=True,
rope_theta=10000.0,
rms_norm_eps=1e-6,
initializer_range=0.01,
attention_dropout=0.0,
**kwargs,
):
if num_experts_per_tok > num_local_experts:
raise ValueError("`num_experts_per_tok` must be less than or equal to `num_local_experts`")
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
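# In the Mixture of Attention heads (MoA) layer, each token is routed to `num_experts_per_tok` head experts,
# so the effective number of attention heads is num_key_value_heads * num_experts_per_tok.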
self.num_attention_heads = num_key_value_heads * num_experts_per_tok
self.num_key_value_heads = num_key_value_heads
self.kv_channels = kv_channels
self.intermediate_size = intermediate_size
self.max_position_embeddings = max_position_embeddings
self.activation_function = activation_function
self.num_local_experts = num_local_experts
self.num_experts_per_tok = num_experts_per_tok
self.output_router_logits = output_router_logits
self.aux_loss_coef = aux_loss_coef
self.use_cache = use_cache
self.initializer_range = initializer_range
self.attention_dropout = attention_dropout

self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id

self.rope_theta = rope_theta
self.rms_norm_eps = rms_norm_eps

super().__init__(
bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
)
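The guard at the top of `__init__` makes inconsistent routing settings fail fast. A quick sketch of that check in action (values chosen only for illustration):

```python
from transformers import JetMoeConfig

config = JetMoeConfig()  # defaults: 8 local experts, 2 experts routed per token
print(config.num_local_experts, config.num_experts_per_tok)  # 8 2

try:
    JetMoeConfig(num_local_experts=4, num_experts_per_tok=8)  # more experts per token than exist
except ValueError as err:
    print(err)  # "`num_experts_per_tok` must be less than or equal to `num_local_experts`"
```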