Add implementation of USM pools (intel#11)
Signed-off-by: Brandon Yates <[email protected]>
bmyates authored May 23, 2023
1 parent 6fe8175 commit f572452
Showing 3 changed files with 117 additions and 37 deletions.
@@ -573,7 +573,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(
   case UR_DEVICE_INFO_USM_CROSS_SHARED_SUPPORT:
   case UR_DEVICE_INFO_USM_SYSTEM_SHARED_SUPPORT: {
     auto MapCaps = [](const ze_memory_access_cap_flags_t &ZeCapabilities) {
-      uint64_t Capabilities = 0;
+      ur_device_usm_access_capability_flags_t Capabilities = 0;
       if (ZeCapabilities & ZE_MEMORY_ACCESS_CAP_FLAG_RW)
         Capabilities |= UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_ACCESS;
       if (ZeCapabilities & ZE_MEMORY_ACCESS_CAP_FLAG_ATOMIC)
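The hunk above only tightens the declared type of Capabilities; the flag values reported to callers are unchanged. For orientation, a minimal caller-side sketch of querying one of these USM capability bits via urDeviceGetInfo (the Device handle and the chosen query are illustrative, not part of this commit):

    ur_device_usm_access_capability_flags_t Caps = 0;
    if (urDeviceGetInfo(Device, UR_DEVICE_INFO_USM_CROSS_SHARED_SUPPORT,
                        sizeof(Caps), &Caps, nullptr) == UR_RESULT_SUCCESS &&
        (Caps & UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_ACCESS)) {
      // The device can at least read/write cross-device shared USM allocations.
    }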
@@ -24,17 +24,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(
         Size, ///< [in] size in bytes of the USM memory object to be allocated
     void **RetMem ///< [out] pointer to USM host memory object
 ) {
-  std::ignore = Pool;
 
-  uint32_t Align = USMDesc->align;
+  uint32_t Align = USMDesc ? USMDesc->align : 0;
   // L0 supports alignment up to 64KB and silently ignores higher values.
   // We flag alignment > 64KB as an invalid value.
   if (Align > 65536)
     return UR_RESULT_ERROR_INVALID_VALUE;
 
-  const ur_usm_advice_flags_t *USMHintFlags = &USMDesc->hints;
-  std::ignore = USMHintFlags;
-
   ur_platform_handle_t Plt = Context->getPlatform();
   // If indirect access tracking is enabled then lock the mutex which is
   // guarding contexts container in the platform. This prevents new kernels from
@@ -77,7 +73,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc(
   // find the allocator depending on context as we do for Shared and Device
   // allocations.
   try {
-    *RetMem = Context->HostMemAllocContext->allocate(Size, Align);
+    if (Pool) {
+      *RetMem = Pool->HostMemPool->allocate(Size, Align);
+    } else {
+      *RetMem = Context->HostMemAllocContext->allocate(Size, Align);
+    }
     if (IndirectAccessTrackingEnabled) {
       // Keep track of all memory allocations in the context
       Context->MemAllocs.emplace(std::piecewise_construct,
@@ -105,18 +105,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc(
         Size, ///< [in] size in bytes of the USM memory object to be allocated
     void **RetMem ///< [out] pointer to USM device memory object
 ) {
-  std::ignore = Pool;
 
-  uint32_t Alignment = USMDesc->align;
+  uint32_t Alignment = USMDesc ? USMDesc->align : 0;
 
   // L0 supports alignment up to 64KB and silently ignores higher values.
   // We flag alignment > 64KB as an invalid value.
   if (Alignment > 65536)
     return UR_RESULT_ERROR_INVALID_VALUE;
 
-  const ur_usm_advice_flags_t *USMHintFlags = &USMDesc->hints;
-  std::ignore = USMHintFlags;
-
   ur_platform_handle_t Plt = Device->Platform;
 
   // If indirect access tracking is enabled then lock the mutex which is
@@ -157,11 +153,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc(
   }
 
   try {
-    auto It = Context->DeviceMemAllocContexts.find(Device->ZeDevice);
-    if (It == Context->DeviceMemAllocContexts.end())
-      return UR_RESULT_ERROR_INVALID_VALUE;
 
-    *RetMem = It->second.allocate(Size, Alignment);
+    if (Pool) {
+      *RetMem = Pool->DeviceMemPools[Device]->allocate(Size, Alignment);
+    } else {
+      auto It = Context->DeviceMemAllocContexts.find(Device->ZeDevice);
+      if (It == Context->DeviceMemAllocContexts.end())
+        return UR_RESULT_ERROR_INVALID_VALUE;
+
+      *RetMem = It->second.allocate(Size, Alignment);
+    }
     if (IndirectAccessTrackingEnabled) {
       // Keep track of all memory allocations in the context
       Context->MemAllocs.emplace(std::piecewise_construct,
@@ -190,17 +191,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc(
         Size, ///< [in] size in bytes of the USM memory object to be allocated
     void **RetMem ///< [out] pointer to USM shared memory object
 ) {
-  std::ignore = Pool;
 
-  uint32_t Alignment = USMDesc->align;
+  uint32_t Alignment = USMDesc ? USMDesc->align : 0;
 
   ur_usm_host_mem_flags_t UsmHostFlags{};
 
   // See if the memory is going to be read-only on the device.
   bool DeviceReadOnly = false;
   ur_usm_device_mem_flags_t UsmDeviceFlags{};
 
-  void *pNext = const_cast<void *>(USMDesc->pNext);
+  void *pNext = USMDesc ? const_cast<void *>(USMDesc->pNext) : nullptr;
   while (pNext != nullptr) {
     const ur_base_desc_t *BaseDesc =
         reinterpret_cast<const ur_base_desc_t *>(pNext);
@@ -259,13 +259,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc(
   }
 
   try {
-    auto &Allocator = (DeviceReadOnly ? Context->SharedReadOnlyMemAllocContexts
-                                      : Context->SharedMemAllocContexts);
-    auto It = Allocator.find(Device->ZeDevice);
-    if (It == Allocator.end())
-      return UR_RESULT_ERROR_INVALID_VALUE;
-
-    *RetMem = It->second.allocate(Size, Alignment);
+    if (Pool) {
+      if (DeviceReadOnly) {
+        *RetMem =
+            Pool->SharedMemReadOnlyPools[Device]->allocate(Size, Alignment);
+      } else {
+        *RetMem = Pool->SharedMemPools[Device]->allocate(Size, Alignment);
+      }
+    } else {
+      auto &Allocator =
+          (DeviceReadOnly ? Context->SharedReadOnlyMemAllocContexts
+                          : Context->SharedMemAllocContexts);
+      auto It = Allocator.find(Device->ZeDevice);
+      if (It == Allocator.end())
+        return UR_RESULT_ERROR_INVALID_VALUE;
+
+      *RetMem = It->second.allocate(Size, Alignment);
+    }
     if (DeviceReadOnly) {
       Context->SharedReadOnlyAllocs.insert(*RetMem);
     }
@@ -518,34 +528,87 @@ static ur_result_t USMAllocationMakeResident(
   return UR_RESULT_SUCCESS;
 }
 
+ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
+                                             ur_usm_pool_desc_t *PoolDesc) {
+
+  zeroInit = static_cast<ur_bool_t>(PoolDesc->flags &
+                                    UR_USM_POOL_FLAG_ZERO_INITIALIZE_BLOCK);
+
+  void *pNext = const_cast<void *>(PoolDesc->pNext);
+  while (pNext != nullptr) {
+    const ur_base_desc_t *BaseDesc =
+        reinterpret_cast<const ur_base_desc_t *>(pNext);
+    switch (BaseDesc->stype) {
+    case UR_STRUCTURE_TYPE_USM_POOL_LIMITS_DESC: {
+      const ur_usm_pool_limits_desc_t *Limits =
+          reinterpret_cast<const ur_usm_pool_limits_desc_t *>(BaseDesc);
+      for (auto &config : USMAllocatorConfigs.Configs) {
+        config.MaxPoolableSize = Limits->maxPoolableSize;
+        config.SlabMinSize = Limits->minDriverAllocSize;
+      }
+      break;
+    }
+    default: {
+      urPrint("urUSMPoolCreate: unexpected chained stype\n");
+      throw UsmAllocationException(UR_RESULT_ERROR_INVALID_ARGUMENT);
+    }
+    }
+    pNext = const_cast<void *>(BaseDesc->pNext);
+  }
+
+  HostMemPool = std::make_unique<USMAllocContext>(
+      std::unique_ptr<SystemMemory>(new USMHostMemoryAlloc(Context)),
+      this->USMAllocatorConfigs.Configs[usm_settings::MemType::Host]);
+
+  for (auto device : Context->Devices) {
+    DeviceMemPools[device] = std::make_unique<USMAllocContext>(
+        std::unique_ptr<SystemMemory>(
+            new USMDeviceMemoryAlloc(Context, device)),
+        this->USMAllocatorConfigs.Configs[usm_settings::MemType::Device]);
+
+    SharedMemPools[device] = std::make_unique<USMAllocContext>(
+        std::unique_ptr<SystemMemory>(
+            new USMSharedMemoryAlloc(Context, device)),
+        this->USMAllocatorConfigs.Configs[usm_settings::MemType::Shared]);
+    SharedMemReadOnlyPools[device] = std::make_unique<USMAllocContext>(
+        std::unique_ptr<SystemMemory>(
+            new USMSharedMemoryAlloc(Context, device)),
+        this->USMAllocatorConfigs
+            .Configs[usm_settings::MemType::SharedReadOnly]);
+  }
+}
+
 UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolCreate(
     ur_context_handle_t Context, ///< [in] handle of the context object
     ur_usm_pool_desc_t
         *PoolDesc, ///< [in] pointer to USM pool descriptor. Can be chained with
                    ///< ::ur_usm_pool_limits_desc_t
     ur_usm_pool_handle_t *Pool ///< [out] pointer to USM memory pool
 ) {
-  std::ignore = Context;
-  std::ignore = PoolDesc;
-  std::ignore = Pool;
-  urPrint("[UR][L0] %s function not implemented!\n", __FUNCTION__);
-  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
+
+  try {
+    *Pool = reinterpret_cast<ur_usm_pool_handle_t>(
+        new ur_usm_pool_handle_t_(Context, PoolDesc));
+  } catch (const UsmAllocationException &Ex) {
+    return Ex.getError();
+  }
+  return UR_RESULT_SUCCESS;
 }
 
 ur_result_t
 urUSMPoolRetain(ur_usm_pool_handle_t Pool ///< [in] pointer to USM memory pool
 ) {
-  std::ignore = Pool;
-  urPrint("[UR][L0] %s function not implemented!\n", __FUNCTION__);
-  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
+  Pool->RefCount.increment();
+  return UR_RESULT_SUCCESS;
 }
 
 ur_result_t
 urUSMPoolRelease(ur_usm_pool_handle_t Pool ///< [in] pointer to USM memory pool
 ) {
-  std::ignore = Pool;
-  urPrint("[UR][L0] %s function not implemented!\n", __FUNCTION__);
-  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
+  if (Pool->RefCount.decrementAndTest()) {
+    delete Pool;
+  }
+  return UR_RESULT_SUCCESS;
 }
 
 ur_result_t urUSMPoolGetInfo(
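Taken together, the changes above make the previously ignored Pool argument functional: a non-null pool routes host, device, and shared allocations to the pool's own USMAllocContext instances, while a null pool keeps using the existing context-level allocators. A minimal caller-side sketch of the new path, assuming Context and Device handles already exist; the pool descriptor stype constant, the urUSMDeviceAlloc argument order, urUSMFree, and the limit values are illustrative rather than taken from this diff:

    // Request pooling for allocations up to 2 MiB, with 64 KiB minimum slabs.
    ur_usm_pool_limits_desc_t Limits{};
    Limits.stype = UR_STRUCTURE_TYPE_USM_POOL_LIMITS_DESC;
    Limits.maxPoolableSize = 2 * 1024 * 1024;
    Limits.minDriverAllocSize = 64 * 1024;

    ur_usm_pool_desc_t PoolDesc{};
    PoolDesc.stype = UR_STRUCTURE_TYPE_USM_POOL_DESC; // assumed stype name
    PoolDesc.pNext = &Limits;
    PoolDesc.flags = 0;

    ur_usm_pool_handle_t Pool = nullptr;
    if (urUSMPoolCreate(Context, &PoolDesc, &Pool) == UR_RESULT_SUCCESS) {
      void *Mem = nullptr;
      // A non-null Pool makes urUSMDeviceAlloc use Pool->DeviceMemPools[Device];
      // passing nullptr falls back to Context->DeviceMemAllocContexts.
      urUSMDeviceAlloc(Context, Device, /*USMDesc=*/nullptr, Pool, 4096, &Mem);
      // ... use Mem ...
      urUSMFree(Context, Mem);
      urUSMPoolRelease(Pool);
    }

Pool creation failures surface as UsmAllocationException inside the constructor and are converted to a ur_result_t by urUSMPoolCreate, which is why the sketch only checks the return code.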
@@ -9,6 +9,23 @@
 
 #include "ur_level_zero_common.hpp"
 
+struct ur_usm_pool_handle_t_ : _ur_object {
+  bool zeroInit;
+
+  usm_settings::USMAllocatorConfig USMAllocatorConfigs;
+
+  std::unique_ptr<USMAllocContext> HostMemPool;
+  std::unordered_map<ur_device_handle_t, std::unique_ptr<USMAllocContext>>
+      SharedMemPools;
+  std::unordered_map<ur_device_handle_t, std::unique_ptr<USMAllocContext>>
+      SharedMemReadOnlyPools;
+  std::unordered_map<ur_device_handle_t, std::unique_ptr<USMAllocContext>>
+      DeviceMemPools;
+
+  ur_usm_pool_handle_t_(ur_context_handle_t Context,
+                        ur_usm_pool_desc_t *PoolDesc);
+};
+
 // Exception type to pass allocation errors
 class UsmAllocationException {
   const ur_result_t Error;
