Skip to content

Commit

Permalink
add support for std::array as shapes in slice
Browse files Browse the repository at this point in the history
  • Loading branch information
luitjens committed Jan 26, 2023
1 parent 8a65b85 commit a57951f
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 21 deletions.
28 changes: 22 additions & 6 deletions include/matx/core/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1521,9 +1521,9 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
*
*/
template <int N = RANK>
__MATX_INLINE__ auto Slice([[maybe_unused]] const typename Desc::shape_type (&firsts)[RANK],
[[maybe_unused]] const typename Desc::shape_type (&ends)[RANK],
[[maybe_unused]] const typename Desc::stride_type (&strides)[RANK]) const
__MATX_INLINE__ auto Slice([[maybe_unused]] const std::array<typename Desc::shape_type, RANK> &firsts,
[[maybe_unused]] const std::array<typename Desc::shape_type, RANK> &ends,
[[maybe_unused]] const std::array<typename Desc::stride_type, RANK> &strides) const
{
static_assert(N <= RANK && RANK > 0, "Must slice to a rank the same or less than current rank.");

Expand Down Expand Up @@ -1578,6 +1578,14 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
tensor_desc_t<decltype(n), decltype(s), N> new_desc{std::move(n), std::move(s)};
return tensor_t<T, N, Storage, decltype(new_desc)>{storage_, std::move(new_desc), data};
}

template <int N = RANK>
__MATX_INLINE__ auto Slice(const typename Desc::shape_type (&firsts)[RANK],
const typename Desc::shape_type (&ends)[RANK],
const typename Desc::stride_type (&strides)[RANK]) const
{
return Slice<N>(detail::to_array(firsts), detail::to_array(ends), detail::to_array(strides));
}

/**
* Slice a tensor either within the same dimension or to a lower dimension
Expand All @@ -1604,17 +1612,25 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
*
*/
template <int N = RANK>
__MATX_INLINE__ auto Slice(const typename Desc::shape_type (&firsts)[RANK],
const typename Desc::shape_type (&ends)[RANK]) const
__MATX_INLINE__ auto Slice(const std::array<typename Desc::shape_type, RANK> &firsts,
const std::array<typename Desc::shape_type, RANK> &ends) const
{
static_assert(N <= RANK && RANK > 0, "Must slice to a rank the same or less than current rank.");

MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)

const std::array<typename Desc::stride_type, RANK> strides = {-1};

const typename Desc::stride_type strides[RANK] = {-1};
return Slice<N>(firsts, ends, strides);
}

template <int N = RANK>
__MATX_INLINE__ auto Slice(const typename Desc::shape_type (&firsts)[RANK],
const typename Desc::shape_type (&ends)[RANK]) const
{
return Slice<N>(detail::to_array(firsts), detail::to_array(ends));
}


/**
* Print a value
Expand Down
75 changes: 60 additions & 15 deletions include/matx/operators/slice.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ namespace matx

__MATX_INLINE__ std::string str() const { return "slice(" + op_.str() + ")"; }

__MATX_INLINE__ SliceOp(T op, const shape_type (&starts)[T::Rank()], const shape_type (&ends)[T::Rank()], const shape_type (&strides)[T::Rank()]) : op_(op) {
__MATX_INLINE__ SliceOp(T op, const std::array<shape_type, T::Rank()> &starts,
const std::array<shape_type, T::Rank()> &ends,
const std::array<shape_type, T::Rank()> &strides) : op_(op) {
int32_t d = 0;
for(int32_t i = 0; i < T::Rank(); i++) {
shape_type start = starts[i];
Expand Down Expand Up @@ -169,9 +171,9 @@ namespace matx
*/
template <typename OpType>
__MATX_INLINE__ auto slice( const OpType &op,
const index_t (&starts)[OpType::Rank()],
const index_t (&ends)[OpType::Rank()],
const index_t (&strides)[OpType::Rank()])
const std::array<index_t, OpType::Rank()> &starts,
const std::array<index_t, OpType::Rank()> &ends,
const std::array<index_t, OpType::Rank()> &strides)
{
if constexpr (is_tensor_view_v<OpType>) {
return op.Slice(starts, ends, strides);
Expand All @@ -180,6 +182,18 @@ namespace matx
}
}

