Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Index* for new types #461

Merged
merged 7 commits into from
Aug 12, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,11 @@ Compared to version 1.0, these are the following API changes:
| operators | 1.0 | master |
|---|---|---|
| `lt`, `le`, `gt`, `ge`, `eq`, `ne` return type | torch.CudaTensor | torch.CudaByteTensor |
| `min`,`max` (2-nd output) | torch.CudaTensor | torch.CudaLongTensor |
| `maskedFill`, `maskedCopy` (mask input) | torch.CudaTensor | torch.CudaByteTensor |
| `min`,`max` (2nd return value) | torch.CudaTensor | torch.CudaLongTensor |
| `maskedFill`, `maskedCopy` (mask input) | torch.CudaTensor | torch.CudaByteTensor |
| `topk`, `sort` (2nd return value) | torch.CudaTensor | torch.CudaLongTensor |

## Inconsistencies with CPU API

| operators | CPU | CUDA |
|---|---|---|
| `topk`, `sort` (2-nd output) | torch.LongTensor | torch.CudaTensor |
| `index`, `indexAdd`, `indexFill`, `indexCopy` input | torch.LongTensor | torch.CudaTensor (or torch.LongTensor) |
| `scatter`, `gather` | torch.LongTensor | torch.CudaTensor |
|---|---|---|
30 changes: 20 additions & 10 deletions TensorMath.lua
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,16 @@ for k, Tensor_ in pairs(handledTypenames) do
{name=real}}
)

wrap("sort",
cname("sort"),
{{name=Tensor, default=true, returned=true},
{name="CudaLongTensor", default=true, returned=true, noreadadd=true},
{name=Tensor},
{name="index", default=lastdim(3)},
{name="boolean", default=0}}
)


-- BLAS functions
if real == 'float' or real == 'double' or real == 'half' then
wrap("mv",
Expand Down Expand Up @@ -1063,20 +1073,20 @@ wrap("scatter",
wrap("sort",
cname("sort"),
{{name=Tensor, default=true, returned=true},
{name=Tensor, default=true, returned=true, noreadadd=true},
{name=Tensor},
{name="index", default=lastdim(3)},
{name="boolean", default=0}})
{name="CudaLongTensor", default=true, returned=true, noreadadd=true},
{name=Tensor},
{name="index", default=lastdim(3)},
{name="boolean", default=0}})

wrap("topk",
cname("topk"),
{{name=Tensor, default=true, returned=true},
{name=Tensor, default=true, returned=true, noreadadd=true},
{name=Tensor},
{name="long", default=1},
{name="index", default=lastdim(3)},
{name="boolean", default=0},
{name="boolean", default=0}})
{name="CudaLongTensor", default=true, returned=true, noreadadd=true},
{name=Tensor},
{name="long", default=1},
{name="index", default=lastdim(3)},
{name="boolean", default=0},
{name="boolean", default=0}})

do
local Tensor = Tensor
Expand Down
5 changes: 4 additions & 1 deletion lib/THC/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ INSTALL(FILES
THCTensorRandom.h
THCTensorMath.h
THCTensorConv.h
THCTensorSort.h
THCTensorTopK.h
THCApply.cuh
THCReduce.cuh
Expand Down Expand Up @@ -213,4 +212,8 @@ INSTALL(FILES
generic/THCTensorMathReduce.cu
generic/THCTensorScatterGather.h
generic/THCTensorScatterGather.cu
generic/THCTensorIndex.h
generic/THCTensorIndex.cu
generic/THCTensorSort.h
generic/THCTensorSort.cu
DESTINATION "${THC_INSTALL_INCLUDE_SUBDIR}/THC/generic")
1 change: 0 additions & 1 deletion lib/THC/THC.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include "THCTensorRandom.h"
#include "THCTensorMath.h"
#include "THCTensorConv.h"
#include "THCTensorSort.h"
#include "THCTensorTopK.h"

#endif
13 changes: 7 additions & 6 deletions lib/THC/THCSortUtils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,20 @@

#include "THCReduceApplyUtils.cuh"
#include "THCTensorTypeUtils.cuh"
#include "THCNumerics.cuh"

// Collection of kernel sort routines
template <typename T>
struct LTComp {
__device__ inline bool operator()(const T& a, const T& b) const {
return (a < b);
return THCNumerics<T>::lt(a, b);
}
};

template <typename T>
struct GTComp {
__device__ inline bool operator()(const T& a, const T& b) const {
return (a > b);
return THCNumerics<T>::gt(a, b);
}
};

Expand Down Expand Up @@ -127,19 +128,19 @@ bitonicSortKVInPlace(TensorInfo<K, IndexType> keys,

bool valid1 = (elem1 < keySliceSize);
K k1 = valid1 ?
keys.data[keyStartOffset + elem1 * keySliceStride] : (K) 0;
keys.data[keyStartOffset + elem1 * keySliceStride] : ScalarConvert<int, K>::to(0);
V v1 = valid1 ?
values.data[valueStartOffset + elem1 * valueSliceStride] : (V) 0;
values.data[valueStartOffset + elem1 * valueSliceStride] : ScalarConvert<int, V>::to(0);

sharedKeys[elem1] = k1;
sharedValues[elem1] = v1;
sharedValid[elem1] = valid1;

bool valid2 = (elem2 < keySliceSize);
K k2 = valid2 ?
keys.data[keyStartOffset + elem2 * keySliceStride] : (K) 0;
keys.data[keyStartOffset + elem2 * keySliceStride] : ScalarConvert<int, K>::to(0);
V v2 = valid2 ?
values.data[valueStartOffset + elem2 * valueSliceStride] : (V) 0;
values.data[valueStartOffset + elem2 * valueSliceStride] : ScalarConvert<int, V>::to(0);

sharedKeys[elem2] = k2;
sharedValues[elem2] = v2;
Expand Down
Loading