From a77d8f9a595ab82cc5beca90b85746eebc74f792 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Thu, 24 Oct 2024 12:04:40 -0500 Subject: [PATCH] Add Nvidia Llama 3.1 70B Nemotron weights (#1803) --- litgpt/config.py | 25 ++++++++++++++++++++++++- tutorials/download_model_weights.md | 2 ++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/litgpt/config.py index 463aa5514c..bb6dd129ce 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -1003,7 +1003,30 @@ def norm_class(self) -> Type: copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) configs.append(copy) - +######################### +# NVIDIA Llama Nemotron +######################### +configs.append( + dict( + name="Llama-3.1-Nemotron-70B-Instruct-HF", + hf_config=dict(org="nvidia", name="Llama-3.1-Nemotron-70B-Instruct-HF"), + block_size=131072, + vocab_size=128000, + padded_vocab_size=128256, + n_layer=80, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=28672, + rope_base=500000, + rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192) + ), +) ############### # Google Gemma ############### diff --git a/tutorials/download_model_weights.md index 35263180ce..9c045ec7d3 100644 --- a/tutorials/download_model_weights.md +++ 
b/tutorials/download_model_weights.md @@ -20,6 +20,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights. | Llama 3 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | | Llama 3.1 | 8B, 70B, 405B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | | Llama 3.2 | 1B, 3B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/MODEL_CARD.md) | +| Llama 3.1 Nemotron | 70B | NVIDIA | [NVIDIA AI 2024](https://build.nvidia.com/nvidia/llama-3_1-nemotron-70b-instruct/modelcard) | | LongChat | 7B, 13B | LMSYS | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) | | Mathstral | 7B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mathstral/) | | MicroLlama | 300M | Ken Wang | [MicroLlama repo](https://github.com/keeeeenw/MicroLlama) @@ -152,6 +153,7 @@ mistralai/Mixtral-8x7B-v0.1 NousResearch/Nous-Hermes-13b NousResearch/Nous-Hermes-llama-2-7b NousResearch/Nous-Hermes-Llama2-13b +nvidia/Llama-3.1-Nemotron-70B-Instruct-HF openlm-research/open_llama_13b openlm-research/open_llama_3b openlm-research/open_llama_7b