Clip Model in Nemo2 #11980

Merged: 76 commits merged on Feb 4, 2025 (changes shown from all commits)

Commits (76)
277dfad
Initial commit
abhinavg4 Dec 20, 2024
27b112c
Apply isort and black reformatting
abhinavg4 Dec 20, 2024
ee31967
Initial commit
abhinavg4 Dec 20, 2024
4ef3f3f
Apply isort and black reformatting
abhinavg4 Dec 20, 2024
a3a9214
Adding clip data module
abhinavg4 Jan 16, 2025
96c6da0
Adding some data modules
abhinavg4 Jan 16, 2025
aab7d84
Merge remote-tracking branch 'origin/main' into wip_ab/clip
abhinavg4 Jan 16, 2025
b7ea72c
Adding code for Yash
abhinavg4 Jan 18, 2025
7710b0d
Adding code for Yash
abhinavg4 Jan 18, 2025
b554044
first commit
Jan 22, 2025
2def80c
cleaning codes
Jan 22, 2025
b36343c
runnable code for one vision encoder
Jan 22, 2025
4b7c17b
Apply isort and black reformatting
huvunvidia Jan 22, 2025
f0801dd
Newer code
abhinavg4 Jan 22, 2025
8b82de4
Merge remote-tracking branch 'origin/huvu/openvla_dataloader' into wi…
abhinavg4 Jan 23, 2025
894a06d
remove irrelevant code in backbones (download models and models-relat…
Jan 24, 2025
f291de6
Apply isort and black reformatting
huvunvidia Jan 24, 2025
0b809a4
Merge remote-tracking branch 'origin/huvu/openvla_dataloader' into wi…
abhinavg4 Jan 24, 2025
21f4aaf
Code for Huy
abhinavg4 Jan 24, 2025
b23bdad
Code for Huy
abhinavg4 Jan 24, 2025
d2225bd
Code for Huy
abhinavg4 Jan 24, 2025
79d8bd1
Code for Huy
abhinavg4 Jan 25, 2025
e53cda0
Code for Huy
abhinavg4 Jan 25, 2025
99f2122
Code for Huy
abhinavg4 Jan 27, 2025
6ae3f90
Removing OpenVLA related commits
abhinavg4 Jan 28, 2025
c0cec3b
Removing OpenVLA related commits
abhinavg4 Jan 28, 2025
6a09cd5
Code for Clip
abhinavg4 Jan 28, 2025
1e9a2a6
Apply isort and black reformatting
abhinavg4 Jan 28, 2025
cf1be0c
Apply isort and black reformatting
artbataev Jan 28, 2025
291cf6a
Removing some debug commits
abhinavg4 Jan 28, 2025
c402664
Removing some debug commits
abhinavg4 Jan 28, 2025
4ef2120
Merge remote-tracking branch 'origin/main' into wip_ab/clip
abhinavg4 Jan 28, 2025
7049551
Code for Clip
abhinavg4 Jan 28, 2025
ec641b2
Code for Clip
abhinavg4 Jan 28, 2025
fb56f0e
Code for Clip
abhinavg4 Jan 28, 2025
5848a6d
Code for Clip
abhinavg4 Jan 28, 2025
d76b077
Removing some not needed files
abhinavg4 Jan 28, 2025
75e747f
Some more minor change
abhinavg4 Jan 28, 2025
eb283de
Apply isort and black reformatting
abhinavg4 Jan 28, 2025
a0fd44b
Adding Nemo Run
abhinavg4 Jan 29, 2025
97c1a2b
Merge remote-tracking branch 'origin/wip_ab/clip' into wip_ab/clip
abhinavg4 Jan 29, 2025
7ef9e21
Apply isort and black reformatting
abhinavg4 Jan 29, 2025
1590b4c
Batch size changes
abhinavg4 Jan 29, 2025
ec0361f
Merge remote-tracking branch 'origin/wip_ab/clip' into wip_ab/clip
abhinavg4 Jan 29, 2025
3e883e6
PR comments
abhinavg4 Jan 29, 2025
2540f82
Apply isort and black reformatting
abhinavg4 Jan 29, 2025
e2d8230
PR comments
abhinavg4 Jan 29, 2025
84323ae
Merge remote-tracking branch 'origin/wip_ab/clip' into wip_ab/clip
abhinavg4 Jan 29, 2025
b44be91
Apply isort and black reformatting
abhinavg4 Jan 29, 2025
431f063
PR comments
abhinavg4 Jan 29, 2025
c380fba
Apply isort and black reformatting
abhinavg4 Jan 30, 2025
9c4c8b0
Adding Energon requirements
abhinavg4 Jan 30, 2025
5569ac2
Merge remote-tracking branch 'origin/wip_ab/clip' into wip_ab/clip
abhinavg4 Jan 30, 2025
f9021c9
Apply isort and black reformatting
abhinavg4 Jan 30, 2025
752bdfa
PR Changes
abhinavg4 Jan 30, 2025
4032c4c
Merge remote-tracking branch 'origin/wip_ab/clip' into wip_ab/clip
abhinavg4 Jan 30, 2025
cb479cf
Apply isort and black reformatting
abhinavg4 Jan 30, 2025
2238200
PR Changes
abhinavg4 Jan 30, 2025
cc6a6f6
Merge remote-tracking branch 'origin/wip_ab/clip' into wip_ab/clip
abhinavg4 Jan 30, 2025
3cce712
Merge branch 'main' into wip_ab/clip
abhinavg4 Jan 30, 2025
01c9f31
Merge branch 'main' into wip_ab/clip
abhinavg4 Jan 31, 2025
bca69de
Changes to Energon
abhinavg4 Jan 31, 2025
fdcb896
Merge remote-tracking branch 'origin/wip_ab/clip' into wip_ab/clip
abhinavg4 Jan 31, 2025
da9715b
Merge branch 'main' into wip_ab/clip
abhinavg4 Feb 1, 2025
a9d9352
Changes for PR
abhinavg4 Feb 3, 2025
84465eb
Merge remote-tracking branch 'origin/wip_ab/clip' into wip_ab/clip
abhinavg4 Feb 3, 2025
6f38106
Apply isort and black reformatting
abhinavg4 Feb 3, 2025
bb1c7d8
Changes for PR
abhinavg4 Feb 3, 2025
253a984
Merge remote-tracking branch 'origin/wip_ab/clip' into wip_ab/clip
abhinavg4 Feb 3, 2025
0b4de02
Apply isort and black reformatting
abhinavg4 Feb 3, 2025
994d69f
Changes for PR
abhinavg4 Feb 3, 2025
72ce757
Merge remote-tracking branch 'origin/wip_ab/clip' into wip_ab/clip
abhinavg4 Feb 3, 2025
0785ee8
Merge remote-tracking branch 'origin/main' into wip_ab/clip
abhinavg4 Feb 3, 2025
41e1527
Apply isort and black reformatting
abhinavg4 Feb 3, 2025
3463ff6
Merge remote-tracking branch 'origin/wip_ab/clip' into wip_ab/clip
abhinavg4 Feb 3, 2025
2233f22
Merge branch 'main' into wip_ab/clip
abhinavg4 Feb 3, 2025
Files changed
@@ -46,6 +46,8 @@

@dataclass
class AugmentationCfg:
"""Augmentation Config"""

scale: Tuple[float, float] = (0.9, 1.0)
ratio: Optional[Tuple[float, float]] = None
color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None
@@ -56,6 +58,8 @@ class AugmentationCfg:


class ResizeMaxSize(nn.Module):
"""Resize module"""

def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
super().__init__()
if not isinstance(max_size, int):
@@ -66,6 +70,7 @@ def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max',
self.fill = fill

def forward(self, img):
# pylint: disable=C0116
if isinstance(img, torch.Tensor):
height, width = img.shape[:2]
else:
@@ -82,6 +87,7 @@ def forward(self, img):


def _convert_to_rgb(image):
# pylint: disable=C0116
return image.convert('RGB')


@@ -94,7 +100,9 @@ def image_transform(
fill_color: int = 0,
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
):
# pylint: disable=C0116
assert TORCHVISION_AVAILABLE, "Torchvision imports failed but they are required."

mean = mean or OPENAI_DATASET_MEAN
if not isinstance(mean, (list, tuple)):
mean = (mean,) * 3
@@ -139,7 +147,9 @@ def image_transform(
train_transform = Compose(
[
RandomResizedCrop(
image_size, scale=aug_cfg_dict.pop('scale'), interpolation=InterpolationMode.BICUBIC,
image_size,
scale=aug_cfg_dict.pop('scale'),
interpolation=InterpolationMode.BICUBIC,
),
_convert_to_rgb,
ToTensor(),
@@ -160,6 +170,10 @@
CenterCrop(image_size),
]
transforms.extend(
[_convert_to_rgb, ToTensor(), normalize,]
[
_convert_to_rgb,
ToTensor(),
normalize,
]
)
return Compose(transforms)
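
As a quick orientation for this transforms file, here is a minimal usage sketch of image_transform (not part of the diff; the 224x224 size and the variable names are placeholders, and the defining module's import path is not shown above, so it is omitted on purpose):

# Usage sketch only -- image_transform is the helper reformatted above.
train_transform = image_transform(
    (224, 224),        # (img_h, img_w)
    is_train=True,     # training branch: RandomResizedCrop + ToTensor + Normalize(mean, std)
    mean=None,         # None falls back to OPENAI_DATASET_MEAN
    std=None,
)
val_transform = image_transform((224, 224), is_train=False)  # eval branch: Resize / CenterCrop + ToTensor + normalize
image_tensor = train_transform(pil_image)  # pil_image: a PIL.Image opened by the caller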
104 changes: 103 additions & 1 deletion nemo/collections/multimodal/data/clip/clip_dataset.py
@@ -72,6 +72,31 @@ def tokenize(texts: Union[str, List[str]], tokenizer: Any, context_length: int =
return result


# pylint: disable=C0116
def get_preprocess_fns_params(
img_h, img_w, img_mean=None, img_std=None, is_train=True, max_position_embedding=None, tokenizer=None
):

# This is equivalent to `get_preprocess_fns` but does not need the whole config to get the functions.
# It is particularly used in NeMo 2.
# Define transforms
img_size = (img_h, img_w)
img_transform = image_transform(
img_size,
is_train=is_train,
mean=img_mean,
std=img_std,
)
text_transform = lambda x: x
if tokenizer is not None:
text_transform = partial(
tokenize,
tokenizer=tokenizer,
context_length=max_position_embedding,
)
return img_transform, text_transform


def get_preprocess_fns(model_cfg, tokenizer=None, is_train=True):
# Define transforms
img_size = (model_cfg.vision.get("img_h"), model_cfg.vision.get("img_w"))
@@ -104,7 +129,8 @@ def tuple_to_dict(inp):

def transform_fn(sample, img_transform, text_transform):
image, text = sample["jpg"], sample["txt"]
return img_transform(image), text_transform(text)
img_transformed, text_transformed = img_transform(image), text_transform(text)
return img_transformed, text_transformed


def build_train_valid_datasets(
@@ -144,8 +170,79 @@ def custom_collate(batch):
return default_collate(batch)


def build_imagenet_validation_dataloader_params(
imagenet_val,
img_h,
img_w,
mbs,
gbs,
num_workers=0,
pin_memory=True,
img_mean=None,
img_std=None,
is_train=False,
max_position_embedding=None,
tokenizer=None,
):
# This is equivalent to `build_imagenet_validation_dataloader` but does not need the whole config.
# It is particularly used in NeMo 2.
val_image_transform, text_transform = get_preprocess_fns_params(
img_h,
img_w,
img_mean,
img_std,
is_train=is_train,
max_position_embedding=max_position_embedding,
tokenizer=tokenizer,
)

imagenet_val_data = {}

imagenet_path = imagenet_val
if imagenet_path is None:
return None

image_dataset = ImageFolder(
root=imagenet_path,
transform=val_image_transform,
)

image_batch_sampler = MegatronPretrainingSampler(
total_samples=len(image_dataset),
consumed_samples=0,
micro_batch_size=mbs,
global_batch_size=gbs,
data_parallel_rank=parallel_state.get_data_parallel_rank(),
data_parallel_size=parallel_state.get_data_parallel_world_size(),
drop_last=False,
)

imagenet_val_data["images"] = torch.utils.data.DataLoader(
image_dataset,
batch_sampler=image_batch_sampler,
num_workers=num_workers,
collate_fn=custom_collate,
pin_memory=pin_memory,
persistent_workers=True,
)
text_dataset = ImagenetClassnameDataset(imagenet_classnames, openai_imagenet_template, text_transform)

imagenet_val_data["texts"] = torch.utils.data.DataLoader(
text_dataset,
batch_size=text_dataset.num_templates,
num_workers=0,
pin_memory=True,
persistent_workers=False,
drop_last=False,
)

return imagenet_val_data


# pylint: enable=C0116
# For zero-shot imagenet validation
def build_imagenet_validation_dataloader(model_cfg, tokenizer=None):
"""Build dataloaders"""
val_image_transform, text_transform = get_preprocess_fns(model_cfg, tokenizer, is_train=False)
data_cfg = model_cfg.data

@@ -192,15 +289,20 @@


class ImagenetClassnameDataset(Dataset):
"""Imagenet class dataset"""

def __init__(self, classnames, templates, text_transform):
# pylint: disable=C0116
self.num_templates = len(templates)
self.samples = []
for classname in classnames:
texts = [template(classname) for template in templates]
self.samples.extend(text_transform(texts))

def __getitem__(self, index):
# pylint: disable=C0116
return self.samples[index]

def __len__(self):
# pylint: disable=C0116
return len(self.samples)
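
Before moving to the Energon data module, a hedged sketch of how the new config-free helpers added above to clip_dataset.py might be called in NeMo 2. Every concrete value below (sizes, batch sizes, paths, my_tokenizer) is a placeholder, not something taken from this PR:

# Sketch only: exercises get_preprocess_fns_params and build_imagenet_validation_dataloader_params.
from nemo.collections.multimodal.data.clip.clip_dataset import (
    build_imagenet_validation_dataloader_params,
    get_preprocess_fns_params,
)

# Transforms without the full model config (the point of the *_params variants):
img_transform, text_transform = get_preprocess_fns_params(
    img_h=224,
    img_w=224,
    is_train=True,
    max_position_embedding=77,   # text context length; placeholder value
    tokenizer=my_tokenizer,      # any tokenizer usable by `tokenize`; assumed to exist
)

# Zero-shot ImageNet validation loaders, again without the full config:
imagenet_val_data = build_imagenet_validation_dataloader_params(
    imagenet_val="/datasets/imagenet/val",  # placeholder ImageFolder root; passing None returns None
    img_h=224,
    img_w=224,
    mbs=32,     # micro batch size
    gbs=256,    # global batch size
    max_position_embedding=77,
    tokenizer=my_tokenizer,
)
# imagenet_val_data["images"] / ["texts"] are torch DataLoaders (images use MegatronPretrainingSampler).

Note that the ImageNet helper queries parallel_state for the data-parallel rank and world size, so Megatron parallel state has to be initialized before it is called.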
63 changes: 48 additions & 15 deletions nemo/collections/multimodal/data/energon/base.py
@@ -64,11 +64,16 @@ def __init__(
micro_batch_size: int = 1,
global_batch_size: int = 1,
num_workers: int = 1,
num_val_workers: int | None = None,
pin_memory: bool = True,
shuffle_buffer_size: int = 100,
max_samples_per_sequence: int | None = None,
multimodal_sample_config: Optional[MultiModalSampleConfig] = MultiModalSampleConfig(),
task_encoder: Optional[MultiModalTaskEncoder] = None,
decoder_seq_length: Optional[int] = None,
packing_buffer_size: Optional[int] = None,
validation_task_encoder: Optional[MultiModalTaskEncoder] = None,
**kwargs,
) -> None:
"""
Initialize the EnergonMultiModalDataModule.
@@ -80,13 +85,20 @@
seq_length (int, optional): The maximum sequence length for tokenized text. Defaults to 2048.
micro_batch_size (int, optional): The batch size for training and validation. Defaults to 1.
num_workers (int, optional): Number of workers for data loading. Defaults to 1.
num_val_workers (int, optional): Number of workers for validation data loading. Defaults to num_workers.
pin_memory (bool, optional): Whether to pin memory in the DataLoader. Defaults to True.
multimodal_sample_config (MultiModalSampleConfig, optional): Configuration object for multimodal samples.
Defaults to MultiModalSampleConfig().
shuffle_buffer_size (int, optional): Size of the shuffle buffer. Defaults to 100.
max_samples_per_sequence (int, optional): Maximum number of samples per sequence to load from memory.
Defaults to None (loads the whole tar file at once).
task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding and batching samples.
If not provided, a default (MultimodalTaskEncoder) encoder will be created. Defaults to None.
decoder_seq_length (int, optional): The maximum sequence length for the decoder. Used in encoder-decoder models.
decoder_seq_length (int, optional): The max sequence length for the decoder. Used in encoder-decoder models
packing_buffer_size (int, optional): Size of the packing buffer for batched samples. Defaults to None.
validation_task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding
and batching samples for validation. Defaults to None and will be the same as task_encoder.
**kwargs: Additional keyword arguments. Will be passed to get_train_dataset() of Energon
"""

super().__init__()
@@ -102,6 +114,8 @@
self.num_workers = num_workers
self.pin_memory = pin_memory
self.multimodal_sample_config = multimodal_sample_config
self.shuffle_buffer_size = shuffle_buffer_size
self.max_samples_per_sequence = max_samples_per_sequence
self.task_encoder = task_encoder or MultiModalTaskEncoder(
tokenizer=self.tokenizer,
image_processor=self.image_processor,
@@ -117,10 +131,17 @@ def __init__(
self.train_dataloader_object = None
self.val_dataloader_object = None
self.packing_buffer_size = packing_buffer_size
self.validation_task_encoder = validation_task_encoder or self.task_encoder
self.num_val_workers = num_val_workers or self.num_workers
self.kwargs = kwargs

def io_init(self, **kwargs) -> fdl.Config[Self]:

cfg_kwargs = {k: deepcopy(v) for k, v in kwargs.items() if k not in ['image_processor', 'task_encoder']}
cfg_kwargs = {
k: deepcopy(v)
for k, v in kwargs.items()
if k not in ['image_processor', 'task_encoder', 'validation_task_encoder']
}

for val in cfg_kwargs.values():
if not serialization.find_node_traverser(type(val)):
@@ -142,18 +163,27 @@ def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val
Returns:
Dataset: The dataset configured for the specified split.
"""

if split not in {'train', 'val'}:
raise ValueError("Invalid value for split. Allowed values are 'train' or 'val'.")

if split == "train":
task_encoder = self.task_encoder
else:
task_encoder = self.validation_task_encoder

_dataset = get_train_dataset(
self.path,
batch_size=self.micro_batch_size,
task_encoder=self.task_encoder,
task_encoder=task_encoder,
worker_config=worker_config,
max_samples_per_sequence=None,
packing_buffer_size=self.packing_buffer_size,
shuffle_buffer_size=100,
split_part=split,
shuffle_buffer_size=self.shuffle_buffer_size,
max_samples_per_sequence=self.max_samples_per_sequence,
**self.kwargs,
)

return _dataset

def train_dataloader(self) -> TRAIN_DATALOADERS:
@@ -216,9 +246,9 @@ def val_dataloader(self) -> EVAL_DATALOADERS:
if not parallel_state.is_initialized():
logging.info(
f"Muiltimodal val data loader parallel state is not initialized,"
"using default worker config with no_workers {self.num_workers}"
f"using default worker config with no_workers {self.num_workers}"
)
worker_config = WorkerConfig.default_worker_config(self.num_workers)
worker_config = WorkerConfig.default_worker_config(self.num_val_workers)
else:
rank = parallel_state.get_data_parallel_rank()
world_size = parallel_state.get_data_parallel_world_size()
@@ -248,7 +278,7 @@ def test_dataloader(self) -> None:
Returns:
None
"""
logging.warning(f"Multimodal dataloader test dataset split does not exist")
logging.warning("Multimodal dataloader test dataset split does not exist")
return None

def state_dict(self) -> Dict[str, Any]:
@@ -264,7 +294,7 @@

if self.trainer:
dataloader_obj = self.trainer.train_dataloader
state = dataloader_obj.save_state()
state = dataloader_obj.save_state_global(dst_rank=0)
consumed_samples = self.data_sampler.compute_consumed_samples(
self.trainer.global_step - self.init_global_step
)
@@ -286,24 +316,27 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""
if not 'dataloader_state' in state_dict:
logging.warning(
f"Data loader state cannot be resumed from state_dict,"
f"Data loader state cannot be resumed from state_dict, "
f"it does not have the required key dataloader_state. It has {state_dict.keys()}"
)
return

state = state_dict['dataloader_state']
try:
if self.trainer:
self.trainer.datamodule.train_dataloader().restore_state(state)
logging.info(f" Multimodal dataloader state restored")
self.trainer.datamodule.train_dataloader().restore_state(state, src_rank=0)
logging.info("Multimodal dataloader state restored")
else:
logging.error(f"Cannot restore state from state_dict {state_dict}")
raise ValueError(
f"Cannot restore state from state_dict: "
f"Is the trainer object is initialized and attached to datamodule???"
"Cannot restore state from state_dict: "
"Is the trainer object is initialized and attached to datamodule???"
)
except Exception as e:
raise RuntimeError(f"Failed to dataloader restore state due to: {e}")
logging.warning(
f"Failed to dataloader restore state due to [Please ensure you are using same version "
f"of energon while saving and loading, Continuing without restoring data loader] : {e}"
)

try:
from megatron.core.num_microbatches_calculator import update_num_microbatches
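
To close the loop on the Energon changes, a hedged construction sketch for the extended EnergonMultiModalDataModule signature. The import path, dataset path, tokenizer, and image processor below are placeholders; the path/tokenizer/image_processor argument names are assumptions based on the attributes referenced in the diff (self.path, self.tokenizer, self.image_processor), not values from this PR:

# Sketch only: highlights the constructor arguments added or surfaced in this PR.
data_module = EnergonMultiModalDataModule(   # assumed importable from nemo.collections.multimodal.data.energon
    path="/datasets/energon_clip",           # Energon-prepared dataset path (placeholder)
    tokenizer=my_tokenizer,                  # placeholder tokenizer
    image_processor=my_image_processor,      # placeholder image processor
    seq_length=2048,
    micro_batch_size=1,
    global_batch_size=8,
    num_workers=4,
    num_val_workers=2,              # new: validation workers, defaults to num_workers
    shuffle_buffer_size=100,        # new: previously hard-coded inside datasets_provider
    max_samples_per_sequence=None,  # new: None loads a whole tar file at once
    validation_task_encoder=None,   # new: falls back to task_encoder when None
    # Any extra keyword arguments are forwarded to Energon's get_train_dataset().
)

On the checkpointing side, the diff switches dataloader state capture to save_state_global(dst_rank=0) and restore via restore_state(state, src_rank=0), and a failed restore now logs a warning and continues instead of raising, which matters when the Energon version used for saving differs from the one used for loading.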