Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v3.1: pml_ucx: add ompi datatype attribute to release ucp_datatype - v3.1 #7053

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 52 additions & 12 deletions ompi/mca/pml/ucx/pml_ucx.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "opal/runtime/opal.h"
#include "opal/mca/pmix/pmix.h"
#include "ompi/attribute/attribute.h"
#include "ompi/message/message.h"
#include "ompi/mca/pml/base/pml_base_bsend.h"
#include "pml_ucx_request.h"
Expand Down Expand Up @@ -187,9 +188,9 @@ int mca_pml_ucx_close(void)
int mca_pml_ucx_init(void)
{
ucp_worker_params_t params;
ucs_status_t status;
ucp_worker_attr_t attr;
int rc;
ucs_status_t status;
int i, rc;

PML_UCX_VERBOSE(1, "mca_pml_ucx_init");

Expand All @@ -206,30 +207,34 @@ int mca_pml_ucx_init(void)
&ompi_pml_ucx.ucp_worker);
if (UCS_OK != status) {
PML_UCX_ERROR("Failed to create UCP worker");
return OMPI_ERROR;
rc = OMPI_ERROR;
goto err;
}

attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE;
status = ucp_worker_query(ompi_pml_ucx.ucp_worker, &attr);
if (UCS_OK != status) {
ucp_worker_destroy(ompi_pml_ucx.ucp_worker);
ompi_pml_ucx.ucp_worker = NULL;
PML_UCX_ERROR("Failed to query UCP worker thread level");
return OMPI_ERROR;
rc = OMPI_ERROR;
goto err_destroy_worker;
}

if (ompi_mpi_thread_multiple && attr.thread_mode != UCS_THREAD_MODE_MULTI) {
if (ompi_mpi_thread_multiple && (attr.thread_mode != UCS_THREAD_MODE_MULTI)) {
/* UCX does not support multithreading, disqualify current PML for now */
/* TODO: we should let OMPI to fallback to THREAD_SINGLE mode */
ucp_worker_destroy(ompi_pml_ucx.ucp_worker);
ompi_pml_ucx.ucp_worker = NULL;
PML_UCX_ERROR("UCP worker does not support MPI_THREAD_MULTIPLE");
return OMPI_ERROR;
rc = OMPI_ERR_NOT_SUPPORTED;
goto err_destroy_worker;
}

rc = mca_pml_ucx_send_worker_address();
if (rc < 0) {
return rc;
goto err_destroy_worker;
}

ompi_pml_ucx.datatype_attr_keyval = MPI_KEYVAL_INVALID;
for (i = 0; i < OMPI_DATATYPE_MAX_PREDEFINED; ++i) {
ompi_pml_ucx.predefined_types[i] = PML_UCX_DATATYPE_INVALID;
}

/* Initialize the free lists */
Expand All @@ -245,15 +250,34 @@ int mca_pml_ucx_init(void)
PML_UCX_VERBOSE(2, "created ucp context %p, worker %p",
(void *)ompi_pml_ucx.ucp_context,
(void *)ompi_pml_ucx.ucp_worker);
return OMPI_SUCCESS;
return rc;

err_destroy_worker:
ucp_worker_destroy(ompi_pml_ucx.ucp_worker);
ompi_pml_ucx.ucp_worker = NULL;
err:
return OMPI_ERROR;
}

int mca_pml_ucx_cleanup(void)
{
int i;

PML_UCX_VERBOSE(1, "mca_pml_ucx_cleanup");

opal_progress_unregister(mca_pml_ucx_progress);

if (ompi_pml_ucx.datatype_attr_keyval != MPI_KEYVAL_INVALID) {
ompi_attr_free_keyval(TYPE_ATTR, &ompi_pml_ucx.datatype_attr_keyval, false);
}

for (i = 0; i < OMPI_DATATYPE_MAX_PREDEFINED; ++i) {
if (ompi_pml_ucx.predefined_types[i] != PML_UCX_DATATYPE_INVALID) {
ucp_dt_destroy(ompi_pml_ucx.predefined_types[i]);
ompi_pml_ucx.predefined_types[i] = PML_UCX_DATATYPE_INVALID;
}
}

ompi_pml_ucx.completed_send_req.req_state = OMPI_REQUEST_INVALID;
OMPI_REQUEST_FINI(&ompi_pml_ucx.completed_send_req);
OBJ_DESTRUCT(&ompi_pml_ucx.completed_send_req);
Expand Down Expand Up @@ -452,6 +476,22 @@ int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)

