Skip to content

Commit

Permalink
Add flag to allow for fp8 quantization through Onnxruntime API
Browse files Browse the repository at this point in the history
  • Loading branch information
TedThemistokleous committed Dec 18, 2024
1 parent 7076f74 commit 4850dfd
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 1 deletion.
1 change: 1 addition & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,7 @@ typedef struct OrtTensorRTProviderOptions {
typedef struct OrtMIGraphXProviderOptions {
int device_id; // hip device id.
int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true
int migraphx_fp8_enable; // MIGraphX FP8 precision. Default 0 = false, nonzero = true
int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true
int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true
const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true);
}

// whether fp8 quantization is enabled
const std::string fp8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP8Enable);
if (!fp8_enable_env.empty()) {
fp8_enable_ = (std::stoi(fp8_enable_env) == 0 ? false : true);
}

// whether int8 is enabled
const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8Enable);
if (!int8_enable_env.empty()) {
Expand Down Expand Up @@ -192,6 +198,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: "
<< "device_id: " << info_.device_id
<< ", migraphx_fp16_enable: " << fp16_enable_
<< ", migraphx_fp8_enable: " << fp8_enable_
<< ", migraphx_int8_enable: " << int8_enable_
<< ", migraphx_int8_enable: " << int8_enable_
<< ", dump_model_ops: " << dump_model_ops_
Expand Down Expand Up @@ -1183,7 +1190,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options);

// Read in the calibration data and map it to an migraphx paramater map for the calibration ops
if (int8_enable_ && int8_calibration_cache_available_) {
if ((int8_enable_ || fp8_enable) && int8_calibration_cache_available_) {
LOGS_DEFAULT(INFO) << "Quantizing input program to int8" << std::endl;
migraphx::quantize_int8_options quant_opts;
migraphx::program_parameters quant_params;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace onnxruntime {

namespace migraphx_env_vars {
static const char kFP16Enable[] = "ORT_MIGRAPHX_FP16_ENABLE";
static const char kFP8Enable[] = "ORT_MIGRAPHX_FP8_ENABLE";
static const char kINT8Enable[] = "ORT_MIGRAPHX_INT8_ENABLE";
static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS";
static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME";
Expand All @@ -43,6 +44,7 @@ struct MIGraphXFuncState {
OrtMutex* mgx_mu_ptr = nullptr;
bool no_input_shape = false;
bool fp16_enable = false;
bool fp8_enable = false;
bool int8_enable = false;
bool int8_calibration_cache_available = false;
std::unordered_map<std::string, float> dynamic_range_map;
Expand Down Expand Up @@ -89,6 +91,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
private:
MIGraphXExecutionProviderInfo info_;
bool fp16_enable_ = false;
bool fp8_enable_ = false;
bool int8_enable_ = false;
std::string int8_calibration_cache_name_;
bool int8_calibration_cache_available_ = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace migraphx {
namespace provider_option_names {
constexpr const char* kDeviceId = "device_id";
constexpr const char* kFp16Enable = "trt_fp16_enable";
constexpr const char* kFp8Enable = "migx_fp8_enable";
constexpr const char* kInt8Enable = "migx_int8_enable";
constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name";
constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table";
Expand Down Expand Up @@ -43,6 +44,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions
return Status::OK();
})
.AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kFp8Enable, info.fp8_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model)
.AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model)
Expand All @@ -56,6 +58,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE
const ProviderOptions options{
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)},
{migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.fp8_enable)},
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)},
Expand All @@ -68,6 +71,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap
const ProviderOptions options{
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)},
{migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.migraphx_fp8_enable)},
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ struct MIGraphXExecutionProviderInfo {
std::string target_device;
OrtDevice::DeviceId device_id{0};
bool fp16_enable{false};
bool fp8_enable{false};
bool int8_enable{false};
std::string int8_calibration_table_name{""};
bool int8_use_native_calibration_table{false};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ struct MIGraphX_Provider : Provider {
info.device_id = static_cast<OrtDevice::DeviceId>(options.device_id);
info.target_device = "gpu";
info.fp16_enable = options.migraphx_fp16_enable;
info.fp8_enable = options.migraphx_fp8_enable;
info.exhaustive_tune = options.migraphx_exhaustive_tune;
info.int8_enable = options.migraphx_int8_enable;
info.int8_calibration_table_name = "";
Expand All @@ -85,6 +86,7 @@ struct MIGraphX_Provider : Provider {
auto& migx_options = *reinterpret_cast<OrtMIGraphXProviderOptions*>(provider_options);
migx_options.device_id = internal_options.device_id;
migx_options.migraphx_fp16_enable = internal_options.fp16_enable;
migx_options.migraphx_fp8_enable = internal_options.fp8_enable;
migx_options.migraphx_int8_enable = internal_options.int8_enable;
migx_options.migraphx_exhaustive_tune = internal_options.exhaustive_tune;

Expand Down
11 changes: 11 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
0,
0,
0,
0,
nullptr,
1,
"./compiled_model.mxr",
Expand All @@ -854,6 +855,16 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
"[ERROR] [MIGraphX] The value for the key 'trt_fp16_enable' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_fp8_enable") {
if (option.second == "True" || option.second == "true") {
params.migraphx_fp8_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_fp8_enable = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_fp8_enable' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_int8_enable") {
if (option.second == "True" || option.second == "true") {
params.migraphx_int8_enable = true;
Expand Down

0 comments on commit 4850dfd

Please sign in to comment.