Additional ModelBuilder pass options (#1112)
shaahji authored Apr 25, 2024
1 parent 31729e3 commit 5e9e762
Showing 1 changed file with 53 additions and 0 deletions: olive/passes/onnx/model_builder.py
@@ -36,6 +36,15 @@ class Precision(str, Enum):
def __str__(self) -> str:
return self.value

class AccuracyLevel(int, Enum):
fp32 = 1
fp16 = 2
bf16 = 3
int8 = 4

def __str__(self) -> str:
return str(self.value)

@classmethod
def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
return {
@@ -53,6 +62,38 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon
"search": PassConfigParam(
type_=Dict[str, Any], required=False, description="Search options to use for generate loop."
),
"int4_block_size": PassConfigParam(
type_=int,
required=False,
description="Specify the block_size for int4 quantization. Acceptable values: 16/32/64/128/256.",
),
"int4_accuracy_level": PassConfigParam(
type_=ModelBuilder.AccuracyLevel,
required=False,
description="Specify the minimum accuracy level for activation of MatMul in int4 quantization.",
),
"exclude_embeds": PassConfigParam(
type_=bool,
default_value=False,
required=False,
description="Remove embedding layer from your ONNX model.",
),
"exclude_lm_head": PassConfigParam(
type_=bool,
default_value=False,
required=False,
description="Remove language modeling head from your ONNX model.",
),
"enable_cuda_graph": PassConfigParam(
type_=bool,
default_value=False,
required=False,
description=(
"The model can use CUDA graph capture for CUDA execution provider. "
"If enabled, all nodes being placed on the CUDA EP is the prerequisite "
"for the CUDA graph to be used correctly."
),
),
}

def validate_search_point(
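
The options added above surface through the pass configuration. Below is a minimal sketch of how they might be set for this pass in an Olive workflow config; the option names and accepted values come from the diff, while the surrounding "type"/"config" layout and the pass key name are assumptions for illustration rather than part of this commit.

# Hypothetical ModelBuilder pass entry (Python dict form; layout assumed).
model_builder_pass = {
    "type": "ModelBuilder",
    "config": {
        "int4_block_size": 32,        # must be one of 16/32/64/128/256
        "int4_accuracy_level": 4,     # coerced to AccuracyLevel.int8
        "exclude_embeds": False,      # keep the embedding layer
        "exclude_lm_head": False,     # keep the language modeling head
        "enable_cuda_graph": True,    # request CUDA graph capture on the CUDA EP
    },
}
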
@@ -137,6 +178,18 @@ def _run_for_config(
model_path = str(model.hf_config.model_name)
input_path = ""

if config.get("int4_block_size"):
if int(config["int4_block_size"]) not in [16, 32, 64, 128, 256]:
raise ValueError("Invalid int4_block_size. Accepted values: 16/32/64/128/256.")
extra_args["int4_block_size"] = config["int4_block_size"]

if config.get("int4_accuracy_level"):
extra_args["int4_accuracy_level"] = config["int4_accuracy_level"].value

extra_args["exclude_embeds"] = config["exclude_embeds"]
extra_args["exclude_lm_head"] = config["exclude_lm_head"]
extra_args["enable_cuda_graph"] = "1" if config["enable_cuda_graph"] else "0"

create_model(
model_name=model_path,
input_path=input_path,
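
A minimal sketch, assuming Olive with this change is installed, of how the new enum and flags serialize before being forwarded to create_model; it mirrors the extra_args mapping in _run_for_config above and is illustrative only.

from olive.passes.onnx.model_builder import ModelBuilder

# AccuracyLevel is an int-valued Enum, so both int() and str() are usable.
level = ModelBuilder.AccuracyLevel.int8
print(int(level))   # 4
print(str(level))   # "4", because __str__ returns str(self.value)

# The pass forwards the enum's value and encodes enable_cuda_graph as "1"/"0".
enable_cuda_graph = True
extra_args = {
    "int4_accuracy_level": level.value,
    "enable_cuda_graph": "1" if enable_cuda_graph else "0",
}
print(extra_args)   # {'int4_accuracy_level': 4, 'enable_cuda_graph': '1'}
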
