Skip to content

Commit

Permalink
Introducing SliceGPT pass
Browse files Browse the repository at this point in the history
  • Loading branch information
shaahji committed Apr 4, 2024
1 parent fb869c1 commit 1b1a556
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 0 deletions.
60 changes: 60 additions & 0 deletions examples/opt/opt_cofig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
{
"input_model":{
"type": "PyTorchModel",
"config": {
"hf_config": {
"model_name": "facebook/opt-125m",
"task": "text-generation"
}
}
},
"systems": {
"local_system": {
"type": "LocalSystem",
"config": {
"accelerators": [
{
"device": "gpu",
"execution_providers": [
"CUDAExecutionProvider"
]
}
]
}
}
},
"data_configs": {
"wikitext2": {
"name": "wikitext2",
"type": "HuggingfaceContainer",
"params_config": {
"data_name": "wikitext",
"subset": "wikitext-2-raw-v1",
"split": "train",
"component_kwargs": {
"pre_process_data": {
"text_cols": ["text"],
"source_max_len": 2048
}
}
}
}
},
"passes": {
"slice": {
"type": "SliceGPT",
"config": {
"sparsity": 0.4,
"calibration_data_config": "wikitext2"
}
}
},
"engine": {
"log_severity_level": 0,
"search_strategy": false,
"evaluate_input_model": false,
"cache_dir": "cache",
"output_name": "sliced_opt",
"output_dir": "models/opt"
}
}
1 change: 1 addition & 0 deletions olive/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class ModelFileFormat(str, Enum):
PYTORCH_STATE_DICT = "PyTorch.StateDict"
PYTORCH_TORCH_SCRIPT = "PyTorch.TorchScript"
PYTORCH_MLFLOW_MODEL = "PyTorch.MLflow"
PYTORCH_SLICE_GPT_MODEL = "PyTorch.SliceGPT"
TENSORFLOW_PROTOBUF = "TensorFlow.Protobuf"
TENSORFLOW_SAVED_MODEL = "TensorFlow.SavedModel"
SNPE_DLC = "SNPE.DLC"
Expand Down
9 changes: 9 additions & 0 deletions olive/model/handler/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def load_model(self, rank: int = None) -> torch.nn.Module:
model = self.load_hf_model(self.model_path)
elif self.model_file_format == ModelFileFormat.PYTORCH_ENTIRE_MODEL:
model = torch.load(self.model_path)
elif self.model_file_format == ModelFileFormat.PYTORCH_SLICE_GPT_MODEL:
model = self._load_slicegpt_model()
elif self.model_file_format == ModelFileFormat.PYTORCH_STATE_DICT:
raise ValueError("Please use customized model loader to load state dict of model.")
else:
Expand Down Expand Up @@ -224,6 +226,13 @@ def _load_mlflow_model(self):
loaded_model.eval()
return loaded_model

def _load_slicegpt_model(self):
    """Load a SliceGPT-sliced model from ``self.model_path``.

    Returns:
        The first element of the tuple returned by
        ``slicegpt.hf_utils.load_sliced_model``; the second element is discarded.

    Raises:
        ImportError: if the ``slicegpt`` package is not installed.
    """
    logger.info("Loading SliceGPT model from %s", self.model_path)
    # Function-scope import: slicegpt is only needed on this code path.
    # The package is named ``slicegpt`` (the previous ``slicgpt`` spelling could never be imported).
    from slicegpt.hf_utils import load_sliced_model

    # NOTE(review): presumably load_sliced_model returns (model_adapter, tokenizer) and may
    # require the base model name and slicing parameters in addition to the path -- confirm
    # against the installed slicegpt version.
    loaded_model, _ = load_sliced_model(self.model_path)
    return loaded_model

def to_json(self, check_object: bool = False):
config = super().to_json(check_object)
# only keep model_attributes that are not in hf_config
Expand Down
2 changes: 2 additions & 0 deletions olive/passes/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from olive.passes.pytorch.gptq import GptqQuantizer
from olive.passes.pytorch.lora import LoftQ, LoRA, QLoRA
from olive.passes.pytorch.quantization_aware_training import QuantizationAwareTraining
from olive.passes.pytorch.slicegpt import SliceGPT
from olive.passes.pytorch.sparsegpt import SparseGPT
from olive.passes.pytorch.tensor_parallel import PyTorchTensorParallel
from olive.passes.pytorch.torch_trt_conversion import TorchTRTConversion
Expand All @@ -17,5 +18,6 @@
"QLoRA",
"QuantizationAwareTraining",
"SparseGPT",
"SliceGPT",
"TorchTRTConversion",
]
153 changes: 153 additions & 0 deletions olive/passes/pytorch/slicegpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# -------------------------------------------------------------------------
import json
import logging
from typing import Any, Dict, Union

