Skip to content
This repository has been archived by the owner on Aug 3, 2021. It is now read-only.

Commit

Permalink
Merge remote-tracking branch 'upstream/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
borisfom committed Oct 3, 2016
2 parents b928ca0 + afd74e7 commit 6933617
Show file tree
Hide file tree
Showing 14 changed files with 329 additions and 304 deletions.
4 changes: 0 additions & 4 deletions FFI.lua
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,10 @@ typedef struct THCState
{
struct THCRNGState* rngState;
struct cudaDeviceProp* deviceProperties;
cudaStream_t currentStream;
cublasHandle_t currentBlasHandle;
THCCudaResourcesPerDevice* resourcesPerDevice;
int numDevices;
int numUserStreams;
int numUserBlasHandles;
int currentPerDeviceStream;
int currentPerDeviceBlasHandle;
struct THAllocator* cudaHostAllocator;
} THCState;
Expand Down
22 changes: 7 additions & 15 deletions init.c
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ static int cutorch_setStream(lua_State *L)
{
THCState *state = cutorch_getstate(L);
int stream = (int) luaL_checknumber(L, 1);
THCState_setStreamForCurrentDevice(state, stream);
THCState_setCurrentStreamIndex(state, stream);

return 0;
}
Expand All @@ -366,7 +366,7 @@ static int cutorch_setBlasHandle(lua_State *L)
{
THCState *state = cutorch_getstate(L);
int handle = (int) luaL_checknumber(L, 1);
THCState_setBlasHandleForCurrentDevice(state, handle);
THCState_setCurrentBlasHandleIndex(state, handle);

return 0;
}
Expand Down Expand Up @@ -408,8 +408,7 @@ static int cutorch_getBlasHandle(lua_State *L)
/*
 * Lua-facing function: fetches the THCState associated with the Lua state
 * and selects stream index 0 on it. Returns 0 (presumably "no Lua results",
 * per the usual lua_CFunction convention -- confirm how it is registered).
 *
 * NOTE(review): this text is a rendered diff with the +/- markers stripped,
 * so two setter calls appear below. THCState_setStreamForCurrentDevice looks
 * like the pre-merge call and THCState_setCurrentStreamIndex its replacement;
 * presumably only the latter exists in the merged init.c -- verify against
 * the repository.
 */
static int cutorch_setDefaultStream(lua_State *L)
{
THCState *state = cutorch_getstate(L);
/* Both calls pass stream index 0. */
THCState_setStreamForCurrentDevice(state, 0);

THCState_setCurrentStreamIndex(state, 0);

return 0;
}
Expand Down Expand Up @@ -719,12 +718,6 @@ static int cutorch_setDevice(lua_State *L)
THCState *state = cutorch_getstate(L);
int device = (int)luaL_checknumber(L, 1)-1;
THCudaCheck(cudaSetDevice(device));
THCRandom_setGenerator(state, device);

/* The stream is per device, so update the stream as well */
THCState_setStream(state, device, THCState_getCurrentStreamIndex(state));
THCState_setBlasHandle(state, device, THCState_getCurrentBlasHandleIndex(state));

return 0;
}

Expand Down Expand Up @@ -903,7 +896,7 @@ static int cutorch_shutdown(lua_State *L)
{
THCState **state = (THCState **) lua_topointer(L, 1);
THCudaShutdown(*state);
free(*state);
THCState_free(*state);
return 0;
}

Expand Down Expand Up @@ -955,18 +948,17 @@ int luaopen_libcutorch(lua_State *L)
lua_setglobal(L, "cutorch");
luaL_setfuncs(L, cutorch_stuff__, 0);

THCState* state = (THCState*)malloc(sizeof(THCState));
memset(state, 0, sizeof(THCState));
THCState* state = THCState_alloc();

char* thc_caching_allocator = getenv("THC_CACHING_ALLOCATOR");
if (thc_caching_allocator && strcmp(thc_caching_allocator, "1") == 0) {
THCCachingAllocator_init(&state->cudaDeviceAllocator);
THCCachingAllocator_init(THCState_getDeviceAllocator(state));
}

