Skip to content

Commit

Permalink
Update reverse/shift APIs (#207)
Browse files Browse the repository at this point in the history
* remove reverseXYZW and replace with just reverse<>

* Update shift operators to take templates and multiple dimensions
  • Loading branch information
luitjens authored Jul 11, 2022
1 parent 9a57e92 commit 1f5cd6e
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 117 deletions.
10 changes: 2 additions & 8 deletions docs_input/api/tensorops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,8 @@ Advanced Operators

.. doxygenclass:: matx::IF
.. doxygenclass:: matx::IFELSE
.. doxygenfunction:: reverseX
.. doxygenfunction:: reverseY
.. doxygenfunction:: reverseZ
.. doxygenfunction:: reverseW
.. doxygenfunction:: shift0
.. doxygenfunction:: shift1
.. doxygenfunction:: shift2
.. doxygenfunction:: shift3
.. doxygenfunction:: reverse
.. doxygenfunction:: shift
.. doxygenfunction:: fftshift1D
.. doxygenfunction:: fftshift2D
.. doxygenfunction:: repmat(T1 t, index_t reps)
Expand Down
2 changes: 1 addition & 1 deletion include/matx_corr.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ void corr(OutputTensor &o, const In1Type &i1, const In2Type &i2,
MATX_ASSERT_STR(method == MATX_C_METHOD_DIRECT, matxNotSupported,
"Only direct correlation method supported at this time");

auto i2r = reverseX(conj(i2));
auto i2r = reverse<In2Type::Rank()-1>(conj(i2));
conv1d(o, i1, i2r, mode, stream);
}

Expand Down
159 changes: 62 additions & 97 deletions include/matx_tensor_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -882,8 +882,8 @@ auto __MATX_INLINE__ as_uint8(T t)
*
*/
namespace detail {
template <typename T1, int DIM>
class ReverseOp : public BaseOp<ReverseOp<T1, DIM>>
template <int DIM, typename T1>
class ReverseOp : public BaseOp<ReverseOp<DIM, T1>>
{
private:
typename base_type<T1>::type op_;
Expand Down Expand Up @@ -928,55 +928,38 @@ auto __MATX_INLINE__ as_uint8(T t)
};
}

/**
* Helper function to reverse the indexing of the last dimension of a tensor
/**
* @brief Operator to logically reverse elements of an operator. Base case for variadic template.
*
* Requires a tensor of at least rank 1
* @tparam DIM Dimension to apply the reverse
* @tparam Op Input operator/tensor type
* @param t Input operator
*/
template <typename T1>
auto __MATX_INLINE__ reverseX(T1 t)
template <int DIM, typename Op>
auto __MATX_INLINE__ reverse(Op t)
{
MATX_STATIC_ASSERT(T1::Rank() > 0, matxInvalidDim);
return detail::ReverseOp<T1, T1::Rank() - 1>(t);
return detail::ReverseOp<DIM, Op>(t);
};

/**
* Helper function to reverse the indexing of the second-to-last
* dimension of a tensor
/**
* @brief Operator to logically reverse elements of an operator.
*
* Requires a tensor of at least rank 2
*/
template <typename T1>
auto __MATX_INLINE__ reverseY(T1 t)
{
MATX_STATIC_ASSERT(T1::Rank() > 1, matxInvalidDim);
return detail::ReverseOp<T1, T1::Rank() - 2>(t);
};

/**
* Helper function to reverse the indexing of the third-to-last
* dimension of a tensor
* This operator can appear as an rvalue or lvalue.
*
* Requires a tensor of at least rank 3
* @tparam DIM Dimension to apply the reverse
* @tparam DIMS... list of multiple dimensions to reverse along
* @tparam Op Input operator/tensor type
* @param t Input operator
*/
template <typename T1>
auto __MATX_INLINE__ reverseZ(T1 t)
{
MATX_STATIC_ASSERT(T1::Rank() > 2, matxInvalidDim);
return detail::ReverseOp<T1, T1::Rank() - 3>(t);
};
template <int DIM1, int DIM2, int... DIMS, typename Op>
auto __MATX_INLINE__ reverse(Op t)
{
// recursively call reverse on the remaining dimensions
auto op = reverse<DIM2, DIMS...>(t);

/**
* Helper function to reverse the indexing of the first dimension of a tensor
*
* Requires a tensor of rank 4
*/
template <typename T1>
auto __MATX_INLINE__ reverseW(T1 t)
{
MATX_STATIC_ASSERT(T1::Rank() > 3, matxInvalidDim);
return detail::ReverseOp<T1, T1::Rank() - 4>(t);
};
// construct reverse op
return detail::ReverseOp<DIM1, decltype(op)>(op);
};

/**
* Flip the vertical axis of a tensor.
Expand All @@ -986,10 +969,10 @@ auto __MATX_INLINE__ as_uint8(T t)
{
if constexpr (T1::Rank() == 1)
{
return detail::ReverseOp<T1, T1::Rank() - 1>(t);
return detail::ReverseOp<T1::Rank() - 1 , T1>(t);
}

return detail::ReverseOp<T1, T1::Rank() - 2>(t);
return detail::ReverseOp<T1::Rank() - 2, T1>(t);
};

/**
Expand All @@ -1000,10 +983,10 @@ auto __MATX_INLINE__ as_uint8(T t)
{
if constexpr (T1::Rank() == 1)
{
return detail::ReverseOp<T1, T1::Rank() - 1>(t);
return detail::ReverseOp<T1::Rank() - 1, T1>(t);
}

return detail::ReverseOp<T1, T1::Rank() - 1>(t);
return detail::ReverseOp<T1::Rank() - 1, T1>(t);
};

/**
Expand Down Expand Up @@ -1453,8 +1436,8 @@ auto __MATX_INLINE__ as_uint8(T t)
* of the tensor.
*/
namespace detail {
template <typename T1, int DIM>
class ShiftOp : public BaseOp<ShiftOp<T1, DIM>>
template <int DIM, typename T1>
class ShiftOp : public BaseOp<ShiftOp<DIM, T1>>
{
private:
typename base_type<T1>::type op_;
Expand Down Expand Up @@ -1504,81 +1487,63 @@ auto __MATX_INLINE__ as_uint8(T t)
}
};
}

