Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: nikitaved <[email protected]>
  • Loading branch information
nikitaved committed Nov 6, 2024
1 parent 37d72e6 commit ac472ad
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from functools import partial
from itertools import chain
from typing import Any, Optional
from collections.abc import Callable

import numpy as np
import packaging
Expand Down Expand Up @@ -68,9 +67,7 @@
from nemo.collections.vision.data.megatron.data_samplers import MegatronVisionPretrainingRandomSampler
from nemo.core import adapter_mixins
from nemo.core.classes.common import PretrainedModelInfo
from nemo.utils import logging
from nemo.constants import NEMO_ENV_VARNAME_TESTING
from nemo.utils.env_var_parsing import get_envbool
from nemo.utils import logging, run_if_testing

try:
from megatron.energon import (
Expand Down Expand Up @@ -119,15 +116,6 @@ def skip_fp8_load(x):
return x


def run_if_testing(f: Callable):
    """Conditionally execute ``f`` based on the testing environment.

    The callable ``f`` is invoked (with no arguments) only when the
    ``NEMO_TESTING`` environment variable is set to a truthy value;
    otherwise this is a no-op. Returns ``None`` either way.
    """

    testing_enabled = get_envbool(NEMO_ENV_VARNAME_TESTING)
    if not testing_enabled:
        return
    f()


class FrozenCLIPVisionTransformer(CLIPVisionTransformer):
"""Frozen version of CLIPVisionTransformer"""

Expand Down Expand Up @@ -275,7 +263,9 @@ def replace_media_embeddings(self, input_ids, inputs_embeds, media):
# create an indices matrix used in torch.scatter {
sorted_media_end_positions_mask, media_end_positions_mask_sort_idx = (
# NOTE: to(torch.long) is needed because PyTorch does not have sort for boolean tensors on CUDA
(input_ids == self.media_end_id).to(torch.long).sort(dim=-1, descending=True, stable=True)
(input_ids == self.media_end_id)
.to(torch.long)
.sort(dim=-1, descending=True, stable=True)
)
# TODO: unless `media_end_positions_mask_sort_idx` is required to be sorted,
# we can replace sort with topk(..., k=num_images_per_sample)
Expand All @@ -287,14 +277,15 @@ def replace_media_embeddings(self, input_ids, inputs_embeds, media):
padded_media_indices = torch.where(
sorted_media_end_positions_mask.to(torch.bool),
media_end_positions_mask_sort_idx - num_patches,
sequence_length
sequence_length,
)
else:
padded_media_indices = torch.where(
sorted_media_end_positions_mask.to(torch.bool),
media_end_positions_mask_sort_idx - num_patches + 1,
sequence_length
sequence_length,
)

# Check whether `padded_media_indices` represents correct indices
# This check is only run when the env var `NEMO_TESTING` is set
def check_padded_media_indices():
Expand Down
13 changes: 13 additions & 0 deletions nemo/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# limitations under the License.


from collections.abc import Callable

from nemo.constants import NEMO_ENV_VARNAME_TESTING
from nemo.utils.app_state import AppState
from nemo.utils.cast_utils import (
CastToFloat,
Expand All @@ -24,6 +27,7 @@
monkeypatched,
)
from nemo.utils.dtype import str_to_dtype
from nemo.utils.env_var_parsing import get_envbool
from nemo.utils.nemo_logging import Logger as _Logger
from nemo.utils.nemo_logging import LogMode as logging_mode

Expand All @@ -34,3 +38,12 @@
add_memory_handlers_to_pl_logger()
except ModuleNotFoundError:
pass


def run_if_testing(f: Callable):
    """Helper function that invokes the input callable `f`
    if the environment variable `NEMO_TESTING` is set.

    Args:
        f: Zero-argument callable to run when testing mode is enabled.
            Its return value is discarded; this helper always returns ``None``.
    """

    # Default to False so the callable is skipped when the env var is unset.
    if get_envbool(NEMO_ENV_VARNAME_TESTING, False):
        f()

0 comments on commit ac472ad

Please sign in to comment.