Skip to content
This repository has been archived by the owner on Aug 3, 2021. It is now read-only.

Commit

Permalink
Implemented cudaMemGetInfo for caching allocator
Browse files Browse the repository at this point in the history
borisfom committed Nov 15, 2016
1 parent 5774690 commit 838ec70
Showing 4 changed files with 84 additions and 14 deletions.
31 changes: 21 additions & 10 deletions init.c
Original file line number Diff line number Diff line change
@@ -694,27 +694,38 @@ static int cutorch_setKernelPeerToPeerAccess(lua_State *L)
}

/*
 * Lua binding: getMemoryUsage([device]) -> freeBytes, totalBytes
 *
 * Optional argument 1 is a 1-based device index. When it is omitted the
 * currently selected CUDA device is queried. The numbers come from
 * THCudaMemGetInfo, so they go through the THC device allocator attached
 * to the cutorch state rather than straight from cudaMemGetInfo.
 *
 * Pushes two Lua numbers (free bytes first, then total bytes) and returns 2.
 * Any CUDA failure is reported through THCudaCheck.
 */
static int cutorch_getMemoryUsage(lua_State *L) {
  size_t freeBytes = 0;
  size_t totalBytes = 0;
  int prevDevice;
  int curDevice;
  int device;

  THCState *state = cutorch_getstate(L);

  /* Remember the active device so it can be restored afterwards. */
  THCudaCheck(cudaGetDevice(&prevDevice));

  /* -10 is the sentinel for "no argument passed". */
  device = luaL_optint(L, 1, -10);
  if (device == -10) {
    /* no argument passed: report on the current device */
    curDevice = prevDevice;
  } else {
    /* argument was given: Lua indices are 1-based, CUDA's are zero-based */
    curDevice = device - 1;
  }

  if (curDevice != prevDevice) {
    THCudaCheck(cudaSetDevice(curDevice));
  }

  /* Argument order must match the prototype: (state, freeBytes, totalBytes). */
  THCudaCheck(THCudaMemGetInfo(state, &freeBytes, &totalBytes));

  if (curDevice != prevDevice) { /* restore current device if we have changed it */
    THCudaCheck(cudaSetDevice(prevDevice));
  }

  lua_pushnumber(L, freeBytes);
  lua_pushnumber(L, totalBytes);
  return 2;
}

static int cutorch_setDevice(lua_State *L)
{
THCState *state = cutorch_getstate(L);
int device = (int)luaL_checknumber(L, 1)-1;
THCudaCheck(cudaSetDevice(device));
return 0;
37 changes: 33 additions & 4 deletions lib/THC/THCCachingAllocator.cpp
Original file line number Diff line number Diff line change
@@ -158,12 +158,12 @@ struct THCCachingAllocator
allocated_blocks.erase(it);

bool small = block->size <= kSmallAlloc;
auto& free_blocks = small ? large_blocks : small_blocks;
try_merge_blocks(block, block->prev, free_blocks);
try_merge_blocks(block, block->next, free_blocks);
auto& cur_free_blocks = small ? large_blocks : small_blocks;
try_merge_blocks(block, block->prev, cur_free_blocks);
try_merge_blocks(block, block->next, cur_free_blocks);

block->allocated = false;
free_blocks.insert(block);
cur_free_blocks.insert(block);

return cudaSuccess;
}
@@ -205,6 +205,27 @@ struct THCCachingAllocator
return basePtr;
}

// Accumulates sizes of all memory blocks for the given device in the given
// free list.
//
//   blocks   free list to scan
//   dev_id   device whose blocks are counted
//   total    in/out: the size of every matching block is ADDED to *total
//   largest  in/out: raised to the largest matching block size if that
//            exceeds the value already stored in *largest
//
// Neither output is reset here, so a caller can chain several lists.
// The scan starts at lower_bound(Block(dev_id, 0, 0)) and stops at the first
// entry for another device; this assumes the FreeBlocks ordering groups
// blocks by device first (comparator not visible here -- confirm).
void cacheInfoAux(FreeBlocks& blocks, int dev_id, size_t* total, size_t* largest)
{
  Block search_key(dev_id, 0, 0);
  auto it = blocks.lower_bound(&search_key);
  for (; it != blocks.end() && *it && (*it)->device == dev_id; ++it) {
    size_t blocksize = (*it)->size;
    // Add to the pointee; `total += blocksize` would only move the pointer
    // and leave the caller's counter untouched.
    *total += blocksize;
    if (blocksize > *largest)
      *largest = blocksize;
  }
}

void cacheInfo(int dev_id, size_t* total, size_t* largest)
{
std::lock_guard<std::mutex> lock(mutex);
cacheInfoAux(large_blocks, dev_id, total, largest);
cacheInfoAux(small_blocks, dev_id, total, largest);
}


/** combine previously split blocks */
void try_merge_blocks(Block* dst, Block* src, FreeBlocks& free_blocks)
{
@@ -327,12 +348,20 @@ static cudaError_t THCCachingAllocator_emptyCache(void* ctx)
return a->emptyCache();
}

