From 064d12a3c6ded6fd634525f1811298962507ca27 Mon Sep 17 00:00:00 2001 From: luitjens Date: Thu, 26 Jan 2023 09:08:44 -0800 Subject: [PATCH] add support for std::array as shapes in slice --- include/matx/core/tensor.h | 28 ++++++++++--- include/matx/operators/slice.h | 75 +++++++++++++++++++++++++++------- 2 files changed, 82 insertions(+), 21 deletions(-) diff --git a/include/matx/core/tensor.h b/include/matx/core/tensor.h index f893232cd..0d00a2d80 100644 --- a/include/matx/core/tensor.h +++ b/include/matx/core/tensor.h @@ -1521,9 +1521,9 @@ class tensor_t : public detail::tensor_impl_t { * */ template - __MATX_INLINE__ auto Slice([[maybe_unused]] const typename Desc::shape_type (&firsts)[RANK], - [[maybe_unused]] const typename Desc::shape_type (&ends)[RANK], - [[maybe_unused]] const typename Desc::stride_type (&strides)[RANK]) const + __MATX_INLINE__ auto Slice([[maybe_unused]] const std::array &firsts, + [[maybe_unused]] const std::array &ends, + [[maybe_unused]] const std::array &strides) const { static_assert(N <= RANK && RANK > 0, "Must slice to a rank the same or less than current rank."); @@ -1578,6 +1578,14 @@ class tensor_t : public detail::tensor_impl_t { tensor_desc_t new_desc{std::move(n), std::move(s)}; return tensor_t{storage_, std::move(new_desc), data}; } + + template + __MATX_INLINE__ auto Slice(const typename Desc::shape_type (&firsts)[RANK], + const typename Desc::shape_type (&ends)[RANK], + const typename Desc::stride_type (&strides)[RANK]) const + { + return Slice(detail::to_array(firsts), detail::to_array(ends), detail::to_array(strides)); + } /** * Slice a tensor either within the same dimension or to a lower dimension @@ -1604,17 +1612,25 @@ class tensor_t : public detail::tensor_impl_t { * */ template - __MATX_INLINE__ auto Slice(const typename Desc::shape_type (&firsts)[RANK], - const typename Desc::shape_type (&ends)[RANK]) const + __MATX_INLINE__ auto Slice(const std::array &firsts, + const std::array &ends) const { static_assert(N <= RANK && RANK > 0, "Must slice to a rank the same or less than current rank."); MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) + + const std::array strides = {-1}; - const typename Desc::stride_type strides[RANK] = {-1}; return Slice(firsts, ends, strides); } + template + __MATX_INLINE__ auto Slice(const typename Desc::shape_type (&firsts)[RANK], + const typename Desc::shape_type (&ends)[RANK]) const + { + return Slice(detail::to_array(firsts), detail::to_array(ends)); + } + /** * Print a value diff --git a/include/matx/operators/slice.h b/include/matx/operators/slice.h index a77f11d71..ec611e528 100644 --- a/include/matx/operators/slice.h +++ b/include/matx/operators/slice.h @@ -65,7 +65,9 @@ namespace matx __MATX_INLINE__ std::string str() const { return "slice(" + op_.str() + ")"; } - __MATX_INLINE__ SliceOp(T op, const shape_type (&starts)[T::Rank()], const shape_type (&ends)[T::Rank()], const shape_type (&strides)[T::Rank()]) : op_(op) { + __MATX_INLINE__ SliceOp(T op, const std::array &starts, + const std::array &ends, + const std::array &strides) : op_(op) { int32_t d = 0; for(int32_t i = 0; i < T::Rank(); i++) { shape_type start = starts[i]; @@ -169,9 +171,9 @@ namespace matx */ template __MATX_INLINE__ auto slice( const OpType &op, - const index_t (&starts)[OpType::Rank()], - const index_t (&ends)[OpType::Rank()], - const index_t (&strides)[OpType::Rank()]) + const std::array &starts, + const std::array &ends, + const std::array &strides) { if constexpr (is_tensor_view_v) { return op.Slice(starts, ends, strides); @@ -180,6 +182,18 @@ namespace matx } } + template + __MATX_INLINE__ auto slice( const OpType &op, + const index_t (&starts)[OpType::Rank()], + const index_t (&ends)[OpType::Rank()], + const index_t (&strides)[OpType::Rank()]) + { + return slice(op, + detail::to_array(starts), + detail::to_array(ends), + detail::to_array(strides)); + } + /** * @brief Operator to logically slice a tensor or operator. * @@ -194,14 +208,23 @@ namespace matx * @return sliced operator */ template + __MATX_INLINE__ auto slice( const OpType &op, + const std::array &starts, + const std::array &ends) + { + std::array strides; + strides.fill(1); + + return slice(op, starts, ends, strides); + } + template __MATX_INLINE__ auto slice( const OpType &op, const index_t (&starts)[OpType::Rank()], const index_t (&ends)[OpType::Rank()]) { - index_t strides[OpType::Rank()]; - for(int i = 0; i < OpType::Rank(); i++) - strides[i] = 1; - return slice(op, starts, ends, strides); + return slice(op, + detail::to_array(starts), + detail::to_array(ends)); } /** @@ -223,9 +246,9 @@ namespace matx */ template __MATX_INLINE__ auto slice( const OpType op, - const index_t (&starts)[OpType::Rank()], - const index_t (&ends)[OpType::Rank()], - const index_t (&strides)[OpType::Rank()]) + const std::array &starts, + const std::array &ends, + const std::array &strides) { if constexpr (is_tensor_view_v) { return op.template Slice(starts, ends, strides); @@ -234,6 +257,18 @@ namespace matx } } + template + __MATX_INLINE__ auto slice( const OpType op, + const index_t (&starts)[OpType::Rank()], + const index_t (&ends)[OpType::Rank()], + const index_t (&strides)[OpType::Rank()]) + { + return slice(op, + detail::to_array(starts), + detail::to_array(ends), + detail::to_array(strides)); + } + /** * @brief Operator to logically slice a tensor or operator. * @@ -250,14 +285,24 @@ namespace matx * @param ends the last element (exclusive) of each dimension of the input operator. matxDrop Dim removes that dimension. matxEnd deontes all remaining elements in that dimension. * @return sliced operator */ + template + __MATX_INLINE__ auto slice (const OpType opIn, + const std::array &starts, + const std::array &ends) + { + std::array strides; + for(int i = 0; i < OpType::Rank(); i++) + strides[i] = 1; + return slice(opIn, starts, ends, strides); + } + template __MATX_INLINE__ auto slice (const OpType opIn, const index_t (&starts)[OpType::Rank()], const index_t (&ends)[OpType::Rank()]) { - typename OpType::shape_type strides[OpType::Rank()]; - for (int i = 0; i < OpType::Rank(); i++) - strides[i] = 1; - return slice(opIn, starts, ends, strides); + return slice(opIn, + detail::to_array(starts), + detail::to_array(ends)); } } // end namespace matx