int mca_pml_ucx_enable(bool enable)
{
ompi_attribute_fn_ptr_union_t copy_fn;
ompi_attribute_fn_ptr_union_t del_fn;
int ret;

/* Create a key for adding custom attributes to datatypes */
copy_fn.attr_datatype_copy_fn =
(MPI_Type_internal_copy_attr_function*)MPI_TYPE_NULL_COPY_FN;
del_fn.attr_datatype_delete_fn = mca_pml_ucx_datatype_attr_del_fn;
ret = ompi_attr_create_keyval(TYPE_ATTR, copy_fn, del_fn,
&ompi_pml_ucx.datatype_attr_keyval, NULL, 0,
NULL);
if (ret != OMPI_SUCCESS) {
PML_UCX_ERROR("Failed to create keyval for UCX datatypes: %d", ret);
return ret;
}

PML_UCX_FREELIST_INIT(&ompi_pml_ucx.persistent_reqs,
mca_pml_ucx_persistent_request_t,
128, -1, 128);
Expand Down
5 changes: 5 additions & 0 deletions ompi/mca/pml/ucx/pml_ucx.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "ompi/mca/pml/pml.h"
#include "ompi/mca/pml/base/base.h"
#include "ompi/datatype/ompi_datatype.h"
#include "ompi/datatype/ompi_datatype_internal.h"
#include "ompi/communicator/communicator.h"
#include "ompi/request/request.h"

Expand All @@ -37,6 +38,10 @@ struct mca_pml_ucx_module {
ucp_context_h ucp_context;
ucp_worker_h ucp_worker;

/* Datatypes */
int datatype_attr_keyval;
ucp_datatype_t predefined_types[OMPI_DATATYPE_MPI_MAX_PREDEFINED];

/* Requests */
mca_pml_ucx_freelist_t persistent_reqs;
ompi_request_t completed_send_req;
Expand Down
37 changes: 34 additions & 3 deletions ompi/mca/pml/ucx/pml_ucx_datatype.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "pml_ucx_datatype.h"

#include "ompi/runtime/mpiruntime.h"
#include "ompi/attribute/attribute.h"

#include <inttypes.h>

Expand Down Expand Up @@ -127,12 +128,25 @@ static ucp_generic_dt_ops_t pml_ucx_generic_datatype_ops = {
.finish = pml_ucx_generic_datatype_finish
};

int mca_pml_ucx_datatype_attr_del_fn(ompi_datatype_t* datatype, int keyval,
void *attr_val, void *extra)
{
ucp_datatype_t ucp_datatype = (ucp_datatype_t)attr_val;

PML_UCX_ASSERT((void*)ucp_datatype == datatype->pml_data);

ucp_dt_destroy(ucp_datatype);
datatype->pml_data = PML_UCX_DATATYPE_INVALID;
return OMPI_SUCCESS;
}

ucp_datatype_t mca_pml_ucx_init_datatype(ompi_datatype_t *datatype)
{
ucp_datatype_t ucp_datatype;
ucs_status_t status;
ptrdiff_t lb;
size_t size;
int ret;

ompi_datatype_type_lb(datatype, &lb);

Expand All @@ -147,16 +161,33 @@ ucp_datatype_t mca_pml_ucx_init_datatype(ompi_datatype_t *datatype)
}

status = ucp_dt_create_generic(&pml_ucx_generic_datatype_ops,
datatype, &ucp_datatype);
datatype, &ucp_datatype);
if (status != UCS_OK) {
PML_UCX_ERROR("Failed to create UCX datatype for %s", datatype->name);
ompi_mpi_abort(&ompi_mpi_comm_world.comm, 1);
}

datatype->pml_data = ucp_datatype;

/* Add custom attribute, to clean up UCX resources when OMPI datatype is
* released.
*/
if (ompi_datatype_is_predefined(datatype)) {
PML_UCX_ASSERT(datatype->id < OMPI_DATATYPE_MAX_PREDEFINED);
ompi_pml_ucx.predefined_types[datatype->id] = ucp_datatype;
} else {
ret = ompi_attr_set_c(TYPE_ATTR, datatype, &datatype->d_keyhash,
ompi_pml_ucx.datatype_attr_keyval,
(void*)ucp_datatype, false);
if (ret != OMPI_SUCCESS) {
PML_UCX_ERROR("Failed to add UCX datatype attribute for %s: %d",
datatype->name, ret);
ompi_mpi_abort(&ompi_mpi_comm_world.comm, 1);
}
}

PML_UCX_VERBOSE(7, "created generic UCX datatype 0x%"PRIx64, ucp_datatype)
// TODO put this on a list to be destroyed later

datatype->pml_data = ucp_datatype;
return ucp_datatype;
}

Expand Down
7 changes: 6 additions & 1 deletion ompi/mca/pml/ucx/pml_ucx_datatype.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include "pml_ucx.h"


#define PML_UCX_DATATYPE_INVALID 0

struct pml_ucx_convertor {
opal_free_list_item_t super;
ompi_datatype_t *datatype;
Expand All @@ -23,14 +25,17 @@ struct pml_ucx_convertor {

ucp_datatype_t mca_pml_ucx_init_datatype(ompi_datatype_t *datatype);

int mca_pml_ucx_datatype_attr_del_fn(ompi_datatype_t* datatype, int keyval,
void *attr_val, void *extra);

OBJ_CLASS_DECLARATION(mca_pml_ucx_convertor_t);


static inline ucp_datatype_t mca_pml_ucx_get_datatype(ompi_datatype_t *datatype)
{
ucp_datatype_t ucp_type = datatype->pml_data;

if (OPAL_LIKELY(ucp_type != 0)) {
if (OPAL_LIKELY(ucp_type != PML_UCX_DATATYPE_INVALID)) {
return ucp_type;
}

Expand Down