/**
* Helper function to shift dimension 0 by a given amount
* Operator to shift dimension by a given amount
*
* @tparam T1
* Type of operator or view
* @param t
* Operator or view to shift
* @param s
* Amount to shift forward
* @tparam DIM
* The dimension to be shifted
*
* @returns
* New operator with shifted indices
*/
template <typename T1>
auto __MATX_INLINE__ shift0(T1 t, index_t s)
{
return detail::ShiftOp<T1, 0>(t, s);
};

/**
* Helper function to shift dimension 1 by a given amount
*
* @tparam T1
* @tparam Op
* Type of operator or view
* @param t
*
* @param op
* Operator or view to shift
*
* @param s
* Amount to shift forward
*
* @returns
* New operator with shifted indices
*/
template <typename T1>
auto __MATX_INLINE__ shift1(T1 t, index_t s)
template <int DIM, typename Op>
auto __MATX_INLINE__ shift(Op op, index_t s)
{
return detail::ShiftOp<T1, 1>(t, s);
return detail::ShiftOp<DIM, Op>(op, s);
};

/**
* Helper function to shift dimension 2 by a given amount

/**
* Operator to shift dimension by a given amount.
* This version allows multiple dimensions.
*
* @tparam T1
* Type of operator or view
* @param t
* Operator or view to shift
* @param s
* Amount to shift forward
* @tparam DIM
* The dimension to be shifted
*
* @returns
* New operator with shifted indices
*/
template <typename T1>
auto __MATX_INLINE__ shift2(T1 t, index_t s)
{
return detail::ShiftOp<T1, 2>(t, s);
};

/**
* Helper function to shift dimension 3 by a given amount
* @tparam DIMS...
* The dimensions targeted for shifts
*
* @tparam T1
* @tparam Op
* Type of operator or view
* @param t
*
* @param op
* Operator or view to shift
*
* @param s
* Amount to shift forward
*
* @returns
* New operator with shifted indices
*/
template <typename T1>
auto __MATX_INLINE__ shift3(T1 t, index_t s)
template <int DIM, int... DIMS, typename Op, typename... Shifts>
auto __MATX_INLINE__ shift(Op op, index_t s, Shifts... shifts)
{
return detail::ShiftOp<T1, 3>(t, s);
static_assert(sizeof...(DIMS) == sizeof...(shifts), "shift: number of DIMs must match number of shifts");

// recursively call shift on the remaining dimensions
auto rop = shift<DIMS...>(op, shifts...);

// construct shift op
return detail::ShiftOp<DIM, decltype(rop)>(rop, s);
};

