diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 2ab54e285c98d..4cb9a04c5aebf 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -114,6 +114,12 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true); } + // whether fp16 is enable + const std::string fast_math_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFastMathOptimization); + if (!fast_math_env.empty()) { + fast_math_enable_ = (std::stoi(fast_math_enable_env) == 0 ? false : true); + } + // whether int8 is enabled const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8Enable); if (!int8_enable_env.empty()) { @@ -168,6 +174,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: " << "device_id: " << device_id_ << ", migraphx_fp16_enable: " << fp16_enable_ + << ", migraphx_fast_math: " << fast_math_enable_ << ", migraphx_int8_enable: " << int8_enable_ << ", dump_model_ops: " << dump_model_ops_ << ", migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_ @@ -1145,7 +1152,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::quantize_int8(prog, t_, quant_opts); } migraphx::compile_options co; - co.set_fast_math(false); + co.set_fast_math(fast_math_enable_); prog.compile(t_, co); auto prog_output_shapes = prog.get_output_shapes(); for (std::size_t i = 0; i < output_names.size(); ++i) { @@ -1165,7 +1172,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& std::unique_ptr p = std::make_unique(); *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, - map_no_input_shape_[context->node_name], fp16_enable_, int8_enable_, + map_no_input_shape_[context->node_name], fp16_enable_, fast_math_enable_, int8_enable_, int8_calibration_cache_available_, dynamic_range_map, dump_model_ops_}; *state = p.release(); return 0; @@ -1265,7 +1272,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& } migraphx::compile_options co; - co.set_fast_math(false); + co.set_fast_math(fast_math_enable); prog.compile(t, co); mgx_state->prog = prog; param_shapes = prog.get_parameter_shapes(); diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index c094be51012e4..69b1d180316cd 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -26,6 +26,7 @@ static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS"; static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"; static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH"; static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"; +static const char kSetFastMathOptimization[] = "ORT_MIGRAPHX_SET_FAST_MATH"; }; // namespace migraphx_env_vars // Information to construct kernel function state. @@ -41,6 +42,7 @@ struct MIGraphXFuncState { OrtMutex* mgx_mu_ptr = nullptr; bool no_input_shape = false; bool fp16_enable = false; + bool fast_math_enable = false; bool int8_enable = false; bool int8_calibration_cache_available = false; std::unordered_map dynamic_range_map; @@ -78,6 +80,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { private: bool fp16_enable_ = false; + bool fast_math_enable_ = false; bool int8_enable_ = false; std::string int8_calibration_cache_name_; bool int8_calibration_cache_available_ = false; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index b7d7a77853df6..85843765a1e57 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -14,6 +14,7 @@ namespace migraphx { namespace provider_option_names { constexpr const char* kDeviceId = "device_id"; constexpr const char* kFp16Enable = "trt_fp16_enable"; +constexpr const char* kFastMathEnable = "migx_fast_math_enable"; constexpr const char* kInt8Enable = "migx_int8_enable"; constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name"; constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table"; @@ -38,6 +39,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions return Status::OK(); }) .AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable) + .AddAssignmentToReference(migraphx::provider_option_names::kFastMathEnable, info.fast_math_enable) .AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable) .Parse(options)); @@ -48,6 +50,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE const ProviderOptions options{ {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, + {migraphx::provider_option_names::kFastMathEnable, MakeStringWithClassicLocale(info.fast_math_enable)}, {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, }; return options; @@ -57,6 +60,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap const ProviderOptions options{ {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)}, + {migraphx::provider_option_names::kFastMathEnable, MakeStringWithClassicLocale(info.migraphx_fast_math_enable)}, {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, }; return options; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index 18ac30fdc1283..95e8538b12ca4 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -16,6 +16,7 @@ struct MIGraphXExecutionProviderInfo { std::string target_device; int device_id{0}; bool fp16_enable{false}; + bool fast_math_enable{false}; bool int8_enable{false}; std::string int8_calibration_table_name{""}; bool int8_use_native_calibration_table{false}; diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index f985682ddc735..c4dfdf95416c5 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -47,6 +47,7 @@ struct MIGraphX_Provider : Provider { info.device_id = options.device_id; info.target_device = "gpu"; info.fp16_enable = options.migraphx_fp16_enable; + info.fast_math_enable = options.migraphx_fast_math_enable; info.int8_enable = options.migraphx_int8_enable; info.int8_calibration_table_name = ""; if (options.migraphx_int8_calibration_table_name != nullptr) { @@ -61,6 +62,7 @@ struct MIGraphX_Provider : Provider { auto& migx_options = *reinterpret_cast(provider_options); migx_options.device_id = internal_options.device_id; migx_options.migraphx_fp16_enable = internal_options.fp16_enable; + migx_options.migraphx_fast_math_enable = internal_options.fast_math_enable; migx_options.migraphx_int8_enable = internal_options.int8_enable; char* dest = nullptr; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 8423dcfbadc58..800fabbd88a6d 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -734,6 +734,7 @@ std::unique_ptr CreateExecutionProviderInstance( 0, 0, 0, + 0, nullptr}; for (auto option : it->second) { if (option.first == "device_id") { @@ -752,6 +753,13 @@ std::unique_ptr CreateExecutionProviderInstance( "[ERROR] [MIGraphX] The value for the key 'trt_fp16_enable' should be" " 'True' or 'False'. Default value is 'False'.\n"); } + } + else if (option.first == "migraphx_set_fast_math") { + if (option.second == "True" || option.second == "true") { + params.migraphx_fast_math_enable = true; + } else { + params.migraphx_fast_math_enable = false; + } } else if (option.first == "migraphx_int8_enable") { if (option.second == "True" || option.second == "true") { params.migraphx_int8_enable = true;