template <typename OpType>
__MATX_INLINE__ auto slice( const OpType &op,
const index_t (&starts)[OpType::Rank()],
const index_t (&ends)[OpType::Rank()],
const index_t (&strides)[OpType::Rank()])
{
return slice(op,
detail::to_array(starts),
detail::to_array(ends),
detail::to_array(strides));
}

/**
* @brief Operator to logically slice a tensor or operator.
*
Expand All @@ -194,14 +208,23 @@ namespace matx
* @return sliced operator
*/
template <typename OpType>
__MATX_INLINE__ auto slice( const OpType &op,
const std::array<index_t, OpType::Rank()> &starts,
const std::array<index_t, OpType::Rank()> &ends)
{
std::array<index_t, OpType::Rank()> strides;
strides.fill(1);

return slice(op, starts, ends, strides);
}
template <typename OpType>
__MATX_INLINE__ auto slice( const OpType &op,
const index_t (&starts)[OpType::Rank()],
const index_t (&ends)[OpType::Rank()])
{
index_t strides[OpType::Rank()];
for(int i = 0; i < OpType::Rank(); i++)
strides[i] = 1;
return slice(op, starts, ends, strides);
return slice(op,
detail::to_array(starts),
detail::to_array(ends));
}

/**
Expand All @@ -223,9 +246,9 @@ namespace matx
*/
template <int N, typename OpType>
__MATX_INLINE__ auto slice( const OpType op,
const index_t (&starts)[OpType::Rank()],
const index_t (&ends)[OpType::Rank()],
const index_t (&strides)[OpType::Rank()])
const std::array<index_t, OpType::Rank()> &starts,
const std::array<index_t, OpType::Rank()> &ends,
const std::array<index_t, OpType::Rank()> &strides)
{
if constexpr (is_tensor_view_v<OpType>) {
return op.template Slice<N>(starts, ends, strides);
Expand All @@ -234,6 +257,18 @@ namespace matx
}
}

template <int N, typename OpType>
__MATX_INLINE__ auto slice( const OpType op,
const index_t (&starts)[OpType::Rank()],
const index_t (&ends)[OpType::Rank()],
const index_t (&strides)[OpType::Rank()])
{
return slice<N,OpType>(op,
detail::to_array(starts),
detail::to_array(ends),
detail::to_array(strides));
}

/**
* @brief Operator to logically slice a tensor or operator.
*
Expand All @@ -250,14 +285,24 @@ namespace matx
* @param ends the last element (exclusive) of each dimension of the input operator. matxDrop Dim removes that dimension. matxEnd deontes all remaining elements in that dimension.
* @return sliced operator
*/
template <int N, typename OpType>
__MATX_INLINE__ auto slice (const OpType opIn,
const std::array<index_t, OpType::Rank()> &starts,
const std::array<index_t, OpType::Rank()> &ends)
{
std::array<index_t, OpType::Rank()> strides;
for(int i = 0; i < OpType::Rank(); i++)
strides[i] = 1;
return slice<N,OpType>(opIn, starts, ends, strides);
}

template <int N, typename OpType>
__MATX_INLINE__ auto slice (const OpType opIn,
const index_t (&starts)[OpType::Rank()],
const index_t (&ends)[OpType::Rank()])
{
typename OpType::shape_type strides[OpType::Rank()];
for (int i = 0; i < OpType::Rank(); i++)
strides[i] = 1;
return slice<N, OpType>(opIn, starts, ends, strides);
return slice<N,OpType>(opIn,
detail::to_array(starts),
detail::to_array(ends));
}
} // end namespace matx

0 comments on commit a57951f

Please sign in to comment.