Skip to content
This repository has been archived by the owner on Aug 3, 2021. It is now read-only.

Commit

Permalink
Implemented cudaMemGetInfo for caching allocator
Browse files Browse the repository at this point in the history
borisfom committed Nov 15, 2016

Verified

This commit was signed with the committer’s verified signature.
primeos Michael Weiss
1 parent 5774690 commit 838ec70
Showing 4 changed files with 84 additions and 14 deletions.
31 changes: 21 additions & 10 deletions init.c
Original file line number Diff line number Diff line change
@@ -694,27 +694,38 @@ static int cutorch_setKernelPeerToPeerAccess(lua_State *L)
}

static int cutorch_getMemoryUsage(lua_State *L) {
size_t freeBytes = 0;
size_t totalBytes = 0;
int curDevice;
THCudaCheck(cudaGetDevice(&curDevice));
size_t freeBytes = 0;

THCState *state = cutorch_getstate(L);

int device = luaL_optint(L, 1, -10);
if (device == -10) { /* no argument passed, current device mem usage */
THCudaCheck(cudaMemGetInfo(&freeBytes, &totalBytes));
} else { /* argument was given, particular device's memory usage */
THCudaCheck(cudaSetDevice(device-1)); /* zero indexed */
THCudaCheck(cudaMemGetInfo(&freeBytes, &totalBytes));
THCudaCheck(cudaSetDevice(curDevice));
if (device != -10) { /* no argument passed, current device mem usage */
--device;
}

int prevDevice, curDevice = -10;
THCudaCheck(cudaGetDevice(&prevDevice));

if (device != -10) { /* no argument passed, current device mem usage */
curDevice = device; /* zero indexed */
if (curDevice != prevDevice)
THCudaCheck(cudaSetDevice(curDevice));
}

THCudaCheck(THCudaMemGetInfo(state, &totalBytes, &freeBytes));

if (curDevice != prevDevice) { /* restore current device if we have changed it */
THCudaCheck(cudaSetDevice(prevDevice));
}

lua_pushnumber(L, freeBytes);
lua_pushnumber(L, totalBytes);
return 2;
}

static int cutorch_setDevice(lua_State *L)
{
THCState *state = cutorch_getstate(L);
int device = (int)luaL_checknumber(L, 1)-1;
THCudaCheck(cudaSetDevice(device));
return 0;
37 changes: 33 additions & 4 deletions lib/THC/THCCachingAllocator.cpp
Original file line number Diff line number Diff line change
@@ -158,12 +158,12 @@ struct THCCachingAllocator
allocated_blocks.erase(it);

bool small = block->size <= kSmallAlloc;
auto& free_blocks = small ? large_blocks : small_blocks;
try_merge_blocks(block, block->prev, free_blocks);
try_merge_blocks(block, block->next, free_blocks);
auto& cur_free_blocks = small ? large_blocks : small_blocks;
try_merge_blocks(block, block->prev, cur_free_blocks);
try_merge_blocks(block, block->next, cur_free_blocks);

block->allocated = false;
free_blocks.insert(block);
cur_free_blocks.insert(block);

return cudaSuccess;
}
@@ -205,6 +205,27 @@ struct THCCachingAllocator
return basePtr;
}

// Accumulates sizes of all memory blocks for given device in given free list
void cacheInfoAux(FreeBlocks& blocks, int dev_id, size_t* total, size_t* largest)
{
Block search_key(dev_id, 0, 0);
auto it = blocks.lower_bound(&search_key);
for (;it != blocks.end() && *it && (*it)->device == dev_id; ++it) {
size_t blocksize = (*it)->size;
total += blocksize;
if (blocksize > *largest)
*largest = blocksize;
}
}

void cacheInfo(int dev_id, size_t* total, size_t* largest)
{
std::lock_guard<std::mutex> lock(mutex);
cacheInfoAux(large_blocks, dev_id, total, largest);
cacheInfoAux(small_blocks, dev_id, total, largest);
}


/** combine previously split blocks */
void try_merge_blocks(Block* dst, Block* src, FreeBlocks& free_blocks)
{
@@ -327,12 +348,20 @@ static cudaError_t THCCachingAllocator_emptyCache(void* ctx)
return a->emptyCache();
}

static cudaError_t THCCachingAllocator_cacheInfo(void* ctx, int dev_id, size_t* totalCached, size_t* largestBlock)
{
THCCachingAllocator* a = (THCCachingAllocator*) ctx;
a->cacheInfo(dev_id, totalCached, largestBlock);
return cudaSuccess;
}

static THCCachingAllocator caching_allocator;
static THCDeviceAllocator device_allocator = {
&THCCachingAllocator_malloc,
NULL,
&THCCachingAllocator_free,
&THCCachingAllocator_emptyCache,
&THCCachingAllocator_cacheInfo,
&caching_allocator
};

28 changes: 28 additions & 0 deletions lib/THC/THCGeneral.c
Original file line number Diff line number Diff line change
@@ -41,6 +41,7 @@ static THCDeviceAllocator defaultDeviceAllocator = {
NULL,
&cudaFreeWrapper,
NULL,
NULL,
NULL
};

@@ -710,6 +711,33 @@ cudaError_t THCudaFree(THCState *state, void *ptr)
return allocator->free(allocator->state, ptr);
}

cudaError_t THCudaMemGetInfo(THCState *state, size_t* freeBytes, size_t* totalBytes)
{
size_t cachedBytes = 0;
size_t largestBlock = 0;
THCDeviceAllocator* allocator = state->cudaDeviceAllocator;

/* get info from CUDA first */
cudaError_t ret = cudaMemGetInfo(freeBytes, totalBytes);
if (ret!= cudaSuccess)
return ret;

int device;
ret = cudaGetDevice(&device);
if (ret!= cudaSuccess)
return ret;

/* not always true - our optimistic guess here */
largestBlock = *freeBytes;

if (allocator->cacheInfo != NULL)
allocator->cacheInfo(allocator->state, device, &cachedBytes, &largestBlock);

/* Adjust resulting free bytes number. largesBlock unused for now */
*freeBytes += cachedBytes;
return cudaSuccess;
}

static ptrdiff_t applyHeapDelta(THCState *state) {
ptrdiff_t newHeapSize = THAtomicAddPtrdiff(&heapSize, state->heapDelta) + state->heapDelta;
state->heapDelta = 0;
2 changes: 2 additions & 0 deletions lib/THC/THCGeneral.h.in
Original file line number Diff line number Diff line change
@@ -49,6 +49,7 @@ typedef struct _THCDeviceAllocator {
cudaError_t (*realloc)(void*, void**, size_t, size_t, cudaStream_t);
cudaError_t (*free)(void*, void*);
cudaError_t (*emptyCache)(void*);
cudaError_t (*cacheInfo)(void*, int, size_t*, size_t*);
void* state;
} THCDeviceAllocator;

@@ -177,6 +178,7 @@ THC_API void __THCublasCheck(cublasStatus_t status, const char *file, const int

THC_API cudaError_t THCudaMalloc(THCState *state, void **ptr, size_t size);
THC_API cudaError_t THCudaFree(THCState *state, void *ptr);
THC_API cudaError_t THCudaMemGetInfo(THCState *state, size_t* freeBytes, size_t* totalBytes);
THC_API void THCSetGCHandler(THCState *state,
void (*torchGCHandlerFunction)(void *data),
void *data );

0 comments on commit 838ec70

Please sign in to comment.