// Free-function adapter for the `cacheInfo` slot of THCDeviceAllocator.
// `ctx` is the opaque allocator pointer; it is converted back to
// THCCachingAllocator* and the query is forwarded to its cacheInfo method.
// Always returns cudaSuccess.
static cudaError_t THCCachingAllocator_cacheInfo(void* ctx, int dev_id, size_t* totalCached, size_t* largestBlock)
{
  static_cast<THCCachingAllocator*>(ctx)->cacheInfo(dev_id, totalCached, largestBlock);
  return cudaSuccess;
}

/* The one file-scope caching allocator instance. */
static THCCachingAllocator caching_allocator;
/* Function table that exposes caching_allocator through the generic
 * THCDeviceAllocator interface. Entries are positional, so their order
 * must match the field order of THCDeviceAllocator. */
static THCDeviceAllocator device_allocator = {
&THCCachingAllocator_malloc,
NULL, /* slot left unset (presumably `realloc` -- confirm against the struct) */
&THCCachingAllocator_free,
&THCCachingAllocator_emptyCache,
&THCCachingAllocator_cacheInfo,
&caching_allocator /* opaque context; the callbacks cast it back to THCCachingAllocator* */
};

28 changes: 28 additions & 0 deletions lib/THC/THCGeneral.c
Original file line number Diff line number Diff line change
@@ -41,6 +41,7 @@ static THCDeviceAllocator defaultDeviceAllocator = {
NULL,
&cudaFreeWrapper,
NULL,
NULL,
NULL
};

@@ -710,6 +711,33 @@ cudaError_t THCudaFree(THCState *state, void *ptr)
return allocator->free(allocator->state, ptr);
}

/*
 * Allocator-aware replacement for cudaMemGetInfo on the current device.
 *
 *   state       THC state; its cudaDeviceAllocator is consulted
 *   freeBytes   out: free bytes reported by CUDA plus the bytes the device
 *               allocator reports as cached for the current device
 *   totalBytes  out: total bytes as reported by cudaMemGetInfo
 *
 * Returns the first CUDA error encountered, or cudaSuccess. Allocators that
 * do not provide a cacheInfo callback contribute nothing, in which case the
 * result equals plain cudaMemGetInfo.
 */
cudaError_t THCudaMemGetInfo(THCState *state, size_t* freeBytes, size_t* totalBytes)
{
  size_t cachedBytes = 0;
  size_t largestBlock = 0;
  int device;
  cudaError_t ret;
  THCDeviceAllocator* allocator = state->cudaDeviceAllocator;

  /* get info from CUDA first */
  ret = cudaMemGetInfo(freeBytes, totalBytes);
  if (ret != cudaSuccess)
    return ret;

  ret = cudaGetDevice(&device);
  if (ret != cudaSuccess)
    return ret;

  /* not always true - our optimistic guess here */
  largestBlock = *freeBytes;

  if (allocator->cacheInfo != NULL) {
    /* Propagate a failing callback instead of silently ignoring it. */
    ret = allocator->cacheInfo(allocator->state, device, &cachedBytes, &largestBlock);
    if (ret != cudaSuccess)
      return ret;
  }

  /* Adjust resulting free bytes number. largestBlock unused for now */
  *freeBytes += cachedBytes;
  return cudaSuccess;
}

static ptrdiff_t applyHeapDelta(THCState *state) {
ptrdiff_t newHeapSize = THAtomicAddPtrdiff(&heapSize, state->heapDelta) + state->heapDelta;
state->heapDelta = 0;
2 changes: 2 additions & 0 deletions lib/THC/THCGeneral.h.in
Original file line number Diff line number Diff line change
@@ -49,6 +49,7 @@ typedef struct _THCDeviceAllocator {
cudaError_t (*realloc)(void*, void**, size_t, size_t, cudaStream_t);
cudaError_t (*free)(void*, void*);
cudaError_t (*emptyCache)(void*);
cudaError_t (*cacheInfo)(void*, int, size_t*, size_t*);
void* state;
} THCDeviceAllocator;

@@ -177,6 +178,7 @@ THC_API void __THCublasCheck(cublasStatus_t status, const char *file, const int

THC_API cudaError_t THCudaMalloc(THCState *state, void **ptr, size_t size);
THC_API cudaError_t THCudaFree(THCState *state, void *ptr);
/* Memory query for the current device that also consults the state's device
 * allocator: *freeBytes is CUDA's free byte count plus whatever the
 * allocator's cacheInfo callback reports as cached; *totalBytes is CUDA's
 * total. Note the argument order: free first, then total. */
THC_API cudaError_t THCudaMemGetInfo(THCState *state, size_t* freeBytes, size_t* totalBytes);
THC_API void THCSetGCHandler(THCState *state,
void (*torchGCHandlerFunction)(void *data),
void *data );

0 comments on commit 838ec70

Please sign in to comment.