Skip to content

Commit

Permalink
Adding slice() operator.
Browse files Browse the repository at this point in the history
  • Loading branch information
luitjens committed Aug 8, 2022
1 parent 7da1304 commit 7314ecb
Show file tree
Hide file tree
Showing 3 changed files with 327 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs_input/api/tensorops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,4 @@ Advanced Operators
.. doxygenfunction:: rcollapse
.. doxygenfunction:: lcollapse
.. doxygenfunction:: clone
.. doxygenfunction:: slice
165 changes: 165 additions & 0 deletions include/matx_tensor_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,171 @@ auto __MATX_INLINE__ remap(Op t, Ind idx, Inds... inds)
return detail::RemapOp<DIM, decltype(op) , Ind>(op, idx);
};

/**
 * Slices elements from an operator/tensor.
 */
namespace detail {
  template <int DIM, typename T>
  class SliceOp : public BaseOp<SliceOp<DIM, T>>
  {
    public:
      using scalar_type = typename T::scalar_type;
      using shape_type = typename T::shape_type;

    private:
      typename base_type<T>::type op_;              // wrapped operator/tensor
      std::array<shape_type, DIM> sizes_;           // size of each output (kept) dimension
      std::array<shape_type, DIM> dims_;            // output dim -> input dim mapping
      std::array<shape_type, T::Rank()> starts_;    // start offset, indexed by INPUT dimension
      std::array<shape_type, T::Rank()> strides_;   // stride, indexed by INPUT dimension

    public:
      using matxop = bool;
      using matxoplvalue = bool;

      static_assert(T::Rank()>0, "SliceOp: Rank of operator must be greater than 0.");
      static_assert(DIM<=T::Rank(), "SliceOp: DIM must be less than or equal to operator rank.");

      /**
       * @brief Construct a slice view over op.
       *
       * @param op      input operator/tensor
       * @param starts  first element (inclusive) per input dimension
       * @param ends    last element (exclusive) per input dimension; matxEnd keeps
       *                the remainder of the dimension, matxDropDim removes it
       * @param strides step between consecutive elements per input dimension
       */
      __MATX_INLINE__ SliceOp(T op, const shape_type (&starts)[T::Rank()], const shape_type (&ends)[T::Rank()], const shape_type (&strides)[T::Rank()]) : op_(op) {
        int d = 0;
        for(int i = 0; i < T::Rank(); i++) {
          shape_type start = starts[i];
          shape_type end = ends[i];

          starts_[i] = start;
          strides_[i] = strides[i];

          // compute dims and sizes for the dimensions that are kept
          if(end != matxDropDim) {
            dims_[d] = i;

            if(end == matxEnd) {
              sizes_[d] = op.Size(i) - start;
            } else {
              sizes_[d] = end - start;
            }

            // Adjust size by stride. BUGFIX: index strides_ with the input
            // dimension i, not the output dimension d -- the two diverge once
            // a lower dimension has been dropped with matxDropDim.
            sizes_[d] = (shape_type)std::ceil(static_cast<double>(sizes_[d]) / static_cast<double>(strides_[i]));
            d++;
          }
        }
        MATX_ASSERT_STR(d==Rank(), matxInvalidDim, "SliceOp: Number of dimensions without matxDropDim must equal new rank.");
      }

      /// Read access: translate output indices into input indices and forward to op_.
      template <typename... Is>
      __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const
      {
        static_assert(sizeof...(Is)==Rank());
        static_assert((std::is_convertible_v<Is, index_t> && ... ));

        // convert variadic pack to arrays so we can read/update per dimension
        std::array<index_t, Rank()> inds{indices...};
        std::array<index_t, T::Rank()> ind{indices...};

        // every input dimension starts at its slice offset; dropped
        // dimensions stay pinned there
#pragma unroll
        for(int i = 0; i < T::Rank(); i++) {
          ind[i] = starts_[i];
        }

#pragma unroll
        for(int i = 0; i < Rank(); i++) {
          // BUGFIX: strides_ is indexed by INPUT dimension, so map the output
          // dimension i through dims_ before looking up its stride.
          ind[dims_[i]] += inds[i] * strides_[dims_[i]];
        }

        return mapply(op_, ind);
      }

