From 66efeca355169c54d8c4b89d47f91a9793ae1995 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Wed, 30 Oct 2024 00:17:56 -0700 Subject: [PATCH 01/12] add mistral/mixtral peft ci test Signed-off-by: Alexandros Koumparoulis --- .github/workflows/cicd-main.yml | 32 +++++++ tests/collections/llm/lora_mistralai.py | 117 ++++++++++++++++++++++++ 2 files changed, 149 insertions(+) create mode 100644 tests/collections/llm/lora_mistralai.py diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index bb239acb00fc..0e7ae99330ac 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -4266,6 +4266,38 @@ jobs: --pp_size 1 \ --mbs 1 --packed + + L2_NeMo_2_Mixtral_LoRA_TP1PP1_MBS1: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_Mixtral_LoRA_TP1PP1_MBS1') || needs.cicd-test-container-setup.outputs.all == 'true' + with: + RUNNER: self-hosted-azure + SCRIPT: | + + python tests/collections/llm/lora_mistralai.py \ + --devices 2 \ + --max_steps 3 \ + --tp 1 \ + --mbs 1 \ + --model mixtral + + L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1') || needs.cicd-test-container-setup.outputs.all == 'true' + with: + RUNNER: self-hosted-azure + SCRIPT: | + + python tests/collections/llm/lora_mistralai.py \ + --devices 2 \ + --max_steps 3 \ + --tp 1 \ + --mbs 1 \ + --model mistral + + L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml diff --git a/tests/collections/llm/lora_mistralai.py b/tests/collections/llm/lora_mistralai.py new file mode 100644 index 000000000000..4c527ec3319f --- /dev/null +++ b/tests/collections/llm/lora_mistralai.py @@ -0,0 +1,117 @@ +import pytorch_lightning as pl +import torch +from megatron.core.optimizer import OptimizerConfig +import argparse +from nemo import lightning as nl +from nemo.collections import llm +from nemo.lightning.io.mixin import track_io + + + +def get_args(): + parser = argparse.ArgumentParser(description='Finetune a small GPT model using NeMo 2.0') + parser.add_argument('--model', type=str.lower, choices=['mistral', 'mixtral'], help="model") + parser.add_argument('--max-steps', type=int, default=9, help="number of devices") + parser.add_argument('--mbs', type=int, default=2, help="micro batch size") + parser.add_argument('--gbs', type=int, default=4, help="global batch size") + parser.add_argument('--tp', type=int, default=1, help="tensor parallel size") + return parser.parse_args() + + +def trainer(devices, tp, sp, max_steps) -> nl.Trainer: + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=tp, + sequence_parallel=sp, + ) + + return nl.Trainer( + devices=tp, + max_steps=max_steps, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + log_every_n_steps=1, + limit_val_batches=0, + val_check_interval=0, + num_sanity_val_steps=0, + ) + + +@track_io +class OrdTokenizer: + def __init__(self, vocab_size=30_000, num_reserved_tokens=128, special_token_names=['bos_id', 'eos_id', 'pad_id']): + self.vocab_size = vocab_size + self.num_reserved_tokens = num_reserved_tokens + self.special_token_names = special_token_names + assert len(self.special_token_names) < num_reserved_tokens + + 
def __getattr__(self, name): + if name in self.__dict__.get('special_token_names', {}): + return self.__dict__['special_token_names'].index(name) + elif name in self.__dict__: + return self.__dict__[name] + else: + raise AttributeError + + def text_to_ids(self, text): + token_ids = list(map(lambda x: self.num_reserved_tokens + ord(x), list(text))) + assert max(token_ids) < self.vocab_size + return token_ids + + +def logger() -> nl.NeMoLogger: + ckpt = nl.ModelCheckpoint( + save_last=True, + every_n_train_steps=10, + monitor="reduced_train_loss", + save_top_k=1, + save_on_train_epoch_end=True, + save_optim_on_train_end=True, + ) + + return nl.NeMoLogger( + name="nemo2_peft", + log_dir="/tmp/peft_logs", + use_datetime_version=False, # must be false if using auto resume + ckpt=ckpt, + wandb=None, + ) + +def squad(mbs, gbs) -> pl.LightningDataModule: + return llm.SquadDataModule(seq_length=2048, micro_batch_size=mbs, global_batch_size=gbs, num_workers=0) + +def mixtral_8x7b() -> pl.LightningModule: + tokenizer = OrdTokenizer() + model = llm.MixtralModel(llm.MixtralConfig8x7B(num_layers=2), tokenizer=tokenizer) + lora = llm.peft.LoRA() + return model, lora + +def mistral_7b() -> pl.LightningModule: + tokenizer = OrdTokenizer() + model = llm.MistralModel(llm.MistralConfig7B(num_layers=2), tokenizer=tokenizer) + lora = llm.peft.LoRA() + return model, lora + +if __name__ == '__main__': + args = get_args() + if args.model == 'mistral': + model, lora = mistral_7b() + else: + model, lora = mixtral_8x7b() + llm.finetune( + model=model, + data=squad(args.mbs, args.gbs), + trainer=trainer(args.tp, args.tp, args.tp > 1, args.max_steps), + peft=lora, + log=logger(), + optim=nl.MegatronOptimizerModule( + config=OptimizerConfig( + optimizer="adam", + lr=0.0001, + adam_beta2=0.98, + use_distributed_optimizer=True, + clip_grad=1.0, + bf16=True, + ), + ), + ) From 0cd5256b26deb6dd3c9b859dc8ccd607453c77e3 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Wed, 30 Oct 2024 00:19:03 -0700 Subject: [PATCH 02/12] add mistral/mixtral peft ci test Signed-off-by: Alexandros Koumparoulis --- tests/collections/llm/lora_mistralai.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/collections/llm/lora_mistralai.py b/tests/collections/llm/lora_mistralai.py index 4c527ec3319f..8c445e95fcd4 100644 --- a/tests/collections/llm/lora_mistralai.py +++ b/tests/collections/llm/lora_mistralai.py @@ -1,3 +1,17 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import pytorch_lightning as pl import torch from megatron.core.optimizer import OptimizerConfig From 53fc65012a888c35236552baf1b8ec82288feb77 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Wed, 30 Oct 2024 01:05:53 -0700 Subject: [PATCH 03/12] add mistral tp2 Signed-off-by: Alexandros Koumparoulis --- .github/workflows/cicd-main.yml | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 0e7ae99330ac..2f5005b9f946 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -4297,6 +4297,21 @@ jobs: --mbs 1 \ --model mistral + L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1') || needs.cicd-test-container-setup.outputs.all == 'true' + with: + RUNNER: self-hosted-azure + SCRIPT: | + + python tests/collections/llm/lora_mistralai.py \ + --devices 2 \ + --max_steps 3 \ + --tp 2 \ + --mbs 1 \ + --model mistral + L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact: needs: [cicd-test-container-setup] From 129aa41409155c485e2f6e8b46f16eeb33b98341 Mon Sep 17 00:00:00 2001 From: akoumpa Date: Wed, 30 Oct 2024 08:06:50 +0000 Subject: [PATCH 04/12] Apply isort and black reformatting Signed-off-by: akoumpa --- tests/collections/llm/lora_mistralai.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/collections/llm/lora_mistralai.py b/tests/collections/llm/lora_mistralai.py index 8c445e95fcd4..006af2124054 100644 --- a/tests/collections/llm/lora_mistralai.py +++ b/tests/collections/llm/lora_mistralai.py @@ -12,16 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import argparse + import pytorch_lightning as pl import torch from megatron.core.optimizer import OptimizerConfig -import argparse + from nemo import lightning as nl from nemo.collections import llm from nemo.lightning.io.mixin import track_io - def get_args(): parser = argparse.ArgumentParser(description='Finetune a small GPT model using NeMo 2.0') parser.add_argument('--model', type=str.lower, choices=['mistral', 'mixtral'], help="model") @@ -91,21 +92,25 @@ def logger() -> nl.NeMoLogger: wandb=None, ) + def squad(mbs, gbs) -> pl.LightningDataModule: return llm.SquadDataModule(seq_length=2048, micro_batch_size=mbs, global_batch_size=gbs, num_workers=0) + def mixtral_8x7b() -> pl.LightningModule: tokenizer = OrdTokenizer() model = llm.MixtralModel(llm.MixtralConfig8x7B(num_layers=2), tokenizer=tokenizer) lora = llm.peft.LoRA() return model, lora + def mistral_7b() -> pl.LightningModule: tokenizer = OrdTokenizer() model = llm.MistralModel(llm.MistralConfig7B(num_layers=2), tokenizer=tokenizer) lora = llm.peft.LoRA() return model, lora + if __name__ == '__main__': args = get_args() if args.model == 'mistral': @@ -119,7 +124,7 @@ def mistral_7b() -> pl.LightningModule: peft=lora, log=logger(), optim=nl.MegatronOptimizerModule( - config=OptimizerConfig( + config=OptimizerConfig( optimizer="adam", lr=0.0001, adam_beta2=0.98, From 5e6bccbaa42bd5af9f791efaea79ed4c6b27e34b Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Wed, 30 Oct 2024 06:01:50 -0700 Subject: [PATCH 05/12] add tests to NEMO_CICD_Test Signed-off-by: Alexandros Koumparoulis --- .github/workflows/cicd-main.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 2f5005b9f946..0e0da988393c 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -4469,6 +4469,9 @@ jobs: - L2_NeMo_2_GPT_LoRA_TP1PP2_MBS2 - L2_NeMo_2_GPT_LoRA_TP2PP1_MBS2 - L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_PACKED + - L2_NeMo_2_Mixtral_LoRA_TP1PP1_MBS1 + - L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1 + - L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1 - L2_NeMo_2_Mixtral_Pretraining - L2_PTQ_Llama2_FP8 - L2_Community_LLM_Checkpoints_tests_Llama3 From 3b04b0430d41c87629e384b20a4c562cadf8897d Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Date: Wed, 30 Oct 2024 06:09:10 -0700 Subject: [PATCH 06/12] Update .github/workflows/cicd-main.yml MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: oliver könig Signed-off-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> --- .github/workflows/cicd-main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 0e0da988393c..8683b0328a64 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -4300,7 +4300,7 @@ jobs: L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1') || needs.cicd-test-container-setup.outputs.all == 'true' + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1') || needs.cicd-test-container-setup.outputs.all == 'true' with: RUNNER: self-hosted-azure SCRIPT: | From ed28860c7d595ad7bf81f76f04d5126066fd3092 Mon Sep 17 00:00:00 2001 From: Alexandros 
Koumparoulis Date: Thu, 31 Oct 2024 00:26:38 -0700 Subject: [PATCH 07/12] fix params Signed-off-by: Alexandros Koumparoulis --- .github/workflows/cicd-main.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 8683b0328a64..d7c84a5fd28f 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -4277,7 +4277,7 @@ jobs: python tests/collections/llm/lora_mistralai.py \ --devices 2 \ - --max_steps 3 \ + --max-steps 3 \ --tp 1 \ --mbs 1 \ --model mixtral @@ -4292,7 +4292,7 @@ jobs: python tests/collections/llm/lora_mistralai.py \ --devices 2 \ - --max_steps 3 \ + --max-steps 3 \ --tp 1 \ --mbs 1 \ --model mistral @@ -4307,7 +4307,7 @@ jobs: python tests/collections/llm/lora_mistralai.py \ --devices 2 \ - --max_steps 3 \ + --max-steps 3 \ --tp 2 \ --mbs 1 \ --model mistral From 777be9c130f3a2a6d3e9acd516a7239cf59ba1b2 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Thu, 31 Oct 2024 06:18:09 -0700 Subject: [PATCH 08/12] rm devices arg Signed-off-by: Alexandros Koumparoulis --- .github/workflows/cicd-main.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index d7c84a5fd28f..21aeea499b88 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -4276,7 +4276,6 @@ jobs: SCRIPT: | python tests/collections/llm/lora_mistralai.py \ - --devices 2 \ --max-steps 3 \ --tp 1 \ --mbs 1 \ @@ -4291,7 +4290,6 @@ jobs: SCRIPT: | python tests/collections/llm/lora_mistralai.py \ - --devices 2 \ --max-steps 3 \ --tp 1 \ --mbs 1 \ @@ -4306,7 +4304,6 @@ jobs: SCRIPT: | python tests/collections/llm/lora_mistralai.py \ - --devices 2 \ --max-steps 3 \ --tp 2 \ --mbs 1 \ From 582012f740f8ec0a0ea2bb0e25a55d69440e01b7 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Thu, 31 Oct 2024 08:05:41 -0700 Subject: [PATCH 09/12] add --dist-opt arg Signed-off-by: Alexandros Koumparoulis --- .github/workflows/cicd-main.yml | 9 ++++++--- tests/collections/llm/lora_mistralai.py | 3 ++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 21aeea499b88..2255f362753f 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -4279,7 +4279,8 @@ jobs: --max-steps 3 \ --tp 1 \ --mbs 1 \ - --model mixtral + --model mixtral \ + --dist-opt L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1: needs: [cicd-test-container-setup] @@ -4293,7 +4294,8 @@ jobs: --max-steps 3 \ --tp 1 \ --mbs 1 \ - --model mistral + --model mistral \ + --dist-opt L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1: needs: [cicd-test-container-setup] @@ -4307,7 +4309,8 @@ jobs: --max-steps 3 \ --tp 2 \ --mbs 1 \ - --model mistral + --model mistral \ + --dist-opt L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact: diff --git a/tests/collections/llm/lora_mistralai.py b/tests/collections/llm/lora_mistralai.py index 006af2124054..690895ae92a0 100644 --- a/tests/collections/llm/lora_mistralai.py +++ b/tests/collections/llm/lora_mistralai.py @@ -30,6 +30,7 @@ def get_args(): parser.add_argument('--mbs', type=int, default=2, help="micro batch size") parser.add_argument('--gbs', type=int, default=4, help="global batch size") parser.add_argument('--tp', type=int, default=1, help="tensor parallel size") + parser.add_argument('--dist-opt', action='store_true', help='use dist opt') return parser.parse_args() @@ -128,7 +129,7 @@ def mistral_7b() -> pl.LightningModule: 
optimizer="adam", lr=0.0001, adam_beta2=0.98, - use_distributed_optimizer=True, + use_distributed_optimizer=args.dist_opt, clip_grad=1.0, bf16=True, ), From 382ba26936600bf89f8c2934e445ea3fd520f9c3 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Thu, 31 Oct 2024 23:53:14 -0700 Subject: [PATCH 10/12] add tp=2 mixtral Signed-off-by: Alexandros Koumparoulis --- .github/workflows/cicd-main.yml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 2255f362753f..d19327c2becf 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -4282,6 +4282,21 @@ jobs: --model mixtral \ --dist-opt + L2_NeMo_2_Mixtral_LoRA_TP2PP1_MBS1: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_Mixtral_LoRA_TP2PP1_MBS1') || needs.cicd-test-container-setup.outputs.all == 'true' + with: + RUNNER: self-hosted-azure + SCRIPT: | + + python tests/collections/llm/lora_mistralai.py \ + --max-steps 3 \ + --tp 2 \ + --mbs 1 \ + --model mixtral \ + --dist-opt + L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml @@ -4470,6 +4485,7 @@ jobs: - L2_NeMo_2_GPT_LoRA_TP2PP1_MBS2 - L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_PACKED - L2_NeMo_2_Mixtral_LoRA_TP1PP1_MBS1 + - L2_NeMo_2_Mixtral_LoRA_TP2PP1_MBS1 - L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1 - L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1 - L2_NeMo_2_Mixtral_Pretraining From 65b50877c18d61faca584c344eb14f3a46f70317 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 1 Nov 2024 00:03:39 -0700 Subject: [PATCH 11/12] add ep test Signed-off-by: Alexandros Koumparoulis --- .github/workflows/cicd-main.yml | 14 ++++++++++++++ tests/collections/llm/lora_mistralai.py | 8 +++++--- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index d19327c2becf..2bdbe673d19b 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -4266,6 +4266,19 @@ jobs: --pp_size 1 \ --mbs 1 --packed + L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2') || needs.cicd-test-container-setup.outputs.all == 'true' + with: + RUNNER: self-hosted-azure + SCRIPT: | + + python tests/collections/llm/lora_mistralai.py \ + --max-steps 3 \ + --ep 1 \ + --mbs 2 \ + --model mixtral L2_NeMo_2_Mixtral_LoRA_TP1PP1_MBS1: needs: [cicd-test-container-setup] @@ -4484,6 +4497,7 @@ jobs: - L2_NeMo_2_GPT_LoRA_TP1PP2_MBS2 - L2_NeMo_2_GPT_LoRA_TP2PP1_MBS2 - L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_PACKED + - L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2 - L2_NeMo_2_Mixtral_LoRA_TP1PP1_MBS1 - L2_NeMo_2_Mixtral_LoRA_TP2PP1_MBS1 - L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1 diff --git a/tests/collections/llm/lora_mistralai.py b/tests/collections/llm/lora_mistralai.py index 690895ae92a0..f6cc880b8ad8 100644 --- a/tests/collections/llm/lora_mistralai.py +++ b/tests/collections/llm/lora_mistralai.py @@ -30,18 +30,20 @@ def get_args(): parser.add_argument('--mbs', type=int, default=2, help="micro batch size") parser.add_argument('--gbs', type=int, default=4, help="global batch size") parser.add_argument('--tp', type=int, default=1, help="tensor parallel size") + parser.add_arguemnt('--ep', type=int, 
default=1, help="expert parallel size") parser.add_argument('--dist-opt', action='store_true', help='use dist opt') return parser.parse_args() -def trainer(devices, tp, sp, max_steps) -> nl.Trainer: +def trainer(devices, tp, ep, sp, max_steps) -> nl.Trainer: strategy = nl.MegatronStrategy( tensor_model_parallel_size=tp, + expert_model_parallel_size=ep, sequence_parallel=sp, ) return nl.Trainer( - devices=tp, + devices=max(ep, tp), max_steps=max_steps, accelerator="gpu", strategy=strategy, @@ -121,7 +123,7 @@ def mistral_7b() -> pl.LightningModule: llm.finetune( model=model, data=squad(args.mbs, args.gbs), - trainer=trainer(args.tp, args.tp, args.tp > 1, args.max_steps), + trainer=trainer(args.tp, args.tp, args.ep, args.tp > 1, args.max_steps), peft=lora, log=logger(), optim=nl.MegatronOptimizerModule( From 8677d5060ab6787f4f59d7a7063e5a97385479e3 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 1 Nov 2024 02:50:35 -0700 Subject: [PATCH 12/12] fix Signed-off-by: Alexandros Koumparoulis --- tests/collections/llm/lora_mistralai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/collections/llm/lora_mistralai.py b/tests/collections/llm/lora_mistralai.py index f6cc880b8ad8..09a52668e3ee 100644 --- a/tests/collections/llm/lora_mistralai.py +++ b/tests/collections/llm/lora_mistralai.py @@ -30,7 +30,7 @@ def get_args(): parser.add_argument('--mbs', type=int, default=2, help="micro batch size") parser.add_argument('--gbs', type=int, default=4, help="global batch size") parser.add_argument('--tp', type=int, default=1, help="tensor parallel size") - parser.add_arguemnt('--ep', type=int, default=1, help="expert parallel size") + parser.add_argument('--ep', type=int, default=1, help="expert parallel size") parser.add_argument('--dist-opt', action='store_true', help='use dist opt') return parser.parse_args()
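
For local reproduction, each of the new CI jobs reduces to a single invocation of the test script. A minimal sketch of running the TP=2 Mixtral variant by hand follows; the command itself is taken verbatim from the workflow above, while the execution environment (a NeMo dev container with nemo, megatron.core and pytorch_lightning importable, and two visible GPUs, as on the self-hosted-azure runner) is an assumption:

    # Sketch: reproduce L2_NeMo_2_Mixtral_LoRA_TP2PP1_MBS1 locally (assumes 2 GPUs).
    # LoRA-finetunes a 2-layer Mixtral on SQuAD for 3 steps with tensor parallel size 2
    # and the distributed optimizer; the script launches max(ep, tp) = 2 devices.
    python tests/collections/llm/lora_mistralai.py \
        --max-steps 3 \
        --tp 2 \
        --mbs 1 \
        --model mixtral \
        --dist-opt

The remaining jobs use the same entry point with --model mistral and/or --tp 1; the expert-parallel Mixtral job passes --ep and --mbs 2 instead.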