Skip to content

Commit

Permalink
Merge pull request openucx#7 from xinzhao3/topic/cuda-md
Browse files Browse the repository at this point in the history
Topic/cuda: Add memory domain detect functions in UCT.
  • Loading branch information
bureddy authored Sep 8, 2017
2 parents ef5fd72 + fba2aa6 commit eeea863
Show file tree
Hide file tree
Showing 19 changed files with 120 additions and 11 deletions.
4 changes: 3 additions & 1 deletion src/ucp/api/ucp_def.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ typedef struct ucp_rkey *ucp_rkey_h;
*/
typedef struct ucp_mem *ucp_mem_h;

typedef uint64_t *ucp_addr_dn_h;
typedef struct ucp_addr_dn {
uint64_t mask;
} ucp_addr_dn_h;


/**
Expand Down
8 changes: 5 additions & 3 deletions src/ucp/core/ucp_mm.c
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ static ucs_status_t ucp_memh_reg_mds(ucp_context_h context, ucp_mem_h memh,
return UCS_OK;
}

static ucs_status_t ucp_adddr_domain_detect_mds(ucp_context_h context, ucp_addr_dn_h addr_dn)
ucs_status_t ucp_addr_domain_detect_mds(ucp_context_h context, void *addr,
ucp_addr_dn_h *addr_dn)
{
ucs_status_t status;
unsigned md_index;
Expand All @@ -117,16 +118,17 @@ static ucs_status_t ucp_adddr_domain_detect_mds(ucp_context_h context, ucp_addr_
for (md_index = 0; md_index < context->num_mds; ++md_index) {
if (context->tl_mds[md_index].attr.cap.flags & UCT_MD_FLAG_ADDR_DN) {
if(!(addr_dn->mask & context->tl_mds[md_index].attr.cap.addr_dn_mask)) {
dn_mask = 0;

status = uct_md_mem_detect(context->tl_mds[md_index].md, memh, memh->address,
memh->length, &dn_mask);
status = uct_md_mem_detect(context->tl_mds[md_index].md, addr, &dn_mask);
if (status != UCS_OK) {
return status;
}
addr_dn->mask |= dn_mask;
}
}
}

return UCS_OK;
}
/**
Expand Down
3 changes: 3 additions & 0 deletions src/ucp/core/ucp_mm.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ void ucp_mpool_free(ucs_mpool_t *mp, void *chunk);

void ucp_mpool_obj_init(ucs_mpool_t *mp, void *obj, void *chunk);

ucs_status_t ucp_addr_domain_detect_mds(ucp_context_h context, void *addr,
ucp_addr_dn_h *addr_dn);

static UCS_F_ALWAYS_INLINE uct_mem_h
ucp_memh2uct(ucp_mem_h memh, ucp_md_index_t md_idx)
{
Expand Down
1 change: 1 addition & 0 deletions src/ucp/core/ucp_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ typedef void (*ucp_request_callback_t)(ucp_request_t *req);
struct ucp_request {
ucs_status_t status; /* Operation status */
uint16_t flags; /* Request flags */
ucp_addr_dn_h dn_mask; /* Memory domain mask */

