From 40139d9915bfec3da00458e07b63bcd5c62ebb8f Mon Sep 17 00:00:00 2001 From: nikitaved Date: Tue, 16 Jul 2024 09:00:17 -0400 Subject: [PATCH] Apply isort and black reformatting Signed-off-by: nikitaved --- .../models/multimodal_llm/neva/neva_model.py | 23 ++++++------------- nemo/utils/__init__.py | 13 +++++++++++ 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py index 88fd22f96403a..714c832121b95 100644 --- a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py +++ b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py @@ -16,7 +16,6 @@ from functools import partial from itertools import chain from typing import Any, Optional -from collections.abc import Callable import numpy as np import torch @@ -64,9 +63,7 @@ from nemo.collections.vision.data.megatron.data_samplers import MegatronVisionPretrainingRandomSampler from nemo.core import adapter_mixins from nemo.core.classes.common import PretrainedModelInfo -from nemo.utils import logging -from nemo.constants import NEMO_ENV_VARNAME_TESTING -from nemo.utils.env_var_parsing import get_envbool +from nemo.utils import logging, run_if_testing try: from megatron.core import InferenceParams, dist_checkpointing, parallel_state, tensor_parallel @@ -82,15 +79,6 @@ HAVE_MEGATRON_CORE = False -def run_if_testing(f: Callable): - """Helper function that invokes the input callable `f` - if the environment variable `NEMO_TESTING` is set. - """ - - if get_envbool(NEMO_ENV_VARNAME_TESTING): - f() - - class FrozenCLIPVisionTransformer(CLIPVisionTransformer): """Frozen version of CLIPVisionTransformer""" @@ -238,7 +226,9 @@ def replace_media_embeddings(self, input_ids, inputs_embeds, media): # create an indices matrix used in torch.scatter { sorted_media_end_positions_mask, media_end_positions_mask_sort_idx = ( # NOTE: to(torch.long) is needed because PyTorch does not have sort for boolean tensors on CUDA - (input_ids == self.media_end_id).to(torch.long).sort(dim=-1, descending=True, stable=True) + (input_ids == self.media_end_id) + .to(torch.long) + .sort(dim=-1, descending=True, stable=True) ) # TODO: unless `media_end_positions_mask_sort_idx` is required to be sorted, # we can replace sort with topk(..., k=num_images_per_sample) @@ -250,14 +240,15 @@ def replace_media_embeddings(self, input_ids, inputs_embeds, media): padded_media_indices = torch.where( sorted_media_end_positions_mask.to(torch.bool), media_end_positions_mask_sort_idx - num_patches, - sequence_length + sequence_length, ) else: padded_media_indices = torch.where( sorted_media_end_positions_mask.to(torch.bool), media_end_positions_mask_sort_idx - num_patches + 1, - sequence_length + sequence_length, ) + # Check whether `padded_media_indices` represents correct indices # This check is only run when the env var `NEMO_TESTING` is set def check_padded_media_indices(): diff --git a/nemo/utils/__init__.py b/nemo/utils/__init__.py index a1e59646ae131..505cc1c7172b0 100644 --- a/nemo/utils/__init__.py +++ b/nemo/utils/__init__.py @@ -13,6 +13,9 @@ # limitations under the License. +from collections.abc import Callable + +from nemo.constants import NEMO_ENV_VARNAME_TESTING from nemo.utils.app_state import AppState from nemo.utils.cast_utils import ( CastToFloat, @@ -24,6 +27,7 @@ monkeypatched, ) from nemo.utils.dtype import str_to_dtype +from nemo.utils.env_var_parsing import get_envbool from nemo.utils.nemo_logging import Logger as _Logger from nemo.utils.nemo_logging import LogMode as logging_mode @@ -34,3 +38,12 @@ add_memory_handlers_to_pl_logger() except ModuleNotFoundError: pass + + +def run_if_testing(f: Callable): + """Helper function that invokes the input callable `f` + if the environment variable `NEMO_TESTING` is set. + """ + + if get_envbool(NEMO_ENV_VARNAME_TESTING, False): + f()