Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding permute operator #239

Merged
merged 1 commit into from
Aug 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs_input/api/tensorops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,4 @@ Advanced Operators
.. doxygenfunction:: lcollapse
.. doxygenfunction:: clone
.. doxygenfunction:: slice
.. doxygenfunction:: permute
148 changes: 134 additions & 14 deletions include/matx_tensor_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,25 @@ __MATX_INLINE__
}
};
}

/**
* @brief Helper function to select values from a predicate operator
*
* select() is used to index from a source operator using indices stored
* in another operator. This is commonly used with the find_idx executor
* which returns the indices of values meeting a selection criteria.
*
* @tparam T Input type
* @tparam IdxType Operator with indices
* @param t Input operator
* @param idx Index tensor
* @return Value in t from each location in idx
*/
template <typename T, typename IdxType>
auto __MATX_INLINE__ select(T t, IdxType idx)
{
return detail::SelectOp<T, IdxType>(t, idx);
};

namespace detail {
template <int CRank, typename T, typename Ind>
Expand Down Expand Up @@ -704,7 +723,7 @@ __MATX_INLINE__
{
return T::Rank();
}
constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int dim) const
constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int32_t dim) const
{
if(dim == DIM)
return idx_.Size(0);
Expand Down Expand Up @@ -787,7 +806,7 @@ auto __MATX_INLINE__ remap(Op t, Ind idx, Inds... inds)
private:
typename base_type<T>::type op_;
std::array<shape_type, DIM> sizes_;
std::array<shape_type, DIM> dims_;
std::array<int32_t, DIM> dims_;
std::array<shape_type, T::Rank()> starts_;
std::array<shape_type, T::Rank()> strides_;

Expand All @@ -799,8 +818,8 @@ auto __MATX_INLINE__ remap(Op t, Ind idx, Inds... inds)
static_assert(DIM<=T::Rank(), "SliceOp: DIM must be less than or equal to operator rank.");

__MATX_INLINE__ SliceOp(T op, const shape_type (&starts)[T::Rank()], const shape_type (&ends)[T::Rank()], const shape_type (&strides)[T::Rank()]) : op_(op) {
int d = 0;
for(int i = 0; i < T::Rank(); i++) {
int32_t d = 0;
for(int32_t i = 0; i < T::Rank(); i++) {
shape_type start = starts[i];
shape_type end = ends[i];

Expand Down Expand Up @@ -836,12 +855,12 @@ auto __MATX_INLINE__ remap(Op t, Ind idx, Inds... inds)
std::array<index_t, T::Rank()> ind{indices...};

#pragma unroll
for(int i = 0; i < T::Rank(); i++) {
for(int32_t i = 0; i < T::Rank(); i++) {
ind[i] = starts_[i];
}

#pragma unroll
for(int i = 0; i < Rank(); i++) {
for(int32_t i = 0; i < Rank(); i++) {
ind[dims_[i]] += inds[i] * strides_[i];
}

Expand All @@ -856,8 +875,8 @@ auto __MATX_INLINE__ remap(Op t, Ind idx, Inds... inds)
static_assert((std::is_convertible_v<Is, index_t> && ... ));

// convert variadic type to tuple so we can read/update
std::array<index_t, Rank()> inds{indices...};
std::array<index_t, T::Rank()> ind{indices...};
std::array<shape_type, Rank()> inds{indices...};
std::array<shape_type, T::Rank()> ind{indices...};

#pragma unroll
for(int i = 0; i < T::Rank(); i++) {
Expand All @@ -877,7 +896,7 @@ auto __MATX_INLINE__ remap(Op t, Ind idx, Inds... inds)
{
return DIM;
}
constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int dim) const
constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ shape_type Size(int32_t dim) const
{
return sizes_[dim];
}
Expand Down Expand Up @@ -951,13 +970,114 @@ auto __MATX_INLINE__ remap(Op t, Ind idx, Inds... inds)
* @param idx Index tensor
* @return Value in t from each location in idx
*/
template <typename T, typename IdxType>
auto __MATX_INLINE__ select(T t, IdxType idx)
{
return detail::SelectOp<T, IdxType>(t, idx);
};


/**
* permutes dimensions of a tensor/operator
*/
namespace detail {
template <typename T>
class PermuteOp : public BaseOp<PermuteOp<T>>
{
public:
using scalar_type = typename T::scalar_type;
using shape_type = typename T::shape_type;

private:
typename base_type<T>::type op_;
std::array<int32_t, T::Rank()> dims_;

public:
using matxop = bool;
using matxoplvalue = bool;

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
{
return T::Rank();
}

static_assert(Rank() > 0, "PermuteOp: Rank of operator must be greater than 0.");

__MATX_INLINE__ PermuteOp(T op, const int32_t (&dims)[Rank()]) : op_(op) {

bool selected[Rank()] = {0};

for(int32_t i = 0; i < Rank(); i++) {
int32_t dim = dims[i];
MATX_ASSERT_STR(dim < Rank() && dim >= 0, matxInvalidDim, "PermuteOp: Invalid permute index.");
MATX_ASSERT_STR(selected[dim] == false, matxInvalidDim, "PermuteOp: Dim selected more than once");
selected[dim] = true;

dims_[i] = dims[i];
}
};

template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const
{
static_assert(sizeof...(Is)==Rank());
static_assert((std::is_convertible_v<Is, index_t> && ... ));

// convert variadic type to tuple so we can read/update
std::array<shape_type, Rank()> inds{indices...};
std::array<shape_type, T::Rank()> ind{indices...};

#pragma unroll
for(int32_t i = 0; i < Rank(); i++) {
ind[dims_[i]] = inds[i];
//ind[i] = inds[dims_[i]];
}

//return op_(ind);
return mapply(op_, ind);
}

template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto& operator()(Is... indices)
{
static_assert(sizeof...(Is)==Rank());
// static_assert((std::is_convertible_v<Is, index_t> && ... ));

// convert variadic type to tuple so we can read/update
std::array<shape_type, Rank()> inds{indices...};
std::array<shape_type, T::Rank()> ind{indices...};

#pragma unroll
for(int i = 0; i < Rank(); i++) {
ind[dims_[i]] = inds[i];
//ind[i] = inds[dims_[i]];
}

//return op_(ind);
return mapply(op_, ind);
}

constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ shape_type Size(int32_t dim) const
{
return op_.Size(dims_[dim]);
}

template<typename R> __MATX_INLINE__ auto operator=(const R &rhs) { return set(*this, rhs); }
};
}

/**
* @brief Operator to permute the dimensions of a tensor or operator.
*
* The each dimension must appear in the dims array once.

* This operator can appear as an rvalue or lvalue.
*
* @tparam T Input operator/tensor type
* @param Op Input operator
* @param dims the reordered dimensions of the operator.
* @return permuted operator
*/
template <typename T>
__MATX_INLINE__ auto permute( const T op,
const int32_t (&dims)[T::Rank()]) {
return detail::PermuteOp<T>(op, dims);
}

/**
* Casts the element of the tensor to a specified type
Expand Down
36 changes: 36 additions & 0 deletions test/00_operators/OperatorTests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,42 @@ TYPED_TEST(OperatorTestsComplex, BaseOp)
MATX_EXIT_HANDLER();
}

TYPED_TEST(OperatorTestsNumericNonComplex, PermuteOp)
{
MATX_ENTER_HANDLER();
auto A = make_tensor<TypeParam>({10,20,30});
for(int i=0; i < A.Size(0); i++) {
for(int j=0; j < A.Size(1); j++) {
for(int k=0; k < A.Size(2); k++) {
A(i,j,k) = TypeParam( i * A.Size(1)*A.Size(2) +
j * A.Size(2) + k);
}
}
}

auto op = permute(A, {2, 0, 1});
auto At = A.Permute({2, 0, 1});

ASSERT_TRUE(op.Size(0) == A.Size(2));
ASSERT_TRUE(op.Size(1) == A.Size(0));
ASSERT_TRUE(op.Size(2) == A.Size(1));

ASSERT_TRUE(op.Size(0) == At.Size(0));
ASSERT_TRUE(op.Size(1) == At.Size(1));
ASSERT_TRUE(op.Size(2) == At.Size(2));

for(int i=0; i < op.Size(0); i++) {
for(int j=0; j < op.Size(1); j++) {
for(int k=0; k < op.Size(2); k++) {
ASSERT_TRUE( op(i,j,k) == A(j,k,i));
ASSERT_TRUE( op(i,j,k) == At(i,j,k));
}
}
}

MATX_EXIT_HANDLER();
}

TYPED_TEST(OperatorTestsFloatNonComplex, FMod)
{
MATX_ENTER_HANDLER();
Expand Down