import torch

from olive.constants import ModelFileFormat
from olive.data.config import DataConfig
from olive.hardware.accelerator import AcceleratorSpec
from olive.model import PyTorchModelHandler
from olive.model.utils.path_utils import normalize_path_suffix
from olive.passes import Pass
from olive.passes.olive_pass import PassConfigParam

logger = logging.getLogger(__name__)


class SliceGPT(Pass):
    """Run SliceGPT on a Hugging Face PyTorch model.

    See https://arxiv.org/pdf/2401.15024.pdf for more details on the algorithm.
    This pass only supports PyTorchModelHandler with hf_config.
    """

    @staticmethod
    def _default_config(accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
        """Return the configuration parameters accepted by this pass, keyed by parameter name."""
        return {
            "calibration_data_config": PassConfigParam(
                type_=Union[DataConfig, Dict],
                required=True,
                description=("Data config for Dataset to calibrate and calculate perplexity on."),
            ),
            "calibration_nsamples": PassConfigParam(
                type_=int,
                required=False,
                default_value=128,
                description=("Number of samples of the calibration data to load."),
            ),
            "calibration_batch_size": PassConfigParam(
                type_=int,
                required=False,
                default_value=16,
                description=("Batch size for loading the calibration data."),
            ),
            "calibration_max_seqlen": PassConfigParam(
                type_=int,
                required=False,
                default_value=2048,
                description=("Maximum sequence length for the calibration data."),
            ),
            "varied_seqlen": PassConfigParam(
                type_=bool,
                required=False,
                default_value=False,
                description=("Varied sequence lengths in the calibration data."),
            ),
            "seed": PassConfigParam(
                type_=int,
                required=False,
                default_value=42,
                description=("Seed for sampling the calibration data."),
            ),
            "sparsity": PassConfigParam(
                type_=float,
                default_value=0.0,
                description="A measure of how much slicing is applied (in the range [0, 1))",
            ),
            "round_interval": PassConfigParam(
                type_=int,
                default_value=8,
                description="Interval for rounding the weights (the best value may depend on your hardware)",
            ),
            "final_orientation": PassConfigParam(
                type_=str,
                default_value="random",
                description="Final orientation of the sliced weights. Choices are random or pca.",
            ),
        }

    @staticmethod
    def _sliced_embedding_dim(hidden_size: int, sparsity: float, round_interval: int) -> int:
        """Compute the embedding dimension left after slicing.

        The dimension is ``(1 - sparsity) * hidden_size`` truncated to an int and then
        rounded down to the nearest multiple of ``round_interval``.

        Raises:
            ValueError: if ``sparsity`` is outside [0, 1), ``round_interval`` is not
                positive, or the resulting dimension is not positive.
        """
        if not 0.0 <= sparsity < 1.0:
            raise ValueError(f"sparsity must be in the range [0, 1), got {sparsity}")
        if round_interval <= 0:
            raise ValueError(f"round_interval must be a positive integer, got {round_interval}")

        new_embedding_dim = int((1 - sparsity) * hidden_size)
        # round (down) to the nearest multiple of round_interval
        new_embedding_dim -= new_embedding_dim % round_interval
        if new_embedding_dim <= 0:
            raise ValueError(
                f"sparsity {sparsity} with round_interval {round_interval} leaves no embedding "
                f"dimensions for hidden size {hidden_size}"
            )
        return new_embedding_dim

    @torch.no_grad()
    def _run_for_config(
        self, model_handler: PyTorchModelHandler, data_root: str, config: Dict[str, Any], output_model_path: str
    ) -> PyTorchModelHandler:
        """Slice the Hugging Face model referenced by ``model_handler`` and save the result.

        Args:
            model_handler: Input model; must carry an ``hf_config`` with a ``model_name``.
            data_root: Not used by this pass.
            config: Pass configuration (see ``_default_config``).
            output_model_path: Location used to derive the output ``model.pt`` (state dict)
                and ``config.json`` (slicing configuration) paths.

        Returns:
            A new PyTorchModelHandler pointing at ``output_model_path`` with the
            ``PYTORCH_SLICE_GPT_MODEL`` file format.

        Raises:
            ValueError: if the input model has no Hugging Face model name, or if
                ``sparsity`` / ``round_interval`` are invalid.
        """
        # Function-scope imports: slicegpt is only needed when this pass actually runs.
        from slicegpt import layernorm_fusion, rotate
        from slicegpt.data_utils import get_dataset, prepare_dataloader
        from slicegpt.hf_utils import get_model_and_tokenizer
        from slicegpt.slicing_scheduler import ConstSlicingScheduler

        # convert config to pass config class
        # this will validate the config and convert to the correct types
        config = self._config_class(**config)

        if model_handler.hf_config is None or model_handler.hf_config.model_name is None:
            raise ValueError("SliceGPT only supports select HuggingFace models")

        model_adapter, tokenizer = get_model_and_tokenizer(model_handler.hf_config.model_name)
        # NOTE(review): this assigns the adapter's model onto the *input* handler before
        # load_model() is called; presumably load_model() then returns that same object so
        # that `model` and `model_adapter` share weights -- confirm against the handler.
        model_handler.model = model_adapter.model
        model = model_handler.load_model()

        # replace and fuse layers
        layernorm_fusion.replace_layers(model_adapter)
        layernorm_fusion.fuse_modules(model_adapter)

        original_param_count = sum(int(p.nelement()) for p in model.parameters())
        logger.info("Original model parameters: %s", f"{original_param_count:,}")

        # compute new embedding dimension given the desired sparsity level
        new_embedding_dim = self._sliced_embedding_dim(
            model_adapter.hidden_size, config.sparsity, config.round_interval
        )
        # new_embedding_dim is an int, so it is logged with %d rather than %f
        logger.info(
            "New embedding dimension: %d (sparsity %.4f%%)",
            new_embedding_dim,
            100 * (1 - new_embedding_dim / model_adapter.hidden_size),
        )

        # calibration_data_config is declared as Union[DataConfig, Dict]; support both forms
        calibration_data_config = config.calibration_data_config
        if isinstance(calibration_data_config, dict):
            dataset_name = calibration_data_config["name"]
        else:
            dataset_name = calibration_data_config.name

        # only the "train" split of the named dataset is used for calibration
        train_dataset = get_dataset(dataset_name)["train"]
        train_loader = prepare_dataloader(
            dataset=train_dataset,
            tokenizer=tokenizer,
            max_seqlen=config.calibration_max_seqlen,
            batch_size=config.calibration_batch_size,
            nsamples=config.calibration_nsamples,
            varied_seqlen=config.varied_seqlen,
            seed=config.seed,
        )

        # rotate and slice
        scheduler = ConstSlicingScheduler(new_embedding_dim)
        rotate.rotate_and_slice(model_adapter, train_loader, scheduler, final_orientation=config.final_orientation)

        sliced_param_count = sum(int(p.nelement()) for p in model.parameters())
        sliced_fraction = 1.0 - sliced_param_count / original_param_count
        logger.info("Sliced model parameters: %s (sliced fraction %.4f)", f"{sliced_param_count:,}", sliced_fraction)

        # save the sliced weights
        output_model_filepath = normalize_path_suffix(output_model_path, "model.pt")
        torch.save(model.state_dict(), output_model_filepath)

        # save the slicing configuration next to the weights
        output_config_filepath = normalize_path_suffix(output_model_path, "config.json")
        with open(output_config_filepath, "w", encoding="utf-8") as strm:
            json.dump(model_adapter.slicing_conf.to_dict(), strm, indent=4)

        # return PyTorchModelHandler
        model_config = model_handler.to_json()["config"]
        model_config["model_path"] = output_model_path
        # drop the input model's file format (if present) so it can be overridden below
        model_config.pop("model_file_format", None)
        return PyTorchModelHandler(**model_config, model_file_format=ModelFileFormat.PYTORCH_SLICE_GPT_MODEL)

0 comments on commit 1b1a556

Please sign in to comment.