Skip to content

Commit

Permalink
Clip Model in Nemo2 (#11980)
Browse files Browse the repository at this point in the history
* Initial commit

* Apply isort and black reformatting

Signed-off-by: abhinavg4 <[email protected]>

* Initial commit

* Apply isort and black reformatting

Signed-off-by: abhinavg4 <[email protected]>

* Adding clip data module

* Adding some data modules

* Adding code for Yash

* Adding code for Yash

* first commit

* cleaning codes

* runnable code for one vision encoder

* Apply isort and black reformatting

Signed-off-by: huvunvidia <[email protected]>

* Newer code

* remove irrelevant code in backbones (download models and models-related methods)

* Apply isort and black reformatting

Signed-off-by: huvunvidia <[email protected]>

* Code for Huy

* Code for Huy

* Code for Huy

* Code for Huy

* Code for Huy

* Code for Huy

* Removing OpenVLA related commits

* Removing OpenVLA related commits

* Code for Clip

* Apply isort and black reformatting

Signed-off-by: abhinavg4 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: artbataev <[email protected]>

* Removing some debug commits

* Removing some debug commits

* Code for Clip

* Code for Clip

* Code for Clip

* Code for Clip

* Removing some not needed files

* Some more minor change

* Apply isort and black reformatting

Signed-off-by: abhinavg4 <[email protected]>

* Adding Nemo Run

* Apply isort and black reformatting

Signed-off-by: abhinavg4 <[email protected]>

* Batch size changes

* PR comments

* Apply isort and black reformatting

Signed-off-by: abhinavg4 <[email protected]>

* PR comments

* Apply isort and black reformatting

Signed-off-by: abhinavg4 <[email protected]>

* PR comments

* Apply isort and black reformatting

Signed-off-by: abhinavg4 <[email protected]>

* Adding Energon requirements

* Apply isort and black reformatting

Signed-off-by: abhinavg4 <[email protected]>

* PR Changes

* Apply isort and black reformatting

Signed-off-by: abhinavg4 <[email protected]>

* PR Changes

* Changes to Energon

* Changes for PR

* Apply isort and black reformatting

Signed-off-by: abhinavg4 <[email protected]>

* Changes for PR

* Apply isort and black reformatting

Signed-off-by: abhinavg4 <[email protected]>

* Changes for PR

* Apply isort and black reformatting

Signed-off-by: abhinavg4 <[email protected]>

---------

Signed-off-by: abhinavg4 <[email protected]>
Signed-off-by: huvunvidia <[email protected]>
Signed-off-by: artbataev <[email protected]>
Co-authored-by: abhinavg4 <[email protected]>
Co-authored-by: Huy Vu2 <[email protected]>
Co-authored-by: huvunvidia <[email protected]>
Co-authored-by: artbataev <[email protected]>
  • Loading branch information
5 people authored and BoxiangW committed Feb 7, 2025
1 parent 5603bc5 commit d9d3103
Show file tree
Hide file tree
Showing 19 changed files with 2,178 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@

@dataclass
class AugmentationCfg:
"""Augmentation Config"""

scale: Tuple[float, float] = (0.9, 1.0)
ratio: Optional[Tuple[float, float]] = None
color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None
Expand All @@ -56,6 +58,8 @@ class AugmentationCfg:


class ResizeMaxSize(nn.Module):
"""Resize module"""

def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
super().__init__()
if not isinstance(max_size, int):
Expand All @@ -66,6 +70,7 @@ def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max',
self.fill = fill

def forward(self, img):
# pylint: disable=C0116
if isinstance(img, torch.Tensor):
height, width = img.shape[:2]
else:
Expand All @@ -82,6 +87,7 @@ def forward(self, img):


def _convert_to_rgb(image):
# pylint: disable=C0116
return image.convert('RGB')


Expand All @@ -94,7 +100,9 @@ def image_transform(
fill_color: int = 0,
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
):
# pylint: disable=C0116
assert TORCHVISION_AVAILABLE, "Torchvision imports failed but they are required."

mean = mean or OPENAI_DATASET_MEAN
if not isinstance(mean, (list, tuple)):
mean = (mean,) * 3
Expand Down Expand Up @@ -139,7 +147,9 @@ def image_transform(
train_transform = Compose(
[
RandomResizedCrop(
image_size, scale=aug_cfg_dict.pop('scale'), interpolation=InterpolationMode.BICUBIC,
image_size,
scale=aug_cfg_dict.pop('scale'),
interpolation=InterpolationMode.BICUBIC,
),
_convert_to_rgb,
ToTensor(),
Expand All @@ -160,6 +170,10 @@ def image_transform(
CenterCrop(image_size),
]
transforms.extend(
[_convert_to_rgb, ToTensor(), normalize,]
[
_convert_to_rgb,
ToTensor(),
normalize,
]
)
return Compose(transforms)
104 changes: 103 additions & 1 deletion nemo/collections/multimodal/data/clip/clip_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,31 @@ def tokenize(texts: Union[str, List[str]], tokenizer: Any, context_length: int =
return result


# pylint: disable=C0116
def get_preprocess_fns_params(
    img_h, img_w, img_mean=None, img_std=None, is_train=True, max_position_embedding=None, tokenizer=None
):
    """Build the image and text preprocessing callables from explicit parameters.

    Equivalent to ``get_preprocess_fns`` but does not need the whole model config;
    used in particular by NeMo 2.

    Args:
        img_h: Image height; first element of the size passed to ``image_transform``.
        img_w: Image width; second element of the size passed to ``image_transform``.
        img_mean: Forwarded to ``image_transform`` as ``mean``.
        img_std: Forwarded to ``image_transform`` as ``std``.
        is_train: Forwarded to ``image_transform`` as ``is_train``.
        max_position_embedding: Forwarded to ``tokenize`` as ``context_length`` when it is not None.
        tokenizer: Tokenizer handed to ``tokenize``. When None, text is passed through unchanged.

    Returns:
        Tuple ``(img_transform, text_transform)``.
    """
    img_transform = image_transform(
        (img_h, img_w),
        is_train=is_train,
        mean=img_mean,
        std=img_std,
    )

    if tokenizer is None:
        # No tokenizer supplied: the text transform is the identity.
        def text_transform(x):
            return x

        return img_transform, text_transform

    tokenize_kwargs = {"tokenizer": tokenizer}
    # Only override ``context_length`` when a value was actually given; forwarding None would
    # replace the default declared by ``tokenize`` (NOTE(review): confirm against ``tokenize``).
    if max_position_embedding is not None:
        tokenize_kwargs["context_length"] = max_position_embedding
    text_transform = partial(tokenize, **tokenize_kwargs)
    return img_transform, text_transform


def get_preprocess_fns(model_cfg, tokenizer=None, is_train=True):
# Define transforms
img_size = (model_cfg.vision.get("img_h"), model_cfg.vision.get("img_w"))
Expand Down Expand Up @@ -104,7 +129,8 @@ def tuple_to_dict(inp):

def transform_fn(sample, img_transform, text_transform):
    """Apply the image and text transforms to one sample.

    Args:
        sample: Mapping with the image under ``"jpg"`` and the text under ``"txt"``.
        img_transform: Callable applied to ``sample["jpg"]``.
        text_transform: Callable applied to ``sample["txt"]``.

    Returns:
        Tuple ``(transformed_image, transformed_text)``.
    """
    image, text = sample["jpg"], sample["txt"]
    img_transformed, text_transformed = img_transform(image), text_transform(text)
    return img_transformed, text_transformed


def build_train_valid_datasets(
Expand Down Expand Up @@ -144,8 +170,79 @@ def custom_collate(batch):
return default_collate(batch)


def build_imagenet_validation_dataloader_params(
    imagenet_val,
    img_h,
    img_w,
    mbs,
    gbs,
    num_workers=0,
    pin_memory=True,
    img_mean=None,
    img_std=None,
    is_train=False,
    max_position_embedding=None,
    tokenizer=None,
):
    """Build the zero-shot ImageNet validation dataloaders from explicit parameters.

    Equivalent to ``build_imagenet_validation_dataloader`` but does not need the whole
    config; used in particular by NeMo 2.

    Args:
        imagenet_val: Root directory handed to ``ImageFolder``. When None, nothing is built.
        img_h: Image height forwarded to ``get_preprocess_fns_params``.
        img_w: Image width forwarded to ``get_preprocess_fns_params``.
        mbs: Micro batch size for ``MegatronPretrainingSampler``.
        gbs: Global batch size for ``MegatronPretrainingSampler``.
        num_workers: Worker count for the image dataloader.
        pin_memory: ``pin_memory`` flag for the image dataloader.
        img_mean: Forwarded to ``get_preprocess_fns_params``.
        img_std: Forwarded to ``get_preprocess_fns_params``.
        is_train: Forwarded to ``get_preprocess_fns_params``.
        max_position_embedding: Forwarded to ``get_preprocess_fns_params``.
        tokenizer: Forwarded to ``get_preprocess_fns_params``.

    Returns:
        None when ``imagenet_val`` is None; otherwise a dict with the image dataloader
        under ``"images"`` and the class-name prompt dataloader under ``"texts"``.
    """
    # Bail out before building transforms or touching parallel state when there is no data.
    if imagenet_val is None:
        return None

    val_image_transform, text_transform = get_preprocess_fns_params(
        img_h,
        img_w,
        img_mean,
        img_std,
        is_train=is_train,
        max_position_embedding=max_position_embedding,
        tokenizer=tokenizer,
    )

    image_dataset = ImageFolder(
        root=imagenet_val,
        transform=val_image_transform,
    )

    image_batch_sampler = MegatronPretrainingSampler(
        total_samples=len(image_dataset),
        consumed_samples=0,
        micro_batch_size=mbs,
        global_batch_size=gbs,
        data_parallel_rank=parallel_state.get_data_parallel_rank(),
        data_parallel_size=parallel_state.get_data_parallel_world_size(),
        drop_last=False,
    )

    image_loader = torch.utils.data.DataLoader(
        image_dataset,
        batch_sampler=image_batch_sampler,
        num_workers=num_workers,
        collate_fn=custom_collate,
        pin_memory=pin_memory,
        # DataLoader rejects persistent_workers=True when num_workers == 0 (the default here),
        # so only keep workers alive when there are workers to keep.
        persistent_workers=num_workers > 0,
    )

    text_dataset = ImagenetClassnameDataset(imagenet_classnames, openai_imagenet_template, text_transform)

    # One batch holds all prompts of a single class (batch size == number of templates).
    text_loader = torch.utils.data.DataLoader(
        text_dataset,
        batch_size=text_dataset.num_templates,
        num_workers=0,
        pin_memory=True,
        persistent_workers=False,
        drop_last=False,
    )

    return {"images": image_loader, "texts": text_loader}


# pylint: enable=C0116
# For zero-shot imagenet validation
def build_imagenet_validation_dataloader(model_cfg, tokenizer=None):
"""Build dataloaders"""
val_image_transform, text_transform = get_preprocess_fns(model_cfg, tokenizer, is_train=False)
data_cfg = model_cfg.data

Expand Down Expand Up @@ -192,15 +289,20 @@ def build_imagenet_validation_dataloader(model_cfg, tokenizer=None):


class ImagenetClassnameDataset(Dataset):
    """Dataset of transformed prompt texts, one entry per (class name, template) pair.

    Samples are stored class by class, with the templates applied in order within each
    class, so consecutive runs of ``num_templates`` samples belong to the same class.
    """

    def __init__(self, classnames, templates, text_transform):
        """Expand every class name with every template and transform the resulting texts.

        Args:
            classnames: Iterable of class names.
            templates: Iterable of callables mapping a class name to a prompt text.
            text_transform: Callable applied to the list of prompts of one class; its
                result is iterated and appended to ``self.samples``.
        """
        # Materialise once: templates are counted and then re-iterated for every class,
        # which would fail or silently yield nothing for a one-shot iterable.
        templates = list(templates)
        self.num_templates = len(templates)
        self.samples = []
        for classname in classnames:
            texts = [template(classname) for template in templates]
            self.samples.extend(text_transform(texts))

    def __getitem__(self, index):
        """Return the transformed prompt stored at ``index``."""
        return self.samples[index]

    def __len__(self):
        """Return the total number of transformed prompts."""
        return len(self.samples)
63 changes: 48 additions & 15 deletions nemo/collections/multimodal/data/energon/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,16 @@ def __init__(
micro_batch_size: int = 1,
global_batch_size: int = 1,
num_workers: int = 1,
num_val_workers: int | None = None,
pin_memory: bool = True,
shuffle_buffer_size: int = 100,
max_samples_per_sequence: int | None = None,
multimodal_sample_config: Optional[MultiModalSampleConfig] = MultiModalSampleConfig(),
task_encoder: Optional[MultiModalTaskEncoder] = None,
decoder_seq_length: Optional[int] = None,
packing_buffer_size: Optional[int] = None,
validation_task_encoder: Optional[MultiModalTaskEncoder] = None,
**kwargs,
) -> None:
"""
Initialize the EnergonMultiModalDataModule.
Expand All @@ -80,13 +85,20 @@ def __init__(
seq_length (int, optional): The maximum sequence length for tokenized text. Defaults to 2048.
micro_batch_size (int, optional): The batch size for training and validation. Defaults to 1.
num_workers (int, optional): Number of workers for data loading. Defaults to 1.
num_val_workers (int, optional): Number of workers for validation data loading. Defaults to num_workers.
pin_memory (bool, optional): Whether to pin memory in the DataLoader. Defaults to True.
multimodal_sample_config (MultiModalSampleConfig, optional): Configuration object for multimodal samples.
Defaults to MultiModalSampleConfig().
shuffle_buffer_size (int, optional): Size of the shuffle buffer. Defaults to 100.
max_samples_per_sequence (int, optional): Maximum number of samples per sequence to load from memory.
Defaults to None (loads the whole tar file at once).
task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding and batching samples.
If not provided, a default (MultimodalTaskEncoder) encoder will be created. Defaults to None.
decoder_seq_length (int, optional): The maximum sequence length for the decoder. Used in encoder-decoder models.
decoder_seq_length (int, optional): The max sequence length for the decoder. Used in encoder-decoder models
packing_buffer_size (int, optional): Size of the packing buffer for batched samples. Defaults to None.
validation_task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding
and batching samples for validation. Defaults to None and will be the same as task_encoder.
**kwargs: Additional keyword arguments. Will be passed to get_train_dataset() of Energon
"""

super().__init__()
Expand All @@ -102,6 +114,8 @@ def __init__(
self.num_workers = num_workers
self.pin_memory = pin_memory
self.multimodal_sample_config = multimodal_sample_config
self.shuffle_buffer_size = shuffle_buffer_size
self.max_samples_per_sequence = max_samples_per_sequence
self.task_encoder = task_encoder or MultiModalTaskEncoder(
tokenizer=self.tokenizer,
image_processor=self.image_processor,
Expand All @@ -117,10 +131,17 @@ def __init__(
self.train_dataloader_object = None
self.val_dataloader_object = None
self.packing_buffer_size = packing_buffer_size
self.validation_task_encoder = validation_task_encoder or self.task_encoder
self.num_val_workers = num_val_workers or self.num_workers
self.kwargs = kwargs

def io_init(self, **kwargs) -> fdl.Config[Self]:

cfg_kwargs = {k: deepcopy(v) for k, v in kwargs.items() if k not in ['image_processor', 'task_encoder']}
cfg_kwargs = {
k: deepcopy(v)
for k, v in kwargs.items()
if k not in ['image_processor', 'task_encoder', 'validation_task_encoder']
}

for val in cfg_kwargs.values():
if not serialization.find_node_traverser(type(val)):
Expand All @@ -142,18 +163,27 @@ def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val
Returns:
Dataset: The dataset configured for the specified split.
"""

if split not in {'train', 'val'}:
raise ValueError("Invalid value for split. Allowed values are 'train' or 'val'.")

if split == "train":
task_encoder = self.task_encoder
else:
task_encoder = self.validation_task_encoder

_dataset = get_train_dataset(
self.path,
batch_size=self.micro_batch_size,
task_encoder=self.task_encoder,
task_encoder=task_encoder,
worker_config=worker_config,
max_samples_per_sequence=None,
packing_buffer_size=self.packing_buffer_size,
shuffle_buffer_size=100,
split_part=split,
shuffle_buffer_size=self.shuffle_buffer_size,
max_samples_per_sequence=self.max_samples_per_sequence,
**self.kwargs,
)

return _dataset

def train_dataloader(self) -> TRAIN_DATALOADERS:
Expand Down Expand Up @@ -216,9 +246,9 @@ def val_dataloader(self) -> EVAL_DATALOADERS:
if not parallel_state.is_initialized():
logging.info(
f"Muiltimodal val data loader parallel state is not initialized,"
"using default worker config with no_workers {self.num_workers}"
f"using default worker config with no_workers {self.num_workers}"
)
worker_config = WorkerConfig.default_worker_config(self.num_workers)
worker_config = WorkerConfig.default_worker_config(self.num_val_workers)
else:
rank = parallel_state.get_data_parallel_rank()
world_size = parallel_state.get_data_parallel_world_size()
Expand Down Expand Up @@ -248,7 +278,7 @@ def test_dataloader(self) -> None:
Returns:
None
"""
logging.warning(f"Multimodal dataloader test dataset split does not exist")
logging.warning("Multimodal dataloader test dataset split does not exist")
return None

def state_dict(self) -> Dict[str, Any]:
Expand All @@ -264,7 +294,7 @@ def state_dict(self) -> Dict[str, Any]:

if self.trainer:
dataloader_obj = self.trainer.train_dataloader
state = dataloader_obj.save_state()
state = dataloader_obj.save_state_global(dst_rank=0)
consumed_samples = self.data_sampler.compute_consumed_samples(
self.trainer.global_step - self.init_global_step
)
Expand All @@ -286,24 +316,27 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""
if not 'dataloader_state' in state_dict:
logging.warning(
f"Data loader state cannot be resumed from state_dict,"
f"Data loader state cannot be resumed from state_dict, "
f"it does not have the required key dataloader_state. It has {state_dict.keys()}"
)
return

state = state_dict['dataloader_state']
try:
if self.trainer:
self.trainer.datamodule.train_dataloader().restore_state(state)
logging.info(f" Multimodal dataloader state restored")
self.trainer.datamodule.train_dataloader().restore_state(state, src_rank=0)
logging.info("Multimodal dataloader state restored")
else:
logging.error(f"Cannot restore state from state_dict {state_dict}")
raise ValueError(
f"Cannot restore state from state_dict: "
f"Is the trainer object is initialized and attached to datamodule???"
"Cannot restore state from state_dict: "
"Is the trainer object is initialized and attached to datamodule???"
)
except Exception as e:
raise RuntimeError(f"Failed to dataloader restore state due to: {e}")
logging.warning(
f"Failed to dataloader restore state due to [Please ensure you are using same version "
f"of energon while saving and loading, Continuing without restoring data loader] : {e}"
)

try:
from megatron.core.num_microbatches_calculator import update_num_microbatches
Expand Down
Loading

0 comments on commit d9d3103

Please sign in to comment.