From 838ec700c17cb9f2b942c8110184fa1eef4c1301 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 15 Nov 2016 02:29:53 -0800 Subject: [PATCH] Implemented cudaMemGetInfo for caching allocator --- init.c | 31 ++++++++++++++++++--------- lib/THC/THCCachingAllocator.cpp | 37 +++++++++++++++++++++++++++++---- lib/THC/THCGeneral.c | 28 +++++++++++++++++++++++++ lib/THC/THCGeneral.h.in | 2 ++ 4 files changed, 84 insertions(+), 14 deletions(-) diff --git a/init.c b/init.c index 69f5583a..f7c97d36 100644 --- a/init.c +++ b/init.c @@ -694,19 +694,31 @@ static int cutorch_setKernelPeerToPeerAccess(lua_State *L) } static int cutorch_getMemoryUsage(lua_State *L) { - size_t freeBytes = 0; size_t totalBytes = 0; - int curDevice; - THCudaCheck(cudaGetDevice(&curDevice)); + size_t freeBytes = 0; + + THCState *state = cutorch_getstate(L); int device = luaL_optint(L, 1, -10); - if (device == -10) { /* no argument passed, current device mem usage */ - THCudaCheck(cudaMemGetInfo(&freeBytes, &totalBytes)); - } else { /* argument was given, particular device's memory usage */ - THCudaCheck(cudaSetDevice(device-1)); /* zero indexed */ - THCudaCheck(cudaMemGetInfo(&freeBytes, &totalBytes)); - THCudaCheck(cudaSetDevice(curDevice)); + if (device != -10) { /* no argument passed, current device mem usage */ + --device; } + + int prevDevice, curDevice = -10; + THCudaCheck(cudaGetDevice(&prevDevice)); + + if (device != -10) { /* no argument passed, current device mem usage */ + curDevice = device; /* zero indexed */ + if (curDevice != prevDevice) + THCudaCheck(cudaSetDevice(curDevice)); + } + + THCudaCheck(THCudaMemGetInfo(state, &totalBytes, &freeBytes)); + + if (curDevice != prevDevice) { /* restore current device if we have changed it */ + THCudaCheck(cudaSetDevice(prevDevice)); + } + lua_pushnumber(L, freeBytes); lua_pushnumber(L, totalBytes); return 2; @@ -714,7 +726,6 @@ static int cutorch_getMemoryUsage(lua_State *L) { static int cutorch_setDevice(lua_State *L) { - THCState *state = cutorch_getstate(L); int device = (int)luaL_checknumber(L, 1)-1; THCudaCheck(cudaSetDevice(device)); return 0; diff --git a/lib/THC/THCCachingAllocator.cpp b/lib/THC/THCCachingAllocator.cpp index e2fc8d85..ab9528e1 100644 --- a/lib/THC/THCCachingAllocator.cpp +++ b/lib/THC/THCCachingAllocator.cpp @@ -158,12 +158,12 @@ struct THCCachingAllocator allocated_blocks.erase(it); bool small = block->size <= kSmallAlloc; - auto& free_blocks = small ? large_blocks : small_blocks; - try_merge_blocks(block, block->prev, free_blocks); - try_merge_blocks(block, block->next, free_blocks); + auto& cur_free_blocks = small ? large_blocks : small_blocks; + try_merge_blocks(block, block->prev, cur_free_blocks); + try_merge_blocks(block, block->next, cur_free_blocks); block->allocated = false; - free_blocks.insert(block); + cur_free_blocks.insert(block); return cudaSuccess; } @@ -205,6 +205,27 @@ struct THCCachingAllocator return basePtr; } + // Accumulates sizes of all memory blocks for given device in given free list + void cacheInfoAux(FreeBlocks& blocks, int dev_id, size_t* total, size_t* largest) + { + Block search_key(dev_id, 0, 0); + auto it = blocks.lower_bound(&search_key); + for (;it != blocks.end() && *it && (*it)->device == dev_id; ++it) { + size_t blocksize = (*it)->size; + total += blocksize; + if (blocksize > *largest) + *largest = blocksize; + } + } + + void cacheInfo(int dev_id, size_t* total, size_t* largest) + { + std::lock_guard lock(mutex); + cacheInfoAux(large_blocks, dev_id, total, largest); + cacheInfoAux(small_blocks, dev_id, total, largest); + } + + /** combine previously split blocks */ void try_merge_blocks(Block* dst, Block* src, FreeBlocks& free_blocks) { @@ -327,12 +348,20 @@ static cudaError_t THCCachingAllocator_emptyCache(void* ctx) return a->emptyCache(); } +static cudaError_t THCCachingAllocator_cacheInfo(void* ctx, int dev_id, size_t* totalCached, size_t* largestBlock) +{ + THCCachingAllocator* a = (THCCachingAllocator*) ctx; + a->cacheInfo(dev_id, totalCached, largestBlock); + return cudaSuccess; +} + static THCCachingAllocator caching_allocator; static THCDeviceAllocator device_allocator = { &THCCachingAllocator_malloc, NULL, &THCCachingAllocator_free, &THCCachingAllocator_emptyCache, + &THCCachingAllocator_cacheInfo, &caching_allocator }; diff --git a/lib/THC/THCGeneral.c b/lib/THC/THCGeneral.c index 13f62be9..403c4fa6 100644 --- a/lib/THC/THCGeneral.c +++ b/lib/THC/THCGeneral.c @@ -41,6 +41,7 @@ static THCDeviceAllocator defaultDeviceAllocator = { NULL, &cudaFreeWrapper, NULL, + NULL, NULL }; @@ -710,6 +711,33 @@ cudaError_t THCudaFree(THCState *state, void *ptr) return allocator->free(allocator->state, ptr); } +cudaError_t THCudaMemGetInfo(THCState *state, size_t* freeBytes, size_t* totalBytes) +{ + size_t cachedBytes = 0; + size_t largestBlock = 0; + THCDeviceAllocator* allocator = state->cudaDeviceAllocator; + + /* get info from CUDA first */ + cudaError_t ret = cudaMemGetInfo(freeBytes, totalBytes); + if (ret!= cudaSuccess) + return ret; + + int device; + ret = cudaGetDevice(&device); + if (ret!= cudaSuccess) + return ret; + + /* not always true - our optimistic guess here */ + largestBlock = *freeBytes; + + if (allocator->cacheInfo != NULL) + allocator->cacheInfo(allocator->state, device, &cachedBytes, &largestBlock); + + /* Adjust resulting free bytes number. largesBlock unused for now */ + *freeBytes += cachedBytes; + return cudaSuccess; +} + static ptrdiff_t applyHeapDelta(THCState *state) { ptrdiff_t newHeapSize = THAtomicAddPtrdiff(&heapSize, state->heapDelta) + state->heapDelta; state->heapDelta = 0; diff --git a/lib/THC/THCGeneral.h.in b/lib/THC/THCGeneral.h.in index 8f55cf3f..c685d373 100644 --- a/lib/THC/THCGeneral.h.in +++ b/lib/THC/THCGeneral.h.in @@ -49,6 +49,7 @@ typedef struct _THCDeviceAllocator { cudaError_t (*realloc)(void*, void**, size_t, size_t, cudaStream_t); cudaError_t (*free)(void*, void*); cudaError_t (*emptyCache)(void*); + cudaError_t (*cacheInfo)(void*, int, size_t*, size_t*); void* state; } THCDeviceAllocator; @@ -177,6 +178,7 @@ THC_API void __THCublasCheck(cublasStatus_t status, const char *file, const int THC_API cudaError_t THCudaMalloc(THCState *state, void **ptr, size_t size); THC_API cudaError_t THCudaFree(THCState *state, void *ptr); +THC_API cudaError_t THCudaMemGetInfo(THCState *state, size_t* freeBytes, size_t* totalBytes); THC_API void THCSetGCHandler(THCState *state, void (*torchGCHandlerFunction)(void *data), void *data );