ExtractAdapters: Support int4 models and external data config (microsoft#1083)

## Describe your changes
- The `ExtractAdapters` pass now supports int4 quantized models.
- Previously, output models were always saved with external data. The external data config options are now exposed to the user (see the config sketch below).
- The `export_adapters` script also supports quantizing the weights to int4.
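
A minimal, hypothetical sketch of how the pass might be configured with the exposed options. The `make_inputs`, `pack_inputs`, `save_as_external_data`, and `all_tensors_to_one_file` names appear in this change; the surrounding config structure and any other external-data options are assumptions (see `get_external_data_config` in `olive/passes/onnx/common.py`).

```python
# Sketch only: the wrapping structure is an assumption, not taken from this diff.
extract_adapters_pass = {
    "type": "ExtractAdapters",
    "config": {
        "make_inputs": True,            # expose extracted adapter weights as model inputs
        "pack_inputs": True,            # pack matching weights into one input per module type
        "save_as_external_data": True,  # previously hardcoded, now user-configurable
        "all_tensors_to_one_file": True,
    },
}
```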

## Checklist before requesting a review
- [x] Add unit tests for this change.
- [x] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [x] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to
update [example
documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md)
in a follow-up PR.

## (Optional) Issue link
jambayk authored and DavitGrigoryan132 committed Aug 14, 2024
1 parent 082ab59 commit a91a162
Showing 3 changed files with 239 additions and 143 deletions.
243 changes: 138 additions & 105 deletions olive/passes/onnx/extract_adapters.py
@@ -6,7 +6,7 @@
import re
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, Any, Dict, List, Set, Tuple

import numpy as np
import onnx
@@ -15,7 +15,7 @@
from olive.model import ONNXModelHandler
from olive.model.utils import resolve_onnx_path
from olive.passes import Pass
from olive.passes.onnx.common import model_proto_to_olive_model
from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model
from olive.passes.onnx.onnx_dag import OnnxDAG
from olive.passes.pass_config import PassConfigParam

@@ -37,7 +37,7 @@ class ExtractAdapters(Pass):

@classmethod
def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
return {
config = {
"make_inputs": PassConfigParam(
type_=bool,
default_value=False,
@@ -55,6 +55,8 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon
),
),
}
config.update(get_external_data_config())
return config

def _run_for_config(
self, model: ONNXModelHandler, data_root: str, config: Dict[str, Any], output_model_path: str
@@ -83,75 +85,86 @@ def _run_for_config(
nodes_to_remove = set()

for node_name in dag.get_node_names():
if dag.get_node_op_type(node_name) != "MatMul" or not any(
op_type = dag.get_node_op_type(node_name)
if op_type not in {"MatMul", "MatMulNBits"} or not any(
re.match(pattern, node_name) for pattern in lora_name_patterns
):
# not a lora module
continue

# new name for the weight
# new name for the float weight
new_weight_name = self._create_new_weight_name(node_name)

# original weight name
old_weight_name = dag.get_node_inputs(node_name)[1]

if dag.is_input(old_weight_name):
# nothing we can do here
continue
elif dag.is_initializer(old_weight_name):
# weight is an initializer (not quantized)
# create initializer with new weight name
self._create_empty_initializer(dag, weights, old_weight_name, new_weight_name)

# change input to the new name
dag.replace_node_input(node_name, old_weight_name, new_weight_name)

# add the module to the float modules
float_modules.add(new_weight_name.replace(".weight", ""))
elif dag.get_node_op_type(dag.get_producer(old_weight_name)) == "DequantizeLinear":
# weight is quantized
# get the dequantize node
old_dequantize_name = dag.get_producer(old_weight_name)
old_dequantize_node = dag.get_node(old_dequantize_name)

# new names for the dequantize node inputs
suffixes = ["quant.weight", "quant.scale", "quant.zero_point"]
new_input_names = [new_weight_name.replace("weight", suffix) for suffix in suffixes]

# zero point is optional so we keep track of used inputs
used_inputs = []
# create new initializers for the dequantize node
for old_input, new_input in zip(old_dequantize_node.inputs, new_input_names):
# new names for quantized weight and parameters
# zero point is optional if symmetric
quantized_suffixes = [".quant.weight", ".quant.scale", ".quant.zero_point"]
new_quantized_names = [new_weight_name.replace(".weight", suffix) for suffix in quantized_suffixes]

if op_type == "MatMul":
# float or QDQ quantized
# original weight name
old_weight_name = dag.get_node_inputs(node_name)[1]

if dag.is_input(old_weight_name):
# nothing to do here
continue
elif dag.is_initializer(old_weight_name):
# weight is a float initializer
# create initializer with new weight name
self._create_empty_initializer(dag, weights, old_weight_name, new_weight_name)

# change input to the new name
dag.replace_node_input(node_name, old_weight_name, new_weight_name)

# add the module to the float modules
float_modules.add(new_weight_name.replace(".weight", ""))
elif dag.get_node_op_type(dag.get_producer(old_weight_name)) == "DequantizeLinear":
# weight is QDQ quantized
# get the dequantize node
old_dequantize_name = dag.get_producer(old_weight_name)
old_dequantize_node = dag.get_node(old_dequantize_name)

# zero point is optional so we keep track of used inputs
used_inputs = []
# create new initializers for the dequantize node
for old_input, new_input in zip(old_dequantize_node.inputs, new_quantized_names):
self._create_empty_initializer(dag, weights, old_input, new_input)
used_inputs.append(new_input)

# create a new dequantize node
# NOTE: We could directly modify the original dequantize node but this assumes that the dequantize
# node is not used elsewhere
# this cannot be guaranteed (for instance, if the float model has lora modules with same weights,
# they might all share the same dequantize node)
new_dequantize_proto = onnx.NodeProto()
new_dequantize_proto.CopyFrom(old_dequantize_node.proto)
# change node name
new_dequantize_proto.name = new_weight_name.replace("weight", "dequantize")
# change input names
for i, new_input in enumerate(used_inputs):
new_dequantize_proto.input[i] = new_input
# change output name
new_dequantize_proto.output[0] = new_weight_name

# add new dequantize node
dag.add_node(new_dequantize_proto, old_dequantize_node.graph_idx)

# replace input to the new name
dag.replace_node_input(node_name, old_weight_name, new_weight_name)

# add old dequantize node to remove
nodes_to_remove.add(old_dequantize_name)

# add the module to the quant modules
quant_modules.add(new_weight_name.replace(".weight", ".quant"))
elif op_type == "MatMulNBits":
# weight is Nbits quantized
# create empty initializers and change node inputs
for old_input, new_input in zip(dag.get_node_inputs(node_name)[1:], new_quantized_names):
self._create_empty_initializer(dag, weights, old_input, new_input)
used_inputs.append(new_input)

# create a new dequantize node
# NOTE: We could directly modify the original dequantize node but this assumes that the dequantize node
# is not used elsewhere
# this cannot be guaranteed (for instance, if the float model has lora modules with same weights, they
# might all share the same dequantize node)
new_dequantize_proto = onnx.NodeProto()
new_dequantize_proto.CopyFrom(old_dequantize_node.proto)
# change node name
new_dequantize_proto.name = new_weight_name.replace("weight", "dequantize")
# change input names
for i, new_input in enumerate(used_inputs):
new_dequantize_proto.input[i] = new_input
# change output name
new_dequantize_proto.output[0] = new_weight_name

# add new dequantize node
dag.add_node(new_dequantize_proto, old_dequantize_node.graph_idx)

# replace input to the new name
dag.replace_node_input(node_name, old_weight_name, new_weight_name)

# add old dequantize node to remove
nodes_to_remove.add(old_dequantize_name)
dag.replace_node_input(node_name, old_input, new_input)

# add the module to the quant modules
quant_modules.add(new_weight_name.replace(".weight", ".quant"))
# TODO(jambayk): Add int4 quantization support

# remove old dequantize nodes
for node_name in nodes_to_remove:
@@ -163,47 +176,10 @@ def _run_for_config(
dag.convert_initializer_to_input(weight_name)
elif config["make_inputs"] and config["pack_inputs"]:
# what weights are packed together
packings = {}

def get_sort_key(module_name: str):
parts = module_name.split(".")
for i, part in enumerate(parts):
try:
# want the layers to be sorted by the number
parts[i] = int(part)
except ValueError:
pass
return parts

# group by module type, sort by name and pack them together
for module_type in lora_modules:
for lora_i in ["lora_A", "lora_B"]:
# base name to use for split node and input
base_name = f"{module_type}.{lora_i}"

matching_float_modules = sorted(
[name for name in float_modules if module_type in name and lora_i in name], key=get_sort_key
)
if matching_float_modules:
packings[f"{base_name}.weight.packed"] = [f"{name}.weight" for name in matching_float_modules]

matching_quant_modules = sorted(
[name for name in quant_modules if module_type in name and lora_i in name], key=get_sort_key
)
if matching_quant_modules:
# zero point is optional so we need to check if it exists
for suffix in [".weight", ".scale", ".zero_point"]:
packings[f"{base_name}.quant{suffix}.packed"] = [
name + suffix for name in matching_quant_modules if name + suffix in weights
]

# pack the weights, create inputs and split nodes
packed_weights = {}
for weight_name, to_pack in packings.items():
if not to_pack:
continue
packed_weights[weight_name] = np.concatenate([np.atleast_1d(weights[name]) for name in to_pack], axis=0)
packed_weights, packings = self.pack_weights(weights, lora_modules, float_modules, quant_modules)

# create inputs and split nodes for the packed weights
for weight_name, to_pack in packings.items():
# input proto
input_proto = onnx.helper.make_tensor_value_info(
name=weight_name,
Expand Down Expand Up @@ -239,7 +215,7 @@ def get_sort_key(module_name: str):
output_model = model_proto_to_olive_model(
dag.model,
output_model_path,
external_data_config={"save_as_external_data": True, "all_tensors_to_one_file": True},
config,
external_initializers_file_name=weights_path.name if not config["make_inputs"] else None,
constant_inputs_file_name=weights_path.name if config["make_inputs"] else None,
)
@@ -254,10 +230,66 @@ def get_sort_key(module_name: str):
output_model.model_attributes["packed_inputs"] = packings
return output_model

@staticmethod
def pack_weights(
weights: Dict[str, "NDArray"], module_types: List[str], float_modules: Set[str], quant_modules: Set[str]
) -> Tuple[Dict[str, "NDArray"], Dict[str, List[str]]]:
"""Pack the weights for a given module type into an array each for lora_A and lora_B.
Assumes the weights for the same module type are in the same format (float, QDQ, Nbits).
:param weights: dictionary of weights
:param module_types: list of module types
:param float_modules: set of float module names
:param quant_modules: set of quantized module names
:return: dictionary of packed weights and a dictionary mapping each packed weight name to the names of the weights packed into it
"""

def get_sort_key(module_name: str):
parts = module_name.split(".")
for i, part in enumerate(parts):
try:
# want the layers to be sorted by the number
parts[i] = int(part)
except ValueError:
pass
return parts

packings = {}
for module_type in module_types:
for lora_i in ["lora_A", "lora_B"]:
matching_float_modules = sorted(
[name for name in float_modules if module_type in name and lora_i in name], key=get_sort_key
)
if matching_float_modules:
weight_name = f"{module_type}.{lora_i}.weight.packed"
packings[weight_name] = [f"{name}.weight" for name in matching_float_modules]

matching_quant_modules = sorted(
[name for name in quant_modules if module_type in name and lora_i in name], key=get_sort_key
)
if matching_quant_modules:
# zero point is optional so we need to check if it exists
for suffix in [".weight", ".scale", ".zero_point"]:
weight_name = f"{module_type}.{lora_i}.quant{suffix}.packed"
to_pack = [name + suffix for name in matching_quant_modules if name + suffix in weights]
if to_pack:
packings[weight_name] = to_pack

packed_weights = {}
for weight_name, to_pack in packings.items():
packed_weights[weight_name] = np.concatenate([np.atleast_1d(weights[name]) for name in to_pack], axis=0)

return packed_weights, packings

@staticmethod
def _get_lora_name_patterns(lora_modules: List[str]) -> List[str]:
"""Get the node name patterns for lora modules."""
return [f".*[./]{key}[./]{name}[./]MatMul$" for key in lora_modules for name in ["default", "default_1"]]
return [
f".*[./]{key}[./]{name}[./]{matmul}$"
for key in lora_modules
for name in ["default", "default_1"]
for matmul in ["MatMul", "MatMul_Q4"]
]

@staticmethod
def _create_new_weight_name(old_name: str) -> str:
@@ -266,11 +298,12 @@ def _create_new_weight_name(old_name: str) -> str:
The new weight name is of the form model.layers.0.self_attn.q_proj.lora_A.quant.weight
"""
weight_name = old_name[1:] if old_name.startswith("/") else old_name
matmul_name = weight_name.split("/")[-1]
return (
weight_name.replace("/", ".")
.replace("default.", "lora_A.")
.replace("default_1.", "lora_B.")
.replace("MatMul", "weight")
.replace(matmul_name, "weight")
)

@staticmethod
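
For reference, a small standalone sketch (not part of this commit) of how the updated LoRA name patterns and weight renaming behave for both float/QDQ `MatMul` nodes and int4 `MatMul_Q4` (MatMulNBits) nodes; the node names below are hypothetical examples.

```python
import re

# Patterns mirror _get_lora_name_patterns(["q_proj"]) from the diff above;
# "MatMul_Q4" is the suffix the pass now matches for int4 (MatMulNBits) nodes.
patterns = [
    f".*[./]{key}[./]{name}[./]{matmul}$"
    for key in ["q_proj"]
    for name in ["default", "default_1"]
    for matmul in ["MatMul", "MatMul_Q4"]
]

# Hypothetical node names; real exported models may name nodes differently.
float_node = "/model/layers.0/self_attn/q_proj/default/MatMul"
int4_node = "/model/layers.0/self_attn/q_proj/default_1/MatMul_Q4"
assert any(re.match(p, float_node) for p in patterns)
assert any(re.match(p, int4_node) for p in patterns)


def create_new_weight_name(old_name: str) -> str:
    """Same renaming scheme as ExtractAdapters._create_new_weight_name above."""
    weight_name = old_name[1:] if old_name.startswith("/") else old_name
    matmul_name = weight_name.split("/")[-1]
    return (
        weight_name.replace("/", ".")
        .replace("default.", "lora_A.")
        .replace("default_1.", "lora_B.")
        .replace(matmul_name, "weight")
    )


print(create_new_weight_name(float_node))  # model.layers.0.self_attn.q_proj.lora_A.weight
print(create_new_weight_name(int4_node))   # model.layers.0.self_attn.q_proj.lora_B.weight
```

For quantized modules, the pass additionally derives parameter names with the `.quant.weight`, `.quant.scale`, and `.quant.zero_point` suffixes, as shown in the diff.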
