diff --git a/prov/verbs/src/fi_verbs.h b/prov/verbs/src/fi_verbs.h index 8710c4bbb3d..4b61c9035ce 100644 --- a/prov/verbs/src/fi_verbs.h +++ b/prov/verbs/src/fi_verbs.h @@ -289,6 +289,7 @@ struct fi_ibv_msg_ep { ofi_atomic32_t unsignaled_send_cnt; ofi_atomic32_t comp_pending; uint64_t ep_id; + struct fi_ibv_domain *domain; }; struct fi_ibv_msg_epe { @@ -370,6 +371,10 @@ int fi_ibv_cq_signal(struct fid_cq *cq); ssize_t fi_ibv_eq_write_event(struct fi_ibv_eq *eq, uint32_t event, const void *buf, size_t len); +int fi_ibv_query_atomic(struct fid_domain *domain_fid, enum fi_datatype datatype, + enum fi_op op, struct fi_atomic_attr *attr, + uint64_t flags); + #define fi_ibv_set_sge(sge, buf, len, desc) \ do { \ sge.addr = (uintptr_t)buf; \ diff --git a/prov/verbs/src/verbs_atomic.c b/prov/verbs/src/verbs_atomic.c index 65cabe0ee6b..1c4e44a59be 100644 --- a/prov/verbs/src/verbs_atomic.c +++ b/prov/verbs/src/verbs_atomic.c @@ -35,54 +35,75 @@ #include "fi_verbs.h" -static int -fi_ibv_msg_ep_atomic_writevalid(struct fid_ep *ep, enum fi_datatype datatype, - enum fi_op op, size_t *count) +#define fi_ibv_atomicvalid(name, flags) \ +static int fi_ibv_msg_ep_atomic_ ## name(struct fid_ep *ep_fid, \ + enum fi_datatype datatype, \ + enum fi_op op, size_t *count) \ +{ \ + struct fi_ibv_msg_ep *ep = container_of(ep_fid, \ + struct fi_ibv_msg_ep, \ + ep_fid); \ + struct fi_atomic_attr attr; \ + int ret; \ + \ + ret = fi_ibv_query_atomic(&ep->domain->domain_fid, datatype, \ + op, &attr, flags); \ + if (!ret) \ + *count = attr.count; \ + return ret; \ +} \ + +fi_ibv_atomicvalid(writevalid, 0); +fi_ibv_atomicvalid(readwritevalid, FI_FETCH_ATOMIC); +fi_ibv_atomicvalid(compwritevalid, FI_COMPARE_ATOMIC); + +int fi_ibv_query_atomic(struct fid_domain *domain_fid, enum fi_datatype datatype, + enum fi_op op, struct fi_atomic_attr *attr, + uint64_t flags) { - switch (op) { - case FI_ATOMIC_WRITE: - break; - default: + struct fi_ibv_domain *domain = container_of(domain_fid, + struct fi_ibv_domain, + domain_fid); + char *log_str_fetch = "fi_fetch_atomic with FI_SUM op"; + char *log_str_comp = "fi_compare_atomic"; + char *log_str; + + if (flags & FI_TAGGED) return -FI_ENOSYS; - } - - switch (datatype) { - case FI_INT64: - case FI_UINT64: -#if __BITS_PER_LONG == 64 - case FI_DOUBLE: - case FI_FLOAT: -#endif - break; - default: - return -FI_EINVAL; - } - if (count) - *count = 1; - return 0; -} + if ((flags & FI_FETCH_ATOMIC) && (flags & FI_COMPARE_ATOMIC)) + return -FI_EBADFLAGS; -static int -fi_ibv_msg_ep_atomic_readwritevalid(struct fid_ep *ep, enum fi_datatype datatype, - enum fi_op op, size_t *count) -{ - struct fi_ibv_msg_ep *_ep = container_of(ep, struct fi_ibv_msg_ep, ep_fid); - - switch (op) { - case FI_ATOMIC_READ: - break; - case FI_SUM: - if (_ep->info->tx_attr->op_flags & FI_INJECT) { - VERBS_INFO(FI_LOG_EP_DATA,"FI_INJECT not " - "supported for fi_fetch_atomic with FI_SUM op\n"); + if (!flags) { + switch (op) { + case FI_ATOMIC_WRITE: + break; + default: + return -FI_ENOSYS; + } + } else { + if (flags & FI_FETCH_ATOMIC) { + switch (op) { + case FI_ATOMIC_READ: + goto check_datatype; + case FI_SUM: + log_str = log_str_fetch; + break; + default: + return -FI_ENOSYS; + } + } else if (flags & FI_COMPARE_ATOMIC) { + if (op != FI_CSWAP) + return -FI_ENOSYS; + log_str = log_str_comp; + } + if (domain->info->tx_attr->op_flags & FI_INJECT) { + VERBS_INFO(FI_LOG_EP_DATA, + "FI_INJECT not supported for %s\n", log_str); return -FI_EINVAL; } - break; - default: - return -FI_ENOSYS; } - +check_datatype: switch (datatype) { case FI_INT64: case FI_UINT64: @@ -95,40 +116,11 @@ fi_ibv_msg_ep_atomic_readwritevalid(struct fid_ep *ep, enum fi_datatype datatype return -FI_EINVAL; } - if (count) - *count = 1; - return 0; -} - -static int -fi_ibv_msg_ep_atomic_compwritevalid(struct fid_ep *ep, enum fi_datatype datatype, - enum fi_op op, size_t *count) -{ - struct fi_ibv_msg_ep *_ep = container_of(ep, struct fi_ibv_msg_ep, ep_fid); - - if (op != FI_CSWAP) - return -FI_ENOSYS; - - if (_ep->info->tx_attr->op_flags & FI_INJECT) { - VERBS_INFO(FI_LOG_EP_DATA, "FI_INJECT not supported " - "for fi_compare_atomic\n"); + attr->size = fi_datatype_size(datatype); + if (attr->size == 0) return -FI_EINVAL; - } - - switch (datatype) { - case FI_INT64: - case FI_UINT64: -#if __BITS_PER_LONG == 64 - case FI_DOUBLE: - case FI_FLOAT: -#endif - break; - default: - return -FI_EINVAL; - } - if (count) - *count = 1; + attr->count = 1; return 0; } diff --git a/prov/verbs/src/verbs_domain.c b/prov/verbs/src/verbs_domain.c index 47600b23ec2..fcba4828e1b 100644 --- a/prov/verbs/src/verbs_domain.c +++ b/prov/verbs/src/verbs_domain.c @@ -307,6 +307,7 @@ static struct fi_ops_domain fi_ibv_rdm_domain_ops = { .poll_open = fi_no_poll_open, .stx_ctx = fi_no_stx_context, .srx_ctx = fi_no_srx_context, + .query_atomic = fi_ibv_query_atomic, }; static int diff --git a/prov/verbs/src/verbs_msg_ep.c b/prov/verbs/src/verbs_msg_ep.c index 138b4d5be44..eb7d6311f98 100644 --- a/prov/verbs/src/verbs_msg_ep.c +++ b/prov/verbs/src/verbs_msg_ep.c @@ -361,6 +361,7 @@ int fi_ibv_open_ep(struct fid_domain *domain, struct fi_info *info, ofi_atomic_initialize32(&_ep->unsignaled_send_cnt, 0); ofi_atomic_initialize32(&_ep->comp_pending, 0); + _ep->domain = dom; *ep = &_ep->ep_fid; return 0;