diff --git a/cpp/include/ucxx/endpoint.h b/cpp/include/ucxx/endpoint.h index e2296d6b..49c97ee8 100644 --- a/cpp/include/ucxx/endpoint.h +++ b/cpp/include/ucxx/endpoint.h @@ -70,7 +70,7 @@ class Endpoint : public Component { * @param[in] endpointErrorHandling whether to enable endpoint error handling. */ Endpoint(std::shared_ptr workerOrListener, - std::unique_ptr params, + ucp_ep_params_t* params, bool endpointErrorHandling); /** diff --git a/cpp/include/ucxx/utils/sockaddr.h b/cpp/include/ucxx/utils/sockaddr.h index d48d5151..7581a056 100644 --- a/cpp/include/ucxx/utils/sockaddr.h +++ b/cpp/include/ucxx/utils/sockaddr.h @@ -4,36 +4,26 @@ */ #pragma once -#include +#include +#include namespace ucxx { namespace utils { /** - * @brief Set socket address and port of a socket address storage. + * @brief Get an addrinfo struct corresponding to an address and port. * - * Set a socket address and port as defined by the user in a socket address storage that - * may later be used to specify an address to bind a UCP listener to. + * This information can later be used to bind a UCP listener or endpoint. * - * @param[in] sockaddr pointer to the UCS socket address storage. * @param[in] ip_address valid socket address (e.g., IP address or hostname) or NULL as a * wildcard for "all" to set the socket address storage to. - * @param[in] port port to set the socket address storaget to. - */ -int sockaddr_set(ucs_sock_addr_t* sockaddr, const char* ip_address, uint16_t port); - -/** - * @brief Release the underlying socket address. - * - * Release the underlying socket address container. - * - * NOTE: This function does not release the `ucs_sock_addr_t`, only the underlying - * `sockaddr` member. + * @param[in] port port to set the socket address storage to. * - * @param[in] sockaddr pointer to the UCS socket address storage. + * @returns unique pointer wrapping a `struct addrinfo` (frees the addrinfo when out of scope) */ -void sockaddr_free(ucs_sock_addr_t* sockaddr); +std::unique_ptr get_addrinfo(const char* ip_address, + uint16_t port); /** * @brief Get socket address and port of a socket address storage. diff --git a/cpp/src/context.cpp b/cpp/src/context.cpp index 2bd0c1a2..29d150db 100644 --- a/cpp/src/context.cpp +++ b/cpp/src/context.cpp @@ -17,13 +17,10 @@ namespace ucxx { Context::Context(const ConfigMap ucxConfig, const uint64_t featureFlags) : _config{ucxConfig}, _featureFlags{featureFlags} { - ucp_params_t params{}; - parseLogLevel(); // UCP - params.field_mask = UCP_PARAM_FIELD_FEATURES; - params.features = featureFlags; + ucp_params_t params = {.field_mask = UCP_PARAM_FIELD_FEATURES, .features = featureFlags}; utils::ucsErrorThrow(ucp_init(¶ms, this->_config.getHandle(), &this->_handle)); ucxx_trace("Context created: %p", this->_handle); diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 92eda07e..b0b61e85 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -23,14 +23,8 @@ namespace ucxx { -void EpParamsDeleter::operator()(ucp_ep_params_t* ptr) -{ - if (ptr != nullptr && ptr->field_mask & UCP_EP_PARAM_FIELD_SOCK_ADDR) - ucxx::utils::sockaddr_free(&ptr->sockaddr); -} - Endpoint::Endpoint(std::shared_ptr workerOrListener, - std::unique_ptr params, + ucp_ep_params_t* params, bool endpointErrorHandling) : _endpointErrorHandling{endpointErrorHandling} { @@ -49,7 +43,7 @@ Endpoint::Endpoint(std::shared_ptr workerOrListener, params->err_handler.cb = Endpoint::errorCallback; params->err_handler.arg = _callbackData.get(); - utils::ucsErrorThrow(ucp_ep_create(worker->getHandle(), params.get(), &_handle)); + utils::ucsErrorThrow(ucp_ep_create(worker->getHandle(), params, &_handle)); ucxx_trace("Endpoint created: %p", _handle); } @@ -61,17 +55,16 @@ std::shared_ptr createEndpointFromHostname(std::shared_ptr wor if (worker == nullptr || worker->getHandle() == nullptr) throw ucxx::Error("Worker not initialized"); - auto params = std::unique_ptr(new ucp_ep_params_t); - - struct hostent* hostname = gethostbyname(ipAddress.c_str()); - if (hostname == nullptr) throw ucxx::Error(std::string("Invalid IP address or hostname")); + ucp_ep_params_t params = {.field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_SOCK_ADDR | + UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | + UCP_EP_PARAM_FIELD_ERR_HANDLER, + .flags = UCP_EP_PARAMS_FLAGS_CLIENT_SERVER}; + auto info = ucxx::utils::get_addrinfo(ipAddress.c_str(), port); - params->field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_SOCK_ADDR | - UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | UCP_EP_PARAM_FIELD_ERR_HANDLER; - params->flags = UCP_EP_PARAMS_FLAGS_CLIENT_SERVER; - if (ucxx::utils::sockaddr_set(¶ms->sockaddr, hostname->h_name, port)) throw std::bad_alloc(); + params.sockaddr.addrlen = info->ai_addrlen; + params.sockaddr.addr = info->ai_addr; - return std::shared_ptr(new Endpoint(worker, std::move(params), endpointErrorHandling)); + return std::shared_ptr(new Endpoint(worker, ¶ms, endpointErrorHandling)); } std::shared_ptr createEndpointFromConnRequest(std::shared_ptr listener, @@ -81,14 +74,13 @@ std::shared_ptr createEndpointFromConnRequest(std::shared_ptrgetHandle() == nullptr) throw ucxx::Error("Worker not initialized"); - auto params = std::unique_ptr(new ucp_ep_params_t); - params->field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_CONN_REQUEST | - UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | UCP_EP_PARAM_FIELD_ERR_HANDLER; - params->flags = UCP_EP_PARAMS_FLAGS_NO_LOOPBACK; - params->conn_request = connRequest; + ucp_ep_params_t params = { + .field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_CONN_REQUEST | + UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | UCP_EP_PARAM_FIELD_ERR_HANDLER, + .flags = UCP_EP_PARAMS_FLAGS_NO_LOOPBACK, + .conn_request = connRequest}; - return std::shared_ptr( - new Endpoint(listener, std::move(params), endpointErrorHandling)); + return std::shared_ptr(new Endpoint(listener, ¶ms, endpointErrorHandling)); } std::shared_ptr createEndpointFromWorkerAddress(std::shared_ptr worker, @@ -100,12 +92,12 @@ std::shared_ptr createEndpointFromWorkerAddress(std::shared_ptrgetHandle() == nullptr || address->getLength() == 0) throw ucxx::Error("Address not initialized"); - auto params = std::unique_ptr(new ucp_ep_params_t); - params->field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS | UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | - UCP_EP_PARAM_FIELD_ERR_HANDLER; - params->address = address->getHandle(); + ucp_ep_params_t params = {.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS | + UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | + UCP_EP_PARAM_FIELD_ERR_HANDLER, + .address = address->getHandle()}; - return std::shared_ptr(new Endpoint(worker, std::move(params), endpointErrorHandling)); + return std::shared_ptr(new Endpoint(worker, ¶ms, endpointErrorHandling)); } Endpoint::~Endpoint() diff --git a/cpp/src/listener.cpp b/cpp/src/listener.cpp index 3996f852..d6e23524 100644 --- a/cpp/src/listener.cpp +++ b/cpp/src/listener.cpp @@ -27,25 +27,19 @@ Listener::Listener(std::shared_ptr worker, if (worker == nullptr || worker->getHandle() == nullptr) throw ucxx::Error("Worker not initialized"); - ucp_listener_params_t params{}; - ucp_listener_attr_t attr{}; - - params.field_mask = UCP_LISTENER_PARAM_FIELD_SOCK_ADDR | UCP_LISTENER_PARAM_FIELD_CONN_HANDLER; - params.conn_handler.cb = callback; - params.conn_handler.arg = callbackArgs; - - if (ucxx::utils::sockaddr_set(¶ms.sockaddr, NULL, port)) - // throw std::bad_alloc("Failed allocation of sockaddr") - throw std::bad_alloc(); - std::unique_ptr sockaddr(¶ms.sockaddr, - ucxx::utils::sockaddr_free); + ucp_listener_params_t params = { + .field_mask = UCP_LISTENER_PARAM_FIELD_SOCK_ADDR | UCP_LISTENER_PARAM_FIELD_CONN_HANDLER, + .conn_handler = {.cb = callback, .arg = callbackArgs}}; + auto info = ucxx::utils::get_addrinfo(NULL, port); + params.sockaddr.addr = info->ai_addr; + params.sockaddr.addrlen = info->ai_addrlen; ucp_listener_h handle = nullptr; utils::ucsErrorThrow(ucp_listener_create(worker->getHandle(), ¶ms, &handle)); _handle = std::unique_ptr(handle, ucpListenerDestructor); ucxx_trace("Listener created: %p", _handle.get()); - attr.field_mask = UCP_LISTENER_ATTR_FIELD_SOCKADDR; + ucp_listener_attr_t attr = {.field_mask = UCP_LISTENER_ATTR_FIELD_SOCKADDR}; utils::ucsErrorThrow(ucp_listener_query(_handle.get(), &attr)); char ipString[INET6_ADDRSTRLEN]; diff --git a/cpp/src/utils/sockaddr.cpp b/cpp/src/utils/sockaddr.cpp index 752a9f37..3d501b6f 100644 --- a/cpp/src/utils/sockaddr.cpp +++ b/cpp/src/utils/sockaddr.cpp @@ -3,33 +3,39 @@ * SPDX-License-Identifier: BSD-3-Clause */ #include -#include +#include +#include #include #include #include +#include #include namespace ucxx { namespace utils { -int sockaddr_set(ucs_sock_addr_t* sockaddr, const char* ip_address, uint16_t port) +std::unique_ptr get_addrinfo(const char* ip_address, + uint16_t port) { - struct sockaddr_in* addr = reinterpret_cast(malloc(sizeof(struct sockaddr_in))); - if (addr == NULL) { return 1; } - memset(addr, 0, sizeof(struct sockaddr_in)); - addr->sin_family = AF_INET; - addr->sin_addr.s_addr = ip_address == NULL ? INADDR_ANY : inet_addr(ip_address); - addr->sin_port = htons(port); - sockaddr->addr = (const struct sockaddr*)addr; - sockaddr->addrlen = sizeof(struct sockaddr_in); - return 0; -} - -void sockaddr_free(ucs_sock_addr_t* sockaddr) -{ - ::free(const_cast(reinterpret_cast(sockaddr->addr))); + std::unique_ptr info(nullptr, ::freeaddrinfo); + { + char ports[6]; + struct addrinfo* result = nullptr; + struct addrinfo hints; + // Don't restrict lookups + ::memset(&hints, 0, sizeof(hints)); + // Except, port is numeric, address may be NULL meaning the + // returned address is the wildcard. + hints.ai_flags = AI_NUMERICSERV | AI_PASSIVE; + if (::snprintf(ports, sizeof(ports), "%u", port) > sizeof(ports)) + throw ucxx::Error(std::string("Invalid port")); + if (::getaddrinfo(ip_address, ports, &hints, &result)) + throw ucxx::Error(std::string("Invalid IP address or hostname")); + info.reset(result); + } + return info; } void sockaddr_get_ip_port_str(const struct sockaddr_storage* sockaddr, @@ -37,18 +43,18 @@ void sockaddr_get_ip_port_str(const struct sockaddr_storage* sockaddr, char* port_str, size_t max_str_size) { - struct sockaddr_in addr_in; - struct sockaddr_in6 addr_in6; + const struct sockaddr_in* addr_in = nullptr; + const struct sockaddr_in6* addr_in6 = nullptr; switch (sockaddr->ss_family) { case AF_INET: - memcpy(&addr_in, sockaddr, sizeof(struct sockaddr_in)); - inet_ntop(AF_INET, &addr_in.sin_addr, ip_str, max_str_size); - snprintf(port_str, max_str_size, "%d", ntohs(addr_in.sin_port)); + addr_in = reinterpret_cast(sockaddr); + inet_ntop(AF_INET, &addr_in->sin_addr, ip_str, max_str_size); + snprintf(port_str, max_str_size, "%u", ntohs(addr_in->sin_port)); case AF_INET6: - memcpy(&addr_in6, sockaddr, sizeof(struct sockaddr_in6)); - inet_ntop(AF_INET6, &addr_in6.sin6_addr, ip_str, max_str_size); - snprintf(port_str, max_str_size, "%d", ntohs(addr_in6.sin6_port)); + addr_in6 = reinterpret_cast(sockaddr); + inet_ntop(AF_INET6, &addr_in6->sin6_addr, ip_str, max_str_size); + snprintf(port_str, max_str_size, "%u", ntohs(addr_in6->sin6_port)); default: ip_str = const_cast(reinterpret_cast("Invalid address family")); port_str = const_cast(reinterpret_cast("Invalid address family")); diff --git a/cpp/src/worker.cpp b/cpp/src/worker.cpp index e9b47968..d08c88ae 100644 --- a/cpp/src/worker.cpp +++ b/cpp/src/worker.cpp @@ -24,13 +24,11 @@ namespace ucxx { Worker::Worker(std::shared_ptr context, const bool enableDelayedSubmission) { - ucp_worker_params_t params{}; - if (context == nullptr || context->getHandle() == nullptr) throw std::runtime_error("Context not initialized"); - params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; - params.thread_mode = UCS_THREAD_MODE_MULTI; + ucp_worker_params_t params = {.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE, + .thread_mode = UCS_THREAD_MODE_MULTI}; utils::ucsErrorThrow(ucp_worker_create(context->getHandle(), ¶ms, &_handle)); if (enableDelayedSubmission)