Skip to content

Commit

Permalink
Faster Transpose 2D (apache#16104)
Browse files Browse the repository at this point in the history
* 2d transpose naive

* omp pragma

* omp pragma unroll

* blocksize

* make it 2d tile

* loop peeling

* better loop peeling

* redundancy

* removed bool

* removing excess for loops, memory save

* fix internal forloop

* remove commented code, lint fix

* Trigger notification

* explain params, indent fix, explain blocksize

* fix p,n and reduce for loop computation j+a,i+b

* kernel

* gpu thread 1

* remove gpu implementation

* fix internal for loop

* unittest to catch the previous error

* optimizations

* Microsoft Visual C++ compiler does not support the OpenMP collapse clause
  • Loading branch information
ChaiBapchya authored and aaronmarkham committed Oct 16, 2019
1 parent d1897a6 commit 2d2938a
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 1 deletion.
52 changes: 51 additions & 1 deletion src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,51 @@ struct TransposeParam : public dmlc::Parameter<TransposeParam> {
}
};


/*!
 * \brief Transpose a 2D matrix using cache-friendly blocking (tiling),
 *        keeping each working tile small enough to stay resident in the
 *        L1 cache while it is read column-wise and written row-wise.
 * \param in  pointer to the input data, row-major with shape (row, col)
 * \param out pointer to the output data, row-major with shape (col, row)
 * \param row shape of dim 0 of input
 * \param col shape of dim 1 of input
 */
template<typename DType>
MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index_t col) {
  // Blocksize rationale:
  //   L1 cache size to be utilized = 32kb = 2^15 bytes
  //   Largest size of a single unit of any dtype <= 8 byte = 2^3
  //   Number of elements that fit = 2^15 / 2^3 = 2^12
  //   => a 2^6 x 2^6 (64 x 64) tile would fill L1 by itself.
  // A 32 x 32 tile is used instead so that several threads' tiles can
  // coexist in cache and the compiler is free to unroll the inner loops:
  //   blocksize * blocksize * num_threads = cache_size / dtype_size
  const index_t blocksize = 32;

  // collapse(2) parallelizes the two outer (tile-stepping) loops; the two
  // inner loops stay serial per thread to preserve cache locality within a
  // tile.  The Microsoft Visual C++ compiler does not support the OpenMP
  // collapse clause, so only the outermost loop is parallelized there.
#ifdef _MSC_VER
  #pragma omp parallel for
#else
  #pragma omp parallel for collapse(2)
#endif  // _MSC_VER
  for (index_t i = 0; i < row; i += blocksize) {
    for (index_t j = 0; j < col; j += blocksize) {
      // Transpose one tile.  The (a < col) / (b < row) guards handle the
      // partial tiles at the right/bottom edges when row or col is not a
      // multiple of blocksize.
      for (index_t a = j; (a < blocksize + j) && (a < col); ++a) {
        for (index_t b = i; (b < blocksize + i) && (b < row); ++b) {
          out[a * row + b] = in[b * col + a];
        }
      }
    }
  }
}


template<typename xpu>
void TransposeImpl(RunContext ctx,
const TBlob& src,
Expand Down Expand Up @@ -285,8 +330,13 @@ void TransposeImpl(RunContext ctx,
case 2: {
mshadow::Tensor<xpu, 2, DType> in = src.FlatTo2D<xpu, DType>(s);
mshadow::Tensor<xpu, 2, DType> out = ret.FlatTo2D<xpu, DType>(s);

if (axes[0] == 1 && axes[1] == 0) {
out = in.T();
if (ctx.get_ctx().dev_mask() == cpu::kDevMask) {
Transpose2D<DType>(in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]);
} else {
out = in.T();
}
} else {
Copy(out, in, s);
}
Expand Down
7 changes: 7 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2874,6 +2874,13 @@ def test_transpose():
assert_allclose(np.transpose(x.asnumpy()), y.asnumpy())


@with_seed()
def test_larger_transpose():
    # Exercise the blocked (tiled) 2D CPU transpose kernel with shapes that
    # cover both partial tiles (dims not divisible by the 32x32 blocksize)
    # and exact tile multiples, plus a dim smaller than one tile.
    for shape in [(50, 51), (32, 32), (64, 32), (31, 100)]:
        x = mx.nd.random.normal(shape=shape)
        y = mx.nd.transpose(x)
        assert_allclose(np.transpose(x.asnumpy()), y.asnumpy())


@with_seed()
def test_expand_dims():
for ndim in range(1, 6):
Expand Down

0 comments on commit 2d2938a

Please sign in to comment.