namespace detail {
Expand Down
22 changes: 11 additions & 11 deletions test/00_operators/OperatorTests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1841,7 +1841,7 @@ TYPED_TEST(OperatorTestsNumeric, Shift)
}

{
(t2s = shift0(t2, 5)).run();
(t2s = shift<0>(t2, 5)).run();
cudaStreamSynchronize(0);

for (index_t i = 0; i < count0; i++) {
Expand All @@ -1853,7 +1853,7 @@ TYPED_TEST(OperatorTestsNumeric, Shift)
}

{
(t2s = shift1(t2, 5)).run();
(t2s = shift<1>(t2, 5)).run();
cudaStreamSynchronize(0);

for (index_t i = 0; i < count0; i++) {
Expand All @@ -1865,7 +1865,7 @@ TYPED_TEST(OperatorTestsNumeric, Shift)
}

{
(t2s = shift0(shift1(t2, 5), 6)).run();
(t2s = shift<1,0>(t2, 5, 6)).run();
cudaStreamSynchronize(0);

for (index_t i = 0; i < count0; i++) {
Expand Down Expand Up @@ -1904,7 +1904,7 @@ TYPED_TEST(OperatorTestsNumeric, Shift)

// Negative shifts
{
(t2s = shift0(t2, -5)).run();
(t2s = shift<0>(t2, -5)).run();
cudaStreamSynchronize(0);

for (index_t i = 0; i < count0; i++) {
Expand All @@ -1916,7 +1916,7 @@ TYPED_TEST(OperatorTestsNumeric, Shift)
}

{
(t2s = shift1(t2, -5)).run();
(t2s = shift<1>(t2, -5)).run();
cudaStreamSynchronize(0);

for (index_t i = 0; i < count0; i++) {
Expand All @@ -1929,7 +1929,7 @@ TYPED_TEST(OperatorTestsNumeric, Shift)

// Large shifts
{
(t2s = shift0(t2, t2.Size(0) * 4)).run();
(t2s = shift<0>(t2, t2.Size(0) * 4)).run();
cudaStreamSynchronize(0);

for (index_t i = 0; i < count0; i++) {
Expand All @@ -1942,8 +1942,8 @@ TYPED_TEST(OperatorTestsNumeric, Shift)
{
// Shift 4 times the size back, minus one. This should be equivalent to
// simply shifting by -1
(t2s = shift0(t2, -t2.Size(0) * 4 - 1)).run();
(t2s2 = shift0(t2, -1)).run();
(t2s = shift<0>(t2, -t2.Size(0) * 4 - 1)).run();
(t2s2 = shift<0>(t2, -1)).run();
cudaStreamSynchronize(0);

for (index_t i = 0; i < count0; i++) {
Expand Down Expand Up @@ -1971,7 +1971,7 @@ TYPED_TEST(OperatorTestsNumeric, Reverse)
}

{
(t2r = reverseY(t2)).run();
(t2r = reverse<0>(t2)).run();
cudaStreamSynchronize(0);

for (index_t i = 0; i < count0; i++) {
Expand All @@ -1983,7 +1983,7 @@ TYPED_TEST(OperatorTestsNumeric, Reverse)
}

{
(t2r = reverseX(t2)).run();
(t2r = reverse<1>(t2)).run();
cudaStreamSynchronize(0);

for (index_t i = 0; i < count0; i++) {
Expand All @@ -1995,7 +1995,7 @@ TYPED_TEST(OperatorTestsNumeric, Reverse)
}

{
(t2r = reverseX(reverseY(t2))).run();
(t2r = reverse<0,1>(t2)).run();
cudaStreamSynchronize(0);

for (index_t i = 0; i < count0; i++) {
Expand Down

0 comments on commit 1f5cd6e

Please sign in to comment.