From 5a17d4b3341637f80129f70dff87ab1d411e0e8b Mon Sep 17 00:00:00 2001 From: Thomas Benson Date: Thu, 30 Nov 2023 15:59:04 -0800 Subject: [PATCH] Fix matvec output dims to match A rather than B For matvecs, the batch dimensions of A and B should match, and the final output dimension should match dimension Rank-1 of A, where Rank is the rank of the output. Also generalize batching support so that the size of out_dims_ is based on the output rank. --- include/matx/operators/matvec.h | 7 ++++--- test/00_transform/MatMul.cu | 26 ++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/include/matx/operators/matvec.h b/include/matx/operators/matvec.h index 6199d07c..8021021e 100644 --- a/include/matx/operators/matvec.h +++ b/include/matx/operators/matvec.h @@ -48,8 +48,9 @@ namespace matx OpB b_; float alpha_; float beta_; - std::array out_dims_; - mutable matx::tensor_t tmp_out_; + static constexpr int RANK = remove_cvref_t::Rank(); + std::array out_dims_; + mutable matx::tensor_t tmp_out_; public: using matxop = bool; @@ -65,7 +66,7 @@ namespace matx a_(A), b_(B), alpha_(alpha), beta_(beta) { for (int r = 0; r < Rank(); r++) { - out_dims_[r] = b_.Size(r); + out_dims_[r] = a_.Size(r); } } diff --git a/test/00_transform/MatMul.cu b/test/00_transform/MatMul.cu index 7aba606d..f7da7a4b 100644 --- a/test/00_transform/MatMul.cu +++ b/test/00_transform/MatMul.cu @@ -682,6 +682,12 @@ TYPED_TEST(MatMulTestFloatTypes, MediumMatVec) (cs = matvec(a, bs)).run(); // example-end matvec-test-1 + // Test the rank/size of the matvec operator + auto a_times_bs = matvec(a, bs); + ASSERT_EQ(a_times_bs.Rank(), 1); + ASSERT_EQ(a_times_bs.Size(0), m); + ASSERT_EQ(cs.Size(0), m); + MATX_TEST_ASSERT_COMPARE(this->pb, c, "c", this->thresh); // Test also with rank-1 tensors rather than just slices @@ -693,6 +699,26 @@ TYPED_TEST(MatMulTestFloatTypes, MediumMatVec) MATX_TEST_ASSERT_COMPARE(this->pb, c, "c", this->thresh); + // Test with batching + constexpr index_t batch1 = 5; + constexpr index_t batch2 = 
9; + auto a_batch = clone<4>(a, {batch1, batch2, matxKeepDim, matxKeepDim}); + auto b_batch = clone<3>(bs, {batch1, batch2, matxKeepDim}); + auto batched_matvec = matvec(a_batch, b_batch); + ASSERT_EQ(batched_matvec.Rank(), 3); + ASSERT_EQ(batched_matvec.Size(0), batch1); + ASSERT_EQ(batched_matvec.Size(1), batch2); + ASSERT_EQ(batched_matvec.Size(2), m); + auto result = make_tensor(batched_matvec.Shape()); + (result = batched_matvec).run(); + for (index_t i = 0; i < batch1; i++) { + for (index_t j = 0; j < batch2; j++) { + auto rs = slice<1>(result, {i,j,0}, {matxDropDim,matxDropDim,matxEnd}); + auto rsc = clone<2>(rs, {matxKeepDim,1}); + MATX_TEST_ASSERT_COMPARE(this->pb, rsc, "c", this->thresh); + } + } + MATX_EXIT_HANDLER(); }