From 531f35cbaa705cb9341de2e7ada824e0df8d3d37 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Fri, 29 Nov 2024 15:56:23 +0000 Subject: [PATCH 1/4] Add fp8 and int4 types in supported list for Onnxruntime EP --- .../migraphx/migraphx_execution_provider.cc | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index fd36b8ae5f678..f320bf61f0ddf 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -232,6 +232,10 @@ static bool IsTypeSupported(const NodeArg* node_arg) { switch (type_proto->tensor_type().elem_type()) { case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FNUZ: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2FNUZ: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16: @@ -261,6 +265,18 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: mgx_type = migraphx_shape_double_type; break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ: + mgx_type = migraphx_shape_fp8e4m3fnuz_type; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN: + mgx_type = migraphx_shape_fp8e4m3fn_type; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2: + mgx_type = migraphx_shape_fp8e5m2_type; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ: + mgx_type = migraphx_shape_fp8e5m2fnuz_type; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: mgx_type = migraphx_shape_int8_type; break; From 7076f7463507e72f1fa01f81ef3ebb33e21ab4b8 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Wed, 4 Dec 2024 23:05:49 +0000 Subject: [PATCH 2/4] Add support for int4 inputs Map things to int8 right now as we don't explicitly set an int4 input type and pack/unpack int4 operands --- .../providers/migraphx/migraphx_execution_provider.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index f320bf61f0ddf..c1cae43480ea2 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -237,10 +237,12 @@ static bool IsTypeSupported(const NodeArg* node_arg) { case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2FNUZ: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32: @@ -277,6 +279,9 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ: mgx_type = migraphx_shape_fp8e5m2fnuz_type; break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: + mgx_type = migraphx_shape_int8_type; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: mgx_type = migraphx_shape_int8_type; break; @@ -289,6 +294,9 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: mgx_type = migraphx_shape_int64_type; break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4: + mgx_type = migraphx_shape_uint8_type; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: mgx_type = migraphx_shape_uint8_type; break; From 11ff6449deb32bdabf7f08c670bc958e3bab416b Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Wed, 18 Dec 2024 20:20:45 +0000 Subject: [PATCH 3/4] Add flag to allow for fp8 quantization through Onnxruntime API --- include/onnxruntime/core/session/onnxruntime_c_api.h | 1 + .../migraphx/migraphx_execution_provider.cc | 12 ++++++++++-- .../providers/migraphx/migraphx_execution_provider.h | 3 +++ .../migraphx/migraphx_execution_provider_info.cc | 4 ++++ .../migraphx/migraphx_execution_provider_info.h | 1 + .../providers/migraphx/migraphx_provider_factory.cc | 2 ++ onnxruntime/python/onnxruntime_pybind_state.cc | 11 +++++++++++ onnxruntime/test/util/default_providers.cc | 1 + 8 files changed, 33 insertions(+), 2 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 4ca2791e26ab9..7ba935633ca3c 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -614,6 +614,7 @@ typedef struct OrtTensorRTProviderOptions { typedef struct OrtMIGraphXProviderOptions { int device_id; // hip device id. int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true + int migraphx_fp8_enable; // MIGraphX FP8 precision. Default 0 = false, nonzero = true int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index c1cae43480ea2..755da8611eedd 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -114,6 +114,12 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true); } + // whether fp8 quantization is enabled + const std::string fp8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP8Enable); + if (!fp8_enable_env.empty()) { + fp8_enable_ = (std::stoi(fp8_enable_env) == 0 ? false : true); + } + // whether int8 is enabled const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8Enable); if (!int8_enable_env.empty()) { @@ -192,6 +198,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: " << "device_id: " << info_.device_id << ", migraphx_fp16_enable: " << fp16_enable_ + << ", migraphx_fp8_enable: " << fp8_enable_ << ", migraphx_int8_enable: " << int8_enable_ << ", migraphx_int8_enable: " << int8_enable_ << ", dump_model_ops: " << dump_model_ops_ @@ -1183,7 +1190,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options); // Read in the calibration data and map it to an migraphx paramater map for the calibration ops - if (int8_enable_ && int8_calibration_cache_available_) { + if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { LOGS_DEFAULT(INFO) << "Quantizing input program to int8" << std::endl; migraphx::quantize_int8_options quant_opts; migraphx::program_parameters quant_params; @@ -1240,7 +1247,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& std::unique_ptr p = std::make_unique(); *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, - map_no_input_shape_[context->node_name], fp16_enable_, int8_enable_, + map_no_input_shape_[context->node_name], fp16_enable_, fp8_enable_, int8_enable_, int8_calibration_cache_available_, dynamic_range_map_, save_compiled_model_, save_compiled_path_, load_compiled_model_, load_compiled_path_, dump_model_ops_}; @@ -1265,6 +1272,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::onnx_options& cmp_options = mgx_state->options; bool& no_input_shape = mgx_state->no_input_shape; bool fp16_enable = mgx_state->fp16_enable; + bool fp8_enable = mgx_state->fp8_enable; bool int8_enable = mgx_state->int8_enable; bool int8_calibration_cache_available = mgx_state->int8_calibration_cache_available; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 2be6c09551a71..dd7cfaedac361 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -17,6 +17,7 @@ namespace onnxruntime { namespace migraphx_env_vars { static const char kFP16Enable[] = "ORT_MIGRAPHX_FP16_ENABLE"; +static const char kFP8Enable[] = "ORT_MIGRAPHX_FP8_ENABLE"; static const char kINT8Enable[] = "ORT_MIGRAPHX_INT8_ENABLE"; static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS"; static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"; @@ -43,6 +44,7 @@ struct MIGraphXFuncState { OrtMutex* mgx_mu_ptr = nullptr; bool no_input_shape = false; bool fp16_enable = false; + bool fp8_enable = false; bool int8_enable = false; bool int8_calibration_cache_available = false; std::unordered_map dynamic_range_map; @@ -89,6 +91,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { private: MIGraphXExecutionProviderInfo info_; bool fp16_enable_ = false; + bool fp8_enable_ = false; bool int8_enable_ = false; std::string int8_calibration_cache_name_; bool int8_calibration_cache_available_ = false; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index 1f9a47d3ad87d..6537d1c12bc9c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -14,6 +14,7 @@ namespace migraphx { namespace provider_option_names { constexpr const char* kDeviceId = "device_id"; constexpr const char* kFp16Enable = "trt_fp16_enable"; +constexpr const char* kFp8Enable = "migx_fp8_enable"; constexpr const char* kInt8Enable = "migx_int8_enable"; constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name"; constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table"; @@ -43,6 +44,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions return Status::OK(); }) .AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable) + .AddAssignmentToReference(migraphx::provider_option_names::kFp8Enable, info.fp8_enable) .AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable) .AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model) .AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model) @@ -56,6 +58,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE const ProviderOptions options{ {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, + {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.fp8_enable)}, {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)}, {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)}, @@ -68,6 +71,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap const ProviderOptions options{ {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)}, + {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.migraphx_fp8_enable)}, {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)}, {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)}, diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index b8bf86580f03d..554f77b6f7f58 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -16,6 +16,7 @@ struct MIGraphXExecutionProviderInfo { std::string target_device; OrtDevice::DeviceId device_id{0}; bool fp16_enable{false}; + bool fp8_enable{false}; bool int8_enable{false}; std::string int8_calibration_table_name{""}; bool int8_use_native_calibration_table{false}; diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 7b192b657b7cc..098512b2ce69c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -60,6 +60,7 @@ struct MIGraphX_Provider : Provider { info.device_id = static_cast(options.device_id); info.target_device = "gpu"; info.fp16_enable = options.migraphx_fp16_enable; + info.fp8_enable = options.migraphx_fp8_enable; info.exhaustive_tune = options.migraphx_exhaustive_tune; info.int8_enable = options.migraphx_int8_enable; info.int8_calibration_table_name = ""; @@ -85,6 +86,7 @@ struct MIGraphX_Provider : Provider { auto& migx_options = *reinterpret_cast(provider_options); migx_options.device_id = internal_options.device_id; migx_options.migraphx_fp16_enable = internal_options.fp16_enable; + migx_options.migraphx_fp8_enable = internal_options.fp8_enable; migx_options.migraphx_int8_enable = internal_options.int8_enable; migx_options.migraphx_exhaustive_tune = internal_options.exhaustive_tune; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index f31734bdfb805..1862072c36e1f 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -831,6 +831,7 @@ std::unique_ptr CreateExecutionProviderInstance( 0, 0, 0, + 0, nullptr, 1, "./compiled_model.mxr", @@ -854,6 +855,16 @@ std::unique_ptr CreateExecutionProviderInstance( "[ERROR] [MIGraphX] The value for the key 'trt_fp16_enable' should be" " 'True' or 'False'. Default value is 'False'.\n"); } + } else if (option.first == "migraphx_fp8_enable") { + if (option.second == "True" || option.second == "true") { + params.migraphx_fp8_enable = true; + } else if (option.second == "False" || option.second == "false") { + params.migraphx_fp8_enable = false; + } else { + ORT_THROW( + "[ERROR] [MIGraphX] The value for the key 'migx_fp8_enable' should be" + " 'True' or 'False'. Default value is 'False'.\n"); + } } else if (option.first == "migraphx_int8_enable") { if (option.second == "True" || option.second == "true") { params.migraphx_int8_enable = true; diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 1feba20e32bbb..eecf194e7c9a1 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -76,6 +76,7 @@ std::unique_ptr DefaultMIGraphXExecutionProvider() { 0, 0, 0, + 0, nullptr, 1, "./compiled_model.mxr", From ac77aacb174a30b1e682645fb73846af78de6e0e Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 19 Dec 2024 01:21:47 +0000 Subject: [PATCH 4/4] Add fp8 quantization to the compile stage of the MIGraphX EP Mirror the same calibration code we use for int8 and just change which quantize we call through the MIGraphx API --- .../migraphx/migraphx_execution_provider.cc | 54 +++++++++++++------ 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 755da8611eedd..980ac7720b8d9 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1190,9 +1190,8 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options); // Read in the calibration data and map it to an migraphx paramater map for the calibration ops - if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { + if ((int8_enable_ xor fp8_enable_) && int8_calibration_cache_available_) { LOGS_DEFAULT(INFO) << "Quantizing input program to int8" << std::endl; - migraphx::quantize_int8_options quant_opts; migraphx::program_parameters quant_params; auto param_shapes = prog.get_parameter_shapes(); @@ -1202,15 +1201,26 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); } - quant_opts.add_calibration_data(quant_params); - - // specify thing we want to int8 quantize - quant_opts.add_op_name("convolution"); - quant_opts.add_op_name("dot"); // perform static quantization on the programs - migraphx::quantize_int8(prog, t_, quant_opts); - LOGS_DEFAULT(INFO) << "Quantizing input program to int8: Complete" << std::endl; + if(int8_enable_) + { + migraphx::quantize_int8_options quant_opts; + quant_opts.add_calibration_data(quant_params); + // specify thing we want to int8 quantize + quant_opts.add_op_name("convolution"); + quant_opts.add_op_name("dot"); + migraphx::quantize_int8(prog, t_, quant_opts); + LOGS_DEFAULT(INFO) << "Quantizing input program to int8: Complete" << std::endl; + } + else if(fp8_enable_) + { + migraphx::quantize_fp8_options quant_opts; + quant_opts.add_calibration_data(quant_params); + migraphx::quantize_fp8(prog, t_, quant_opts); + LOGS_DEFAULT(INFO) << "Quantizing input program to fp8: Complete" << std::endl; + } + } if (fp16_enable_) { @@ -1333,9 +1343,8 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options); // Read in the calibration data and map it to an migraphx paramater map for the calibration ops - if (int8_enable && int8_calibration_cache_available) { + if ((int8_enable xor fp8_enable) && int8_calibration_cache_available) { LOGS_DEFAULT(INFO) << "Quantize Int8: Begin" << std::endl; - migraphx::quantize_int8_options quant_opts; migraphx::program_parameters quant_params; auto param_shapes = prog.get_parameter_shapes(); @@ -1364,14 +1373,25 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); } - quant_opts.add_calibration_data(quant_params); - // specify thing we want to int8 quantize - quant_opts.add_op_name("convolution"); - quant_opts.add_op_name("dot"); // perform static quantization on the programs - migraphx::quantize_int8(prog, t, quant_opts); - LOGS_DEFAULT(INFO) << "Quantize Int8: Completed" << std::endl; + if(int8_enable) + { + migraphx::quantize_int8_options quant_opts; + quant_opts.add_calibration_data(quant_params); + // specify thing we want to int8 quantize + quant_opts.add_op_name("convolution"); + quant_opts.add_op_name("dot"); + migraphx::quantize_int8(prog, t, quant_opts); + LOGS_DEFAULT(INFO) << "Quantizing input program to fp8: Complete" << std::endl; + } + else if(fp8_enable) + { + migraphx::quantize_fp8_options quant_opts; + quant_opts.add_calibration_data(quant_params); + migraphx::quantize_fp8(prog, t, quant_opts); + LOGS_DEFAULT(INFO) << "Quantizing input program to fp8: Complete" << std::endl; + } } if (fp16_enable) {