Skip to content
This repository has been archived by the owner on Aug 3, 2021. It is now read-only.

Commit

Permalink
Merge pull request torch#558 from gchanan/genericDeviceTensorUtils
Browse files Browse the repository at this point in the history
Add generic type support for toDeviceTensor.
  • Loading branch information
soumith authored Oct 19, 2016
2 parents 8c3f15a + 0d6ae18 commit 21ad069
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 53 deletions.
1 change: 1 addition & 0 deletions lib/THC/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -275,4 +275,5 @@ INSTALL(FILES
generic/THCTensorIndex.cu
generic/THCTensorSort.h
generic/THCTensorSort.cu
generic/THCDeviceTensorUtils.cu
DESTINATION "${THC_INSTALL_INCLUDE_SUBDIR}/THC/generic")
34 changes: 0 additions & 34 deletions lib/THC/THCDeviceTensorUtils-inl.cuh
Original file line number Diff line number Diff line change
@@ -1,37 +1,3 @@
#include <limits>

template <typename T, int Dim,
typename IndexT, template <typename U> class PtrTraits>
THCDeviceTensor<T, Dim, IndexT, PtrTraits>
toDeviceTensor(THCState* state, THCudaTensor* t) {
if (Dim != THCudaTensor_nDimension(state, t)) {
THError("THCudaTensor dimension mismatch");
}

// Determine the maximum offset into the tensor achievable; `IndexT`
// must be smaller than this type in order to use it.
ptrdiff_t maxOffset = 0;
IndexT sizes[Dim];
IndexT strides[Dim];

for (int i = 0; i < Dim; ++i) {
long size = THCudaTensor_size(state, t, i);
long stride = THCudaTensor_stride(state, t, i);

maxOffset += (size - 1) * stride;

sizes[i] = (IndexT) size;
strides[i] = (IndexT) stride;
}

if (maxOffset > std::numeric_limits<IndexT>::max()) {
THError("THCudaTensor sizes too large for THCDeviceTensor conversion");
}

return THCDeviceTensor<T, Dim, IndexT, PtrTraits>(
THCudaTensor_data(state, t), sizes, strides);
}

namespace detail {

// Add a layer of SFINAE to support static_assert
Expand Down
23 changes: 4 additions & 19 deletions lib/THC/THCDeviceTensorUtils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,7 @@

#include "THCDeviceTensor.cuh"
#include "THCTensor.h"

/// Constructs a THCDeviceTensor initialized from a THCudaTensor. Will
/// error if the dimensionality does not match exactly.
template <typename T, int Dim,
typename IndexT, template <typename U> class PtrTraits>
THCDeviceTensor<T, Dim, IndexT, PtrTraits>
toDeviceTensor(THCState* state, THCudaTensor* t);

template <typename T, int Dim, typename IndexT>
THCDeviceTensor<T, Dim, IndexT, DefaultPtrTraits>
toDeviceTensor(THCState* state, THCudaTensor* t) {
return toDeviceTensor<T, Dim, IndexT, DefaultPtrTraits>(state, t);
}

template <typename T, int Dim>
THCDeviceTensor<T, Dim, int, DefaultPtrTraits>
toDeviceTensor(THCState* state, THCudaTensor* t) {
return toDeviceTensor<T, Dim, int, DefaultPtrTraits>(state, t);
}
#include <limits>

/// Constructs a DeviceTensor initialized from a THCudaTensor by
/// upcasting or downcasting the tensor to that of a different
Expand All @@ -43,6 +25,9 @@ toDeviceTensorCast(THCState* state, THCudaTensor* t) {
return toDeviceTensorCast<T, Dim, int, DefaultPtrTraits>(state, t);
}

#include "generic/THCDeviceTensorUtils.cu"
#include "THCGenerateAllTypes.h"

#include "THCDeviceTensorUtils-inl.cuh"

#endif // THC_DEVICE_TENSOR_UTILS_INC
55 changes: 55 additions & 0 deletions lib/THC/generic/THCDeviceTensorUtils.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/THCDeviceTensorUtils.cu"
#else

/// Constructs a THCDeviceTensor initialized from a THCudaTensor. Will
/// error if the dimensionality does not match exactly.
template <typename T, int Dim,
typename IndexT, template <typename U> class PtrTraits>
THCDeviceTensor<T, Dim, IndexT, PtrTraits>
toDeviceTensor(THCState* state, THCTensor* t);

template <typename T, int Dim, typename IndexT>
THCDeviceTensor<T, Dim, IndexT, DefaultPtrTraits>
toDeviceTensor(THCState* state, THCTensor* t) {
return toDeviceTensor<T, Dim, IndexT, DefaultPtrTraits>(state, t);
}

template <typename T, int Dim>
THCDeviceTensor<T, Dim, int, DefaultPtrTraits>
toDeviceTensor(THCState* state, THCTensor* t) {
return toDeviceTensor<T, Dim, int, DefaultPtrTraits>(state, t);
}

template <typename T, int Dim,
typename IndexT, template <typename U> class PtrTraits>
THCDeviceTensor<T, Dim, IndexT, PtrTraits>
toDeviceTensor(THCState* state, THCTensor* t) {
if (Dim != THCTensor_(nDimension)(state, t)) {
THError("THCudaTensor dimension mismatch");
}
// Determine the maximum offset into the tensor achievable; `IndexT`
// must be smaller than this type in order to use it.
ptrdiff_t maxOffset = 0;
IndexT sizes[Dim];
IndexT strides[Dim];

for (int i = 0; i < Dim; ++i) {
long size = THCTensor_(size)(state, t, i);
long stride = THCTensor_(stride)(state, t, i);

maxOffset += (size - 1) * stride;

sizes[i] = (IndexT) size;
strides[i] = (IndexT) stride;
}

if (maxOffset > std::numeric_limits<IndexT>::max()) {
THError("THCudaTensor sizes too large for THCDeviceTensor conversion");
}

return THCDeviceTensor<T, Dim, IndexT, PtrTraits>(
THCTensor_(data)(state, t), sizes, strides);
}

#endif

0 comments on commit 21ad069

Please sign in to comment.