THCudaInit(state);

/* Register torch.CudaHostAllocator. */
luaT_pushudata(L, state->cudaHostAllocator, "torch.Allocator");
luaT_pushudata(L, THCState_getCudaHostAllocator(state), "torch.Allocator");
lua_setfield(L, -2, "CudaHostAllocator");

#ifdef USE_MAGMA
Expand Down
11 changes: 6 additions & 5 deletions lib/THC/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ IF(MAGMA_FOUND)
MESSAGE(STATUS "MAGMA INCLUDE DIRECTORIES: ${MAGMA_INCLUDE_DIR}")
MESSAGE(STATUS "MAGMA LIBRARIES: ${MAGMA_LIBRARIES}")
MESSAGE(STATUS "MAGMA V2 check: ${MAGMA_V2}")
# This is required for MAGMA which has C++ symbols
IF ($ENV{TH_BINARY_BUILD})
MESSAGE(STATUS "TH_BINARY_BUILD detected. Statically linking libstdc++")
SET(CMAKE_CXX_FLAGS "-static-libstdc++ ${CMAKE_CXX_FLAGS}")
ENDIF()
ELSE(MAGMA_FOUND)
MESSAGE(STATUS "MAGMA not found. Compiling without MAGMA support")
ENDIF(MAGMA_FOUND)

IF ($ENV{TH_BINARY_BUILD})
MESSAGE(STATUS "TH_BINARY_BUILD detected. Statically linking libstdc++")
SET(CMAKE_CXX_FLAGS "-static-libstdc++ ${CMAKE_CXX_FLAGS}")
ENDIF()

IF(APPLE)
IF(${CUDA_VERSION} LESS 6.0)
# work around for mac os x bug:
Expand Down Expand Up @@ -126,6 +126,7 @@ SET(src
THCStorageCopy.c
THCTensor.c
THCTensorCopy.c
THCThreadLocal.c
)

