From cda76e807a68c58ac5abae70f8c1326e48b7839b Mon Sep 17 00:00:00 2001
From: shaahji <96227573+shaahji@users.noreply.github.com>
Date: Wed, 1 May 2024 13:23:00 -0700
Subject: [PATCH] Add support for DML execution provider (#1130)

## Add support for DML execution provider

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [x] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [x] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes.
- [ ] Does this PR include example changes? If yes, please remember to update the [example documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md) in a follow-up PR.

## (Optional) Issue link
---
 olive/passes/onnx/model_builder.py | 29 ++++++++++++++---------------
 1 file changed, 14 insertions(+), 15 deletions(-)

diff --git a/olive/passes/onnx/model_builder.py b/olive/passes/onnx/model_builder.py
index eb454b6e1..0c9171701 100644
--- a/olive/passes/onnx/model_builder.py
+++ b/olive/passes/onnx/model_builder.py
@@ -13,7 +13,7 @@
 from pathlib import Path
 from typing import Any, Dict, Union
 
-from olive.hardware.accelerator import AcceleratorLookup, AcceleratorSpec, Device
+from olive.hardware.accelerator import AcceleratorSpec, Device
 from olive.model import ONNXModelHandler, PyTorchModelHandler
 from olive.model.utils import resolve_onnx_path
 from olive.passes import Pass
@@ -86,7 +86,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon
         ),
         "enable_cuda_graph": PassConfigParam(
             type_=bool,
-            default_value=False,
+            default_value=None,  # Explicitly setting to None to differentiate between user intent and default.
             required=False,
             description=(
                 "The model can use CUDA graph capture for CUDA execution provider. "
@@ -102,13 +102,12 @@ def validate_search_point(
         if with_fixed_value:
             search_point = self.config_at_search_point(search_point or {})
         precision = search_point.get("precision")
-        device = (
-            Device.CPU
-            if self.accelerator_spec.execution_provider
-            in AcceleratorLookup.get_execution_providers_for_device(Device.CPU)
-            else Device.GPU
-        )
-        if precision == ModelBuilder.Precision.FP16 and device == Device.CPU:
+
+        # FP16 is only valid on a GPU accelerator with a non-CPU execution provider
+        if (precision == ModelBuilder.Precision.FP16) and not (
+            accelerator_spec.accelerator_type == Device.GPU
+            and accelerator_spec.execution_provider != "CPUExecutionProvider"
+        ):
             logger.info(
                 "FP16 is not supported on CPU. Valid precision + execution "
                 "provider combinations are: FP32 CPU, FP32 CUDA, FP16 CUDA, INT4 CPU, INT4 CUDA"
@@ -152,12 +151,12 @@ def _run_for_config(
             else Path(resolve_onnx_path(output_model_path, model.onnx_file_name))
         )
 
-        target_execution_provider = (
-            "cpu"
-            if self.accelerator_spec.execution_provider
-            in AcceleratorLookup.get_execution_providers_for_device(Device.CPU)
-            else "cuda"
-        )
+        if self.accelerator_spec.execution_provider == "DmlExecutionProvider":
+            target_execution_provider = "dml"
+        elif self.accelerator_spec.execution_provider == "CUDAExecutionProvider":
+            target_execution_provider = "cuda"
+        else:
+            target_execution_provider = "cpu"
 
         # Select cache location based on priority
         # HF_CACHE (HF >= v5) -> TRANSFORMERS_CACHE (HF < v5) -> local dir
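
The core of the change is the execution-provider-to-target mapping in `_run_for_config`. A minimal standalone sketch of that mapping follows; the `select_target` helper name is illustrative only, not part of Olive's API (in the patch the logic lives inline in the pass):

```python
# Sketch of the EP -> model builder target mapping this patch adds.
# The helper name is hypothetical; the patch keeps this logic inline.
def select_target(execution_provider: str) -> str:
    if execution_provider == "DmlExecutionProvider":
        return "dml"
    if execution_provider == "CUDAExecutionProvider":
        return "cuda"
    # Any other EP (e.g. CPUExecutionProvider) falls back to the CPU target.
    return "cpu"


if __name__ == "__main__":
    assert select_target("DmlExecutionProvider") == "dml"
    assert select_target("CUDAExecutionProvider") == "cuda"
    assert select_target("CPUExecutionProvider") == "cpu"
```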
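The reworked `validate_search_point` check can likewise be read as a small predicate: FP16 passes only when the accelerator is a GPU and the execution provider is not the CPU one. A sketch under that reading, with plain strings standing in for Olive's `Device` enum to keep it self-contained:

```python
# Sketch of the FP16 validity rule from validate_search_point; plain strings
# stand in for olive.hardware.accelerator.Device for self-containment.
def fp16_allowed(accelerator_type: str, execution_provider: str) -> bool:
    return accelerator_type == "gpu" and execution_provider != "CPUExecutionProvider"


if __name__ == "__main__":
    assert fp16_allowed("gpu", "CUDAExecutionProvider")
    # Per the new condition, DML on GPU is no longer rejected for FP16.
    assert fp16_allowed("gpu", "DmlExecutionProvider")
    assert not fp16_allowed("gpu", "CPUExecutionProvider")
    assert not fp16_allowed("cpu", "CPUExecutionProvider")
```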