
PTQ memory optimization #11257

Merged Nov 18, 2024 (15 commits)
19 changes: 10 additions & 9 deletions nemo/collections/llm/quantization/quantizer.py
@@ -272,9 +272,8 @@ def export(self, model: llm.GPTModel, model_dir: str) -> None:
         # TODO: Add sample generate
         # TODO: Support megatron_amp_O2
         export_dir = self.export_config.path
-        use_nfs_workspace = (model.trainer._fabric.__io__.num_nodes > 1) or (
-            model.config.pipeline_model_parallel_size > 1
-        )
+
+        use_nfs_workspace = model.config.pipeline_model_parallel_size > 1
         export_tensorrt_llm_checkpoint(
             model=get_unwrapped_mcore_model(model),
             decoder_type=self._get_decoder_type(model.config),
@@ -284,15 +283,17 @@ def export(self, model: llm.GPTModel, model_dir: str) -> None:
             inference_pipeline_parallel=self.export_config.inference_pipeline_parallel,
             use_nfs_workspace=use_nfs_workspace,
         )
+        dist.barrier()

         # Save the model context in order to restore its tokenizer later. The destination
         # path is "nemo_context" as this name is used in nemo.export to setup tokenizer.
-        shutil.copytree(
-            os.path.join(model_dir, CONTEXT_PATH),
-            os.path.join(export_dir, "nemo_context"),
-            dirs_exist_ok=True,
-        )
-        logging.info(f"Model context saved.")
+        if dist.get_rank() == 0:
+            shutil.copytree(
+                os.path.join(model_dir, CONTEXT_PATH),
+                os.path.join(export_dir, "nemo_context"),
+                dirs_exist_ok=True,
+            )
+            logging.info("Model context saved.")

         logging.info(f"Export succeeded, model has been exported to {export_dir}.")

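The dist.barrier() plus rank-0 guard added above is a standard pattern for writing a shared artifact from a multi-rank job: wait for all ranks to finish their part of the export, then let a single rank perform the copy. A minimal sketch of the idea follows; the helper name and the "context" subdirectory default are hypothetical, and torch.distributed is assumed to be already initialized.

```python
import os
import shutil

import torch.distributed as dist


def copy_context_on_rank_0(model_dir: str, export_dir: str, context_subdir: str = "context") -> None:
    """Illustrative helper: copy the checkpoint context directory exactly once, from rank 0."""
    # Wait until every rank has finished writing its part of the exported checkpoint.
    dist.barrier()
    # A single rank performs the copy to avoid redundant I/O and races on shared filesystems.
    if dist.get_rank() == 0:
        shutil.copytree(
            os.path.join(model_dir, context_subdir),
            os.path.join(export_dir, "nemo_context"),
            dirs_exist_ok=True,
        )
```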
42 changes: 31 additions & 11 deletions nemo/collections/llm/quantization/utils.py
@@ -18,6 +18,7 @@

 from nemo import lightning as nl
 from nemo.collections import llm
+from nemo.collections.llm.inference.base import _setup_trainer_and_restore_model
 from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
 from nemo.utils import logging

@@ -42,25 +43,44 @@ def quantizable_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig:
     return model_cfg


-def load_with_modelopt_layer_spec(nemo_checkpoint_path: str, calib_tp: int = 1, calib_pp: int = 1) -> llm.GPTModel:
+def load_with_modelopt_layer_spec(
+    nemo_checkpoint_path: str, calib_tp: int = 1, calib_pp: int = 1, inference_only: bool = True
+):
+    # TODO: setting ddp="pytorch" with manually deleting model.optim is a hackish way to disable DDP initialization. Needs a systematic solution.
+    if inference_only:
+        strategy = nl.MegatronStrategy(
+            tensor_model_parallel_size=calib_tp,
+            pipeline_model_parallel_size=calib_pp,
+            pipeline_dtype=torch.bfloat16,
+            ckpt_load_optimizer=False,
+            ckpt_parallel_save_optim=False,
+            setup_optimizers=False,
+            lazy_init=True,
+            ddp="pytorch",
+        )
+    else:
+        strategy = nl.MegatronStrategy(
+            tensor_model_parallel_size=calib_tp, pipeline_model_parallel_size=calib_pp, pipeline_dtype=torch.bfloat16
+        )
+
     trainer = nl.Trainer(
         devices=calib_tp,
         num_nodes=calib_pp,
-        strategy=nl.MegatronStrategy(
-            tensor_model_parallel_size=calib_tp, pipeline_model_parallel_size=calib_pp, pipeline_dtype=torch.bfloat16
-        ),
-        plugins=nl.MegatronMixedPrecision(precision='bf16', pipeline_dtype=torch.bfloat16, autocast_enabled=True),
+        strategy=strategy,
+        plugins=nl.MegatronMixedPrecision(precision='bf16', params_dtype=torch.bfloat16, autocast_enabled=True),
     )
-    fabric = trainer.to_fabric()
-    fabric.launch()
-
     model_path = Path(nemo_checkpoint_path)
-    model = nl.io.load_context(ckpt_to_context_subdir(model_path)).model
+    model = nl.io.load_context(path=ckpt_to_context_subdir(model_path), subpath="model")
     model.config = quantizable_model_config(model.config)
-    return fabric.load_model(nemo_checkpoint_path, model=model)
+
+    if inference_only:
+        del model.optim
+
+    _setup_trainer_and_restore_model(nemo_checkpoint_path, trainer, model)
+    return model


-def get_unwrapped_mcore_model(model: llm.GPTModel):
+def get_unwrapped_mcore_model(model):
     from megatron.core.models.gpt import GPTModel as MCoreGPTModel

     unwrapped_model = model
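The body of get_unwrapped_mcore_model is truncated in the diff above. A typical unwrapping loop looks roughly like the sketch below; this is an assumption based on common Megatron-Core wrapper layouts, not necessarily the exact implementation.

```python
from megatron.core.models.gpt import GPTModel as MCoreGPTModel


def get_unwrapped_mcore_model(model):
    """Peel off wrapper modules (e.g. mixed-precision or DDP wrappers) until the
    underlying Megatron-Core GPTModel is reached."""
    unwrapped_model = model
    while not isinstance(unwrapped_model, MCoreGPTModel):
        unwrapped_model = unwrapped_model.module
    return unwrapped_model
```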
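Finally, a minimal usage sketch of the updated loader. The checkpoint path and parallel sizes are placeholders, and the script is assumed to run under torchrun or another launcher that sets up the distributed environment.

```python
from nemo.collections.llm.quantization.utils import (
    get_unwrapped_mcore_model,
    load_with_modelopt_layer_spec,
)

# inference_only=True (the new default) skips optimizer setup and DDP initialization,
# which is where the memory savings of this PR come from.
model = load_with_modelopt_layer_spec(
    "/path/to/checkpoint",  # placeholder path
    calib_tp=1,
    calib_pp=1,
    inference_only=True,
)

# The raw Megatron-Core GPTModel is what calibration and quantization operate on.
mcore_model = get_unwrapped_mcore_model(model)
```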