SET(src-cuda
Expand Down
13 changes: 4 additions & 9 deletions lib/THC/THCAllocator.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,8 @@ static void *THCudaHostAllocator_realloc(void* ctx, void* ptr, long size) {
return ptr;
}

/*
 * Variant of THCAllocator_init that takes the whole THCState: it
 * heap-allocates a THAllocator, stores it in state->cudaHostAllocator, and
 * points its malloc/realloc/free hooks at the THCudaHostAllocator_*
 * functions.
 *
 * NOTE(review): this appears to be the removed side of the diff (a second
 * definition taking THAllocator* follows in this text) -- confirm.
 * NOTE(review): the malloc() result is dereferenced without a NULL check.
 */
void THCAllocator_init(THCState *state) {
state->cudaHostAllocator = (THAllocator*)malloc(sizeof(THAllocator));
state->cudaHostAllocator->malloc = &THCudaHostAllocator_alloc;
state->cudaHostAllocator->realloc = &THCudaHostAllocator_realloc;
state->cudaHostAllocator->free = &THCudaHostAllocator_free;
}

void THCAllocator_shutdown(THCState *state) {
free(state->cudaHostAllocator);
/*
 * Populates a caller-supplied THAllocator with the CUDA host allocator
 * hooks: its malloc, realloc and free function pointers are set to
 * THCudaHostAllocator_alloc, THCudaHostAllocator_realloc and
 * THCudaHostAllocator_free respectively.
 *
 * The THAllocator itself is not allocated here; cudaHostAllocator is
 * dereferenced unconditionally, so it must point to valid storage.
 */
void THCAllocator_init(THAllocator *cudaHostAllocator) {
cudaHostAllocator->malloc = &THCudaHostAllocator_alloc;
cudaHostAllocator->realloc = &THCudaHostAllocator_realloc;
cudaHostAllocator->free = &THCudaHostAllocator_free;
}
3 changes: 1 addition & 2 deletions lib/THC/THCAllocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

#include "THCGeneral.h"

THC_API void THCAllocator_init(THCState *state);
THC_API void THCAllocator_shutdown(THCState *state);
THC_API void THCAllocator_init(THAllocator *state);

#endif
63 changes: 47 additions & 16 deletions lib/THC/THCBlas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ float THCudaBlas_Sdot(THCState *state, long n, float *x, long incx, float *y, lo
int i_incx = (int)incx;
int i_incy = (int)incy;
float result;
THCublasCheck(cublasSdot(THCState_getCurrentBlasHandle(state), i_n, x, i_incx, y, i_incy, &result));
cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
cublasSetStream(handle, THCState_getCurrentStream(state));
THCublasCheck(cublasSdot(handle, i_n, x, i_incx, y, i_incy, &result));
return result;
}

Expand All @@ -35,7 +37,9 @@ double THCudaBlas_Ddot(THCState *state, long n, double *x, long incx, double *y,
int i_incx = (int)incx;
int i_incy = (int)incy;
double result;
THCublasCheck(cublasDdot(THCState_getCurrentBlasHandle(state), i_n, x, i_incx, y, i_incy, &result));
cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
cublasSetStream(handle, THCState_getCurrentStream(state));
THCublasCheck(cublasDdot(handle, i_n, x, i_incx, y, i_incy, &result));
return result;
}

Expand Down Expand Up @@ -66,7 +70,9 @@ void THCudaBlas_Sgemv(THCState *state, char trans, long m, long n, float alpha,
int i_incx = (int)incx;
int i_incy = (int)incy;

THCublasCheck(cublasSgemv(THCState_getCurrentBlasHandle(state), op, i_m, i_n, &alpha, a, i_lda, x, i_incx, &beta, y, i_incy));
cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
cublasSetStream(handle, THCState_getCurrentStream(state));
THCublasCheck(cublasSgemv(handle, op, i_m, i_n, &alpha, a, i_lda, x, i_incx, &beta, y, i_incy));
return;
}
THError("Cublas_Sgemv only supports m, n, lda, incx, incy"
Expand Down Expand Up @@ -94,7 +100,9 @@ void THCudaBlas_Dgemv(THCState *state, char trans, long m, long n, double alpha,
int i_incx = (int)incx;
int i_incy = (int)incy;

THCublasCheck(cublasDgemv(THCState_getCurrentBlasHandle(state), op, i_m, i_n, &alpha, a, i_lda, x, i_incx, &beta, y, i_incy));
cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
cublasSetStream(handle, THCState_getCurrentStream(state));
THCublasCheck(cublasDgemv(handle, op, i_m, i_n, &alpha, a, i_lda, x, i_incx, &beta, y, i_incy));
return;
}
THError("Cublas_Dgemv only supports m, n, lda, incx, incy"
Expand All @@ -114,7 +122,9 @@ void THCudaBlas_Sger(THCState *state, long m, long n, float alpha, float *x, lon
int i_incx = (int)incx;
int i_incy = (int)incy;

THCublasCheck(cublasSger(THCState_getCurrentBlasHandle(state), i_m, i_n, &alpha, x, i_incx, y, i_incy, a, i_lda));
cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
cublasSetStream(handle, THCState_getCurrentStream(state));
THCublasCheck(cublasSger(handle, i_m, i_n, &alpha, x, i_incx, y, i_incy, a, i_lda));
return;
}
THError("Cublas_Sger only supports m, n, lda, incx, incy"
Expand All @@ -134,7 +144,9 @@ void THCudaBlas_Dger(THCState *state, long m, long n, double alpha, double *x, l
int i_incx = (int)incx;
int i_incy = (int)incy;

THCublasCheck(cublasDger(THCState_getCurrentBlasHandle(state), i_m, i_n, &alpha, x, i_incx, y, i_incy, a, i_lda));
cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
cublasSetStream(handle, THCState_getCurrentStream(state));
THCublasCheck(cublasDger(handle, i_m, i_n, &alpha, x, i_incx, y, i_incy, a, i_lda));
return;
}
THError("Cublas_Dger only supports m, n, lda, incx, incy"
Expand Down Expand Up @@ -199,7 +211,9 @@ void THCudaBlas_Sgemm(THCState *state, char transa, char transb, long m, long n,
int i_ldb = (int)ldb;
int i_ldc = (int)ldc;

THCublasCheck(cublasSgemm(THCState_getCurrentBlasHandle(state), opa, opb, i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb, &beta, c, i_ldc));
cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
cublasSetStream(handle, THCState_getCurrentStream(state));
THCublasCheck(cublasSgemm(handle, opa, opb, i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb, &beta, c, i_ldc));
return;
}
THError("Cublas_Sgemm only supports m, n, k, lda, ldb, ldc"
Expand Down Expand Up @@ -227,17 +241,20 @@ void THCudaBlas_Hgemm(THCState *state, char transa, char transb, long m, long n,
int i_ldb = (int)ldb;
int i_ldc = (int)ldc;

cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
cublasSetStream(handle, THCState_getCurrentStream(state));

// Check for native Hgemm support
if (THC_nativeHalfInstructions(state)) {
THCublasCheck(cublasHgemm(THCState_getCurrentBlasHandle(state), opa, opb,
THCublasCheck(cublasHgemm(handle, opa, opb,
i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb,
&beta, c, i_ldc));
} else {
// Simulated Hgemm
float fAlpha = THC_half2float(alpha);
float fBeta = THC_half2float(beta);

THCublasCheck(cublasSgemmEx(THCState_getCurrentBlasHandle(state), opa, opb,
THCublasCheck(cublasSgemmEx(handle, opa, opb,
i_m, i_n, i_k, &fAlpha,
a, CUDA_R_16F, i_lda, b, CUDA_R_16F,
i_ldb, &fBeta, c, CUDA_R_16F, i_ldc));
Expand Down Expand Up @@ -265,7 +282,9 @@ void THCudaBlas_Dgemm(THCState *state, char transa, char transb, long m, long n,
int i_ldb = (int)ldb;
int i_ldc = (int)ldc;

THCublasCheck(cublasDgemm(THCState_getCurrentBlasHandle(state), opa, opb, i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb, &beta, c, i_ldc));
cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
cublasSetStream(handle, THCState_getCurrentStream(state));
THCublasCheck(cublasDgemm(handle, opa, opb, i_m, i_n, i_k, &alpha, a, i_lda, b, i_ldb, &beta, c, i_ldc));
return;
}
THError("Cublas_Dgemm only supports m, n, k, lda, ldb, ldc"
Expand All @@ -287,7 +306,9 @@ void THCudaBlas_SgemmBatched(THCState *state, char transa, char transb, long m,
cublasOperation_t opa = convertTransToCublasOperation(transa);
cublasOperation_t opb = convertTransToCublasOperation(transb);

THCublasCheck(cublasSgemmBatched(THCState_getCurrentBlasHandle(state),
cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
cublasSetStream(handle, THCState_getCurrentStream(state));
THCublasCheck(cublasSgemmBatched(handle,
opa, opb, (int)m, (int)n, (int)k,
&alpha, a, (int)lda, b, (int)ldb, &beta, c, (int)ldc,
(int)batchCount));
Expand All @@ -307,7 +328,9 @@ void THCudaBlas_DgemmBatched(THCState *state, char transa, char transb, long m,
cublasOperation_t opa = convertTransToCublasOperation(transa);
cublasOperation_t opb = convertTransToCublasOperation(transb);

THCublasCheck(cublasDgemmBatched(THCState_getCurrentBlasHandle(state),
cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
cublasSetStream(handle, THCState_getCurrentStream(state));
THCublasCheck(cublasDgemmBatched(handle,
opa, opb, (int)m, (int)n, (int)k,
&alpha, a, (int)lda, b, (int)ldb, &beta, c, (int)ldc,
(int)batchCount));
Expand All @@ -320,7 +343,9 @@ void THCudaBlas_Sgetrf(THCState *state, int n, float **a, int lda, int *pivot, i
THError("Cublas_Sgetrf only supports n, lda, batchSize"
"with the bound [val] <= %d", INT_MAX);
}
THCublasCheck(cublasSgetrfBatched(THCState_getCurrentBlasHandle(state), n, a, lda, pivot, info, batchSize));
cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
cublasSetStream(handle, THCState_getCurrentStream(state));
THCublasCheck(cublasSgetrfBatched(handle, n, a, lda, pivot, info, batchSize));
}

void THCudaBlas_Dgetrf(THCState *state, int n, double **a, int lda, int *pivot, int *info, int batchSize) {
Expand All @@ -329,7 +354,9 @@ void THCudaBlas_Dgetrf(THCState *state, int n, double **a, int lda, int *pivot,
THError("Cublas_Dgetrf only supports n, lda, batchSize"
"with the bound [val] <= %d", INT_MAX);
}
THCublasCheck(cublasDgetrfBatched(THCState_getCurrentBlasHandle(state), n, a, lda, pivot, info, batchSize));
cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
cublasSetStream(handle, THCState_getCurrentStream(state));
THCublasCheck(cublasDgetrfBatched(handle, n, a, lda, pivot, info, batchSize));
}

void THCudaBlas_Sgetri(THCState *state, int n, const float **a, int lda, int *pivot, float **c, int ldc, int *info, int batchSize) {
Expand All @@ -339,7 +366,9 @@ void THCudaBlas_Sgetri(THCState *state, int n, const float **a, int lda, int *pi
THError("Cublas_Sgetri only supports n, lda, ldc, batchSize"
"with the bound [val] <= %d", INT_MAX);
}
THCublasCheck(cublasSgetriBatched(THCState_getCurrentBlasHandle(state), n, a, lda, pivot, c, ldc, info, batchSize));
cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
cublasSetStream(handle, THCState_getCurrentStream(state));
THCublasCheck(cublasSgetriBatched(handle, n, a, lda, pivot, c, ldc, info, batchSize));
}

void THCudaBlas_Dgetri(THCState *state, int n, const double **a, int lda, int *pivot, double **c, int ldc, int *info, int batchSize) {
Expand All @@ -349,5 +378,7 @@ void THCudaBlas_Dgetri(THCState *state, int n, const double **a, int lda, int *p
THError("Cublas_Dgetri only supports n, lda, ldc, batchSize"
"with the bound [val] <= %d", INT_MAX);
}
THCublasCheck(cublasDgetriBatched(THCState_getCurrentBlasHandle(state), n, a, lda, pivot, c, ldc, info, batchSize));
cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
cublasSetStream(handle, THCState_getCurrentStream(state));
THCublasCheck(cublasDgetriBatched(handle, n, a, lda, pivot, c, ldc, info, batchSize));
}
5 changes: 2 additions & 3 deletions lib/THC/THCCachingAllocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ struct THCCachingAllocator
FreeBlocks small_blocks;

// allocated blocks by device pointer
std::unordered_map<char*, Block*> allocated_blocks;
std::unordered_map<void*, Block*> allocated_blocks;

THCCachingAllocator() :
large_blocks(BlockComparator),
Expand Down Expand Up @@ -141,13 +141,12 @@ struct THCCachingAllocator
return cudaSuccess;
}

auto it = allocated_blocks.find((char*)ptr);
auto it = allocated_blocks.find(ptr);
if (it == allocated_blocks.end()) {
return cudaErrorInvalidDevicePointer;
}

Block* block = it->second;
int device = block->device;
allocated_blocks.erase(it);

bool small = block->size <= kSmallAlloc;
Expand Down
Loading

0 comments on commit 6933617

Please sign in to comment.