union {
struct {
Expand Down
2 changes: 2 additions & 0 deletions src/ucp/tag/tag_recv.c
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ ucp_tag_recv_request_init(ucp_request_t *req, ucp_worker_h worker, void* buffer,
req->recv.state.offset = 0;
req->recv.worker = worker;

ucp_addr_domain_detect_mds(worker->context, buffer, &(req->dn_mask));

switch (datatype & UCP_DATATYPE_CLASS_MASK) {
case UCP_DATATYPE_IOV:
req->recv.state.dt.iov.iov_offset = 0;
Expand Down
2 changes: 2 additions & 0 deletions src/ucp/tag/tag_send.c
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ static void ucp_tag_send_req_init(ucp_request_t* req, ucp_ep_h ep,
#if ENABLE_ASSERT
req->send.lane = UCP_NULL_LANE;
#endif

ucp_addr_domain_detect_mds(ep->worker->context, (void *)buffer, &(req->dn_mask));
}

UCS_PROFILE_FUNC(ucs_status_ptr_t, ucp_tag_send_nb,
Expand Down
12 changes: 12 additions & 0 deletions src/uct/api/uct.h
Original file line number Diff line number Diff line change
Expand Up @@ -1271,6 +1271,18 @@ ucs_status_t uct_md_mem_reg(uct_md_h md, void *address, size_t length,
ucs_status_t uct_md_mem_dereg(uct_md_h md, uct_mem_h memh);


/**
* @ingroup UCT_MD
* @brief Detect memory on the memory domain.
*
* Detect memory on the memory domain. Return memory domain in domain mask.
*
* @param [in] md Memory domain to register memory on.
* @param [in] address Memory address to detect.
* @param [out] dn_mask Filled with memory domain mask.
*/
ucs_status_t uct_md_mem_detect(uct_md_h md, void *addr, uint64_t *dn_mask);

/**
* @ingroup UCT_MD
* @brief Allocate memory for zero-copy communications and remote access.
Expand Down
5 changes: 5 additions & 0 deletions src/uct/base/uct_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -506,3 +506,8 @@ ucs_status_t uct_md_mem_dereg(uct_md_h md, uct_mem_h memh)
{
return md->ops->mem_dereg(md, memh);
}

ucs_status_t uct_md_mem_detect(uct_md_h md, void *addr, uint64_t *dn_mask)
{
return md->ops->mem_detect(md, addr, dn_mask);
}
2 changes: 2 additions & 0 deletions src/uct/base/uct_md.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ struct uct_md_ops {
ucs_status_t (*mem_dereg)(uct_md_h md, uct_mem_h memh);

ucs_status_t (*mkey_pack)(uct_md_h md, uct_mem_h memh, void *rkey_buffer);

ucs_status_t (*mem_detect)(uct_md_h md, void *addr, uint64_t *dn_mask);
};


Expand Down
36 changes: 35 additions & 1 deletion src/uct/cuda/cuda_copy/cuda_copy_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,39 @@ static ucs_status_t uct_cuda_copy_mem_dereg(uct_md_h md, uct_mem_h memh)
return UCS_OK;
}

static ucs_status_t uct_cuda_copy_mem_detect(uct_md_h md, void *addr, uint64_t *dn_mask)
{
#if HAVE_CUDA
int memory_type;
cudaError_t cuda_err = cudaSuccess;
struct cudaPointerAttributes attributes;
CUresult cu_err = CUDA_SUCCESS;

(*dn_mask) = 0;

if (addr == NULL) {
return UCS_OK;
}

cu_err = cuPointerGetAttribute(&memory_type,
CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
(CUdeviceptr)addr);
if (cu_err != CUDA_SUCCESS) {
cuda_err = cudaPointerGetAttributes (&attributes, addr);
if (cuda_err == cudaSuccess) {
if (attributes.memoryType == cudaMemoryTypeDevice) {
(*dn_mask) = UCT_MD_ADDR_DOMAIN_CUDA;
}
}
} else if (memory_type == CU_MEMORYTYPE_DEVICE) {
(*dn_mask) = UCT_MD_ADDR_DOMAIN_CUDA;
}
#else
(*dn_mask) = 0;
#endif
return UCS_OK;
}

static ucs_status_t uct_cuda_copy_query_md_resources(uct_md_resource_desc_t **resources_p,
unsigned *num_resources_p)
{
Expand All @@ -98,7 +131,8 @@ static ucs_status_t uct_cuda_copy_md_open(const char *md_name, const uct_md_conf
.query = uct_cuda_copy_md_query,
.mkey_pack = uct_cuda_copy_mkey_pack,
.mem_reg = uct_cuda_copy_mem_reg,
.mem_dereg = uct_cuda_copy_mem_dereg
.mem_dereg = uct_cuda_copy_mem_dereg,
.mem_detect = uct_cuda_copy_mem_detect
};
static uct_md_t md = {
.ops = &md_ops,
Expand Down
38 changes: 37 additions & 1 deletion src/uct/cuda/gdr_copy/gdr_copy_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include <ucs/sys/sys.h>
#include <ucs/debug/memtrack.h>
#include <ucs/type/class.h>
#include <cuda_runtime.h>
#include <cuda.h>


static ucs_status_t uct_gdr_copy_md_query(uct_md_h md, uct_md_attr_t *md_attr)
Expand Down Expand Up @@ -71,6 +73,39 @@ static ucs_status_t uct_gdr_copy_mem_dereg(uct_md_h md, uct_mem_h memh)
return UCS_OK;
}

static ucs_status_t uct_gdr_copy_mem_detect(uct_md_h md, void *addr, uint64_t *dn_mask)
{
#if HAVE_CUDA
int memory_type;
cudaError_t cuda_err = cudaSuccess;
struct cudaPointerAttributes attributes;
CUresult cu_err = CUDA_SUCCESS;

(*dn_mask) = 0;

if (addr == NULL) {
return UCS_OK;
}

cu_err = cuPointerGetAttribute(&memory_type,
CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
(CUdeviceptr)addr);
if (cu_err != CUDA_SUCCESS) {
cuda_err = cudaPointerGetAttributes (&attributes, addr);
if (cuda_err == cudaSuccess) {
if (attributes.memoryType == cudaMemoryTypeDevice) {
(*dn_mask) = UCT_MD_ADDR_DOMAIN_CUDA;
}
}
} else if (memory_type == CU_MEMORYTYPE_DEVICE) {
(*dn_mask) = UCT_MD_ADDR_DOMAIN_CUDA;
}
#else
(*dn_mask) = 0;
#endif
return UCS_OK;
}

static ucs_status_t uct_gdr_copy_query_md_resources(uct_md_resource_desc_t **resources_p,
unsigned *num_resources_p)
{
Expand All @@ -85,7 +120,8 @@ static ucs_status_t uct_gdr_copy_md_open(const char *md_name, const uct_md_confi
.query = uct_gdr_copy_md_query,
.mkey_pack = uct_gdr_copy_mkey_pack,
.mem_reg = uct_gdr_copy_mem_reg,
.mem_dereg = uct_gdr_copy_mem_dereg
.mem_dereg = uct_gdr_copy_mem_dereg,
.mem_detect = uct_gdr_copy_mem_detect
};
static uct_md_t md = {
.ops = &md_ops,
Expand Down
1 change: 1 addition & 0 deletions src/uct/ib/base/ib_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,7 @@ static uct_md_ops_t uct_ib_md_ops = {
.mem_alloc = uct_ib_mem_alloc,
.mem_free = uct_ib_mem_free,
.mem_reg = uct_ib_mem_reg,
.mem_detect = ucs_empty_function_return_success,
.mem_dereg = uct_ib_mem_dereg,
.mem_advise = uct_ib_mem_advise,
.mkey_pack = uct_ib_mkey_pack,
Expand Down
3 changes: 2 additions & 1 deletion src/uct/rocm/rocm_cma_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ static ucs_status_t uct_rocm_cma_md_query(uct_md_h md, uct_md_attr_t *md_attr)
md_attr->rkey_packed_size = sizeof(uct_rocm_cma_key_t);
md_attr->cap.flags = UCT_MD_FLAG_REG |
UCT_MD_FLAG_NEED_RKEY;
md_attr->cap.addr_dn_mask = 0;
md_attr->cap.addr_dn_mask = 0;
md_attr->cap.max_alloc = 0;
md_attr->cap.max_reg = ULONG_MAX;

Expand Down Expand Up @@ -212,6 +212,7 @@ static ucs_status_t uct_rocm_cma_md_open(const char *md_name,
.query = uct_rocm_cma_md_query,
.mkey_pack = uct_rocm_cma_rkey_pack,
.mem_reg = uct_rocm_cma_mem_reg,
.mem_detect = ucs_empty_function_return_success,
.mem_dereg = uct_rocm_cma_mem_dereg
};

Expand Down
3 changes: 2 additions & 1 deletion src/uct/sm/cma/cma_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ static ucs_status_t uct_cma_md_open(const char *md_name, const uct_md_config_t *
.mem_free = (void*)ucs_empty_function_return_success,
.mkey_pack = (void*)ucs_empty_function_return_success,
.mem_reg = uct_cma_mem_reg,
.mem_dereg = (void*)ucs_empty_function_return_success
.mem_dereg = (void*)ucs_empty_function_return_success,
.mem_detect = ucs_empty_function_return_success,
};
static uct_md_t md = {
.ops = &md_ops,
Expand Down
3 changes: 2 additions & 1 deletion src/uct/sm/knem/knem_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ static ucs_status_t uct_knem_md_open(const char *md_name, const uct_md_config_t
.mem_free = (void*)ucs_empty_function_return_success,
.mkey_pack = uct_knem_rkey_pack,
.mem_reg = uct_knem_mem_reg,
.mem_dereg = uct_knem_mem_dereg
.mem_dereg = uct_knem_mem_dereg,
.mem_detect = ucs_empty_function_return_success
};

knem_md = ucs_malloc(sizeof(uct_knem_md_t), "uct_knem_md_t");
Expand Down
1 change: 1 addition & 0 deletions src/uct/sm/mm/mm_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ uct_md_ops_t uct_mm_md_ops = {
.mem_free = uct_mm_mem_free,
.mem_reg = uct_mm_mem_reg,
.mem_dereg = uct_mm_mem_dereg,
.mem_detect = ucs_empty_function_return_success,
.mkey_pack = uct_mm_mkey_pack,
};

Expand Down
3 changes: 2 additions & 1 deletion src/uct/sm/self/self_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ static ucs_status_t uct_self_md_open(const char *md_name, const uct_md_config_t
.query = uct_self_md_query,
.mkey_pack = ucs_empty_function_return_success,
.mem_reg = uct_self_mem_reg,
.mem_dereg = ucs_empty_function_return_success
.mem_dereg = ucs_empty_function_return_success,
.mem_detect = ucs_empty_function_return_success
};
static uct_md_t md = {
.ops = &md_ops,
Expand Down
3 changes: 2 additions & 1 deletion src/uct/tcp/tcp_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ static ucs_status_t uct_tcp_md_open(const char *md_name, const uct_md_config_t *
.query = uct_tcp_md_query,
.mkey_pack = ucs_empty_function_return_unsupported,
.mem_reg = ucs_empty_function_return_unsupported,
.mem_dereg = ucs_empty_function_return_unsupported
.mem_dereg = ucs_empty_function_return_unsupported,
.mem_detect = ucs_empty_function_return_success
};
static uct_md_t md = {
.ops = &md_ops,
Expand Down
1 change: 1 addition & 0 deletions src/uct/ugni/base/ugni_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ static ucs_status_t uct_ugni_md_open(const char *md_name, const uct_md_config_t
.mem_free = (void*)ucs_empty_function,
.mem_reg = uct_ugni_mem_reg,
.mem_dereg = uct_ugni_mem_dereg,
.mem_detect = ucs_empty_function_return_success,
.mkey_pack = uct_ugni_rkey_pack
};

Expand Down

0 comments on commit eeea863

Please sign in to comment.