diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml
index 7cecdd513ae4..27f276f99713 100644
--- a/.github/workflows/cicd-main.yml
+++ b/.github/workflows/cicd-main.yml
@@ -3888,6 +3888,18 @@ jobs:
       AFTER_SCRIPT: |
         rm -rf nemo_experiments
 
+  # L2: SpeechLM tests
+  L2_HF_Transformer_SpeechLM_SFT_2gpu:
+    needs: [pre-flight, cicd-test-container-build]
+    uses: ./.github/workflows/_test_template.yml
+    if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_HF_Transformer_SpeechLM_SFT_2gpu') || needs.pre-flight.outputs.all == 'true'
+    with:
+      RUNNER: self-hosted-azure
+      SCRIPT: |
+        TRANSFORMERS_OFFLINE=1 python tests/collections/speechlm/hf/sft.py --model /home/TestData/speechlm/whisper-small/ --max-steps 10 --devices 2 --strategy ddp
+      AFTER_SCRIPT: |
+        rm -rf nemo_experiments
+
   # L2: Megatron Mock Data Generation
   L2_Megatron_Mock_Data_Generation_MockGPTDataset:
     needs: [pre-flight, cicd-test-container-build]
@@ -5164,6 +5176,7 @@ jobs:
       - L2_HF_Transformer_PT_2gpu
       - L2_HF_Transformer_PT_2gpu_nemorun
       - L2_HF_Transformer_PT_TE_Acceleration
+      - L2_HF_Transformer_SpeechLM_SFT_2gpu
       - L2_NeMo_2_SSM_Pretraining
       - L2_NeMo_2_SSM_Finetuning
       - L2_NeMo_2_T5_Pretraining
diff --git a/examples/speechlm/sft/hf.py b/examples/speechlm/sft/hf.py
index 96e785dac97f..3a64ea62dcd3 100755
--- a/examples/speechlm/sft/hf.py
+++ b/examples/speechlm/sft/hf.py
@@ -27,6 +27,17 @@
 
 
 class LhotseHfNeMoDataset(torch.utils.data.Dataset):
+    """Class for a speechLM dataset
+
+    Args:
+        processor (AutoProcessor): the processor to use
+        tokenizer (AutoTokenizer): the tokenizer to use
+        decoder_mask_fill (int): Value to fill in decoder mask
+
+    Returns:
+        torch.utils.data.Dataset: the dataset to train with.
+    """
+
     def __init__(self, processor, tokenizer, decoder_mask_fill=-100):
         super().__init__()
         self.processor = processor
@@ -69,6 +80,7 @@ def __getitem__(self, cuts):
     # Models can be one of the supported ones by AutoModelForSpeechSeq2Seq such as
     # openai/whisper-large-v3 and facebook/s2t-small-librispeech-asr
     parser.add_argument('--model', default='openai/whisper-large-v3')
+    parser.add_argument('--data-path', type=str, required=True)
     parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp'])
     parser.add_argument('--devices', default=1)
     parser.add_argument('--accelerator', default='gpu', choices=['gpu'])