      /// Write access: identical index mapping, returns a mutable reference.
      template <typename... Is>
      __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto& operator()(Is... indices)
      {
        static_assert(sizeof...(Is)==Rank());
        static_assert((std::is_convertible_v<Is, index_t> && ... ));

        // convert variadic pack to arrays so we can read/update per dimension
        std::array<index_t, Rank()> inds{indices...};
        std::array<index_t, T::Rank()> ind{indices...};

        // every input dimension starts at its slice offset; dropped
        // dimensions stay pinned there
#pragma unroll
        for(int i = 0; i < T::Rank(); i++) {
          ind[i] = starts_[i];
        }

#pragma unroll
        for(int i = 0; i < Rank(); i++) {
          // BUGFIX: strides_ is indexed by INPUT dimension, so map the output
          // dimension i through dims_ before looking up its stride.
          ind[dims_[i]] += inds[i] * strides_[dims_[i]];
        }

        return mapply(op_, ind);
      }

      /// Rank of the sliced view (number of dimensions not dropped).
      static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
      {
        return DIM;
      }

      /// Size of output dimension dim after applying start/end/stride.
      constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int dim) const
      {
        return sizes_[dim];
      }

      /// Allow the slice to appear on the left-hand side of an assignment.
      template<typename R> __MATX_INLINE__ auto operator=(const R &rhs) { return set(*this, rhs); }
  };
}

/**
 * @brief Rank-preserving slice of an operator/tensor with explicit strides.
 *
 * The output rank equals the input rank; use the slice<N> overload with
 * matxDropDim entries in ends to reduce rank.
 *
 * NOTE(review): strides is declared as T::stride_type but is bound to the
 * SliceOp constructor's shape_type array parameter -- this assumes the two
 * aliases are the same type; confirm for all tensor types.
 *
 * @tparam T Input operator/tensor type
 * @param op      input operator
 * @param starts  first element (inclusive) per dimension
 * @param ends    last element (exclusive) per dimension; matxEnd keeps the rest
 * @param strides stride between consecutive elements per dimension
 * @return sliced operator
 */
template <typename T>
__MATX_INLINE__ auto slice(const T op,
                           const typename T::shape_type (&starts)[T::Rank()],
                           const typename T::shape_type (&ends)[T::Rank()],
                           const typename T::stride_type (&strides)[T::Rank()])
{
  return detail::SliceOp<T::Rank(), T>(op, starts, ends, strides);
}

/**
 * @brief Rank-preserving contiguous slice (all strides default to 1).
 *
 * @tparam T Input operator/tensor type
 * @param op     input operator
 * @param starts first element (inclusive) per dimension
 * @param ends   last element (exclusive) per dimension; matxEnd keeps the rest
 * @return sliced operator
 */
template <typename T>
__MATX_INLINE__ auto slice(const T op,
                           const typename T::shape_type (&starts)[T::Rank()],
                           const typename T::shape_type (&ends)[T::Rank()])
{
  // Every dimension is traversed with a unit stride.
  typename T::shape_type unit_strides[T::Rank()];
  for (auto &s : unit_strides) {
    s = 1;
  }
  return detail::SliceOp<T::Rank(), T>(op, starts, ends, unit_strides);
}

/**
 * @brief Operator to logically slice a tensor or operator.
 *
 * The rank of the operator must be greater than 0.
 * This operator can appear as an rvalue or lvalue.
 *
 * @tparam N The Rank of the output operator
 * @tparam T Input operator/tensor type
 * @param op Input operator
 * @param starts the first element (inclusive) of each dimension of the input operator.
 * @param ends the last element (exclusive) of each dimension of the input operator. matxDropDim removes that dimension. matxEnd denotes all remaining elements in that dimension.
 * @param strides Optional: the stride between consecutive elements
 * @return sliced operator
 */
template <int N, typename T>
__MATX_INLINE__ auto slice( const T op,
const typename T::shape_type (&starts)[T::Rank()],
const typename T::shape_type (&ends)[T::Rank()],
const typename T::stride_type (&strides)[T::Rank()]) {
return detail::SliceOp<N,T>(op, starts, ends, strides);
}

/**
 * @brief Rank-reducing contiguous slice (all strides default to 1).
 *
 * Dimensions marked matxDropDim in ends are removed from the output.
 *
 * @tparam N The Rank of the output operator
 * @tparam T Input operator/tensor type
 * @param op     input operator
 * @param starts first element (inclusive) per input dimension
 * @param ends   last element (exclusive) per input dimension; matxDropDim
 *               removes the dimension, matxEnd keeps the remainder
 * @return sliced operator
 */
