From 11ff6449deb32bdabf7f08c670bc958e3bab416b Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Wed, 18 Dec 2024 20:20:45 +0000 Subject: [PATCH] Add flag to allow for fp8 quantization through Onnxruntime API --- include/onnxruntime/core/session/onnxruntime_c_api.h | 1 + .../migraphx/migraphx_execution_provider.cc | 12 ++++++++++-- .../providers/migraphx/migraphx_execution_provider.h | 3 +++ .../migraphx/migraphx_execution_provider_info.cc | 4 ++++ .../migraphx/migraphx_execution_provider_info.h | 1 + .../providers/migraphx/migraphx_provider_factory.cc | 2 ++ onnxruntime/python/onnxruntime_pybind_state.cc | 11 +++++++++++ onnxruntime/test/util/default_providers.cc | 1 + 8 files changed, 33 insertions(+), 2 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 4ca2791e26ab9..7ba935633ca3c 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -614,6 +614,7 @@ typedef struct OrtTensorRTProviderOptions { typedef struct OrtMIGraphXProviderOptions { int device_id; // hip device id. int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true + int migraphx_fp8_enable; // MIGraphX FP8 precision. Default 0 = false, nonzero = true int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index c1cae43480ea2..755da8611eedd 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -114,6 +114,12 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true); } + // whether fp8 quantization is enabled + const std::string fp8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP8Enable); + if (!fp8_enable_env.empty()) { + fp8_enable_ = (std::stoi(fp8_enable_env) == 0 ? false : true); + } + // whether int8 is enabled const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8Enable); if (!int8_enable_env.empty()) { @@ -192,6 +198,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: " << "device_id: " << info_.device_id << ", migraphx_fp16_enable: " << fp16_enable_ + << ", migraphx_fp8_enable: " << fp8_enable_ << ", migraphx_int8_enable: " << int8_enable_ << ", migraphx_int8_enable: " << int8_enable_ << ", dump_model_ops: " << dump_model_ops_ @@ -1183,7 +1190,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options); // Read in the calibration data and map it to an migraphx paramater map for the calibration ops - if (int8_enable_ && int8_calibration_cache_available_) { + if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { LOGS_DEFAULT(INFO) << "Quantizing input program to int8" << std::endl; migraphx::quantize_int8_options quant_opts; migraphx::program_parameters quant_params; @@ -1240,7 +1247,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& std::unique_ptr p = std::make_unique(); *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, - map_no_input_shape_[context->node_name], fp16_enable_, int8_enable_, + map_no_input_shape_[context->node_name], fp16_enable_, fp8_enable_, int8_enable_, int8_calibration_cache_available_, dynamic_range_map_, save_compiled_model_, save_compiled_path_, load_compiled_model_, load_compiled_path_, dump_model_ops_}; @@ -1265,6 +1272,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::onnx_options& cmp_options = mgx_state->options; bool& no_input_shape = mgx_state->no_input_shape; bool fp16_enable = mgx_state->fp16_enable; + bool fp8_enable = mgx_state->fp8_enable; bool int8_enable = mgx_state->int8_enable; bool int8_calibration_cache_available = mgx_state->int8_calibration_cache_available; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 2be6c09551a71..dd7cfaedac361 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -17,6 +17,7 @@ namespace onnxruntime { namespace migraphx_env_vars { static const char kFP16Enable[] = "ORT_MIGRAPHX_FP16_ENABLE"; +static const char kFP8Enable[] = "ORT_MIGRAPHX_FP8_ENABLE"; static const char kINT8Enable[] = "ORT_MIGRAPHX_INT8_ENABLE"; static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS"; static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"; @@ -43,6 +44,7 @@ struct MIGraphXFuncState { OrtMutex* mgx_mu_ptr = nullptr; bool no_input_shape = false; bool fp16_enable = false; + bool fp8_enable = false; bool int8_enable = false; bool int8_calibration_cache_available = false; std::unordered_map dynamic_range_map; @@ -89,6 +91,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { private: MIGraphXExecutionProviderInfo info_; bool fp16_enable_ = false; + bool fp8_enable_ = false; bool int8_enable_ = false; std::string int8_calibration_cache_name_; bool int8_calibration_cache_available_ = false; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index 1f9a47d3ad87d..6537d1c12bc9c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -14,6 +14,7 @@ namespace migraphx { namespace provider_option_names { constexpr const char* kDeviceId = "device_id"; constexpr const char* kFp16Enable = "trt_fp16_enable"; +constexpr const char* kFp8Enable = "migx_fp8_enable"; constexpr const char* kInt8Enable = "migx_int8_enable"; constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name"; constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table"; @@ -43,6 +44,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions return Status::OK(); }) .AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable) + .AddAssignmentToReference(migraphx::provider_option_names::kFp8Enable, info.fp8_enable) .AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable) .AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model) .AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model) @@ -56,6 +58,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE const ProviderOptions options{ {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, + {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.fp8_enable)}, {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)}, {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)}, @@ -68,6 +71,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap const ProviderOptions options{ {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)}, + {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.migraphx_fp8_enable)}, {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)}, {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)}, diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index b8bf86580f03d..554f77b6f7f58 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -16,6 +16,7 @@ struct MIGraphXExecutionProviderInfo { std::string target_device; OrtDevice::DeviceId device_id{0}; bool fp16_enable{false}; + bool fp8_enable{false}; bool int8_enable{false}; std::string int8_calibration_table_name{""}; bool int8_use_native_calibration_table{false}; diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 7b192b657b7cc..098512b2ce69c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -60,6 +60,7 @@ struct MIGraphX_Provider : Provider { info.device_id = static_cast(options.device_id); info.target_device = "gpu"; info.fp16_enable = options.migraphx_fp16_enable; + info.fp8_enable = options.migraphx_fp8_enable; info.exhaustive_tune = options.migraphx_exhaustive_tune; info.int8_enable = options.migraphx_int8_enable; info.int8_calibration_table_name = ""; @@ -85,6 +86,7 @@ struct MIGraphX_Provider : Provider { auto& migx_options = *reinterpret_cast(provider_options); migx_options.device_id = internal_options.device_id; migx_options.migraphx_fp16_enable = internal_options.fp16_enable; + migx_options.migraphx_fp8_enable = internal_options.fp8_enable; migx_options.migraphx_int8_enable = internal_options.int8_enable; migx_options.migraphx_exhaustive_tune = internal_options.exhaustive_tune; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index f31734bdfb805..1862072c36e1f 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -831,6 +831,7 @@ std::unique_ptr CreateExecutionProviderInstance( 0, 0, 0, + 0, nullptr, 1, "./compiled_model.mxr", @@ -854,6 +855,16 @@ std::unique_ptr CreateExecutionProviderInstance( "[ERROR] [MIGraphX] The value for the key 'trt_fp16_enable' should be" " 'True' or 'False'. Default value is 'False'.\n"); } + } else if (option.first == "migraphx_fp8_enable") { + if (option.second == "True" || option.second == "true") { + params.migraphx_fp8_enable = true; + } else if (option.second == "False" || option.second == "false") { + params.migraphx_fp8_enable = false; + } else { + ORT_THROW( + "[ERROR] [MIGraphX] The value for the key 'migx_fp8_enable' should be" + " 'True' or 'False'. Default value is 'False'.\n"); + } } else if (option.first == "migraphx_int8_enable") { if (option.second == "True" || option.second == "true") { params.migraphx_int8_enable = true; diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 1feba20e32bbb..eecf194e7c9a1 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -76,6 +76,7 @@ std::unique_ptr DefaultMIGraphXExecutionProvider() { 0, 0, 0, + 0, nullptr, 1, "./compiled_model.mxr",