Adding ortvalue features support for MGX EP #81

Open
wants to merge 4 commits into
base: rocm6.3_internal_testing
15 changes: 15 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -622,6 +622,21 @@
int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, nonzero = true
const char* migraphx_load_model_path; // migraphx model path name
bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false

/** \brief MIGraphX memory limit (To use all possible memory pass in maximum size_t)
* Defaults to SIZE_MAX.
* \note If a ::OrtArenaCfg has been applied, it will override this field
*/
size_t migraphx_mem_limit;

/** \brief Strategy used to grow the memory arena
* 0 = kNextPowerOfTwo<br>
* 1 = kSameAsRequested<br>
* Defaults to 0.
* \note If a ::OrtArenaCfg has been applied, it will override this field
*/
int migraphx_arena_extend_strategy;

} OrtMIGraphXProviderOptions;
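
For readers unfamiliar with the two strategy values, here is a minimal sketch of how the growth step might differ between them. It is a simplified model, not ORT's arena code: `DemoExtendStrategy` and `DemoNextExtensionBytes` are hypothetical names, and the doubling rule for `kNextPowerOfTwo` is an assumption based on the option's name and its documented intent.

```cpp
#include <cstddef>
#include <limits>

// Hypothetical stand-in for the two values documented on
// migraphx_arena_extend_strategy (0 and 1).
enum class DemoExtendStrategy : int { kNextPowerOfTwo = 0, kSameAsRequested = 1 };

// Size of the next arena extension under a simplified model (assumption).
// current_region_bytes is the size of the most recently allocated region.
inline std::size_t DemoNextExtensionBytes(DemoExtendStrategy strategy,
                                          std::size_t requested_bytes,
                                          std::size_t current_region_bytes) {
  if (strategy == DemoExtendStrategy::kSameAsRequested) {
    return requested_bytes;  // grow by exactly what was asked for
  }
  std::size_t bytes = current_region_bytes == 0 ? 1 : current_region_bytes;
  while (bytes < requested_bytes) {
    if (bytes > std::numeric_limits<std::size_t>::max() / 2) {
      return requested_bytes;  // doubling would overflow; fall back to the request
    }
    bytes *= 2;  // keep doubling until the request fits
  }
  return bytes;
}
```

Under this model, a 3 MiB request against a 1 MiB region grows by 4 MiB with `kNextPowerOfTwo` and by 3 MiB with `kSameAsRequested`, which is the trade-off between fewer extensions and tighter memory use.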

/** \brief OpenVINO Provider Options
onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
@@ -13,10 +13,11 @@
#include "core/common/safeint.h"
#include "core/common/logging/severity.h"
#include "migraphx_execution_provider.h"
#include "migraphx_execution_provider_info.h"

#include "migraphx_execution_provider_utils.h"
#include "migraphx_allocator.h"
#include "gpu_data_transfer.h"
#include "migraphx_inc.h"
#include "migraphx_call.h"


#include "migraphx_stream_handle.h"

@@ -208,6 +209,44 @@
MIGraphXExecutionProvider::~MIGraphXExecutionProvider() {
}

AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(OrtDevice::DeviceId device_id,
size_t migx_mem_limit,
ArenaExtendStrategy arena_extend_strategy,
MIGraphXExecutionProviderExternalAllocatorInfo
external_allocator_info,
const OrtArenaCfg* default_memory_arena_cfg) {
if (external_allocator_info.UseExternalAllocator()) {
AllocatorCreationInfo default_memory_info(
[external_allocator_info](OrtDevice::DeviceId id) {
return std::make_unique<MIGraphXExternalAllocator>(id, HIP,
external_allocator_info.alloc,
external_allocator_info.free,
external_allocator_info.empty_cache);
},
device_id,
false);

return CreateAllocator(default_memory_info);
} else {
AllocatorCreationInfo default_memory_info(
[](OrtDevice::DeviceId id) {
return std::make_unique<MIGraphXAllocator>(id, HIP);
},
device_id,
true,
{default_memory_arena_cfg ? *default_memory_arena_cfg
: OrtArenaCfg(migx_mem_limit, static_cast<int>(arena_extend_strategy),
-1, -1, -1, -1L)},
// make it stream aware
true,
// enable cross stream sharing?
false);


Is this something we want to make controllable from the API later?


// ROCM malloc/free is expensive so always use an arena
return CreateAllocator(default_memory_info);
}
}
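
To make the branch structure of `CreateMIGraphXAllocator` easier to review, here is a minimal sketch of the decision it encodes: external callbacks bypass the arena, and otherwise an explicitly supplied arena config takes precedence over the `migx_mem_limit` / `arena_extend_strategy` pair. All `Demo*` types and `DemoPlanAllocator` are hypothetical stand-ins, not ORT types, and the sketch only models which settings win.

```cpp
#include <cstddef>

// Hypothetical stand-in for the two OrtArenaCfg fields the diff sets explicitly.
struct DemoArenaCfg {
  std::size_t max_mem;
  int extend_strategy;
};

// Hypothetical stand-in for MIGraphXExecutionProviderExternalAllocatorInfo.
struct DemoExternalAlloc {
  void* alloc;
  void* free;
  bool UseExternalAllocator() const { return alloc != nullptr && free != nullptr; }
};

// Result of the decision: whether an arena is used and with which settings.
struct DemoAllocatorPlan {
  bool use_arena;
  DemoArenaCfg cfg;
};

inline DemoAllocatorPlan DemoPlanAllocator(const DemoExternalAlloc& external,
                                           std::size_t mem_limit,
                                           int extend_strategy,
                                           const DemoArenaCfg* arena_cfg) {
  if (external.UseExternalAllocator()) {
    // External alloc/free callbacks are used directly, without an arena.
    return DemoAllocatorPlan{false, DemoArenaCfg{0, 0}};
  }
  // A supplied arena config overrides the individual provider-option fields.
  return DemoAllocatorPlan{true, arena_cfg ? *arena_cfg
                                           : DemoArenaCfg{mem_limit, extend_strategy}};
}
```

This mirrors the note in the C API header that an applied `OrtArenaCfg` overrides `migraphx_mem_limit` and `migraphx_arena_extend_strategy`.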

std::vector<AllocatorPtr> MIGraphXExecutionProvider::CreatePreferredAllocators() {
AllocatorCreationInfo default_memory_info(
[](OrtDevice::DeviceId device_id) { return CreateMIGraphXAllocator(device_id, onnxruntime::CUDA); }, info_.device_id);
onnxruntime/core/providers/migraphx/migraphx_execution_provider.h
@@ -7,7 +7,7 @@
#include "core/framework/execution_provider.h"
#include "core/platform/ort_mutex.h"
#include "core/providers/migraphx/migraphx_execution_provider_info.h"
#include "core/providers/migraphx/migraphx_inc.h"
#include "core/providers/migraphx/migraphx_call.h"

#include <map>
#include <unordered_map>
@@ -76,6 +76,9 @@
virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const override;
std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const override;

static AllocatorPtr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t migx_mem_limit, ArenaExtendStrategy arena_extend_strategy,
MIGraphXExecutionProviderExternalAllocatorInfo external_alloc_info, const OrtArenaCfg* arena_cfg);

std::unique_ptr<IndexedSubGraph> GetSubGraph(const std::vector<std::size_t>& graph_nodes_index, const GraphViewer& graph) const;
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override;
OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;
onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/shared_library/provider_api.h"
#include "core/providers/migraphx/migraphx_execution_provider_info.h"

#include "core/common/make_string.h"
@@ -10,6 +11,12 @@
#include "migraphx_call.h"

namespace onnxruntime {

const EnumNameMapping<ArenaExtendStrategy> arena_extend_strategy_mapping{
{ArenaExtendStrategy::kNextPowerOfTwo, "kNextPowerOfTwo"},
{ArenaExtendStrategy::kSameAsRequested, "kSameAsRequested"},
};

namespace migraphx {
namespace provider_option_names {
constexpr const char* kDeviceId = "device_id";
@@ -22,12 +29,20 @@
constexpr const char* kLoadCompiledModel = "migx_load_compiled_model";
constexpr const char* kLoadModelPath = "migx_load_model_name";
constexpr const char* kExhaustiveTune = "migx_exhaustive_tune";
constexpr const char* kMemLimit = "migx_mem_limit";
constexpr const char* kArenaExtendStrategy = "migx_arena_extend_strategy";
constexpr const char* kGpuExternalAlloc = "migx_external_alloc";
constexpr const char* kGpuExternalFree = "migx_external_free";
constexpr const char* kGpuExternalEmptyCache = "migx_external_empty_cache";

} // namespace provider_option_names
} // namespace migraphx

MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) {
MIGraphXExecutionProviderInfo info{};
void* alloc = nullptr;
void* free = nullptr;
void* empty_cache = nullptr;
ORT_THROW_IF_ERROR(
ProviderOptionsParser{}
.AddValueParser(
@@ -42,13 +57,42 @@
", must be between 0 (inclusive) and ", num_devices, " (exclusive).");
return Status::OK();
})
.AddValueParser(
migraphx_provider_option::kGpuExternalAlloc,
[&alloc](const std::string& value_str) -> Status {
size_t address;
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
alloc = reinterpret_cast<void*>(address);
return Status::OK();
})
.AddValueParser(
migraphx_provider_option::kGpuExternalFree,
[&free](const std::string& value_str) -> Status {
size_t address;
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
free = reinterpret_cast<void*>(address);
return Status::OK();
})
.AddValueParser(
migraphx_provider_option::kGpuExternalEmptyCache,
[&empty_cache](const std::string& value_str) -> Status {

size_t address;
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
empty_cache = reinterpret_cast<void*>(address);
return Status::OK();
})
.AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model)
.AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model)
.AddAssignmentToReference(migraphx::provider_option_names::kExhaustiveTune, info.exhaustive_tune)
.AddAssignmentToReference(migraphx_provider_option::kMemLimit, info.mem_limit)
.AddAssignmentToEnumReference(migraphx_provider_option::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy)

.Parse(options));

MIGraphXExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache};
info.external_allocator_info = alloc_info;

return info;
}
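
The three `migx_external_*` options carry callback addresses as decimal strings, which the parsers above turn back into pointers. Here is a minimal sketch of that round trip, assuming `std::to_string` and `std::stoull` as stand-ins for ORT's `MakeStringWithClassicLocale` and `ParseStringWithClassicLocale`. `DemoAllocFn`, `DemoAlloc`, `DemoEncodeAddress`, and `DemoDecodeAddress` are hypothetical names, and the sketch uses `std::uintptr_t` where the diff uses `size_t`.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <string>

// Hypothetical signature for an externally supplied allocation callback.
using DemoAllocFn = void* (*)(std::size_t);

// Hypothetical callback that a caller would register.
inline void* DemoAlloc(std::size_t bytes) { return std::malloc(bytes); }

// Caller side: serialise the callback's address as a decimal string,
// suitable for a string-to-string provider-options map.
inline std::string DemoEncodeAddress(DemoAllocFn fn) {
  return std::to_string(
      static_cast<unsigned long long>(reinterpret_cast<std::uintptr_t>(fn)));
}

// Provider side: parse the decimal string and recover the callable pointer.
// The string must come from the same process; the address is meaningless elsewhere.
inline DemoAllocFn DemoDecodeAddress(const std::string& text) {
  const auto address = static_cast<std::uintptr_t>(std::stoull(text));
  return reinterpret_cast<DemoAllocFn>(address);
}
```

Because nothing validates the parsed address, a malformed or stale option value would produce a pointer that crashes when called, so the option is only safe when set programmatically within the same process.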

@@ -59,6 +103,12 @@
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)},
{migraphx_provider_option::kMemLimit, MakeStringWithClassicLocale(info.mem_limit)},
{migraphx_provider_option::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.alloc))},
{migraphx_provider_option::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.free))},
{migraphx_provider_option::kGpuExternalEmptyCache, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.empty_cache))},
{migraphx_provider_option::kArenaExtendStrategy,
EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
{migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.exhaustive_tune)},
};
return options;
@@ -71,6 +121,8 @@
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)},
{migraphx_provider_option::kMemLimit, MakeStringWithClassicLocale(info.migraphx_mem_limit)},
{migraphx_provider_option::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast<onnxruntime::ArenaExtendStrategy>(info.migraphx_arena_extend_strategy))},

{migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.migraphx_exhaustive_tune)},
};
return options;
onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h
@@ -7,10 +7,36 @@
#include <string>

#include "core/framework/ortdevice.h"
#include "core/common/hash_combine.h"
#include "core/framework/arena_extend_strategy.h"
#include "core/framework/provider_options.h"
#include "core/session/onnxruntime_c_api.h"

namespace onnxruntime {

// Information needed to construct MIGraphX execution providers.
struct MIGraphXExecutionProviderExternalAllocatorInfo {
void* alloc{nullptr};
void* free{nullptr};
void* empty_cache{nullptr};

MIGraphXExecutionProviderExternalAllocatorInfo() {
alloc = nullptr;
free = nullptr;
empty_cache = nullptr;
}

MIGraphXExecutionProviderExternalAllocatorInfo(void* a, void* f, void* e) {
alloc = a;
free = f;
empty_cache = e;
}

bool UseExternalAllocator() const {
return (alloc != nullptr) && (free != nullptr);
}
};

// Information needed to construct trt execution providers.
struct MIGraphXExecutionProviderInfo {
std::string target_device;
@@ -25,8 +51,43 @@
std::string load_model_file{"./compiled_model.mxr"};
bool exhaustive_tune{false};

size_t mem_limit{std::numeric_limits<size_t>::max()}; // Will be over-ridden by contents of `default_memory_arena_cfg` (if specified)
ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo}; // Will be over-ridden by contents of `default_memory_arena_cfg` (if specified)

OrtArenaCfg* default_memory_arena_cfg{nullptr};
MIGraphXExecutionProviderExternalAllocatorInfo external_allocator_info{};

static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info);
static ProviderOptions ToProviderOptions(const OrtMIGraphXProviderOptions& info);
};
} // namespace onnxruntime

template <>
struct std::hash<::onnxruntime::MIGraphXExecutionProviderInfo> {
size_t operator()(const ::onnxruntime::MIGraphXExecutionProviderInfo& info) const {
size_t value{0xbc9f1d34}; // seed

// Bits: device_id (16), arena_extend_strategy (reserved 2), boolean options (1 each)
size_t data = static_cast<size_t>(info.device_id) ^
(static_cast<size_t>(info.arena_extend_strategy) << 16) ^
(static_cast<size_t>(info.fp16_enable) << 18) ^
(static_cast<size_t>(info.int8_enable) << 19) ^
(static_cast<size_t>(info.int8_use_native_calibration_table) << 20) ^
(static_cast<size_t>(info.model_cache_enable) << 21) ^
(static_cast<size_t>(info.save_compiled_model) << 22) ^
(static_cast<size_t>(info.load_compiled_model) << 23) ^
(static_cast<size_t>(info.exhaustive_tune) << 24);


Going forward, is the intent to add the other flags (fp16/int8) and other quantize modes in here as well?

onnxruntime::HashCombine(data, value);

onnxruntime::HashCombine(info.mem_limit, value);

// Memory pointers
onnxruntime::HashCombine(reinterpret_cast<size_t>(info.external_allocator_info.alloc), value);
onnxruntime::HashCombine(reinterpret_cast<size_t>(info.external_allocator_info.free), value);
onnxruntime::HashCombine(reinterpret_cast<size_t>(info.external_allocator_info.empty_cache), value);

// The default memory arena cfg is not used in hashing right now.
return value;
}
};
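
For reviewers checking the bit layout in the hash specialization, here is a minimal sketch of the same packing scheme on a reduced set of fields. `DemoInfo`, `DemoPackBits`, `DemoHashCombine`, and `DemoHash` are hypothetical names. The combine step is a common boost-style formula used here as an assumption, not ORT's `HashCombine`, although it keeps the same `(value, seed)` argument order as the calls in the diff.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>

// Hypothetical reduced version of the provider info: only some of the hashed fields.
struct DemoInfo {
  std::uint16_t device_id;
  int arena_extend_strategy;  // 0 or 1, so it fits in the 2 reserved bits
  bool fp16_enable;
  bool int8_enable;
  std::size_t mem_limit;
};

// Pack the small fields into disjoint bit ranges:
// device_id in bits 0-15, strategy in bits 16-17, one bit per boolean after that.
inline std::size_t DemoPackBits(const DemoInfo& info) {
  return static_cast<std::size_t>(info.device_id) ^
         (static_cast<std::size_t>(info.arena_extend_strategy) << 16) ^
         (static_cast<std::size_t>(info.fp16_enable) << 18) ^
         (static_cast<std::size_t>(info.int8_enable) << 19);
}

// Boost-style combine step (assumption; stands in for onnxruntime::HashCombine).
inline void DemoHashCombine(std::size_t value, std::size_t& seed) {
  seed ^= std::hash<std::size_t>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

inline std::size_t DemoHash(const DemoInfo& info) {
  std::size_t seed = 0xbc9f1d34;  // same seed constant as the diff
  DemoHashCombine(DemoPackBits(info), seed);
  DemoHashCombine(info.mem_limit, seed);  // wide fields are combined separately
  return seed;
}
```

Because the ranges do not overlap, XOR behaves like OR here, and any change to a single flag changes the packed word before it reaches the combine step.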
26 changes: 26 additions & 0 deletions onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc
@@ -5,6 +5,7 @@
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/migraphx/migraphx_provider_factory.h"
#include "migraphx_execution_provider.h"
#include "migraphx_execution_provider_info.h"

#include "migraphx_provider_factory_creator.h"
#include "migraphx_allocator.h"
#include "gpu_data_transfer.h"
@@ -42,6 +43,27 @@
return std::make_unique<HIPPinnedAllocator>(device_id, name);
}

void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, size_t count) override {
// hipMemcpy() operates on the default stream
HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyHostToDevice));

// To ensure that the copy has completed, invoke a stream sync for the default stream.
// For transfers from pageable host memory to device memory, a stream sync is performed before the copy is initiated.

// The function will return once the pageable buffer has been copied to the staging memory for DMA transfer
// to device memory, but the DMA to final destination may not have completed.

HIP_CALL_THROW(hipStreamSynchronize(0));


Do we always want to be using HIP stream 0 for this?

}

// Used by onnxruntime_pybind_state.cc
void MIGraphXMemcpy_DeviceToHost(void* dst, const void* src, size_t count) override {
// For transfers from device to either pageable or pinned host memory, the function returns only once the copy has completed.

HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyDeviceToHost));
}

std::shared_ptr<IAllocator> CreateMIGraphXAllocator(int16_t device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override {
return MIGraphXExecutionProvider::CreateMIGraphXAllocator(device_id, migx_mem_limit, arena_extend_strategy, external_allocator_info, default_memory_arena_cfg);
}
} g_info;

struct MIGraphX_Provider : Provider {
@@ -77,6 +99,8 @@
if (options.migraphx_load_model_path != nullptr) {
info.load_model_file = options.migraphx_load_model_path;
}
info.arena_extend_strategy = static_cast<onnxruntime::ArenaExtendStrategy>(options.migraphx_arena_extend_strategy);
info.mem_limit = options.migraphx_mem_limit;
return std::make_shared<MIGraphXProviderFactory>(info);
}

@@ -109,6 +133,8 @@
migx_options.migraphx_save_model_path = internal_options.save_model_file.c_str();
migx_options.migraphx_load_compiled_model = internal_options.load_compiled_model;
migx_options.migraphx_load_model_path = internal_options.load_model_file.c_str();
migx_options.migraphx_arena_extend_strategy = static_cast<int>(internal_options.arena_extend_strategy);
migx_options.migraphx_mem_limit = internal_options.mem_limit;
}

ProviderOptions GetProviderOptions(const void* provider_options) override {
onnxruntime/core/providers/migraphx/migraphx_provider_factory.h
@@ -14,6 +14,9 @@
struct ProviderInfo_MIGraphX {
virtual std::unique_ptr<onnxruntime::IAllocator> CreateMIGraphXAllocator(int16_t device_id, const char* name) = 0;
virtual std::unique_ptr<onnxruntime::IAllocator> CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) = 0;
virtual void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, size_t count) = 0;
virtual void MIGraphXMemcpy_DeviceToHost(void* dst, const void* src, size_t count) = 0;
virtual std::shared_ptr<onnxruntime::IAllocator> CreateMIGraphXAllocator(int16_t device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) = 0;


protected:
~ProviderInfo_MIGraphX() = default; // Can only be destroyed through a subclass instance