Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify handling of sockaddr information lookup #38

Merged
merged 8 commits into from
Apr 28, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/include/ucxx/endpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class Endpoint : public Component {
* @param[in] endpointErrorHandling whether to enable endpoint error handling.
*/
Endpoint(std::shared_ptr<Component> workerOrListener,
std::unique_ptr<ucp_ep_params_t, EpParamsDeleter> params,
ucp_ep_params_t* params,
bool endpointErrorHandling);

/**
Expand Down
26 changes: 8 additions & 18 deletions cpp/include/ucxx/utils/sockaddr.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,26 @@
*/
#pragma once

#include <ucp/api/ucp.h>
#include <memory>
#include <netdb.h>

namespace ucxx {

namespace utils {

/**
* @brief Set socket address and port of a socket address storage.
* @brief Get an addrinfo struct corresponding to an address and port.
*
* Set a socket address and port as defined by the user in a socket address storage that
* may later be used to specify an address to bind a UCP listener to.
* This information can later be used to bind a UCP listener or endpoint.
*
* @param[in] sockaddr pointer to the UCS socket address storage.
* @param[in] ip_address valid socket address (e.g., IP address or hostname) or NULL as a
* wildcard for "all" to set the socket address storage to.
* @param[in] port port to set the socket address storaget to.
*/
int sockaddr_set(ucs_sock_addr_t* sockaddr, const char* ip_address, uint16_t port);

/**
* @brief Release the underlying socket address.
*
* Release the underlying socket address container.
*
* NOTE: This function does not release the `ucs_sock_addr_t`, only the underlying
* `sockaddr` member.
* @param[in] port port to set the socket address storage to.
*
* @param[in] sockaddr pointer to the UCS socket address storage.
* @returns unique pointer wrapping a `struct addrinfo` (frees the addrinfo when out of scope)
*/
void sockaddr_free(ucs_sock_addr_t* sockaddr);
std::unique_ptr<struct addrinfo, void (*)(struct addrinfo*)> get_addrinfo(const char* ip_address,
uint16_t port);

/**
* @brief Get socket address and port of a socket address storage.
Expand Down
50 changes: 21 additions & 29 deletions cpp/src/endpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,8 @@

namespace ucxx {

void EpParamsDeleter::operator()(ucp_ep_params_t* ptr)
{
if (ptr != nullptr && ptr->field_mask & UCP_EP_PARAM_FIELD_SOCK_ADDR)
ucxx::utils::sockaddr_free(&ptr->sockaddr);
}

Endpoint::Endpoint(std::shared_ptr<Component> workerOrListener,
std::unique_ptr<ucp_ep_params_t, EpParamsDeleter> params,
ucp_ep_params_t* params,
bool endpointErrorHandling)
: _endpointErrorHandling{endpointErrorHandling}
{
Expand All @@ -49,7 +43,7 @@ Endpoint::Endpoint(std::shared_ptr<Component> workerOrListener,
params->err_handler.cb = Endpoint::errorCallback;
params->err_handler.arg = _callbackData.get();

utils::ucsErrorThrow(ucp_ep_create(worker->getHandle(), params.get(), &_handle));
utils::ucsErrorThrow(ucp_ep_create(worker->getHandle(), params, &_handle));
ucxx_trace("Endpoint created: %p", _handle);
}

Expand All @@ -61,17 +55,16 @@ std::shared_ptr<Endpoint> createEndpointFromHostname(std::shared_ptr<Worker> wor
if (worker == nullptr || worker->getHandle() == nullptr)
throw ucxx::Error("Worker not initialized");

auto params = std::unique_ptr<ucp_ep_params_t, EpParamsDeleter>(new ucp_ep_params_t);

struct hostent* hostname = gethostbyname(ipAddress.c_str());
if (hostname == nullptr) throw ucxx::Error(std::string("Invalid IP address or hostname"));
ucp_ep_params_t params = {.field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_SOCK_ADDR |
UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE |
UCP_EP_PARAM_FIELD_ERR_HANDLER,
.flags = UCP_EP_PARAMS_FLAGS_CLIENT_SERVER};
auto info = ucxx::utils::get_addrinfo(ipAddress.c_str(), port);

params->field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_SOCK_ADDR |
UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | UCP_EP_PARAM_FIELD_ERR_HANDLER;
params->flags = UCP_EP_PARAMS_FLAGS_CLIENT_SERVER;
if (ucxx::utils::sockaddr_set(&params->sockaddr, hostname->h_name, port)) throw std::bad_alloc();
params.sockaddr.addrlen = info->ai_addrlen;
params.sockaddr.addr = info->ai_addr;

return std::shared_ptr<Endpoint>(new Endpoint(worker, std::move(params), endpointErrorHandling));
return std::shared_ptr<Endpoint>(new Endpoint(worker, &params, endpointErrorHandling));
}

std::shared_ptr<Endpoint> createEndpointFromConnRequest(std::shared_ptr<Listener> listener,
Expand All @@ -81,14 +74,13 @@ std::shared_ptr<Endpoint> createEndpointFromConnRequest(std::shared_ptr<Listener
if (listener == nullptr || listener->getHandle() == nullptr)
throw ucxx::Error("Worker not initialized");

auto params = std::unique_ptr<ucp_ep_params_t, EpParamsDeleter>(new ucp_ep_params_t);
params->field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_CONN_REQUEST |
UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | UCP_EP_PARAM_FIELD_ERR_HANDLER;
params->flags = UCP_EP_PARAMS_FLAGS_NO_LOOPBACK;
params->conn_request = connRequest;
ucp_ep_params_t params = {
.field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_CONN_REQUEST |
UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | UCP_EP_PARAM_FIELD_ERR_HANDLER,
.flags = UCP_EP_PARAMS_FLAGS_NO_LOOPBACK,
.conn_request = connRequest};

return std::shared_ptr<Endpoint>(
new Endpoint(listener, std::move(params), endpointErrorHandling));
return std::shared_ptr<Endpoint>(new Endpoint(listener, &params, endpointErrorHandling));
}

std::shared_ptr<Endpoint> createEndpointFromWorkerAddress(std::shared_ptr<Worker> worker,
Expand All @@ -100,12 +92,12 @@ std::shared_ptr<Endpoint> createEndpointFromWorkerAddress(std::shared_ptr<Worker
if (address == nullptr || address->getHandle() == nullptr || address->getLength() == 0)
throw ucxx::Error("Address not initialized");

auto params = std::unique_ptr<ucp_ep_params_t, EpParamsDeleter>(new ucp_ep_params_t);
params->field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS | UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE |
UCP_EP_PARAM_FIELD_ERR_HANDLER;
params->address = address->getHandle();
ucp_ep_params_t params = {.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS |
UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE |
UCP_EP_PARAM_FIELD_ERR_HANDLER,
.address = address->getHandle()};

return std::shared_ptr<Endpoint>(new Endpoint(worker, std::move(params), endpointErrorHandling));
return std::shared_ptr<Endpoint>(new Endpoint(worker, &params, endpointErrorHandling));
}

Endpoint::~Endpoint()
Expand Down
8 changes: 3 additions & 5 deletions cpp/src/listener.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,9 @@ Listener::Listener(std::shared_ptr<Worker> worker,
params.conn_handler.cb = callback;
params.conn_handler.arg = callbackArgs;

if (ucxx::utils::sockaddr_set(&params.sockaddr, NULL, port))
// throw std::bad_alloc("Failed allocation of sockaddr")
throw std::bad_alloc();
std::unique_ptr<ucs_sock_addr_t, void (*)(ucs_sock_addr_t*)> sockaddr(&params.sockaddr,
ucxx::utils::sockaddr_free);
auto info = ucxx::utils::get_addrinfo(NULL, port);
params.sockaddr.addr = info->ai_addr;
params.sockaddr.addrlen = info->ai_addrlen;

ucp_listener_h handle = nullptr;
utils::ucsErrorThrow(ucp_listener_create(worker->getHandle(), &params, &handle));
Expand Down
54 changes: 30 additions & 24 deletions cpp/src/utils/sockaddr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,52 +3,58 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
#include <arpa/inet.h>
#include <netinet/in.h>
#include <memory>
#include <netdb.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>

#include <ucxx/exception.h>
#include <ucxx/utils/sockaddr.h>

namespace ucxx {

namespace utils {

int sockaddr_set(ucs_sock_addr_t* sockaddr, const char* ip_address, uint16_t port)
std::unique_ptr<struct addrinfo, void (*)(struct addrinfo*)> get_addrinfo(const char* ip_address,
uint16_t port)
{
struct sockaddr_in* addr = reinterpret_cast<sockaddr_in*>(malloc(sizeof(struct sockaddr_in)));
if (addr == NULL) { return 1; }
memset(addr, 0, sizeof(struct sockaddr_in));
addr->sin_family = AF_INET;
addr->sin_addr.s_addr = ip_address == NULL ? INADDR_ANY : inet_addr(ip_address);
addr->sin_port = htons(port);
sockaddr->addr = (const struct sockaddr*)addr;
sockaddr->addrlen = sizeof(struct sockaddr_in);
return 0;
}

void sockaddr_free(ucs_sock_addr_t* sockaddr)
{
::free(const_cast<void*>(reinterpret_cast<const void*>(sockaddr->addr)));
std::unique_ptr<struct addrinfo, void (*)(struct addrinfo*)> info(nullptr, ::freeaddrinfo);
{
char ports[6];
struct addrinfo* result = nullptr;
struct addrinfo hints;
// Don't restrict lookups
::memset(&hints, 0, sizeof(hints));
// Except, port is numeric, address may be NULL meaning the
// returned address is the wildcard.
hints.ai_flags = AI_NUMERICSERV | AI_PASSIVE;
if (::snprintf(ports, sizeof(ports), "%u", port) > sizeof(ports))
throw ucxx::Error(std::string("Invalid port"));
if (::getaddrinfo(ip_address, ports, &hints, &result))
throw ucxx::Error(std::string("Invalid IP address or hostname"));
info.reset(result);
}
return info;
}

void sockaddr_get_ip_port_str(const struct sockaddr_storage* sockaddr,
char* ip_str,
char* port_str,
size_t max_str_size)
{
struct sockaddr_in addr_in;
struct sockaddr_in6 addr_in6;
const struct sockaddr_in* addr_in = nullptr;
const struct sockaddr_in6* addr_in6 = nullptr;

switch (sockaddr->ss_family) {
case AF_INET:
memcpy(&addr_in, sockaddr, sizeof(struct sockaddr_in));
inet_ntop(AF_INET, &addr_in.sin_addr, ip_str, max_str_size);
snprintf(port_str, max_str_size, "%d", ntohs(addr_in.sin_port));
addr_in = reinterpret_cast<decltype(addr_in)>(sockaddr);
inet_ntop(AF_INET, &addr_in->sin_addr, ip_str, max_str_size);
snprintf(port_str, max_str_size, "%u", ntohs(addr_in->sin_port));
case AF_INET6:
memcpy(&addr_in6, sockaddr, sizeof(struct sockaddr_in6));
inet_ntop(AF_INET6, &addr_in6.sin6_addr, ip_str, max_str_size);
snprintf(port_str, max_str_size, "%d", ntohs(addr_in6.sin6_port));
addr_in6 = reinterpret_cast<decltype(addr_in6)>(sockaddr);
inet_ntop(AF_INET6, &addr_in6->sin6_addr, ip_str, max_str_size);
snprintf(port_str, max_str_size, "%u", ntohs(addr_in6->sin6_port));
default:
ip_str = const_cast<char*>(reinterpret_cast<const char*>("Invalid address family"));
port_str = const_cast<char*>(reinterpret_cast<const char*>("Invalid address family"));
Expand Down