template <int N, typename T>
__MATX_INLINE__ auto slice(const T op,
                           const typename T::shape_type (&starts)[T::Rank()],
                           const typename T::shape_type (&ends)[T::Rank()])
{
  // Every dimension is traversed with a unit stride.
  typename T::shape_type unit_strides[T::Rank()];
  for (auto &s : unit_strides) {
    s = 1;
  }
  return detail::SliceOp<N, T>(op, starts, ends, unit_strides);
}


/**
* @brief Helper function to select values from a predicate operator
*
Expand Down
161 changes: 161 additions & 0 deletions test/00_operators/OperatorTests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,167 @@ TYPED_TEST(OperatorTestsNumericNonComplex, CloneOp)
MATX_EXIT_HANDLER();
}



// Verifies strided 1D slices: every-other-element views starting at
// offsets 0 and 2 must see the expected source elements.
TYPED_TEST(OperatorTestsNumericNonComplex, SliceStrideOp)
{
  MATX_ENTER_HANDLER();
  tensor_t<TypeParam, 1> src{{10}};
  src.SetVals({10, 20, 30, 40, 50, 60, 70, 80, 90, 100});

  // Stride-2 slice starting at element 0: view index j maps to src index 2*j.
  auto even = slice(src, {0}, {matxEnd}, {2});
  for (index_t j = 0; j < even.Size(0); j++) {
    ASSERT_EQ(src(2 * j), even(j));
  }

  // Stride-2 slice starting at element 2: expected values are 30, 50, 70, 90.
  auto offset = slice(src, {2}, {matxEnd}, {2});
  for (index_t j = 0; j < offset.Size(0); j++) {
    ASSERT_EQ(TypeParam(30 + 20 * j), offset(j));
  }

  MATX_EXIT_HANDLER();
}

// Verifies unit-stride rank-preserving slices on 2D/3D/4D tensors: each
// sliced view must have the expected shape and alias the correct offset
// region of its source tensor.
TYPED_TEST(OperatorTestsNumericNonComplex, SliceOp)
{
MATX_ENTER_HANDLER();

tensor_t<TypeParam, 2> t2{{20, 10}};
tensor_t<TypeParam, 3> t3{{30, 20, 10}};
tensor_t<TypeParam, 4> t4{{40, 30, 20, 10}};

// Fill each tensor with a linspace ramp along its last dimension.
(t2 = linspace<1>(t2.Shape(), 0, 10)).run();
(t3 = linspace<2>(t3.Shape(), 0, 10)).run();
(t4 = linspace<3>(t4.Shape(), 0, 10)).run();

// Slice [start, end) per dimension; sizes should be end - start.
auto t2t = slice(t2, {1, 2}, {3, 5});
auto t3t = slice(t3, {1, 2, 3}, {3, 5, 7});
auto t4t = slice(t4, {1, 2, 3, 4}, {3, 5, 7, 9});

ASSERT_EQ(t2t.Size(0), 2);
ASSERT_EQ(t2t.Size(1), 3);

ASSERT_EQ(t3t.Size(0), 2);
ASSERT_EQ(t3t.Size(1), 3);
ASSERT_EQ(t3t.Size(2), 4);

ASSERT_EQ(t4t.Size(0), 2);
ASSERT_EQ(t4t.Size(1), 3);
ASSERT_EQ(t4t.Size(2), 4);
ASSERT_EQ(t4t.Size(3), 5);

// Each view element (i, j, ...) must equal the source element offset by
// the per-dimension start values.
for (index_t i = 0; i < t2t.Size(0); i++) {
for (index_t j = 0; j < t2t.Size(1); j++) {
ASSERT_EQ(t2t(i, j), t2(i + 1, j + 2));
}
}

for (index_t i = 0; i < t3t.Size(0); i++) {
for (index_t j = 0; j < t3t.Size(1); j++) {
for (index_t k = 0; k < t3t.Size(2); k++) {
ASSERT_EQ(t3t(i, j, k), t3(i + 1, j + 2, k + 3));
}
}
}

for (index_t i = 0; i < t4t.Size(0); i++) {
for (index_t j = 0; j < t4t.Size(1); j++) {
for (index_t k = 0; k < t4t.Size(2); k++) {
for (index_t l = 0; l < t4t.Size(3); l++) {
ASSERT_EQ(t4t(i, j, k, l), t4(i + 1, j + 2, k + 3, l + 4));
}
}
}
}
MATX_EXIT_HANDLER();
}

// Verifies rank-reducing slices (slice<N> with matxDropDim): every 1D/2D
// view carved out of a 2D/3D tensor must alias the expected elements of
// its source, with the dropped dimensions pinned at their start index.
// NOTE(review): all of these use unit strides, so the stride/drop-dim
// interaction is not exercised here.
TYPED_TEST(OperatorTestsNumericNonComplex, SliceAndReduceOp)
{
MATX_ENTER_HANDLER();

tensor_t<TypeParam, 2> t2t{{20, 10}};
tensor_t<TypeParam, 3> t3t{{30, 20, 10}};
(t2t = linspace<1>(t2t.Shape(), 0, 10)).run();
(t3t = linspace<2>(t3t.Shape(), 0, 10)).run();

// 2D -> 1D: keep dim 0, pin dim 1 at column j.
{
index_t j = 0;
auto t2sly = slice<1>(t2t, {0, j}, {matxEnd, matxDropDim});
for (index_t i = 0; i < t2sly.Size(0); i++) {
ASSERT_EQ(t2sly(i), t2t(i, j));
}
}

// 2D -> 1D: pin dim 0 at row i, keep dim 1.
{
index_t i = 0;
auto t2slx = slice<1>(t2t, {i, 0}, {matxDropDim, matxEnd});
for (index_t j = 0; j < t2slx.Size(0); j++) {
ASSERT_EQ(t2slx(j), t2t(i, j));
}
}

// 3D -> 1D: keep dim 0, pin dims 1 and 2.
{
index_t j = 0;
index_t k = 0;
auto t3slz = slice<1>(t3t, {0, j, k}, {matxEnd, matxDropDim, matxDropDim});
for (index_t i = 0; i < t3slz.Size(0); i++) {
ASSERT_EQ(t3slz(i), t3t(i, j, k));
}
}

// 3D -> 1D: keep dim 1, pin dims 0 and 2.
{
index_t i = 0;
index_t k = 0;
auto t3sly = slice<1>(t3t, {i, 0, k}, {matxDropDim, matxEnd, matxDropDim});
for (index_t j = 0; j < t3sly.Size(0); j++) {
ASSERT_EQ(t3sly(j), t3t(i, j, k));
}
}

// 3D -> 1D: keep dim 2, pin dims 0 and 1.
{
index_t i = 0;
index_t j = 0;
auto t3slx = slice<1>(t3t, {i, j, 0}, {matxDropDim, matxDropDim, matxEnd});
for (index_t k = 0; k < t3slx.Size(0); k++) {
ASSERT_EQ(t3slx(k), t3t(i, j, k));
}
}

// 3D -> 2D: keep dims 0 and 1, pin dim 2.
{
index_t k = 0;
auto t3slzy = slice<2>(t3t, {0, 0, k}, {matxEnd, matxEnd, matxDropDim});
for (index_t i = 0; i < t3slzy.Size(0); i++) {
for (index_t j = 0; j < t3slzy.Size(1); j++) {
ASSERT_EQ(t3slzy(i, j), t3t(i, j, k));
}
}
}

// 3D -> 2D: keep dims 0 and 2, pin dim 1.
{
index_t j = 0;
auto t3slzx = slice<2>(t3t, {0, j, 0}, {matxEnd, matxDropDim, matxEnd});
for (index_t i = 0; i < t3slzx.Size(0); i++) {
for (index_t k = 0; k < t3slzx.Size(1); k++) {
ASSERT_EQ(t3slzx(i, k), t3t(i, j, k));
}
}
}

// 3D -> 2D: keep dims 1 and 2, pin dim 0.
{
index_t i = 0;
auto t3slyx = slice<2>(t3t, {i, 0, 0}, {matxDropDim, matxEnd, matxEnd});
for (index_t j = 0; j < t3slyx.Size(0); j++) {
for (index_t k = 0; k < t3slyx.Size(1); k++) {
ASSERT_EQ(t3slyx(j, k), t3t(i, j, k));
}
}
}
MATX_EXIT_HANDLER();
}

TYPED_TEST(OperatorTestsNumericNonComplex, CollapseOp)
{
int N = 10;
Expand Down

0 comments on commit 7314ecb

Please sign in to comment.