Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed issues with static tensor unit tests compiling #615

Merged
merged 1 commit into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs_input/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ tensor:

Note that for a static tensor the shape is moved to the template parameters instead of function arguments.

.. note::
While static tensors are preferred over dynamic tensors, they currently have limitations when calling certain library functions.
These limitations may be removed over time, but in cases where they don't work a dynamic tensor will work.

After calling the make function, MatX will allocate CUDA managed memory large enough to accommodate the specified tensor size. Users can also
pass their own pointers in a different for of the ``make_`` family of functions to allow for more control over buffer types and ownership
semantics.
Expand Down
18 changes: 14 additions & 4 deletions include/matx/core/tensor_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -347,11 +347,12 @@ class tensor_desc_t {
template <index_t I, index_t... Is>
class static_tensor_desc_t {
public:
using ShapeContainer = std::array<index_t, sizeof...(Is)>; ///< Type trait of shape type
using StrideContainer = std::array<index_t, sizeof...(Is)>; ///< Type trait of stride type
using shape_container = std::array<index_t, sizeof...(Is) + 1>; ///< Type trait of shape type
using stride_container = std::array<index_t, sizeof...(Is) + 1>; ///< Type trait of stride type
using shape_type = index_t; ///< Type trait of shape container
using stride_type = index_t; ///< Type trait of stride container
using matx_descriptor = bool; ///< Type trait to indicate this is a tensor descriptor
using matx_static_descriptor = bool; ///< Type trait to indicate this is a static tensor descriptor

/**
* Check if a descriptor is linear in memory for all elements in the view
Expand Down Expand Up @@ -380,6 +381,15 @@ class static_tensor_desc_t {
*/
static constexpr auto Stride(int dim) { return stride_[dim]; }

/**
* @brief Return strides contaienr of descriptor
*
* @return Strides container
*/
static constexpr auto Strides() {
return stride_;
}

/**
* @brief Get rank of descriptor
*
Expand Down Expand Up @@ -419,8 +429,8 @@ class static_tensor_desc_t {
return m;
}

static constexpr ShapeContainer shape_ = make_shape();
static constexpr StrideContainer stride_ = make_strides();
static constexpr shape_container shape_ = make_shape();
static constexpr stride_container stride_ = make_strides();
};

/**
Expand Down
18 changes: 18 additions & 0 deletions include/matx/core/type_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,24 @@ struct is_matx_descriptor<T, std::void_t<typename T::matx_descriptor>>
template <typename T>
inline constexpr bool is_matx_descriptor_v = detail::is_matx_descriptor<typename remove_cvref<T>::type>::value;

namespace detail {
template <typename T, typename = void>
struct is_matx_static_descriptor : std::false_type {
};
template <typename T>
struct is_matx_static_descriptor<T, std::void_t<typename T::matx_static_descriptor>>
: std::true_type {
};
}

/**
* @brief Determine if a type is a MatX static descriptor
*
* @tparam T Type to test
*/
template <typename T>
inline constexpr bool is_matx_static_descriptor_v = detail::is_matx_static_descriptor<typename remove_cvref<T>::type>::value;


namespace detail {
template <typename T, typename = void>
Expand Down
2 changes: 2 additions & 0 deletions include/matx/transforms/fft/fft_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ namespace detail {
index_type nom_fft_size = in_size;
index_type act_fft_size;

static_assert(!is_matx_static_descriptor_v<decltype(i.Descriptor())>, "FFTs cannot use static descriptors at this time");

// Auto-detect FFT size
if (fft_size == 0) {
act_fft_size = o.Lsize();
Expand Down
40 changes: 20 additions & 20 deletions test/00_tensor/TensorCreationTests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -116,23 +116,23 @@ TYPED_TEST(TensorCreationTestsAll, MakeShape)
ASSERT_EQ(mt4.Size(3), 3);
}

// TYPED_TEST(TensorCreationTestsAll, MakeStaticShape)
// {
// auto mt1 = make_static_tensor<TypeParam, 10>();
// ASSERT_EQ(mt1.Size(0), 10);

// auto mt2 = make_static_tensor<float, 10, 40>();
// ASSERT_EQ(mt2.Size(0), 10);
// ASSERT_EQ(mt2.Size(1), 40);

// auto mt3 = make_static_tensor<TypeParam, 10, 40, 30>();
// ASSERT_EQ(mt3.Size(0), 10);
// ASSERT_EQ(mt3.Size(1), 40);
// ASSERT_EQ(mt3.Size(2), 30);

// auto mt4 = make_static_tensor<TypeParam, 10, 40, 30, 6>();
// ASSERT_EQ(mt4.Size(0), 10);
// ASSERT_EQ(mt4.Size(1), 40);
// ASSERT_EQ(mt4.Size(2), 30);
// ASSERT_EQ(mt4.Size(3), 6);
// }
TYPED_TEST(TensorCreationTestsAll, MakeStaticShape)
{
auto mt1 = make_static_tensor<TypeParam, 10>();
ASSERT_EQ(mt1.Size(0), 10);

auto mt2 = make_static_tensor<float, 10, 40>();
ASSERT_EQ(mt2.Size(0), 10);
ASSERT_EQ(mt2.Size(1), 40);

auto mt3 = make_static_tensor<TypeParam, 10, 40, 30>();
ASSERT_EQ(mt3.Size(0), 10);
ASSERT_EQ(mt3.Size(1), 40);
ASSERT_EQ(mt3.Size(2), 30);

auto mt4 = make_static_tensor<TypeParam, 10, 40, 30, 6>();
ASSERT_EQ(mt4.Size(0), 10);
ASSERT_EQ(mt4.Size(1), 40);
ASSERT_EQ(mt4.Size(2), 30);
ASSERT_EQ(mt4.Size(3), 6);
}