diff --git a/sycl/plugins/unified_runtime/ur/adapters/level_zero/ur_level_zero_device.cpp b/sycl/plugins/unified_runtime/ur/adapters/level_zero/ur_level_zero_device.cpp index dc9f6a9f7069d..7b95bb9bf5b1a 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/level_zero/ur_level_zero_device.cpp +++ b/sycl/plugins/unified_runtime/ur/adapters/level_zero/ur_level_zero_device.cpp @@ -573,7 +573,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo( case UR_DEVICE_INFO_USM_CROSS_SHARED_SUPPORT: case UR_DEVICE_INFO_USM_SYSTEM_SHARED_SUPPORT: { auto MapCaps = [](const ze_memory_access_cap_flags_t &ZeCapabilities) { - uint64_t Capabilities = 0; + ur_device_usm_access_capability_flags_t Capabilities = 0; if (ZeCapabilities & ZE_MEMORY_ACCESS_CAP_FLAG_RW) Capabilities |= UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_ACCESS; if (ZeCapabilities & ZE_MEMORY_ACCESS_CAP_FLAG_ATOMIC) diff --git a/sycl/plugins/unified_runtime/ur/adapters/level_zero/ur_level_zero_usm.cpp b/sycl/plugins/unified_runtime/ur/adapters/level_zero/ur_level_zero_usm.cpp index 0b0cc51c845d9..9f215d06d85a8 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/level_zero/ur_level_zero_usm.cpp +++ b/sycl/plugins/unified_runtime/ur/adapters/level_zero/ur_level_zero_usm.cpp @@ -24,17 +24,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc( Size, ///< [in] size in bytes of the USM memory object to be allocated void **RetMem ///< [out] pointer to USM host memory object ) { - std::ignore = Pool; - uint32_t Align = USMDesc->align; + uint32_t Align = USMDesc ? USMDesc->align : 0; // L0 supports alignment up to 64KB and silently ignores higher values. // We flag alignment > 64KB as an invalid value. if (Align > 65536) return UR_RESULT_ERROR_INVALID_VALUE; - const ur_usm_advice_flags_t *USMHintFlags = &USMDesc->hints; - std::ignore = USMHintFlags; - ur_platform_handle_t Plt = Context->getPlatform(); // If indirect access tracking is enabled then lock the mutex which is // guarding contexts container in the platform. This prevents new kernels from @@ -77,7 +73,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc( // find the allocator depending on context as we do for Shared and Device // allocations. try { - *RetMem = Context->HostMemAllocContext->allocate(Size, Align); + if (Pool) { + *RetMem = Pool->HostMemPool->allocate(Size, Align); + } else { + *RetMem = Context->HostMemAllocContext->allocate(Size, Align); + } if (IndirectAccessTrackingEnabled) { // Keep track of all memory allocations in the context Context->MemAllocs.emplace(std::piecewise_construct, @@ -105,18 +105,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc( Size, ///< [in] size in bytes of the USM memory object to be allocated void **RetMem ///< [out] pointer to USM device memory object ) { - std::ignore = Pool; - uint32_t Alignment = USMDesc->align; + uint32_t Alignment = USMDesc ? USMDesc->align : 0; // L0 supports alignment up to 64KB and silently ignores higher values. // We flag alignment > 64KB as an invalid value. if (Alignment > 65536) return UR_RESULT_ERROR_INVALID_VALUE; - const ur_usm_advice_flags_t *USMHintFlags = &USMDesc->hints; - std::ignore = USMHintFlags; - ur_platform_handle_t Plt = Device->Platform; // If indirect access tracking is enabled then lock the mutex which is @@ -157,11 +153,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc( } try { - auto It = Context->DeviceMemAllocContexts.find(Device->ZeDevice); - if (It == Context->DeviceMemAllocContexts.end()) - return UR_RESULT_ERROR_INVALID_VALUE; - *RetMem = It->second.allocate(Size, Alignment); + if (Pool) { + *RetMem = Pool->DeviceMemPools[Device]->allocate(Size, Alignment); + } else { + auto It = Context->DeviceMemAllocContexts.find(Device->ZeDevice); + if (It == Context->DeviceMemAllocContexts.end()) + return UR_RESULT_ERROR_INVALID_VALUE; + + *RetMem = It->second.allocate(Size, Alignment); + } if (IndirectAccessTrackingEnabled) { // Keep track of all memory allocations in the context Context->MemAllocs.emplace(std::piecewise_construct, @@ -190,9 +191,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc( Size, ///< [in] size in bytes of the USM memory object to be allocated void **RetMem ///< [out] pointer to USM shared memory object ) { - std::ignore = Pool; - uint32_t Alignment = USMDesc->align; + uint32_t Alignment = USMDesc ? USMDesc->align : 0; ur_usm_host_mem_flags_t UsmHostFlags{}; @@ -200,7 +200,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc( bool DeviceReadOnly = false; ur_usm_device_mem_flags_t UsmDeviceFlags{}; - void *pNext = const_cast(USMDesc->pNext); + void *pNext = USMDesc ? const_cast(USMDesc->pNext) : nullptr; while (pNext != nullptr) { const ur_base_desc_t *BaseDesc = reinterpret_cast(pNext); @@ -259,13 +259,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc( } try { - auto &Allocator = (DeviceReadOnly ? Context->SharedReadOnlyMemAllocContexts - : Context->SharedMemAllocContexts); - auto It = Allocator.find(Device->ZeDevice); - if (It == Allocator.end()) - return UR_RESULT_ERROR_INVALID_VALUE; - - *RetMem = It->second.allocate(Size, Alignment); + if (Pool) { + if (DeviceReadOnly) { + *RetMem = + Pool->SharedMemReadOnlyPools[Device]->allocate(Size, Alignment); + } else { + *RetMem = Pool->SharedMemPools[Device]->allocate(Size, Alignment); + } + } else { + auto &Allocator = + (DeviceReadOnly ? Context->SharedReadOnlyMemAllocContexts + : Context->SharedMemAllocContexts); + auto It = Allocator.find(Device->ZeDevice); + if (It == Allocator.end()) + return UR_RESULT_ERROR_INVALID_VALUE; + + *RetMem = It->second.allocate(Size, Alignment); + } if (DeviceReadOnly) { Context->SharedReadOnlyAllocs.insert(*RetMem); } @@ -518,6 +528,56 @@ static ur_result_t USMAllocationMakeResident( return UR_RESULT_SUCCESS; } +ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context, + ur_usm_pool_desc_t *PoolDesc) { + + zeroInit = static_cast(PoolDesc->flags & + UR_USM_POOL_FLAG_ZERO_INITIALIZE_BLOCK); + + void *pNext = const_cast(PoolDesc->pNext); + while (pNext != nullptr) { + const ur_base_desc_t *BaseDesc = + reinterpret_cast(pNext); + switch (BaseDesc->stype) { + case UR_STRUCTURE_TYPE_USM_POOL_LIMITS_DESC: { + const ur_usm_pool_limits_desc_t *Limits = + reinterpret_cast(BaseDesc); + for (auto &config : USMAllocatorConfigs.Configs) { + config.MaxPoolableSize = Limits->maxPoolableSize; + config.SlabMinSize = Limits->minDriverAllocSize; + } + break; + } + default: { + urPrint("urUSMPoolCreate: unexpected chained stype\n"); + throw UsmAllocationException(UR_RESULT_ERROR_INVALID_ARGUMENT); + } + } + pNext = const_cast(BaseDesc->pNext); + } + + HostMemPool = std::make_unique( + std::unique_ptr(new USMHostMemoryAlloc(Context)), + this->USMAllocatorConfigs.Configs[usm_settings::MemType::Host]); + + for (auto device : Context->Devices) { + DeviceMemPools[device] = std::make_unique( + std::unique_ptr( + new USMDeviceMemoryAlloc(Context, device)), + this->USMAllocatorConfigs.Configs[usm_settings::MemType::Device]); + + SharedMemPools[device] = std::make_unique( + std::unique_ptr( + new USMSharedMemoryAlloc(Context, device)), + this->USMAllocatorConfigs.Configs[usm_settings::MemType::Shared]); + SharedMemReadOnlyPools[device] = std::make_unique( + std::unique_ptr( + new USMSharedMemoryAlloc(Context, device)), + this->USMAllocatorConfigs + .Configs[usm_settings::MemType::SharedReadOnly]); + } +} + UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolCreate( ur_context_handle_t Context, ///< [in] handle of the context object ur_usm_pool_desc_t @@ -525,27 +585,30 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolCreate( ///< ::ur_usm_pool_limits_desc_t ur_usm_pool_handle_t *Pool ///< [out] pointer to USM memory pool ) { - std::ignore = Context; - std::ignore = PoolDesc; - std::ignore = Pool; - urPrint("[UR][L0] %s function not implemented!\n", __FUNCTION__); - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + + try { + *Pool = reinterpret_cast( + new ur_usm_pool_handle_t_(Context, PoolDesc)); + } catch (const UsmAllocationException &Ex) { + return Ex.getError(); + } + return UR_RESULT_SUCCESS; } ur_result_t urUSMPoolRetain(ur_usm_pool_handle_t Pool ///< [in] pointer to USM memory pool ) { - std::ignore = Pool; - urPrint("[UR][L0] %s function not implemented!\n", __FUNCTION__); - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + Pool->RefCount.increment(); + return UR_RESULT_SUCCESS; } ur_result_t urUSMPoolRelease(ur_usm_pool_handle_t Pool ///< [in] pointer to USM memory pool ) { - std::ignore = Pool; - urPrint("[UR][L0] %s function not implemented!\n", __FUNCTION__); - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + if (Pool->RefCount.decrementAndTest()) { + delete Pool; + } + return UR_RESULT_SUCCESS; } ur_result_t urUSMPoolGetInfo( diff --git a/sycl/plugins/unified_runtime/ur/adapters/level_zero/ur_level_zero_usm.hpp b/sycl/plugins/unified_runtime/ur/adapters/level_zero/ur_level_zero_usm.hpp index ba0130089906e..a53b6d35712f9 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/level_zero/ur_level_zero_usm.hpp +++ b/sycl/plugins/unified_runtime/ur/adapters/level_zero/ur_level_zero_usm.hpp @@ -9,6 +9,23 @@ #include "ur_level_zero_common.hpp" +struct ur_usm_pool_handle_t_ : _ur_object { + bool zeroInit; + + usm_settings::USMAllocatorConfig USMAllocatorConfigs; + + std::unique_ptr HostMemPool; + std::unordered_map> + SharedMemPools; + std::unordered_map> + SharedMemReadOnlyPools; + std::unordered_map> + DeviceMemPools; + + ur_usm_pool_handle_t_(ur_context_handle_t Context, + ur_usm_pool_desc_t *PoolDesc); +}; + // Exception type to pass allocation errors class UsmAllocationException { const ur_result_t Error;