diff --git a/src/uct/ib/mlx5/dv/ib_mlx5dv_md.c b/src/uct/ib/mlx5/dv/ib_mlx5dv_md.c index f86fd719bc5..d4dab85484a 100644 --- a/src/uct/ib/mlx5/dv/ib_mlx5dv_md.c +++ b/src/uct/ib/mlx5/dv/ib_mlx5dv_md.c @@ -113,11 +113,12 @@ uct_ib_mlx5_devx_reg_ksm(uct_ib_mlx5_md_t *md, int atomic, uint64_t address, } static ucs_status_t -uct_ib_mlx5_devx_reg_ksm_data(uct_ib_mlx5_md_t *md, int atomic, void *address, - uct_ib_mlx5_devx_ksm_data_t *ksm_data, - size_t length, uint64_t iova, uint32_t mkey_index, - const char *reason, struct mlx5dv_devx_obj **mr_p, - uint32_t *mkey) +uct_ib_mlx5_devx_reg_ksm_data_mt(uct_ib_mlx5_md_t *md, int atomic, + void *address, + uct_ib_mlx5_devx_ksm_data_t *ksm_data, + size_t length, uint64_t iova, + uint32_t mkey_index, const char *reason, + struct mlx5dv_devx_obj **mr_p, uint32_t *mkey) { void *mr_address = address; ucs_status_t status; @@ -362,6 +363,37 @@ static void uct_ib_mlx5_devx_mr_lru_cleanup(uct_ib_mlx5_md_t *md) kh_destroy_inplace(rkeys, &md->lru_rkeys.hash); } +static ucs_status_t +uct_ib_mlx5_devx_reg_ksm_data(uct_ib_mlx5_md_t *md, uct_ib_mlx5_devx_mem_t *memh, + uct_ib_mr_type_t mr_type, uint32_t iova_offset, int atomic, uint32_t mkey_index, + const char *reason, struct mlx5dv_devx_obj **mr_p, + uint32_t *mkey) { + uct_ib_mlx5_devx_mr_t *mr = &memh->mrs[mr_type]; + uint64_t iova = (uint64_t)memh->address + iova_offset; + void *address = uct_ib_mlx5_devx_memh_base_address(memh); + ucs_status_t status; + size_t length; + + if (memh->super.flags & UCT_IB_MEM_MULTITHREADED) { + length = mr->ksm_data->length; + status = uct_ib_mlx5_devx_reg_ksm_data_mt(md, atomic, address, + mr->ksm_data, length, + iova, mkey_index, + reason, mr_p, mkey); + } else { + length = mr->super.ib->length; + status = uct_ib_mlx5_devx_reg_ksm_data_contig(md, mr, address, iova, + atomic, mkey_index, + reason, mr_p, mkey); + } + + ucs_debug("KSM registered memory %p..%p offset 0x%x%s on %s rkey 0x%x", + address, UCS_PTR_BYTE_OFFSET(address, length), + iova_offset, atomic ? " atomic" : "", + uct_ib_device_name(&md->super.dev), *mkey); + return status; +} + UCS_PROFILE_FUNC_ALWAYS(ucs_status_t, uct_ib_mlx5_devx_reg_indirect_key, (md, memh), uct_ib_mlx5_md_t *md, uct_ib_mlx5_devx_mem_t *memh) @@ -372,11 +404,10 @@ UCS_PROFILE_FUNC_ALWAYS(ucs_status_t, uct_ib_mlx5_devx_reg_indirect_key, md->super.name); do { - status = uct_ib_mlx5_devx_reg_ksm_data_contig( - md, &memh->mrs[UCT_IB_MR_DEFAULT], - uct_ib_mlx5_devx_memh_base_address(memh), - (uint64_t)memh->address, 0, 0, "indirect key", - &memh->indirect_dvmr, &memh->indirect_rkey); + status = uct_ib_mlx5_devx_reg_ksm_data(md, memh, UCT_IB_MR_DEFAULT, 0, + 0, 0, "indirect key", + &memh->indirect_dvmr, + &memh->indirect_rkey); if (status != UCS_OK) { break; } @@ -408,12 +439,9 @@ UCS_PROFILE_FUNC_ALWAYS(ucs_status_t, uct_ib_mlx5_devx_reg_atomic_key, uct_ib_mlx5_devx_mem_t *memh) { uct_ib_mr_type_t mr_type = uct_ib_md_get_atomic_mr_type(&md->super); - uct_ib_mlx5_devx_mr_t *mr = &memh->mrs[mr_type]; uint8_t mr_id = uct_ib_md_get_atomic_mr_id(&md->super); uint32_t atomic_offset = uct_ib_md_atomic_offset(mr_id); uint32_t mkey_index; - uint64_t iova; - ucs_status_t status; int is_atomic; if (memh->smkey_mr != NULL) { @@ -424,31 +452,11 @@ UCS_PROFILE_FUNC_ALWAYS(ucs_status_t, uct_ib_mlx5_devx_reg_atomic_key, } is_atomic = memh->super.flags & UCT_IB_MEM_ACCESS_REMOTE_ATOMIC; - iova = (uint64_t)memh->address + atomic_offset; - - if (memh->super.flags & UCT_IB_MEM_MULTITHREADED) { - return uct_ib_mlx5_devx_reg_ksm_data(md, is_atomic, memh->address, - mr->ksm_data, mr->ksm_data->length, - iova, mkey_index, - "multi-thread atomic key", - &memh->atomic_dvmr, - &memh->atomic_rkey); - } - - status = uct_ib_mlx5_devx_reg_ksm_data_contig( - md, mr, uct_ib_mlx5_devx_memh_base_address(memh), iova, is_atomic, - mkey_index, "atomic key", &memh->atomic_dvmr, &memh->atomic_rkey); - if (status != UCS_OK) { - return status; - } - ucs_debug("KSM registered memory %p..%p lkey 0x%x offset 0x%x%s on %s rkey " - "0x%x", - memh->address, - UCS_PTR_BYTE_OFFSET(memh->address, mr->super.ib->length), - mr->super.ib->lkey, atomic_offset, is_atomic ? " atomic" : "", - uct_ib_device_name(&md->super.dev), memh->atomic_rkey); - return UCS_OK; + return uct_ib_mlx5_devx_reg_ksm_data(md, memh, mr_type, atomic_offset, + is_atomic, mkey_index, "atomic key", + &memh->atomic_dvmr, + &memh->atomic_rkey); } static ucs_status_t @@ -494,10 +502,10 @@ uct_ib_mlx5_devx_reg_mt(uct_ib_mlx5_md_t *md, void *address, size_t length, goto err_free; } - status = uct_ib_mlx5_devx_reg_ksm_data(md, is_atomic, address, ksm_data, - length, (uint64_t)address, 0, - "multi-thread key", &ksm_data->dvmr, - mkey_p); + status = uct_ib_mlx5_devx_reg_ksm_data_mt(md, is_atomic, address, ksm_data, + length, (uint64_t)address, 0, + "multi-thread key", + &ksm_data->dvmr, mkey_p); if (status != UCS_OK) { goto err_dereg; } @@ -556,6 +564,7 @@ static void uct_ib_mlx5_devx_reg_symmetric(uct_ib_mlx5_md_t *md, uint32_t symmetric_rkey; ucs_status_t status; + ucs_assert(!(memh->super.flags & UCT_IB_MEM_MULTITHREADED)); /* Best effort, only allocate in the range below the atomic keys. */ while (md->smkey_index < md->super.mkey_by_name_reserve.size) { status = uct_ib_mlx5_devx_reg_ksm_data_contig( @@ -1977,6 +1986,7 @@ UCS_PROFILE_FUNC_ALWAYS(ucs_status_t, uct_ib_mlx5_devx_reg_exported_key, goto out_umem_mr; } + ucs_assert(!(memh->super.flags & UCT_IB_MEM_MULTITHREADED)); status = uct_ib_mlx5_devx_reg_ksm_data_contig(md, &memh->mrs[UCT_IB_MR_DEFAULT], memh->address, diff --git a/test/gtest/uct/ib/test_ib_md.cc b/test/gtest/uct/ib/test_ib_md.cc index e7ab20876df..65e290b6d87 100644 --- a/test/gtest/uct/ib/test_ib_md.cc +++ b/test/gtest/uct/ib/test_ib_md.cc @@ -27,6 +27,9 @@ class test_ib_md : public test_md uct_rkey_t *rkey_p = NULL); void check_smkeys(uct_rkey_t rkey1, uct_rkey_t rkey2); + void test_mkey_pack_mt(bool invalidate); + void test_mkey_pack_mt_internal(unsigned access_mask, bool invalidate); + private: #ifdef HAVE_MLX5_DV uint32_t m_mlx5_flags = 0; @@ -222,4 +225,73 @@ UCS_TEST_P(test_ib_md, smkey_reg_atomic) ucs_mmap_free(buffer, size); } +void +test_ib_md::test_mkey_pack_mt_internal(unsigned access_mask, bool invalidate) +{ + constexpr size_t size = 1 * UCS_MBYTE; + std::array buffer; + unsigned pack_flags, dereg_flags; + uct_mem_h memh; + + if ((access_mask & UCT_MD_MEM_ACCESS_REMOTE_ATOMIC) && is_bf_arm()) { + UCS_TEST_SKIP_R("FIXME: AMO reg key bug on BF device, skipping"); + return; + } + + if (!is_supported_pack_mem_flags(access_mask)) { + UCS_TEST_SKIP_R("memory packing is unsupported"); + } + + if (invalidate) { + pack_flags = UCT_MD_MKEY_PACK_FLAG_INVALIDATE_RMA; + dereg_flags = UCT_MD_MEM_DEREG_FLAG_INVALIDATE; + } else { + pack_flags = dereg_flags = 0; + } + + ucs_status_t status = reg_mem(access_mask, buffer.data(), size, &memh); + ASSERT_UCS_OK(status); + + /* memh isn't always registered as multithreaded due to the following error: + mlx5dv_devx_obj_create(CREATE_MKEY, mode=KSM) failed, syndrome 0x103e77: Remote I/O error + */ + + uct_ib_mem_t *ib_memh = (uct_ib_mem_t *)memh; + EXPECT_TRUE(ib_memh->flags & UCT_IB_MEM_MULTITHREADED); + + std::vector rkey(md_attr().rkey_packed_size); + uct_md_mkey_pack_params_t pack_params; + pack_params.field_mask = UCT_MD_MKEY_PACK_FIELD_FLAGS; + pack_params.flags = pack_flags; + status = uct_md_mkey_pack_v2(md(), memh, &pack_params, rkey.data()); + EXPECT_UCS_OK(status); + + uct_md_mem_dereg_params_t params; + params.field_mask = UCT_MD_MEM_DEREG_FIELD_MEMH | + UCT_MD_MEM_DEREG_FIELD_COMPLETION | + UCT_MD_MEM_DEREG_FIELD_FLAGS; + params.memh = memh; + params.flags = dereg_flags; + comp().comp.func = dereg_cb; + comp().comp.count = 1; + comp().comp.status = UCS_OK; + comp().self = this; + params.comp = &comp().comp; + status = uct_md_mem_dereg_v2(md(), ¶ms); + EXPECT_UCS_OK(status); +} + +void test_ib_md::test_mkey_pack_mt(bool invalidate) { + test_mkey_pack_mt_internal(UCT_MD_MEM_ACCESS_REMOTE_ATOMIC, invalidate); + test_mkey_pack_mt_internal(md_flags_remote_rma, invalidate); + test_mkey_pack_mt_internal(UCT_MD_MEM_ACCESS_ALL, invalidate); + test_mkey_pack_mt_internal(UCT_MD_MEM_ACCESS_RMA, invalidate); +} + +UCS_TEST_P(test_ib_md, pack_mkey_mt, "REG_MT_THRESH=128K", "REG_MT_CHUNK=128K") +{ + test_mkey_pack_mt(false); + test_mkey_pack_mt(true); +} + _UCT_MD_INSTANTIATE_TEST_CASE(test_ib_md, ib) diff --git a/test/gtest/uct/test_md.cc b/test/gtest/uct/test_md.cc index 9137e132345..0155924554f 100644 --- a/test/gtest/uct/test_md.cc +++ b/test/gtest/uct/test_md.cc @@ -100,7 +100,7 @@ void test_md::test_reg_mem(unsigned access_mask, params.flags = UCT_MD_MEM_DEREG_FLAG_INVALIDATE; params.comp = &comp().comp; - if (!is_supported_reg_mem_flags(access_mask)) { + if (!is_supported_pack_mem_flags(access_mask)) { params.field_mask = UCT_MD_MEM_DEREG_FIELD_COMPLETION | UCT_MD_MEM_DEREG_FIELD_FLAGS | UCT_MD_MEM_DEREG_FIELD_MEMH; @@ -171,7 +171,7 @@ test_md::test_md() /* coverity[uninit_member] */ } -bool test_md::is_supported_reg_mem_flags(unsigned reg_flags) const +bool test_md::is_supported_pack_mem_flags(unsigned reg_flags) const { return (reg_flags & md_flags_remote_rma) ? check_caps(UCT_MD_FLAG_INVALIDATE_RMA) : diff --git a/test/gtest/uct/test_md.h b/test/gtest/uct/test_md.h index f81fd91525d..f08340836a6 100644 --- a/test/gtest/uct/test_md.h +++ b/test/gtest/uct/test_md.h @@ -30,7 +30,7 @@ class test_md : public testing::TestWithParam, test_md(); - bool is_supported_reg_mem_flags(unsigned reg_flags) const; + bool is_supported_pack_mem_flags(unsigned reg_flags) const; bool is_bf_arm() const;