diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index a4d5bd12be7f7..8a0d0c56d68f2 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -600,6 +600,7 @@ typedef struct OrtTensorRTProviderOptions { typedef struct OrtMIGraphXProviderOptions { int device_id; // hip device id. int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true + int migraphx_fast_math_enable; // MIGraphX Fast Math Optimize. Default 0 = false, nonzero = true int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 4cb9a04c5aebf..4b7b1441a6563 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -115,8 +115,8 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv } // whether fp16 is enable - const std::string fast_math_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFastMathOptimization); - if (!fast_math_env.empty()) { + const std::string fast_math_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSetFastMathOptimization); + if (!fast_math_enable_env.empty()) { fast_math_enable_ = (std::stoi(fast_math_enable_env) == 0 ? false : true); } @@ -1195,6 +1195,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::onnx_options& cmp_options = mgx_state->options; bool& no_input_shape = mgx_state->no_input_shape; bool fp16_enable = mgx_state->fp16_enable; + bool fast_math_enable = mgx_state->fast_math_enable; bool int8_enable = mgx_state->int8_enable; bool int8_calibration_cache_available = mgx_state->int8_calibration_cache_available;