Skip to content

Commit

Permalink
Update shift operators to take templates and multiple dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
luitjens committed Jul 11, 2022
1 parent 0d33b16 commit 9b9a244
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 63 deletions.
5 changes: 1 addition & 4 deletions docs_input/api/tensorops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,7 @@ Advanced Operators
.. doxygenclass:: matx::IF
.. doxygenclass:: matx::IFELSE
.. doxygenfunction:: reverse
.. doxygenfunction:: shift0
.. doxygenfunction:: shift1
.. doxygenfunction:: shift2
.. doxygenfunction:: shift3
.. doxygenfunction:: shift
.. doxygenfunction:: fftshift1D
.. doxygenfunction:: fftshift2D
.. doxygenfunction:: repmat(T1 t, index_t reps)
Expand Down
84 changes: 33 additions & 51 deletions include/matx_tensor_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -1436,8 +1436,8 @@ auto __MATX_INLINE__ reverse(Op t)
* of the tensor.
*/
namespace detail {
template <typename T1, int DIM>
class ShiftOp : public BaseOp<ShiftOp<T1, DIM>>
template <int DIM, typename T1>
class ShiftOp : public BaseOp<ShiftOp<DIM, T1>>
{
private:
typename base_type<T1>::type op_;
Expand Down Expand Up @@ -1487,81 +1487,63 @@ auto __MATX_INLINE__ reverse(Op t)
}
};
}

/**
* Helper function to shift dimension 0 by a given amount
*
* @tparam T1
* Type of operator or view
* @param t
* Operator or view to shift
* @param s
* Amount to shift forward
* Operator to shift a single dimension by a given amount
*
* @returns
* New operator with shifted indices
*/
template <typename T1>
auto __MATX_INLINE__ shift0(T1 t, index_t s)
{
return detail::ShiftOp<T1, 0>(t, s);
};

/**
* Helper function to shift dimension 1 by a given amount
* @tparam DIM
* The dimension to be shifted
*
* @tparam T1
* @tparam Op
* Type of operator or view
* @param t
*
* @param op
* Operator or view to shift
*
* @param s
* Amount to shift forward
*
* @returns
* New operator with shifted indices
*/
template <typename T1>
auto __MATX_INLINE__ shift1(T1 t, index_t s)
template <int DIM, typename Op>
auto __MATX_INLINE__ shift(Op op, index_t s)
{
return detail::ShiftOp<T1, 1>(t, s);
return detail::ShiftOp<DIM, Op>(op, s);
};

/**
* Helper function to shift dimension 2 by a given amount

/**
* Operator to shift several dimensions, each by its own amount.
* This version accepts multiple dimensions; one shift amount must be supplied per dimension.
*
* @tparam T1
* Type of operator or view
* @param t
* Operator or view to shift
* @param s
* Amount to shift forward
* @tparam DIM
* The dimension to be shifted
*
* @returns
* New operator with shifted indices
*/
template <typename T1>
auto __MATX_INLINE__ shift2(T1 t, index_t s)
{
return detail::ShiftOp<T1, 2>(t, s);
};

/**
* Helper function to shift dimension 3 by a given amount
* @tparam DIMS...
* The dimensions targeted for shifts
*
* @tparam T1
* @tparam Op
* Type of operator or view
* @param t
*
* @param op
* Operator or view to shift
*
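* @param s
* Amount to shift dimension DIM forward
*
* @param shifts
* Amounts to shift the remaining DIMS forward, given in the same order as DIMS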
*
* @returns
* New operator with shifted indices
*/
template <typename T1>
auto __MATX_INLINE__ shift3(T1 t, index_t s)
template <int DIM, int... DIMS, typename Op, typename... Shifts>
auto __MATX_INLINE__ shift(Op op, index_t s, Shifts... shifts)
{
return detail::ShiftOp<T1, 3>(t, s);
static_assert(sizeof...(DIMS) == sizeof...(shifts), "shift: number of DIMs must match number of shifts");

// recursively call shift on remaining bits
auto rop = shift<DIMS...>(op, shifts...);

// construct remap op
return detail::ShiftOp<DIM, decltype(rop)>(rop, s);
};

namespace detail {
Expand Down
16 changes: 8 additions & 8 deletions test/00_operators/OperatorTests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1841,7 +1841,7 @@ TYPED_TEST(OperatorTestsNumeric, Shift)
}

{
(t2s = shift0(t2, 5)).run();
(t2s = shift<0>(t2, 5)).run();
cudaStreamSynchronize(0);

for (index_t i = 0; i < count0; i++) {
Expand All @@ -1853,7 +1853,7 @@ TYPED_TEST(OperatorTestsNumeric, Shift)
}

{
(t2s = shift1(t2, 5)).run();
(t2s = shift<1>(t2, 5)).run();
cudaStreamSynchronize(0);

for (index_t i = 0; i < count0; i++) {
Expand All @@ -1865,7 +1865,7 @@ TYPED_TEST(OperatorTestsNumeric, Shift)
}

{
(t2s = shift0(shift1(t2, 5), 6)).run();
(t2s = shift<1,0>(t2, 5, 6)).run();
cudaStreamSynchronize(0);

for (index_t i = 0; i < count0; i++) {
Expand Down Expand Up @@ -1904,7 +1904,7 @@ TYPED_TEST(OperatorTestsNumeric, Shift)

// Negative shifts
{
(t2s = shift0(t2, -5)).run();
(t2s = shift<0>(t2, -5)).run();
cudaStreamSynchronize(0);

for (index_t i = 0; i < count0; i++) {
Expand All @@ -1916,7 +1916,7 @@ TYPED_TEST(OperatorTestsNumeric, Shift)
}

{
(t2s = shift1(t2, -5)).run();
(t2s = shift<1>(t2, -5)).run();
cudaStreamSynchronize(0);

for (index_t i = 0; i < count0; i++) {
Expand All @@ -1929,7 +1929,7 @@ TYPED_TEST(OperatorTestsNumeric, Shift)

// Large shifts
{
(t2s = shift0(t2, t2.Size(0) * 4)).run();
(t2s = shift<0>(t2, t2.Size(0) * 4)).run();
cudaStreamSynchronize(0);

for (index_t i = 0; i < count0; i++) {
Expand All @@ -1942,8 +1942,8 @@ TYPED_TEST(OperatorTestsNumeric, Shift)
{
// Shift 4 times the size back, minus one. This should be equivalent to
// simply shifting by -1
(t2s = shift0(t2, -t2.Size(0) * 4 - 1)).run();
(t2s2 = shift0(t2, -1)).run();
(t2s = shift<0>(t2, -t2.Size(0) * 4 - 1)).run();
(t2s2 = shift<0>(t2, -1)).run();
cudaStreamSynchronize(0);

for (index_t i = 0; i < count0; i++) {
Expand Down

0 comments on commit 9b9a244

Please sign in to comment.