Commit
Showing 5 changed files with 225 additions and 0 deletions.
@@ -0,0 +1,60 @@
{
    "input_model": {
        "type": "PyTorchModel",
        "config": {
            "hf_config": {
                "model_name": "facebook/opt-125m",
                "task": "text-generation"
            }
        }
    },
    "systems": {
        "local_system": {
            "type": "LocalSystem",
            "config": {
                "accelerators": [
                    {
                        "device": "gpu",
                        "execution_providers": [
                            "CUDAExecutionProvider"
                        ]
                    }
                ]
            }
        }
    },
    "data_configs": {
        "wikitext2": {
            "name": "wikitext2",
            "type": "HuggingfaceContainer",
            "params_config": {
                "data_name": "wikitext",
                "subset": "wikitext-2-raw-v1",
                "split": "train",
                "component_kwargs": {
                    "pre_process_data": {
                        "text_cols": ["text"],
                        "source_max_len": 2048
                    }
                }
            }
        }
    },
    "passes": {
        "slice": {
            "type": "SliceGPT",
            "config": {
                "sparsity": 0.4,
                "calibration_data_config": "wikitext2"
            }
        }
    },
    "engine": {
        "log_severity_level": 0,
        "search_strategy": false,
        "evaluate_input_model": false,
        "cache_dir": "cache",
        "output_name": "sliced_opt",
        "output_dir": "models/opt"
    }
}
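
With this config saved to disk (the file name below is illustrative), the workflow can be launched from Python. A minimal sketch, assuming Olive's olive.workflows.run entry point and that the slicegpt package is installed:

# run_slicegpt.py -- minimal sketch; "opt_slicegpt.json" is a hypothetical
# file name for the config above
from olive.workflows import run as olive_run

# runs the single SliceGPT pass; outputs land in models/opt per "output_dir"
olive_run("opt_slicegpt.json")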
@@ -0,0 +1,153 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import json
import logging
from typing import Any, Dict, Union

import torch

from olive.constants import ModelFileFormat
from olive.data.config import DataConfig
from olive.hardware.accelerator import AcceleratorSpec
from olive.model import PyTorchModelHandler
from olive.model.utils.path_utils import normalize_path_suffix
from olive.passes import Pass
from olive.passes.olive_pass import PassConfigParam

logger = logging.getLogger(__name__)


class SliceGPT(Pass):
    """Run SliceGPT on a Hugging Face PyTorch model.

    See https://arxiv.org/pdf/2401.15024.pdf for more details on the algorithm.
    This pass only supports PyTorchModelHandler with hf_config.
    """

    @staticmethod
    def _default_config(accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
        return {
            "calibration_data_config": PassConfigParam(
                type_=Union[DataConfig, Dict],
                required=True,
                description="Data config for the dataset used to calibrate and calculate perplexity on.",
            ),
            "calibration_nsamples": PassConfigParam(
                type_=int,
                required=False,
                default_value=128,
                description="Number of samples of the calibration data to load.",
            ),
            "calibration_batch_size": PassConfigParam(
                type_=int,
                required=False,
                default_value=16,
                description="Batch size for loading the calibration data.",
            ),
            "calibration_max_seqlen": PassConfigParam(
                type_=int,
                required=False,
                default_value=2048,
                description="Maximum sequence length for the calibration data.",
            ),
            "varied_seqlen": PassConfigParam(
                type_=bool,
                required=False,
                default_value=False,
                description="Whether to allow varied sequence lengths in the calibration data.",
            ),
            "seed": PassConfigParam(
                type_=int,
                required=False,
                default_value=42,
                description="Seed for sampling the calibration data.",
            ),
            "sparsity": PassConfigParam(
                type_=float,
                default_value=0.0,
                description="A measure of how much slicing is applied (in the range [0, 1)).",
            ),
            "round_interval": PassConfigParam(
                type_=int,
                default_value=8,
                description="Interval for rounding the weights (the best value may depend on your hardware).",
            ),
            "final_orientation": PassConfigParam(
                type_=str,
                default_value="random",
                description="Final orientation of the sliced weights. Choices are 'random' or 'pca'.",
            ),
        }
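
    # Illustrative only (not part of the original pass): a JSON pass entry
    # that overrides several of the defaults above; the values are
    # hypothetical.
    #
    #   "passes": {
    #       "slice": {
    #           "type": "SliceGPT",
    #           "config": {
    #               "calibration_data_config": "wikitext2",
    #               "calibration_nsamples": 64,
    #               "sparsity": 0.25,
    #               "final_orientation": "pca"
    #           }
    #       }
    #   }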

    @torch.no_grad()
    def _run_for_config(
        self, model_handler: PyTorchModelHandler, data_root: str, config: Dict[str, Any], output_model_path: str
    ) -> PyTorchModelHandler:
        from slicegpt import layernorm_fusion, rotate
        from slicegpt.data_utils import get_dataset, prepare_dataloader
        from slicegpt.hf_utils import get_model_and_tokenizer
        from slicegpt.slicing_scheduler import ConstSlicingScheduler

        # convert config to pass config class
        # this will validate the config and convert to the correct types
        config = self._config_class(**config)

        if model_handler.hf_config is None or model_handler.hf_config.model_name is None:
            raise ValueError("SliceGPT only supports select HuggingFace models")

        model_adapter, tokenizer = get_model_and_tokenizer(model_handler.hf_config.model_name)
        model_handler.model = model_adapter.model
        model = model_handler.load_model()

        # replace and fuse layers
        layernorm_fusion.replace_layers(model_adapter)
        layernorm_fusion.fuse_modules(model_adapter)

        original_param_count = sum(int(p.nelement()) for p in model.parameters())
        logger.info("Original model parameters: %s", f"{original_param_count:,}")

        # compute new embedding dimension given the desired sparsity level
        new_embedding_dim = int((1 - config.sparsity) * model_adapter.hidden_size)
        # round (down) to the nearest multiple of round_interval
        new_embedding_dim -= new_embedding_dim % config.round_interval
        logger.info(
            "New embedding dimension: %d (sparsity %.4f%%)",
            new_embedding_dim,
            100 * (1 - new_embedding_dim / model_adapter.hidden_size),
        )
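
        # Worked example (illustrative numbers, not from this commit): for
        # facebook/opt-125m, hidden_size is 768. With sparsity=0.4 and
        # round_interval=8:
        #   int((1 - 0.4) * 768) = 460
        #   460 - (460 % 8)      = 456
        # giving an achieved sparsity of 1 - 456/768 ≈ 40.63%, slightly above
        # the requested 40% because of the rounding down.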

        train_dataset = get_dataset(config.calibration_data_config.name)["train"]
        train_loader = prepare_dataloader(
            dataset=train_dataset,
            tokenizer=tokenizer,
            max_seqlen=config.calibration_max_seqlen,
            batch_size=config.calibration_batch_size,
            nsamples=config.calibration_nsamples,
            varied_seqlen=config.varied_seqlen,
            seed=config.seed,
        )

        # rotate and slice
        scheduler = ConstSlicingScheduler(new_embedding_dim)
        rotate.rotate_and_slice(model_adapter, train_loader, scheduler, final_orientation=config.final_orientation)

        sliced_param_count = sum(int(p.nelement()) for p in model.parameters())
        sliced_fraction = 1.0 - sliced_param_count / original_param_count
        logger.info("Sliced model parameters: %s (sliced fraction %.4f)", f"{sliced_param_count:,}", sliced_fraction)

        output_model_filepath = normalize_path_suffix(output_model_path, "model.pt")
        torch.save(model.state_dict(), output_model_filepath)

        output_config_filepath = normalize_path_suffix(output_model_path, "config.json")
        with open(output_config_filepath, "w") as strm:
            json.dump(model_adapter.slicing_conf.to_dict(), strm, indent=4)

        # return PyTorchModelHandler
        model_config = model_handler.to_json()["config"]
        model_config["model_path"] = output_model_path
        del model_config["model_file_format"]
        return PyTorchModelHandler(**model_config, model_file_format=ModelFileFormat.PYTORCH_SLICE_GPT_MODEL)
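
The pass leaves two artifacts under the engine's output directory: the sliced weights (model.pt) and the slicing configuration recorded by slicegpt (config.json). A minimal sketch of inspecting those outputs, assuming the models/opt paths from the engine config above; the exact file layout is an assumption:

import json

import torch

# hypothetical paths derived from "output_dir"/"output_name" in the config above
state_dict = torch.load("models/opt/model.pt", map_location="cpu")
print(f"saved tensors: {len(state_dict)}")

with open("models/opt/config.json") as strm:
    slicing_conf = json.load(strm)
print(slicing_conf)  # per-layer slicing dimensions recorded during rotate_and_slice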