Commit 399ba02

update
sxjscience committed Dec 28, 2017
1 parent 5cf2d8a commit 399ba02
Showing 3 changed files with 23 additions and 82 deletions.
19 changes: 0 additions & 19 deletions src/operator/tensor/indexing_op.cc
@@ -136,25 +136,6 @@ inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const OpContext& ctx,
   });
 }
 
-
-template<typename DType, typename IType>
-inline void ScatterNDAccForwardImpl(int N, int M, int K,
-                                    const mshadow::Shape<10> strides,
-                                    DType* out,
-                                    const DType* data,
-                                    const IType* indices,
-                                    mshadow::Stream<cpu> *s) {
-  for (int i = 0; i < N; i++) {
-    int offset = 0;
-    for (int j = 0; j < M; ++j) {
-      offset += strides[j] * static_cast<int>(indices[j*N + i]);
-    }
-    for (int j = 0; j < K; ++j) {
-      out[offset + j] += data[i * K + j];
-    }
-  }
-}
-
 DMLC_REGISTER_PARAMETER(EmbeddingParam);
 DMLC_REGISTER_PARAMETER(TakeParam);
 DMLC_REGISTER_PARAMETER(OneHotParam);
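
For reference, the CPU loop removed above (and re-created as a shared kernel in indexing_op.h below) scatters N slices of K contiguous elements into the flat output: the destination offset of slice i is the sum over the M indexed dimensions of strides[j] * indices[j*N + i], and slices whose index tuples coincide are accumulated. A minimal standalone C++ sketch of that logic, with illustrative shapes and values (not taken from the MXNet sources):

#include <cstdio>
#include <vector>

// Scatter-accumulate: N index tuples of M components each; slice i (K elements
// of `data`) is added into `out` at the offset encoded by its index tuple.
// Duplicate index tuples accumulate instead of overwriting.
void scatter_nd_acc_ref(int N, int M, int K,
                        const std::vector<int>& strides,   // strides of the M indexed dims, in elements
                        const std::vector<int>& indices,   // M x N, laid out indices[j*N + i]
                        const std::vector<float>& data,    // N x K
                        std::vector<float>* out) {
  for (int i = 0; i < N; ++i) {
    int offset = 0;
    for (int j = 0; j < M; ++j) {
      offset += strides[j] * indices[j * N + i];
    }
    for (int j = 0; j < K; ++j) {
      (*out)[offset + j] += data[i * K + j];
    }
  }
}

int main() {
  // Output shape (3, 4): M = 1 indexed dimension, K = 4 trailing elements.
  const int N = 3, M = 1, K = 4;
  std::vector<int> strides = {K};            // stride of dim 0
  std::vector<int> indices = {0, 2, 0};      // rows 0, 2, 0 -- row 0 is hit twice
  std::vector<float> data = {1, 1, 1, 1,  2, 2, 2, 2,  3, 3, 3, 3};
  std::vector<float> out(3 * 4, 0.0f);
  scatter_nd_acc_ref(N, M, K, strides, indices, data, &out);
  for (int r = 0; r < 3; ++r) {
    for (int c = 0; c < K; ++c) printf("%g ", out[r * K + c]);
    printf("\n");                            // row 0: 4s (1+3), row 1: 0s, row 2: 2s
  }
  return 0;
}
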
43 changes: 0 additions & 43 deletions src/operator/tensor/indexing_op.cu
@@ -179,49 +179,6 @@ inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const OpContext& ctx,
   });
 }
 
-template<typename DType, typename IType>
-__global__ void ScatterNDAccForwardImplKernel(int N, int M, int K,
-                                              const mshadow::Shape<10> strides,
-                                              DType* out,
-                                              const DType* data,
-                                              const IType* indices) {
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
-    int offset = 0;
-    for (int j = 0; j < M; ++j) {
-      offset += strides[j] * static_cast<int>(indices[j*N + i]);
-    }
-    for (int j = 0; j < K; ++j) {
-      atomicAdd(out + (offset + j), data[i * K + j]);
-    }
-  }
-}
-
-struct scatter_nd_acc_gpu {
-  template<typename DType, typename IType>
-  MSHADOW_XINLINE static void Map(int i, int N, int M, int K,
-                                  const mshadow::Shape<10> strides,
-                                  DType* out, const DType* data,
-                                  const IType* indices) {
-    int offset = 0;
-    for (int j = 0; j < M; ++j) {
-      offset += strides[j] * static_cast<int>(indices[j*N + i]);
-    }
-    for (int j = 0; j < K; ++j) {
-      atomicAdd(out + (offset + j), data[i * K + j]);
-    }
-  }
-};
-
-template<typename DType, typename IType>
-inline void ScatterNDAccForwardImpl(int N, int M, int K,
-                                    const mshadow::Shape<10> strides,
-                                    DType* out,
-                                    const DType* data,
-                                    const IType* indices,
-                                    mshadow::Stream<gpu> *s) {
-  mxnet_op::Kernel<scatter_nd_acc_gpu, gpu>::Launch(s, N, N, M, K, strides, out, data, indices);
-}
-
 NNVM_REGISTER_OP(Embedding)
 .set_attr<FCompute>("FCompute<gpu>", EmbeddingOpForward<gpu>);
 
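
Both GPU variants removed above accumulate with atomicAdd because independent CUDA threads can compute the same destination offset when index tuples repeat, and a plain += from concurrent threads would lose updates. A self-contained CUDA sketch of that pattern, using the same grid-stride loop as the deleted kernel (the shapes, values, and precomputed offsets are illustrative, not MXNet code):

#include <cstdio>
#include <cuda_runtime.h>

// Grid-stride scatter-add. Offsets are precomputed here to keep the sketch
// short; slices whose offsets collide must be accumulated atomically.
__global__ void scatter_add(int N, int K, const int* offsets,
                            const float* data, float* out) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
       i += blockDim.x * gridDim.x) {
    for (int j = 0; j < K; ++j) {
      atomicAdd(out + offsets[i] + j, data[i * K + j]);
    }
  }
}

int main() {
  const int N = 3, K = 4, out_size = 12;
  const int h_offsets[N] = {0, 8, 0};        // two slices land on offset 0
  const float h_data[N * K] = {1, 1, 1, 1,  2, 2, 2, 2,  3, 3, 3, 3};
  int* d_offsets;
  float* d_data;
  float* d_out;
  cudaMalloc((void**)&d_offsets, sizeof(h_offsets));
  cudaMalloc((void**)&d_data, sizeof(h_data));
  cudaMalloc((void**)&d_out, out_size * sizeof(float));
  cudaMemcpy(d_offsets, h_offsets, sizeof(h_offsets), cudaMemcpyHostToDevice);
  cudaMemcpy(d_data, h_data, sizeof(h_data), cudaMemcpyHostToDevice);
  cudaMemset(d_out, 0, out_size * sizeof(float));
  scatter_add<<<1, 128>>>(N, K, d_offsets, d_data, d_out);
  float h_out[out_size];
  cudaMemcpy(h_out, d_out, out_size * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < out_size; ++i) printf("%g ", h_out[i]);   // 4 4 4 4 0 0 0 0 2 2 2 2
  printf("\n");
  cudaFree(d_offsets); cudaFree(d_data); cudaFree(d_out);
  return 0;
}
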
43 changes: 23 additions & 20 deletions src/operator/tensor/indexing_op.h
@@ -1143,21 +1143,25 @@ void ScatterNDForward(const nnvm::NodeAttrs& attrs,
   });
 }
 
-template<typename DType, typename IType>
-inline void ScatterNDAccForwardImpl(int N, int M, int K,
-                                    const mshadow::Shape<10> strides,
-                                    DType* out,
-                                    const DType* data,
-                                    const IType* indices,
-                                    mshadow::Stream<cpu> *s);
-
-template<typename DType, typename IType>
-inline void ScatterNDAccForwardImpl(int N, int M, int K,
-                                    const mshadow::Shape<10> strides,
-                                    DType* out,
-                                    const DType* data,
-                                    const IType* indices,
-                                    mshadow::Stream<gpu> *s);
+struct scatter_nd_acc {
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, int N, int M, int K,
+                                  const mshadow::Shape<10> strides,
+                                  DType* out, const DType* data,
+                                  const IType* indices) {
+    int offset = 0;
+    for (int j = 0; j < M; ++j) {
+      offset += strides[j] * static_cast<int>(indices[j*N + i]);
+    }
+    for (int j = 0; j < K; ++j) {
+#if __CUDA_ARCH__
+      atomicAdd(out + (offset + j), data[i * K + j]);
+#else
+      out[offset + j] += data[i * K + j];
+#endif
+    }
+  }
+};

 template<typename xpu>
 void ScatterNDAccForward(const nnvm::NodeAttrs& attrs,
@@ -1182,11 +1186,10 @@ void ScatterNDAccForward(const nnvm::NodeAttrs& attrs,
   }
   MXNET_NO_INT8_TYPE_SWITCH(inputs[0].type_flag_, DType, { // output data type switch
     MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // indices data type switch
-      ScatterNDAccForwardImpl(N, M, K, strides,
-                              outputs[0].dptr<DType>(),
-                              inputs[0].dptr<DType>(),
-                              inputs[1].dptr<IType>(),
-                              s);
+      mxnet_op::Kernel<scatter_nd_acc, xpu>::Launch(s, N, N, M, K, strides,
+                                                    outputs[0].dptr<DType>(),
+                                                    inputs[0].dptr<DType>(),
+                                                    inputs[1].dptr<IType>());
     });
   });
 }
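
The net effect of this change: the two per-backend ScatterNDAccForwardImpl specializations above are replaced by a single scatter_nd_acc functor whose Map handles one index tuple, dispatched through mxnet_op::Kernel<scatter_nd_acc, xpu>::Launch on both backends. The cpu specialization of Launch invokes Map for each i in a (possibly OpenMP-parallel) loop, while the gpu specialization maps i onto CUDA threads, where the __CUDA_ARCH__ branch selects atomicAdd. A simplified, host-only sketch of this functor-plus-Launch pattern (LaunchCPU and add_at below are illustrative stand-ins, not MXNet's mxnet_op implementation):

#include <cstdio>

// Simplified stand-in for mxnet_op::Kernel<OP, cpu>::Launch: call OP::Map once
// per work item. The real launcher also supports OpenMP on the CPU, and the
// <OP, gpu> specialization launches a CUDA kernel that maps i to threads.
template <typename OP, typename... Args>
void LaunchCPU(int N, Args... args) {
  for (int i = 0; i < N; ++i) {
    OP::Map(i, args...);
  }
}

// A functor in the same shape as scatter_nd_acc: one Map call processes one
// work item (here, trivially, add data[i] into out[idx[i]]).
struct add_at {
  static void Map(int i, const int* idx, const float* data, float* out) {
    out[idx[i]] += data[i];
  }
};

int main() {
  const int idx[4] = {0, 1, 0, 2};
  const float data[4] = {1.f, 2.f, 3.f, 4.f};
  float out[3] = {0.f, 0.f, 0.f};
  LaunchCPU<add_at>(4, idx, data, out);
  printf("%g %g %g\n", out[0], out[1], out[2]);   // 4 2 4
  return 0;
}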
