From 11f44e146a33a9490d086b04f77b73bb10638c22 Mon Sep 17 00:00:00 2001 From: Yu Yao <54727607+yaoyu-33@users.noreply.github.com> Date: Fri, 19 Jul 2024 11:02:08 -0700 Subject: [PATCH] Fix few issues and docs for neva and clip in r2.0.0rc1 (#9681) * Fix clip model Signed-off-by: yaoyu-33 * fix neva loading due to fp8 change Signed-off-by: yaoyu-33 * Fix CLIP comment issues Signed-off-by: yaoyu-33 * update docs Signed-off-by: yaoyu-33 * Fix neva tutorial Signed-off-by: yaoyu-33 * Apply isort and black reformatting Signed-off-by: yaoyu-33 * fix siglip few things Signed-off-by: yaoyu-33 * Apply isort and black reformatting Signed-off-by: yaoyu-33 * Update convert_siglip_hf_to_nemo.py Signed-off-by: Yu Yao <54727607+yaoyu-33@users.noreply.github.com> * fix comments Signed-off-by: yaoyu-33 --------- Signed-off-by: yaoyu-33 Signed-off-by: yaoyu-33 Signed-off-by: Yu Yao <54727607+yaoyu-33@users.noreply.github.com> Co-authored-by: yaoyu-33 --- docs/source/multimodal/mllm/checkpoint.rst | 114 ------- docs/source/multimodal/mllm/intro.rst | 1 - docs/source/multimodal/vlm/checkpoint.rst | 56 ++-- .../neva/conf/llava_config.yaml | 4 +- .../clip/conf/megatron_siglip_config.yaml | 253 ++++++++++++++++ .../clip/convert_external_clip_to_nemo.py | 5 +- .../multimodal/data/clip/clip_dataset.py | 5 +- .../models/multimodal_llm/neva/neva_model.py | 11 + .../clip/megatron_clip_models.py | 1 + .../convert_clip_hf_to_nemo.py | 2 +- .../convert_llava_hf_to_nemo.py | 4 +- .../convert_siglip_hf_to_nemo.py | 40 ++- tutorials/multimodal/NeVA Tutorial.ipynb | 277 +++++++++--------- 13 files changed, 459 insertions(+), 314 deletions(-) delete mode 100644 docs/source/multimodal/mllm/checkpoint.rst create mode 100644 examples/multimodal/vision_language_foundation/clip/conf/megatron_siglip_config.yaml diff --git a/docs/source/multimodal/mllm/checkpoint.rst b/docs/source/multimodal/mllm/checkpoint.rst deleted file mode 100644 index d1fe7b651e66..000000000000 --- a/docs/source/multimodal/mllm/checkpoint.rst +++ /dev/null @@ -1,114 +0,0 @@ -Checkpoints -=========== - -In this section, we present four key functionalities of NVIDIA NeMo related to checkpoint management: - -1. **Checkpoint Loading**: Load local ``.nemo`` checkpoint files with the :code:`restore_from()` method. -2. **Partial Checkpoint Conversion**: Convert partially-trained ``.ckpt`` checkpoints to the ``.nemo`` format. -3. **Community Checkpoint Conversion**: Transition checkpoints from community sources, like HuggingFace, into the ``.nemo`` format. -4. **Model Parallelism Adjustment**: Modify model parallelism to efficiently train models that exceed the memory of a single GPU. NeMo employs both tensor (intra-layer) and pipeline (inter-layer) model parallelisms. Dive deeper with "Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM" (`link `_). This tool aids in adjusting model parallelism, accommodating users who need to deploy on larger GPU arrays due to memory constraints. - -Understanding Checkpoint Formats --------------------------------- - -A ``.nemo`` checkpoint is fundamentally a tar file that bundles the model configurations (given as a YAML file), model weights, and other pertinent artifacts like tokenizer models or vocabulary files. This consolidated design streamlines sharing, loading, tuning, evaluating, and inference. - -On the other hand, the ``.ckpt`` file is a product of PyTorch Lightning training. It stores model weights and optimizer states, and it's generally used for resuming training. - -Subsequent sections delve into each of the previously listed functionalities, emphasizing the loading of fully trained checkpoints for evaluation or additional fine-tuning. - - -Loading Local Checkpoints -------------------------- - -NeMo inherently saves any model's checkpoints in the ``.nemo`` format. To manually save a model at any stage: - -.. code-block:: python - - model.save_to(.nemo) - -To load a local ``.nemo`` checkpoint: - -.. code-block:: python - - import nemo.collections.multimodal as nemo_multimodal - model = nemo_multimodal.models..restore_from(restore_path="") - -Replace `` with the appropriate MM model class. - -Converting Local Checkpoints ----------------------------- - -The training script only auto-converts the final checkpoint into the ``.nemo`` format. To evaluate intermediate training checkpoints, conversion to ``.nemo`` might be needed. For this: - -.. code-block:: bash - - python -m torch.distributed.launch --nproc_per_node= * \ - examples/multimodal/convert_ckpt_to_nemo.py \ - --checkpoint_folder \ - --checkpoint_name \ - --nemo_file_path \ - --tensor_model_parallel_size \ - --pipeline_model_parallel_size - -Converting Community Checkpoints --------------------------------- - -NeVA Checkpoints -^^^^^^^^^^^^^^^^ - -Currently, the conversion mainly supports LLaVA checkpoints based on "llama-2 chat" checkpoints. As a reference, we'll consider the checkpoint `llava-llama-2-13b-chat-lightning-preview `_. - -After downloading this checkpoint and saving it at ``/path/to/llava-llama-2-13b-chat-lightning-preview``, undertake the following procedures: - -Modifying the Tokenizer -""""""""""""""""""""""" - -NeMo mandates adding specific tokens to the tokenizer model for peak performance. To modify an existing tokenizer located in ``/path/to/llava-llama-2-13b-chat-lightning-preview/tokenizer``, execute the following in the NeMo container: - -.. code-block:: bash - - cd /opt/sentencepiece/src/ - protoc --python_out=/opt/NeMo/scripts/tokenizers/ sentencepiece_model.proto - python /opt/NeMo/scripts/tokenizers/add_special_tokens_to_sentencepiece.py \ - --input_file /path/to/llava-llama-2-13b-chat-lightning-preview/tokenizer.model \ - --output_file /path/to/llava-llama-2-13b-chat-lightning-preview/tokenizer_neva.model \ - --is_userdefined \ - --tokens "" "" "" "" \ - "" "" "" "" - -Checkpoint Conversion -""""""""""""""""""""" - -For conversion: - -.. code-block:: bash - - python examples/multimodal/mllm/neva/convert_hf_llava_to_neva.py \ - --in-file /path/to/llava-llama-2-13b-chat-lightning-preview \ - --out-file /path/to/neva-llava-llama-2-13b-chat-lightning-preview.nemo \ - --tokenizer-model /path/to/llava-llama-2-13b-chat-lightning-preview/tokenizer_add_special.model - --conv-template llama_2 - - -Model Parallelism Adjustment ----------------------------- - -NeVA Checkpoints -^^^^^^^^^^^^^^^^ - -Adjust model parallelism with: - -.. code-block:: bash - - python examples/nlp/language_modeling/megatron_change_num_partitions.py \ - --model_file=/path/to/source.nemo \ - --target_file=/path/to/target.nemo \ - --tensor_model_parallel_size=??? \ - --target_tensor_model_parallel_size=??? \ - --pipeline_model_parallel_size=??? \ - --target_pipeline_model_parallel_size=??? \ - --model_class="nemo.collections.multimodal.models.multimodal_llm.neva.neva_model.MegatronNevaModel" \ - --precision=32 \ - --tokenizer_model_path=/path/to/tokenizer.model \ - --tp_conversion_only diff --git a/docs/source/multimodal/mllm/intro.rst b/docs/source/multimodal/mllm/intro.rst index 0e76a9737a0f..48bfd56f9ae1 100644 --- a/docs/source/multimodal/mllm/intro.rst +++ b/docs/source/multimodal/mllm/intro.rst @@ -8,7 +8,6 @@ The endeavor to extend Language Models (LLMs) into multimodal domains by integra datasets configs - checkpoint neva video_neva sequence_packing diff --git a/docs/source/multimodal/vlm/checkpoint.rst b/docs/source/multimodal/vlm/checkpoint.rst index 996d9828f5aa..d984f1453510 100644 --- a/docs/source/multimodal/vlm/checkpoint.rst +++ b/docs/source/multimodal/vlm/checkpoint.rst @@ -35,58 +35,36 @@ To load a local ``.nemo`` checkpoint: Replace `` with the appropriate MM model class. -Converting Local Checkpoints ----------------------------- - -Only the last checkpoint is automatically saved in the ``.nemo`` format. If intermediate training checkpoints evaluation is required, a ``.nemo`` conversion might be necessary. For this, refer to the script at `script `_: - -.. code-block:: python - - python -m torch.distributed.launch --nproc_per_node= * \ - examples/multimodal/convert_ckpt_to_nemo.py \ - --checkpoint_folder \ - --checkpoint_name \ - --nemo_file_path \ - --tensor_model_parallel_size \ - --pipeline_model_parallel_size - Converting Community Checkpoints -------------------------------- CLIP Checkpoints ^^^^^^^^^^^^^^^^ -To migrate community checkpoints: -.. code-block:: python +To migrate community checkpoints, use the following command: + +.. code-block:: bash - python examples/multimodal/foundation/clip/convert_external_clip_to_nemo.py \ - --arch=ViT-H-14 \ - --version=laion2b_s32b_b79k \ - --hparams_file=path/to/saved.yaml \ - --nemo_file_path=open_clip.nemo + torchrun --nproc-per-node=1 /opt/NeMo/scripts/checkpoint_converters/convert_clip_hf_to_nemo.py \ + --input_name_or_path=openai/clip-vit-large-patch14 \ + --output_path=openai_clip.nemo \ + --hparams_file=/opt/NeMo/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_VIT-L-14.yaml Ensure the NeMo hparams file has the correct model architectural parameters, placed at `path/to/saved.yaml`. An example can be found in `examples/multimodal/foundation/clip/conf/megatron_clip_config.yaml`. -For OpenCLIP migrations, provide the architecture (`arch`) and version (`version`) according to the OpenCLIP `model list `_. For Hugging Face conversions, set the version to `huggingface` and the architecture (`arch`) to the specific Hugging Face model identifier, e.g., `yuvalkirstain/PickScore_v1`. +After conversion, you can verify the model with the following command: -Model Parallelism Adjustment ----------------------------- +.. code-block:: bash -CLIP Checkpoints -^^^^^^^^^^^^^^^^ + wget https://upload.wikimedia.org/wikipedia/commons/0/0f/1665_Girl_with_a_Pearl_Earring.jpg + torchrun --nproc-per-node=1 /opt/NeMo/examples/multimodal/vision_language_foundation/clip/megatron_clip_infer.py \ + model.restore_from_path=./openai_clip.nemo \ + image_path=./1665_Girl_with_a_Pearl_Earring.jpg \ + texts='["a dog", "a boy", "a girl"]' -To adjust model parallelism from original model parallelism size to a new model parallelism size (Note: NeMo CLIP currently only supports `pipeline_model_parallel_size=1`): +It should generate a high probability for the "a girl" tag. For example: -.. code-block:: python +.. code-block:: text - python examples/nlp/language_modeling/megatron_change_num_partitions.py \ - --model_file=/path/to/source.nemo \ - --target_file=/path/to/target.nemo \ - --tensor_model_parallel_size=??? \ - --target_tensor_model_parallel_size=??? \ - --pipeline_model_parallel_size=-1 \ - --target_pipeline_model_parallel_size=1 \ - --precision=32 \ - --model_class="nemo.collections.multimodal.models.clip.megatron_clip_models.MegatronCLIPModel" \ - --tp_conversion_only + Given image's CLIP text probability: [('a dog', 0.0049710185), ('a boy', 0.002258187), ('a girl', 0.99277073)] diff --git a/examples/multimodal/multimodal_llm/neva/conf/llava_config.yaml b/examples/multimodal/multimodal_llm/neva/conf/llava_config.yaml index 3ec90b2d1b53..d8a31fa19ca9 100644 --- a/examples/multimodal/multimodal_llm/neva/conf/llava_config.yaml +++ b/examples/multimodal/multimodal_llm/neva/conf/llava_config.yaml @@ -71,10 +71,10 @@ model: freeze: False model_type: llama_2 # Only support nvgpt or llama_2 vision_encoder: - from_pretrained: "openai/clip-vit-large-patch14" # path or name + from_pretrained: "openai/clip-vit-large-patch14-336" # path or name from_hf: True patch_dim: 14 - crop_size: [224, 224] + crop_size: [336, 336] hidden_size: 1024 # could be found from model but tricky in code vision_select_layer: -2 # default to the last layer class_token_length: 1 diff --git a/examples/multimodal/vision_language_foundation/clip/conf/megatron_siglip_config.yaml b/examples/multimodal/vision_language_foundation/clip/conf/megatron_siglip_config.yaml new file mode 100644 index 000000000000..59f21813ce01 --- /dev/null +++ b/examples/multimodal/vision_language_foundation/clip/conf/megatron_siglip_config.yaml @@ -0,0 +1,253 @@ +name: megatron_siglip +restore_from_path: null # used when starting from a .nemo file + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 # PTL default. In practice, max_steps will be reached first. + max_steps: 375000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + val_check_interval: 100 + check_val_every_n_epoch: null + limit_val_batches: 50 + limit_test_batches: 500 + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: False # default PTL callback for this does not support model parallelism, instead we log manually + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: megatron_siglip + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 10 + mode: min + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + save_nemo_on_train_end: False # not recommended when training large models on clusters with short time limits + filename: 'megatron_siglip--{val_loss:.2f}-{step}-{consumed_samples}' + model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} + ema: + enable: False + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + +model: + precision: ${trainer.precision} + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 32 # limited by GPU memory + global_batch_size: 32 # will use more micro batches to reach global batch size + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + virtual_pipeline_model_parallel_size: null # interleaved pipeline + + restore_from_path: null # used in fine-tuning + # multimodal configs + output_dim: 1152 + # As the number of devices used to train increases, so does the space complexity of + # the logit matrix. Using a naïve all-gather scheme, space complexity will be + # `O(n^2)`. Instead, complexity may become effectively linear if the flags + # `--gather-with-grad` and `--local-loss` are used. This alteration results in one-to-one + # numerical results as the naïve method. + + use_siglip: True + mcore_gpt: True + transformer_engine: True + + vision: + precision: ${trainer.precision} + # vision configs + patch_dim: 14 + img_h: 378 + img_w: 378 + image_mean: null + image_std: null + num_channels: 3 + drop_patch_rate: 0.0 + drop_path_rate: 0.0 + global_average_pool: False + output_dim: ${model.output_dim} + class_token_length: 0 + preprocess_layernorm: True # apply layer norm to embedded tokens + + # model architecture + encoder_seq_length: 196 + max_position_embeddings: ${.encoder_seq_length} + position_embedding_type: learned_absolute + num_layers: 27 + hidden_size: 1152 + ffn_hidden_size: 4304 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 16 + init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + use_scaled_init_method: True # use scaled residuals initialization + hidden_dropout: 0. # Dropout probability for hidden state transformer. + attention_dropout: 0. + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: layernorm # Type of normalization layers + layernorm_epsilon: 1e-5 + do_layer_norm_weight_decay: False # True means weight decay on all params + pre_process: True # add embedding + post_process: True # add pooler + persist_layer_norm: True # Use of persistent fused layer norm kernel. + + ## Activation Checkpointing + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + activations_checkpoint_num_layers: null # not used with 'selective' + sequence_parallel: False + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + hysteresis: 2 # Gradient scale hysteresis + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # model fusions + masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. + bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition. + + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism. + openai_gelu: True + bias_activation_fusion: False + activation: approx-gelu + megatron_legacy: False + + + text: + precision: ${trainer.precision} + # text configs + output_dim: ${model.output_dim} + + # model architecture + encoder_seq_length: 64 + max_position_embeddings: ${.encoder_seq_length} + position_embedding_type: learned_absolute + num_layers: 27 + hidden_size: 1152 + ffn_hidden_size: 4304 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 16 + init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + use_scaled_init_method: True # use scaled residuals initialization + hidden_dropout: 0. # Dropout probability for hidden state transformer. + attention_dropout: 0. + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: layernorm # Type of normalization layers + layernorm_epsilon: 1e-5 + do_layer_norm_weight_decay: False # True means weight decay on all params + pre_process: True # add embedding + post_process: True # add pooler + persist_layer_norm: True # Use of persistent fused layer norm kernel. + + ## Activation Checkpointing + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + activations_checkpoint_num_layers: null # not used with 'selective' + num_micro_batches_with_partial_activation_checkpoints: null + activations_checkpoint_layers_per_pipeline: null + sequence_parallel: False + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + hysteresis: 2 # Gradient scale hysteresis + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # model fusions + masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. + bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition. + + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism. + openai_gelu: True + bias_activation_fusion: False + megatron_legacy: False + + fp8: False # enables fp8 in TransformerLayer forward + fp8_e4m3: False # sets fp8_format = recipe.Format.E4M3 + fp8_hybrid: False # sets fp8_format = recipe.Format.HYBRID + fp8_margin: 0 # scaling margin + fp8_interval: 1 # scaling update interval + fp8_amax_history_len: 1 # Number of steps for which amax history is recorded per tensor + fp8_amax_compute_algo: most_recent # 'most_recent' or 'max'. Algorithm for computing amax from history + use_emha: False # Use fused multi-head attention for large sequence-length. Note this is not yet supported. Please set to False. + activation: approx-gelu + + # Megatron O2-style half-precision + megatron_amp_O2: True # Enable O2-level automatic mixed precision using main parameters + grad_allreduce_chunk_size_mb: 125 + grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + tokenizer: + library: 'huggingface' + type: 'google/siglip-so400m-patch14-384' + model: null + vocab_file: null + merge_file: null + delimiter: null # only used for tabular tokenizer + sentencepiece_legacy: False # Legacy=True allows you to add special tokens to sentencepiece tokenizers. + make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency. + + data: + num_workers: 8 + train: + dataset_path: # List of paths to pkl files or tar files + - /datasets/coyo/test.pkl + validation: # List of paths to pkl files or tar files + dataset_path: + - /datasets/coyo/test.pkl + webdataset: + infinite_sampler: False + local_root_path: /datasets/coyo + + imagenet_val: null # Path to imagenet val set for conducting zero shot evaluation. + + # Nsys profiling options + nsys_profile: + enabled: False + start_step: 10 # Global batch to start profiling + end_step: 10 # Global batch to end profiling + ranks: [ 0 ] # Global rank IDs to profile + gen_shape: False # Generate model and kernel details including input shapes + + optim: + name: fused_adam + lr: 1e-3 + weight_decay: 0.2 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 2000 + constant_steps: 0 + min_lr: 1e-5 \ No newline at end of file diff --git a/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py b/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py index 9af25181d07e..178140aac828 100644 --- a/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py +++ b/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py @@ -283,7 +283,10 @@ def convert(local_rank, rank, world_size, args): if __name__ == '__main__': - logging.warning("This script is going to be deprecated soon. Please use ") + logging.warning( + "This script is going to be deprecated soon. Please use " + "`scripts/checkpoint_converters/convert_clip_hf_to_nemo.py`" + ) args = get_args() local_rank, rank, world_size = initialize_distributed(args) convert(local_rank, rank, world_size, args) diff --git a/nemo/collections/multimodal/data/clip/clip_dataset.py b/nemo/collections/multimodal/data/clip/clip_dataset.py index 6b63d546194a..448efba4b8ba 100644 --- a/nemo/collections/multimodal/data/clip/clip_dataset.py +++ b/nemo/collections/multimodal/data/clip/clip_dataset.py @@ -57,8 +57,9 @@ def tokenize(texts: Union[str, List[str]], tokenizer: Any, context_length: int = bos_id = tokenizer.bos_id eos_id = tokenizer.eos_id - all_tokens = [[bos_id] + tokenizer.text_to_ids(text) + [eos_id] for text in texts] - result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + pad_id = tokenizer.pad_id + all_tokens = [([bos_id] if bos_id is not None else []) + tokenizer.text_to_ids(text) + [eos_id] for text in texts] + result = torch.ones(len(all_tokens), context_length, dtype=torch.long) * pad_id for i, tokens in enumerate(all_tokens): if len(tokens) > context_length: diff --git a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py index 92066b89c1a1..095861b9c1fc 100644 --- a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py +++ b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py @@ -67,6 +67,8 @@ try: from megatron.core import InferenceParams, dist_checkpointing, parallel_state, tensor_parallel + from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace + from megatron.core.dist_checkpointing.mapping import LocalNonpersitentObject, ShardedObject from megatron.core.models.gpt import GPTModel as MCoreGPTModel from megatron.core.num_microbatches_calculator import get_num_microbatches from megatron.core.pipeline_parallel.schedules import get_forward_backward_func @@ -79,6 +81,12 @@ HAVE_MEGATRON_CORE = False +def skip_fp8_load(x): + if isinstance(x, ShardedObject) and 'fused_attention' in x.key and '_extra_state' in x.key: + x = LocalNonpersitentObject(x.data) # use the FP8 state from initialization, not from ckpt + return x + + class FrozenCLIPVisionTransformer(CLIPVisionTransformer): """Frozen version of CLIPVisionTransformer""" @@ -521,6 +529,9 @@ def _load_model_weights(self, nemo_path): sharded_state_dict = None if getattr(self, "sharded_state_dict", None) is not None: sharded_state_dict = self.sharded_state_dict(prefix="model.") + # WAR: This is a temporary fix to skip loading FP8 parameters for Dot Product Attention + # TODO(yuya): Check if this skip affecting fp8 native checkpoints loading + dict_list_map_inplace(skip_fp8_load, sharded_state_dict) state_dict, self.is_dist_ckpt = load_nemo_model_weights(nemo_path, sharded_state_dict) return state_dict diff --git a/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py b/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py index d811ce94dbea..a197ae291880 100644 --- a/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py +++ b/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py @@ -984,6 +984,7 @@ def training_step(self, dataloader_iter): for module in modules: if isinstance(module, (Float16Module, MCoreFloat16Module)): module = module.module + module = module.text_encoder if not self.mcore_gpt: module = module.language_model if hasattr(module, 'embedding'): diff --git a/scripts/checkpoint_converters/convert_clip_hf_to_nemo.py b/scripts/checkpoint_converters/convert_clip_hf_to_nemo.py index 690fa74abccd..2b8156ad4b26 100644 --- a/scripts/checkpoint_converters/convert_clip_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_clip_hf_to_nemo.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/scripts/checkpoint_converters/convert_llava_hf_to_nemo.py b/scripts/checkpoint_converters/convert_llava_hf_to_nemo.py index d91899348e8c..85f65ca05ecf 100644 --- a/scripts/checkpoint_converters/convert_llava_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_llava_hf_to_nemo.py @@ -292,7 +292,7 @@ def convert(args): batch_dict = hf_tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt') batch_dict_cuda = {k: v.cuda() for k, v in batch_dict.items()} hf_model = hf_model.cuda().eval() - model = model.eval() + model = model.cuda().eval() hf_outputs = hf_model(**batch_dict_cuda, output_hidden_states=True) ids = batch_dict_cuda['input_ids'] @@ -307,7 +307,7 @@ def convert(args): attn_mask, _, pos_ids = attn_mask_and_pos_ids outputs = model( - tokens=tokens, text_position_ids=pos_ids.cuda(), attention_mask=attn_mask.cuda(), labels=None + tokens=tokens.cuda(), text_position_ids=pos_ids.cuda(), attention_mask=attn_mask.cuda(), labels=None ) hf_next_token = hf_outputs.logits[0, -1].argmax() diff --git a/scripts/checkpoint_converters/convert_siglip_hf_to_nemo.py b/scripts/checkpoint_converters/convert_siglip_hf_to_nemo.py index 97a9d557f78b..053b3a053884 100644 --- a/scripts/checkpoint_converters/convert_siglip_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_siglip_hf_to_nemo.py @@ -13,11 +13,10 @@ # limitations under the License. """ -Requires HF transformers updated to support Gemma Models - python3 /opt/NeMo/scripts/nlp_language_modeling/convert_gemma_hf_to_nemo.py \ - --input_name_or_path /path/to/gemma/checkpoints/hf/7b \ - --output_path /path/to/gemma-7b.nemo \ - --tokenizer_path /path/to/tokenizer.model +Requires HF transformers updated to support Siglip Models + python /opt/NeMo/scripts/checkpoint_converters/convert_siglip_hf_to_nemo.py \ + --input_name_or_path=google/siglip-so400m-patch14-384 \ + --output_path=test.nemo """ import os @@ -352,7 +351,7 @@ def get_args(): def convert(args): logging.info(f"Loading checkpoint from HF: `{args.input_name_or_path}`") hf_model = AutoModel.from_pretrained(args.input_name_or_path) - # hf_processor = AutoProcessor.from_pretrained(args.input_name_or_path) + hf_processor = AutoProcessor.from_pretrained(args.input_name_or_path) logging.info("HF Model loading done.") nemo_config = OmegaConf.load(args.hparams_file) @@ -369,6 +368,35 @@ def convert(args): nemo_state_dict = adjust_tensor_shapes(model, new_state_dict) model.load_state_dict(nemo_state_dict, strict=False) + logging.info(f'=' * 100) + # Verifications + import requests + from PIL import Image + + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw) + + texts = ["a photo of 2 cats", "a photo of 2 dogs"] + inputs = hf_processor(text=texts, images=image, padding="max_length", return_tensors="pt") + + tokens = inputs["input_ids"].cuda() + text_model = model.model.text_encoder.cuda() + hf_text_model = hf_model.text_model.cuda() + text_model_output = text_model(tokens) + hf_text_model_output = hf_text_model(tokens).pooler_output + assert torch.allclose(text_model_output, hf_text_model_output, atol=0.01) + logging.info(f'! Text model results matched.') + + pixels = inputs["pixel_values"].cuda() + vision_model = model.model.vision_encoder.cuda() + hf_vision_model = hf_model.vision_model.cuda() + vision_model_output = vision_model(pixels) + hf_vision_model_output = hf_vision_model(pixels).pooler_output + assert torch.allclose(vision_model_output, hf_vision_model_output, atol=0.01) + logging.info(f'! Vision model results matched.') + + logging.info(f'=' * 100) + dtype = torch_dtype_from_precision(args.precision) model = model.to(dtype=dtype) model.save_to(args.output_path) diff --git a/tutorials/multimodal/NeVA Tutorial.ipynb b/tutorials/multimodal/NeVA Tutorial.ipynb index b57bdb47df57..1ad1101a0299 100644 --- a/tutorials/multimodal/NeVA Tutorial.ipynb +++ b/tutorials/multimodal/NeVA Tutorial.ipynb @@ -2,8 +2,13 @@ "cells": [ { "cell_type": "markdown", - "id": "a2225742c5996304", - "metadata": {}, + "id": "b29a4b72-31bb-4268-9598-2cd2b6f7475e", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + } + }, "source": [ "# NeVA Training / Inference Tutorial\n", "\n", @@ -20,28 +25,19 @@ "\n", "This notebook illustrates how to train and perform inference using NeVA with the NeMo Toolkit. NeVA originates from [LLaVA](https://github.com/haotian-liu/LLaVA) (Large Language and Vision Assistant) and is a powerful multimodal image-text instruction tuned model optimized within the NeMo Framework. \n", "\n", - "\n", "This tutorial will guide you through the following topics:\n", - "1. Training a NeVA model\n", - "2. Performing inference with the trained model\n", + "1. Prepare pre-requisites for NeVA training\n", + "2. Training a NeVA model\n", + "3. Performing inference with the trained model\n", "\n", "## Datasets\n", "\n", - "After downloading all below datasets for pretraining and instruction tuning, your dataset directory should look something similar to:\n", + "Please refer to [NeMo User Guide](https://docs.nvidia.com/nemo-framework/user-guide/latest/multimodalmodels/multimodallanguagemodel/neva/dataprep.html#prepare-pretraining-and-fine-tuning-datasets) for preparing NeVA dataset for pretrain and fine-tuning.\n", "\n", - "```\n", - "LLaVA-Pretrain-LCS-558K\n", - "├── blip_laion_cc_sbu_558k.json\n", - "├── images\n", - "LLaVA-Instruct-mixture\n", - "├── llava_v1_5_mix665k.json\n", - "└── images\n", - " └── ...\n", - "```\n", "\n", "### Pre-Training Dataset\n", "\n", - "The pre-training dataset is open-sourced from the LLaVA implementation and can be downloaded [here](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain). The dataset consists of a 558K subset of the LAION-CC-SBU dataset with BLIP captions. \n", + "The pre-training dataset is open-sourced from the LLaVA implementation and can be downloaded [here](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain). The dataset consists of a 558K subset of the LAION-CC-SBU dataset with BLIP captions.\n", "\n", "The associated images for pretraining can be downloaded via HuggingFace [here](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/blob/main/images.zip).\n", "\n", @@ -66,14 +62,75 @@ " └── VG_100K_2\n", "```\n", "\n", - "## Training\n", + "After downloading all below datasets for pretraining and instruction tuning, please put data folder at `/workspace/datasets`. Your dataset directory should look something similar to:\n", + "\n", + "```\n", + "LLaVA-Pretrain-LCS-558K\n", + "├── blip_laion_cc_sbu_558k.json\n", + "├── images\n", + "LLaVA-Instruct-mixture\n", + "├── llava_v1_5_mix665k.json\n", + "└── images\n", + " └── ...\n", + "```\n", + "\n", + "## Setting up Checkpoint and Tokenizer\n", + "\n", + "In this notebook, we first need to convert the Vicuna 1.5 checkpoint into the .nemo format. Meanwhile, special tokens must be incorporated into the tokenizer for NeVA training. After downloading language models from Hugging Face, ensure you also fetch the corresponding tokenizer model. Using the 7B-chat model as a reference." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d80adff-bd3a-40e0-9441-684328ec7596", + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "! mkdir -p /workspace/checkpoints\n", + "\n", + "# Download vicuna checkpoint from HF\n", + "! git clone https://huggingface.co/lmsys/vicuna-7b-v1.5 /workspace/checkpoints/vicuna-7b-v1.5\n", + "\n", + "# Convert checkpoint\n", + "! python /opt/NeMo/scripts/checkpoint_converters/convert_llama_hf_to_nemo.py \\\n", + " --input_name_or_path /workspace/checkpoints/vicuna-7b-v1.5 \\\n", + " --output_path /workspace/checkpoints/vicuna-7b-v1.5.nemo\n", + "\n", + "# Prepare tokenizer\n", + "! cd /opt && git clone https://github.com/google/sentencepiece.git && \\\n", + " cd sentencepiece && \\\n", + " mkdir build && \\\n", + " cd build && \\\n", + " cmake .. && \\\n", + " make && \\\n", + " make install && \\\n", + " ldconfig && \\\n", + "cd /opt/sentencepiece/src/ && protoc --python_out=/opt/NeMo/scripts/tokenizers/ sentencepiece_model.proto\n", "\n", + "! python /opt/NeMo/scripts/tokenizers/add_special_tokens_to_sentencepiece.py \\\n", + "--input_file /workspace/checkpoints/vicuna-7b-v1.5/tokenizer.model \\\n", + "--output_file /workspace/checkpoints/vicuna-7b-v1.5/tokenizer_neva.model \\\n", + "--is_userdefined \\\n", + "--tokens \"\" \"\" \"\" \"\" \\\n", + " \"\" \"\" \"\" \"\"\n" + ] + }, + { + "cell_type": "markdown", + "id": "6b619e0a", + "metadata": {}, + "source": [ + "## Training\n", "\n", "### Feature Alignment Pre-Training\n", "\n", "We provide a set of scripts for pre-training and fine-tuning which can be kicked off with CLI flags defining specified arguments. \n", "\n", - "An example of a pre-training script execution:" + "An example of a pre-training script execution (note the scripts will only perform 100 steps with a small micro batch size, this is not a full training):" ] }, { @@ -92,61 +149,58 @@ " trainer.precision=bf16 \\\n", " trainer.num_nodes=1 \\\n", " trainer.devices=4 \\\n", - " trainer.val_check_interval=1000 \\\n", + " trainer.val_check_interval=50 \\\n", " trainer.limit_val_batches=5 \\\n", " trainer.log_every_n_steps=1 \\\n", - " trainer.max_steps=1000 \\\n", + " trainer.max_steps=100 \\\n", " model.megatron_amp_O2=True \\\n", " model.micro_batch_size=1 \\\n", - " model.global_batch_size=2 \\\n", - " model.tensor_model_parallel_size=4 \\\n", + " model.global_batch_size=4 \\\n", + " model.tensor_model_parallel_size=1 \\\n", " model.pipeline_model_parallel_size=1 \\\n", " model.mcore_gpt=True \\\n", " model.transformer_engine=True \\\n", - " model.data.data_path=/path/to/datasets/LLaVA-Pretrain-LCS-558K/blip_laion_cc_sbu_558k.json \\\n", - " model.data.image_folder=/path/to/dataset/LLaVA-Pretrain-LCS-558K/images \\\n", + " model.data.data_path=/workspace/datasets/LLaVA-Pretrain-LCS-558K/blip_laion_cc_sbu_558k.json \\\n", + " model.data.image_folder=/workspace/datasets/LLaVA-Pretrain-LCS-558K/images \\\n", " model.tokenizer.library=sentencepiece \\\n", - " model.tokenizer.model=/path/to/tokenizer/model \\\n", + " model.tokenizer.model=/workspace/checkpoints/vicuna-7b-v1.5/tokenizer_neva.model \\\n", " model.encoder_seq_length=4096 \\\n", " model.num_layers=32 \\\n", " model.hidden_size=4096 \\\n", - " model.ffn_hidden_size=16384 \\\n", + " model.ffn_hidden_size=11008 \\\n", " model.num_attention_heads=32 \\\n", - " model.normalization=layernorm1p \\\n", + " model.normalization=rmsnorm \\\n", " model.do_layer_norm_weight_decay=False \\\n", " model.apply_query_key_layer_scaling=True \\\n", - " model.activation=squared-relu \\\n", + " model.bias=False \\\n", + " model.activation=fast-swiglu \\\n", " model.headscale=False \\\n", " model.position_embedding_type=rope \\\n", - " model.rotary_percentage=0.5 \\\n", + " model.rotary_percentage=1.0 \\\n", " model.num_query_groups=null \\\n", " model.data.num_workers=0 \\\n", - " model.mm_cfg.llm.from_pretrained=/path/to/checkpoint \\\n", - " model.mm_cfg.llm.model_type=nvgpt \\\n", - " model.data.conv_template=nvgpt \\\n", + " model.mm_cfg.llm.from_pretrained=/workspace/checkpoints/vicuna-7b-v1.5.nemo \\\n", + " model.mm_cfg.llm.model_type=v1 \\\n", + " model.data.conv_template=v1 \\\n", " model.mm_cfg.vision_encoder.from_pretrained='openai/clip-vit-large-patch14' \\\n", " model.mm_cfg.vision_encoder.from_hf=True \\\n", - " model.data.image_token_len=256 \\\n", " model.optim.name=\"fused_adam\" \\\n", " exp_manager.create_checkpoint_callback=True \\\n", - " exp_manager.create_wandb_logger=False \\\n", - " exp_manager.wandb_logger_kwargs.project=neva_demo" + " exp_manager.create_wandb_logger=False" ] }, { "cell_type": "markdown", - "id": "6b619e0a", + "id": "f24ee70d-3025-47f6-8571-295b024c3e05", "metadata": {}, "source": [ - "\n", - "\n", "**Note**: To initialize training a model from scratch rather than from a pretrained checkpoint, you may specify `null` instead of a path in the CLI arguments.\n", "\n", "### Image-Language Pair Instruction Fine-Tuning\n", "\n", - "Fine-tuning can also be run from within the container via a similar command leveraging the `neva_finetune.py` script.\n", + "Fine-tuning can also be run from within the container via a similar command leveraging the `neva_finetune.py` script. We leverage the checkpoint saved from pretrain step to further finetune it, given by `model.restore_from_path=/workspace/nemo_experiments/nemo_neva/checkpoints/nemo_neva.nemo`.\n", "\n", - "An example of an image-text pair instruction tuning script execution:" + "An example of an image-text pair instruction tuning script execution (note the scripts will only perform 1000 steps with a small micro batch size, this is not a full training):" ] }, { @@ -164,42 +218,44 @@ "++cluster_type=BCP \\\n", " trainer.precision=bf16 \\\n", " trainer.num_nodes=1 \\\n", - " trainer.devices=1 \\\n", - " trainer.val_check_interval=100 \\\n", + " trainer.devices=4 \\\n", + " trainer.val_check_interval=50 \\\n", " trainer.limit_val_batches=50 \\\n", - " trainer.max_steps=4900 \\\n", + " trainer.max_steps=100 \\\n", + " model.restore_from_path=/workspace/nemo_experiments/nemo_neva/checkpoints/nemo_neva.nemo \\\n", " model.megatron_amp_O2=True \\\n", - " model.micro_batch_size=4 \\\n", - " model.global_batch_size=32 \\\n", - " model.tensor_model_parallel_size=1 \\\n", + " model.micro_batch_size=1 \\\n", + " model.global_batch_size=2 \\\n", + " model.tensor_model_parallel_size=4 \\\n", " model.pipeline_model_parallel_size=1 \\\n", " model.mcore_gpt=True \\\n", " model.transformer_engine=True \\\n", - " model.data.data_path=/path/to/dataset/LLaVA-Pretrain-LCS-558K/blip_laion_cc_sbu_558k.json \\\n", - " model.data.image_folder=/path/to/dataset/LLaVA-Pretrain-LCS-558K/images \\\n", - " model.tokenizer.library=megatron \\\n", - " model.tokenizer.model=/path/to/tokenizer \\\n", + " model.data.data_path=/workspace/datasets/LLaVA-Instruct-mixture/llava_v1_5_mix665k.json \\\n", + " model.data.image_folder=/workspace/datasets/LLaVA-Instruct-mixture/images \\\n", + " model.tokenizer.library=sentencepiece \\\n", + " model.tokenizer.model=/workspace/checkpoints/vicuna-7b-v1.5/tokenizer_neva.model \\\n", " model.encoder_seq_length=4096 \\\n", - " model.num_layers=24 \\\n", - " model.hidden_size=2048 \\\n", - " model.ffn_hidden_size=5440 \\\n", - " model.num_attention_heads=16 \\\n", - " model.normalization=layernorm1p \\\n", + " model.num_layers=32 \\\n", + " model.hidden_size=4096 \\\n", + " model.ffn_hidden_size=11008 \\\n", + " model.num_attention_heads=32 \\\n", + " model.normalization=rmsnorm \\\n", " model.do_layer_norm_weight_decay=False \\\n", " model.apply_query_key_layer_scaling=True \\\n", + " model.bias=False \\\n", " model.activation=fast-swiglu \\\n", " model.headscale=False \\\n", " model.position_embedding_type=rope \\\n", - " model.rotary_percentage=0.5 \\\n", + " model.rotary_percentage=1.0 \\\n", " model.num_query_groups=null \\\n", - " model.data.num_workers=8 \\\n", - " model.mm_cfg.llm.from_pretrained=/path/to/checkpoint \\\n", - " model.mm_cfg.llm.model_type=nvgpt \\\n", - " exp_manager.create_checkpoint_callback=True \\\n", - " model.data.conv_template=nvgpt \\\n", + " model.data.num_workers=0 \\\n", + " model.mm_cfg.llm.from_pretrained=/workspace/checkpoints/vicuna-7b-v1.5.nemo \\\n", + " model.mm_cfg.llm.model_type=v1 \\\n", + " model.data.conv_template=v1 \\\n", " model.mm_cfg.vision_encoder.from_pretrained='openai/clip-vit-large-patch14' \\\n", " model.mm_cfg.vision_encoder.from_hf=True \\\n", - " model.data.image_token_len=256 \\\n", + " exp_manager.create_checkpoint_callback=True \\\n", + " exp_manager.name=\"nemo_neva_finetune\" \\\n", " model.optim.name=\"fused_adam\"" ] }, @@ -212,38 +268,7 @@ "\n", "### From Pre-trained Checkpoints\n", "\n", - "If you would like to use NeVA for inference from pre-trained checkpoint via HuggingFace, you can convert from HuggingFace to `.nemo` first.\n", - "\n", - "First, download the model checkpoint from HuggingFace [here](https://huggingface.co/liuhaotian/llava-v1.5-7b). The tokenizer (stored as `tokenizer.model` within the pretrained checkpoint) must be modified with the following commands:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0d30003f", - "metadata": { - "vscode": { - "languageId": "plaintext" - } - }, - "outputs": [], - "source": [ - "! cd /opt/sentencepiece/src/\n", - "! protoc --python_out=/opt/NeMo/scripts/tokenizers/ sentencepiece_model.proto\n", - "! python /opt/NeMo/scripts/tokenizers/add_special_tokens_to_sentencepiece.py \\\n", - "--input_file /path/to/tokenizer.model \\\n", - "--output_file /path/to/tokenizer_neva.model \\\n", - "--is_userdefined \\\n", - "--tokens \"\" \"\" \"\" \"\" \\\n", - " \"\" \"\" \"\" \"\"" - ] - }, - { - "cell_type": "markdown", - "id": "470c093b", - "metadata": {}, - "source": [ - "Finally, convert to `.nemo` via the provided script:" + "If you would like to use NeVA for inference from pre-trained checkpoint via HuggingFace, you can use the checkpoint from fine-tune step or convert from HuggingFace to `.nemo` first. Since we didn't finish full training in this tutorial with NeMo. We will instruct how you can convert a checkpoint from Hugging Face." ] }, { @@ -257,12 +282,10 @@ }, "outputs": [], "source": [ - "! git clone --depth 1 --branch v1.2.2 https://github.com/haotian-liu/LLaVA/\n", - "! export PYTHONPATH=/opt/LLaVA/:$PYTHONPATH\n", - "! python /opt/NeMo/examples/multimodal/multimodal_llm/neva/convert_llava_to_neva.py \\\n", - "--in-file /path/to/llava-v1.5-7b \\\n", - "--out-file /path/to/llava-v1.5-7b.nemo \\\n", - "--tokenizer-model /path/to/tokenizer_neva.model" + "! python3 /opt/NeMo/scripts/checkpoint_converters/convert_llava_hf_to_nemo.py \\\n", + " --input_name_or_path llava-hf/llava-1.5-7b-hf \\\n", + " --output_path /workspace/checkpoints/llava-7b.nemo \\\n", + " --tokenizer_path /workspace/checkpoints/vicuna-7b-v1.5/tokenizer_neva.model" ] }, { @@ -290,57 +313,19 @@ }, "outputs": [], "source": [ + "! echo '{\"image\": \"RTX4080.png\", \"prompt\": \"\\nCan you describe this image?\"}' > sample.json\n", + "! mkdir images && wget https://assets.nvidia.partners/images/png/TUF_Gaming_GeForce_RTX_4080_SUPER_OC_edition_packaging_with_card__12419.png --output-document=images/RTX4080.png\n", "! torchrun --nproc_per_node=1 /opt/NeMo/examples/multimodal/multimodal_llm/neva/neva_evaluation.py \\\n", "tensor_model_parallel_size=1 \\\n", "pipeline_model_parallel_size=1 \\\n", - "neva_model_file=/path/to/checkpoint \\\n", + "neva_model_file=/workspace/checkpoints/llava-7b.nemo \\\n", "trainer.devices=1 \\\n", "trainer.precision=bf16 \\\n", - "prompt_file=/path/to/prompt/file \\\n", - "inference.media_base_path=/path/to/image \\\n", - "output_file=path/for/output/file/ \\\n", + "prompt_file=sample.json \\\n", + "inference.media_base_path=images \\\n", + "output_file=output.json \\\n", "inference.temperature=0.2 \\\n", - "inference.top_k=0 \\\n", - "inference.top_p=0.9 \\\n", - "inference.greedy=False \\\n", - "inference.add_BOS=False \\\n", - "inference.all_probs=False \\\n", - "inference.repetition_penalty=1.2 \\\n", - "inference.insert_media_token=null \\\n", - "inference.tokens_to_generate=256 \\\n", - "quantization.algorithm=awq \\\n", - "quantization.enable=False" - ] - }, - { - "cell_type": "markdown", - "id": "7d989385", - "metadata": {}, - "source": [ - "#### Running Inference via Launcher\n", - "\n", - "Inference can also be run via the NeMo Launcher, where parameters are specified in the inference config file rather than CLI arguments. To customize the default config provided in `conf/config.yaml` for NeVA inference, see below.\n", - "\n", - "##### Inference Config Setup\n", - "1. Modify `fw_inference` within `defaults` to use `neva/inference` \n", - "2. In `stages`, ensure that `fw_inference` is included\n", - "3. Within the `inference.yaml` default NeVA inference config file, ensure that the path to the `prompt` file, `neva_model_file`, and `media_base_path` within `inference` are specified.\n", - "\n", - "Once either the necessary checkpoints have been loaded or the training workflow is complete, inference can be executed within the launcher pipeline with the following command:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "68d434ff", - "metadata": { - "vscode": { - "languageId": "plaintext" - } - }, - "outputs": [], - "source": [ - "! python3 main.py" + "inference.tokens_to_generate=256" ] } ], @@ -360,7 +345,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.11.6" } }, "nbformat": 4,