r2.1.0 cherrypick #11680

Merged: 4 commits, Dec 20, 2024
Changes from all commits
nemo/collections/llm/gpt/data/fine_tuning.py (7 changes: 4 additions & 3 deletions)
@@ -93,6 +93,7 @@ def __init__(
self.packed_sequence_size = -1 if not packed_sequence_specs else packed_sequence_specs.packed_sequence_size
self.validate_batch_size_for_packed_sequence()
self.dataset_kwargs = dataset_kwargs or {}
self.init_global_step = 0

def validate_batch_size_for_packed_sequence(self):
"""
@@ -163,9 +164,7 @@ def state_dict(self) -> Dict[str, Any]:
A dictionary containing datamodule state.

"""
- consumed_samples = self.data_sampler.compute_consumed_samples(
- self.trainer.global_step - self.data_sampler.init_global_step
- )
+ consumed_samples = self.data_sampler.compute_consumed_samples(self.trainer.global_step - self.init_global_step)
return {"consumed_samples": consumed_samples}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -240,6 +239,8 @@ def _create_dataset(self, path, is_test=False, **kwargs):

def _create_dataloader(self, dataset, mode, **kwargs) -> DataLoader:
# pylint: disable=C0115,C0116
self.init_global_step = self.trainer.global_step
self.data_sampler.init_global_step = self.init_global_step
return WrappedDataLoader(
mode=mode,
dataset=dataset,
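For orientation, a minimal sketch of the consumed-samples bookkeeping that the change above relies on: the datamodule records the global step at which the dataloader is created and computes samples consumed relative to it. The `global_batch_size` name below is illustrative, not taken from the diff.

def compute_consumed_samples(global_step, init_global_step, global_batch_size):
    # Samples consumed since the dataloader was (re)created at init_global_step.
    steps_since_init = global_step - init_global_step
    return steps_since_init * global_batch_size

# e.g. saving state at global_step=120 after the dataloader was created at
# step 100, with a global batch size of 8 -> 160 samples consumed in this run.
assert compute_consumed_samples(120, 100, 8) == 160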
nemo/core/connectors/save_restore_connector.py (7 changes: 6 additions & 1 deletion)
@@ -601,7 +601,12 @@ def _is_safe_path(member, extract_to):
# Construct the full path where the member would be extracted
full_path = os.path.join(extract_to, member_path)
# Ensure the member would be extracted within the intended directory
- return os.path.commonprefix([full_path, extract_to]) == extract_to
+ if os.path.commonprefix([full_path, extract_to]) != extract_to:
+ return False
+ # Check if the member is a symbolic link
+ if member.issym() or member.islnk():
+ return False
+ return True

@staticmethod
def _safe_extract(tar, out_folder: str, members=None):
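As a standalone illustration, a hedged sketch of applying the same guard (path-traversal check plus symlink/hardlink rejection) while extracting an archive; the helper names here are hypothetical and not part of the connector:

import os
import tarfile

def is_safe_member(member: tarfile.TarInfo, extract_to: str) -> bool:
    # Reject members that would resolve outside the target directory.
    full_path = os.path.join(extract_to, member.name)
    if os.path.commonprefix([full_path, extract_to]) != extract_to:
        return False
    # Reject symbolic and hard links outright, as in the patch above.
    if member.issym() or member.islnk():
        return False
    return True

def safe_extract(tar_path: str, out_folder: str) -> None:
    with tarfile.open(tar_path) as tar:
        safe_members = [m for m in tar.getmembers() if is_safe_member(m, out_folder)]
        tar.extractall(path=out_folder, members=safe_members)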
nemo/export/__init__.py (2 changes: 0 additions & 2 deletions)
@@ -11,5 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

- from nemo.export.tensorrt_lazy_compiler import trt_compile
nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py (57 changes: 52 additions & 5 deletions)
@@ -38,6 +38,13 @@
from nemo.export.tarutils import TarPath, ZarrPathStore
from nemo.export.tiktoken_tokenizer import TiktokenTokenizer

try:
from nemo.lightning import io

HAVE_NEMO2 = True
except (ImportError, ModuleNotFoundError):
HAVE_NEMO2 = False

LOGGER = logging.getLogger("NeMo")


@@ -287,14 +294,54 @@ def copy_tokenizer_files(config, out_dir):
return config


def get_tokenizer_from_nemo2_context(model_context_dir: Path):
"""
Retrieve tokenizer configuration from NeMo 2.0 context and instantiate the tokenizer.

Args:
model_context_dir (Path): Path to the model context directory.

Returns:
The instantiated tokenizer (various classes possible).
"""

if HAVE_NEMO2:
# Use NeMo tokenizer loaded from the NeMo 2.0 model context
tokenizer_spec = io.load_context(model_context_dir, subpath="model.tokenizer")
return build_tokenizer(tokenizer_spec)
else:
# Use local nemo.export SentencePieceTokenizer implementation
# or directly a HuggingFace tokenizer based on the model config
with (model_context_dir / "model.yaml").open("r") as stream:
model_config = yaml.safe_load(stream)

tokenizer_config = model_config["tokenizer"]
target_class = tokenizer_config["_target_"]
tokenizer_module = "nemo.collections.common.tokenizers."
assert target_class.startswith(tokenizer_module)
target_class = target_class.removeprefix(tokenizer_module)

if target_class == "sentencepiece_tokenizer.SentencePieceTokenizer":
tokenizer = SentencePieceTokenizer(
model_path=str(model_context_dir / tokenizer_config["model_path"]),
special_tokens=tokenizer_config.get("special_tokens", None),
legacy=tokenizer_config.get("legacy", False),
)
elif target_class == "huggingface.auto_tokenizer.AutoTokenizer":
tokenizer = AutoTokenizer.from_pretrained(
str(model_context_dir / tokenizer_config["pretrained_model_name"])
)
else:
raise ValueError(f"Unsupported tokenizer type: {tokenizer_module}{target_class}.")

return tokenizer


def get_tokenizer(tokenizer_dir_or_path: Union[str, Path]) -> PreTrainedTokenizer:
"""Loads the tokenizer from the decoded NeMo weights dir."""
tokenizer_dir_or_path = Path(tokenizer_dir_or_path)
if (tokenizer_dir_or_path / "nemo_context").exists():
- from nemo.lightning import io
-
- tokenizer_spec = io.load_context((tokenizer_dir_or_path / "nemo_context"), subpath="model.tokenizer")
- return build_tokenizer(tokenizer_spec)
+ return get_tokenizer_from_nemo2_context(tokenizer_dir_or_path / "nemo_context")
elif os.path.exists(os.path.join(tokenizer_dir_or_path, "vocab.json")):
vocab_path = tokenizer_dir_or_path / "vocab.json" if tokenizer_dir_or_path.is_dir() else tokenizer_dir_or_path
tokenizer_config = {"library": "tiktoken", "vocab_file": str(vocab_path)}
@@ -474,7 +521,7 @@ def load_nemo_model(nemo_ckpt: Union[str, Path], nemo_export_dir: Union[str, Pat
elif k == "activation_func":
nemo_model_config["activation"] = v["_target_"].rsplit('.', 1)[-1]
else:
- from nemo.lightning import io
+ assert HAVE_NEMO2, "nemo_toolkit>=2.0.0 is required to load the model context."

config = io.load_context(io_folder, subpath="model.config")

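A hedged usage sketch of the resulting behaviour: when nemo_toolkit>=2.0.0 is importable, `get_tokenizer_from_nemo2_context` resolves the tokenizer via `io.load_context`; otherwise it falls back to parsing model.yaml as shown above. The checkpoint path below is hypothetical.

from pathlib import Path

from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import get_tokenizer

checkpoint_dir = Path("/results/decoded_nemo_weights")  # hypothetical path
# get_tokenizer() dispatches on the directory layout: a nemo_context/ subdir
# routes to the NeMo 2.0 context path, a vocab.json routes to tiktoken.
tokenizer = get_tokenizer(checkpoint_dir)
print(type(tokenizer).__name__)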
nemo/lightning/io/__init__.py (3 changes: 2 additions & 1 deletion)
@@ -2,14 +2,15 @@
from nemo.lightning.io.api import export_ckpt, import_ckpt, load, load_context, model_exporter, model_importer
from nemo.lightning.io.capture import reinit
from nemo.lightning.io.connector import Connector, ModelConnector
- from nemo.lightning.io.mixin import ConnectorMixin, IOMixin, track_io
+ from nemo.lightning.io.mixin import ConnectorMixin, IOMixin, drop_unexpected_params, track_io
from nemo.lightning.io.pl import TrainerContext, is_distributed_ckpt
from nemo.lightning.io.state import TransformCTX, apply_transforms, state_transform

__all__ = [
"apply_transforms",
"Connector",
"ConnectorMixin",
"drop_unexpected_params",
"IOMixin",
"track_io",
"import_ckpt",
nemo/lightning/io/mixin.py (39 changes: 39 additions & 0 deletions)
@@ -687,6 +687,45 @@ def _artifact_transform_load(cfg: fdl.Config, path: Path):
pass


def drop_unexpected_params(config: fdl.Config) -> bool:
"""
Analyzes config to detect unexpected keyword arguments -- for example, deprecated parameters -- and
updates the config by dropping them. Returns True if the config gets updated and False otherwise.

Args:
config (fdl.Config): The configuration object to analyze.
"""

updated = False

def analyze(config: fdl.Config, prefix: str):

if isinstance(config, fdl.Config):
signature = inspect.signature(config.__fn_or_cls__)

accept_kwargs = any(param.kind is inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values())

if not accept_kwargs:
to_drop = [param for param in config.__arguments__ if param not in signature.parameters]

if to_drop:
nonlocal updated
updated = True
logging.warning(f"Deprecated parameters to drop from {prefix}: {to_drop}")
for param in to_drop:
del config.__arguments__[param]
else:
logging.info(f"Skip analyzing {prefix} as it accepts arbitrary keyword arguments.")

# Proceed recursively for all arguments
for key, value in config.__arguments__.items():
analyze(value, prefix + "." + key)

analyze(config, "<root>")

return updated


def load(path: Path, output_type: Type[CkptType] = Any, subpath: Optional[str] = None, build: bool = True) -> CkptType:
"""
Loads a configuration from a pickle file and constructs an object of the specified type.
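A minimal, self-contained illustration of what the new helper does. The Block classes are hypothetical, and the `__fn_or_cls__` override mirrors the technique used by the new unit tests (further below) to simulate a config serialized against an older class signature:

import fiddle as fdl

from nemo.lightning.io import drop_unexpected_params

class BlockV1:  # hypothetical old signature that still accepted 'deprecated'
    def __init__(self, x, y, deprecated):
        pass

class BlockV2:  # hypothetical current signature without 'deprecated'
    def __init__(self, x, y):
        pass

config = fdl.Config(BlockV1, x=1, y=2, deprecated=3)
# Pretend the config was serialized against the old class and is now being
# loaded in an environment where only the new signature exists.
config.__dict__["__fn_or_cls__"] = BlockV2

assert drop_unexpected_params(config)            # config was modified in place
assert "deprecated" not in config.__arguments__  # stale argument was dropped
fdl.build(config)                                # builds cleanly again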
nemo/lightning/pytorch/callbacks/peft.py (7 changes: 6 additions & 1 deletion)
@@ -204,7 +204,12 @@ def apply_transform(self, trainer):
)
trainer.strategy.load_model_state_dict(adapter_state, strict=False)
if trainer.state.fn == TrainerFn.FITTING:
- trainer.strategy.load_optimizer_state_dict(adapter_state, selective_restore=True)
+ # Load optimizer
+ trainer.strategy.load_optimizer_state_dict(adapter_state, selective_restore=False)
+ # Load lr scheduler
+ if (lr_schedulers := adapter_state.get('lr_schedulers', None)) is not None:
+ for config, lrs_state in zip(trainer.lr_scheduler_configs, lr_schedulers):
+ config.scheduler.load_state_dict(lrs_state)

for cb in trainer.callbacks[::-1]:
if isinstance(cb, MegatronOptimizerModule):
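In other words, when resuming a fitting run the callback now restores the full optimizer state and, when present, each learning-rate scheduler, pairing the saved states with the trainer's scheduler configs by position. A hedged sketch of that pairing, with illustrative names:

def restore_lr_schedulers(lr_scheduler_configs, saved_states):
    # Restore each scheduler from its saved state dict, matched by position
    # to the trainer's scheduler configs, as the patch above does.
    for config, state in zip(lr_scheduler_configs, saved_states):
        config.scheduler.load_state_dict(state)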
scripts/llm/update_io_context.py (94 changes: 94 additions & 0 deletions)
@@ -0,0 +1,94 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import sys
from datetime import datetime
from pathlib import Path

import fiddle as fdl
from fiddle._src.experimental import serialization

from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
from nemo.lightning.io import drop_unexpected_params, load
from nemo.utils import logging

IO_FILE = "io.json"

"""
Script to update a NeMo 2.0 model context (stored in io.json) by dropping unexpected
keyword arguments, for compatibility with the currently running environment.

It accepts a path to a NeMo 2.0 checkpoint and an optional flag for building
the updated configuration. It performs the following steps:

1. Loads config from the model context directory.
2. Checks the config for unexpected (e.g. deprecated) arguments and drops them.
3. Attempts to build the updated configuration if --build flag is on.
4. Backs up the existing context file and saves the updated configuration.
"""


def get_args():
"""Parses command line arguments."""
parser = argparse.ArgumentParser(
description="Script to drop unexpected arguments from NeMo 2.0 io.json model context."
)
parser.add_argument("--model_path", type=str, required=True, help="Path to a NeMo 2.0 checkpoint.")
parser.add_argument("--build", action="store_true", help="Whether to test building the updated config.")
return parser.parse_args()


def save_io(config: fdl.Config, path: str):
"""
Saves the given configuration object to a specified file path in JSON format.

Args:
config (fdl.Config): The configuration object to be saved.
path (str): The file path where the configuration will be saved.
"""
config_json = serialization.dump_json(config)
with open(path, "w") as f:
f.write(config_json)


if __name__ == "__main__":
args = get_args()

model_path = Path(args.model_path)
context_path = ckpt_to_context_subdir(model_path)
logging.info(f"Path to model context: {context_path}.")

config = load(context_path, build=False)
updated = drop_unexpected_params(config)

if not updated:
logging.info("Config does not need any updates.")
sys.exit(0)

if args.build:
try:
fdl.build(config)
except Exception as e:
logging.error("Build for the updated config failed.")
raise
else:
logging.info("Build for the updated config successful.")

# Backup the existing context file and save the updated config
io_path = context_path / IO_FILE
io_path_backup = context_path / f"BACKUP_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}_{IO_FILE}"
io_path.rename(io_path_backup)
save_io(config, io_path)
logging.info(f"Config saved to {io_path}.")
tests/collections/llm/io/test_drop_unexpected_params.py (84 changes: 84 additions & 0 deletions)
@@ -0,0 +1,84 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import fiddle as fdl

from nemo.lightning.io import drop_unexpected_params


class TestDropUnexpectedParams:

def setup_method(self):
"""
Setup common test resources.
"""

class MockClassOld:
def __init__(self, x, y, deprecated):
pass

class MockClassNew:
def __init__(self, x, y):
pass

class OuterClass:
def __init__(self, z, t):
pass

self.MockClassOld = MockClassOld
self.MockClassNew = MockClassNew
self.OuterClass = OuterClass

def test_valid_config_stays_same(self):
"""
Test that a valid config remains unchanged.
"""

config = fdl.Config(self.MockClassNew, x=1, y=2)
updated = drop_unexpected_params(config)

assert not updated, "Expected the config to remain unchanged."
assert config.x == 1
assert config.y == 2

def test_config_updates(self):
"""
Test that a config with unexpected parameters gets updated.
"""
config = fdl.Config(self.MockClassOld, x=1, y=2, deprecated=3)

# Simulate deprecation issue by overriding target class
config.__dict__['__fn_or_cls__'] = self.MockClassNew

updated = drop_unexpected_params(config)
assert updated, "Expected the config to be updated."
assert config.x == 1
assert config.y == 2
assert not hasattr(config, "deprecated"), "Expected 'deprecated' to be removed from the config."

def test_nested_config_updates(self):
"""
Test that a nested config with unexpected parameters gets updated.
"""
config = fdl.Config(self.OuterClass, z=4, t=fdl.Config(self.MockClassOld, x=1, y=2, deprecated=3))

# Simulate deprecation issue by overriding target class
config.t.__dict__["__fn_or_cls__"] = self.MockClassNew

updated = drop_unexpected_params(config)
assert updated, "Expected the nested config to be updated."
assert config.z == 4
assert config.t.x == 1
assert config.t.y == 2
assert not hasattr(config.t, "deprecated"), "Expected 'deprecated' to be removed from the inner config."
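These tests are self-contained apart from fiddle and the nemo.lightning.io import, so they can be run directly, for example:

pytest tests/collections/llm/io/test_drop_unexpected_params.py -v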