@@ -83,7 +95,7 @@ def __getitem__(self, cuts):
 
     config = OmegaConf.create(
         {
-            "cuts_path": "/opt/checkpoints/lhotse/libri/libri-train-5.jsonl.gz",
+            "cuts_path": args.data_path,
             "sample_rate": 16000,
             "shuffle": True,
             "num_workers": 2,
diff --git a/tests/collections/speechlm/hf/sft.py b/tests/collections/speechlm/hf/sft.py
new file mode 100755
index 000000000000..41f626f10852
--- /dev/null
+++ b/tests/collections/speechlm/hf/sft.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import fiddle as fdl
+import torch
+from lhotse.dataset.collation import collate_matrices, collate_vectors
+from omegaconf import OmegaConf
+
+from nemo import lightning as nl
+from nemo.collections import speechlm
+from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
+from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
+from nemo.collections.speechlm.models import HFAutoModelForSpeechSeq2Seq
+
+torch.set_float32_matmul_precision("medium")
+
+
+class LhotseHfNeMoDataset(torch.utils.data.Dataset):
+    """Dataset that converts a batch of Lhotse cuts into speech seq2seq inputs.
+
+    Args:
+        processor: called on each cut's audio and text to get "input_features"
+        tokenizer: tokenizer whose pad_id replaces masked decoder input ids
+        decoder_mask_fill (int): label value swapped for pad_id in decoder inputs
+    """
+
+    def __init__(self, processor, tokenizer, decoder_mask_fill=-100):
+        super().__init__()
+        self.processor = processor
+        self.tokenizer = tokenizer
+        self.decoder_mask_fill = decoder_mask_fill
+
+    def __getitem__(self, cuts):
+        features = []
+        for cut in cuts:
+            audio = cut.load_audio()
+            features.append(
+                self.processor(
+                    audio,
+                    sampling_rate=cut.sampling_rate,
+                    return_tensors="pt",
+                    text=cut.supervisions[0].text,
+                )
+            )
+
+        input_features = collate_matrices(tensors=[f["input_features"].squeeze(0) for f in features])
+        labels = collate_vectors(tensors=[c.supervisions[0].tokens for c in cuts])
+        decoder_input_ids = labels[:, :-1]
+        decoder_input_ids = decoder_input_ids.masked_fill(
+            decoder_input_ids == self.decoder_mask_fill, self.tokenizer.pad_id
+        )
+        labels = labels[:, 1:].reshape(-1)
+
+        return {
+            "input_features": input_features,
+            "labels": labels,
+            "decoder_input_ids": decoder_input_ids,
+        }
+
+
+if __name__ == '__main__':
+    import argparse
+
+    parser = argparse.ArgumentParser()
+
+    # Models can be one of the supported ones by AutoModelForSpeechSeq2Seq such as
+    # openai/whisper-large-v3 and facebook/s2t-small-librispeech-asr
+    parser.add_argument('--model', default='openai/whisper-large-v3')
+    parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp'])
+    parser.add_argument('--devices', default=1)
+    parser.add_argument('--accelerator', default='gpu', choices=['gpu'])
+    parser.add_argument('--max-steps', type=int, default=100)
+    parser.add_argument('--model-save-path', type=str, default=None)
+    args = parser.parse_args()
+
+    model = HFAutoModelForSpeechSeq2Seq(model_name=args.model)
+    model = model.to(torch.float)
+    processor = model.processor
+    tokenizer = AutoTokenizer(args.model, include_special_tokens=True)
+
+    config = OmegaConf.create(
+        {
+            "cuts_path": "/home/TestData/speechlm/lhotse/libri/libri-train-5.jsonl.gz",
+            "sample_rate": 16000,
+            "shuffle": True,
+            "num_workers": 2,
+            "batch_size": 4,
+            "shuffle_buffer_size": 100,
+        }
+    )
+
+    train_dataloader = get_lhotse_dataloader_from_config(
+        config,
+        global_rank=0,
+        world_size=1,
+        dataset=LhotseHfNeMoDataset(
+            processor=processor,
+            tokenizer=tokenizer,
+        ),
+        tokenizer=tokenizer,
+    )
+
+    speechlm.api.finetune(
+        model=model,
+        data=train_dataloader,
+        trainer=nl.Trainer(
+            devices=args.devices,
+            max_steps=args.max_steps,
+            accelerator=args.accelerator,
+            strategy=args.strategy,
+            precision="bf16-mixed",
+            log_every_n_steps=1,
+            limit_val_batches=0.0,
+            num_sanity_val_steps=0,
+            accumulate_grad_batches=10,
+            gradient_clip_val=0.5,
+            use_distributed_sampler=False,
+            callbacks=[],
+            logger=None,
+        ),
+        optim=fdl.build(speechlm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
+        log=None,
+    )
+
+    if args.model_save_path is not None:
+        model.save_pretrained(args.model_save_path)