Skip to content

Commit

Permalink
new function: aws_socket_get_bound_address() (#491)
Browse files Browse the repository at this point in the history
Enable users to bind on port 0, which has the OS assign a port, and then query which port it ended up with.

The socket stores this address during `aws_socket_bind()` call, which allows `aws_socket_get_bound_address()` to be const and avoid any tricky threading issues where the socked closes on another thread.

Also fix a few subtle bugs in Windows socket code.

Co-authored-by: Vitaly Khalmansky <[email protected]>
  • Loading branch information
graebm and bgklika authored May 25, 2022
1 parent 66a38bc commit df07e42
Show file tree
Hide file tree
Showing 5 changed files with 413 additions and 300 deletions.
6 changes: 6 additions & 0 deletions include/aws/io/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ AWS_IO_API int aws_socket_connect(
*/
AWS_IO_API int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint *local_endpoint);

/**
* Get the local address which the socket is bound to.
* Raises an error if no address is bound.
*/
AWS_IO_API int aws_socket_get_bound_address(const struct aws_socket *socket, struct aws_socket_endpoint *out_address);

/**
* TCP, LOCAL and VSOCK only. Sets up the socket to listen on the address bound to in `aws_socket_bind()`.
*/
Expand Down
216 changes: 135 additions & 81 deletions source/posix/socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,98 @@ void aws_socket_clean_up(struct aws_socket *socket) {
socket->io_handle.data.fd = -1;
}

/* Update socket->local_endpoint based on the results of getsockname() */
static int s_update_local_endpoint(struct aws_socket *socket) {
struct aws_socket_endpoint tmp_endpoint;
AWS_ZERO_STRUCT(tmp_endpoint);

struct sockaddr_storage address;
AWS_ZERO_STRUCT(address);
socklen_t address_size = sizeof(address);

if (getsockname(socket->io_handle.data.fd, (struct sockaddr *)&address, &address_size) != 0) {
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: getsockname() failed with error %d",
(void *)socket,
socket->io_handle.data.fd,
errno);
int aws_error = s_determine_socket_error(errno);
return aws_raise_error(aws_error);
}

if (address.ss_family == AF_INET) {
struct sockaddr_in *s = (struct sockaddr_in *)&address;
tmp_endpoint.port = ntohs(s->sin_port);
if (inet_ntop(AF_INET, &s->sin_addr, tmp_endpoint.address, sizeof(tmp_endpoint.address)) == NULL) {
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: inet_ntop() failed with error %d",
(void *)socket,
socket->io_handle.data.fd,
errno);
int aws_error = s_determine_socket_error(errno);
return aws_raise_error(aws_error);
}
} else if (address.ss_family == AF_INET6) {
struct sockaddr_in6 *s = (struct sockaddr_in6 *)&address;
tmp_endpoint.port = ntohs(s->sin6_port);
if (inet_ntop(AF_INET6, &s->sin6_addr, tmp_endpoint.address, sizeof(tmp_endpoint.address)) == NULL) {
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: inet_ntop() failed with error %d",
(void *)socket,
socket->io_handle.data.fd,
errno);
int aws_error = s_determine_socket_error(errno);
return aws_raise_error(aws_error);
}
} else if (address.ss_family == AF_UNIX) {
struct sockaddr_un *s = (struct sockaddr_un *)&address;

/* Ensure there's a null-terminator.
* On some platforms it may be missing when the path gets very long. See:
* https://man7.org/linux/man-pages/man7/unix.7.html#BUGS
* But let's keep it simple, and not deal with that madness until someone demands it. */
size_t sun_len;
if (aws_secure_strlen(s->sun_path, sizeof(tmp_endpoint.address), &sun_len)) {
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: UNIX domain socket name is too long",
(void *)socket,
socket->io_handle.data.fd);
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}
memcpy(tmp_endpoint.address, s->sun_path, sun_len);
#if USE_VSOCK
} else if (address.ss_family == AF_VSOCK) {
struct sockaddr_vm *s = (struct sockaddr_vm *)&address;

/* VSOCK port is 32bit, but aws_socket_endpoint.port is only 16bit.
* Hopefully this isn't an issue, since users can only pass in 16bit values.
* But if it becomes an issue, we'll need to make aws_socket_endpoint more flexible */
if (s->svm_port > UINT16_MAX) {
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: aws_socket_endpoint can't deal with VSOCK port > UINT16_MAX",
(void *)socket,
socket->io_handle.data.fd);
return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS);
}
tmp_endpoint.port = (uint16_t)s->svm_port;

snprintf(tmp_endpoint.address, sizeof(tmp_endpoint.address), "%" PRIu32, s->svm_cid);
return AWS_OP_SUCCESS;
#endif /* USE_VSOCK */
} else {
AWS_ASSERT(0);
return aws_raise_error(AWS_IO_SOCKET_UNSUPPORTED_ADDRESS_FAMILY);
}

socket->local_endpoint = tmp_endpoint;
return AWS_OP_SUCCESS;
}

static void s_on_connection_error(struct aws_socket *socket, int error);

static int s_on_connection_success(struct aws_socket *socket) {
Expand Down Expand Up @@ -308,67 +400,8 @@ static int s_on_connection_success(struct aws_socket *socket) {

AWS_LOGF_INFO(AWS_LS_IO_SOCKET, "id=%p fd=%d: connection success", (void *)socket, socket->io_handle.data.fd);

struct sockaddr_storage address;
AWS_ZERO_STRUCT(address);
socklen_t address_size = sizeof(address);
if (!getsockname(socket->io_handle.data.fd, (struct sockaddr *)&address, &address_size)) {
uint16_t port = 0;

if (address.ss_family == AF_INET) {
struct sockaddr_in *s = (struct sockaddr_in *)&address;
port = ntohs(s->sin_port);
/* this comes straight from the kernal. a.) they won't fail. b.) even if they do, it's not fatal
* once we add logging, we can log this if it fails. */
if (inet_ntop(
AF_INET, &s->sin_addr, socket->local_endpoint.address, sizeof(socket->local_endpoint.address))) {
AWS_LOGF_DEBUG(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: local endpoint %s:%d",
(void *)socket,
socket->io_handle.data.fd,
socket->local_endpoint.address,
port);
} else {
AWS_LOGF_WARN(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: determining local endpoint failed",
(void *)socket,
socket->io_handle.data.fd);
}
} else if (address.ss_family == AF_INET6) {
struct sockaddr_in6 *s = (struct sockaddr_in6 *)&address;
port = ntohs(s->sin6_port);
/* this comes straight from the kernal. a.) they won't fail. b.) even if they do, it's not fatal
* once we add logging, we can log this if it fails. */
if (inet_ntop(
AF_INET6, &s->sin6_addr, socket->local_endpoint.address, sizeof(socket->local_endpoint.address))) {
AWS_LOGF_DEBUG(
AWS_LS_IO_SOCKET,
"id=%p fd %d: local endpoint %s:%d",
(void *)socket,
socket->io_handle.data.fd,
socket->local_endpoint.address,
port);
} else {
AWS_LOGF_WARN(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: determining local endpoint failed",
(void *)socket,
socket->io_handle.data.fd);
}
}

socket->local_endpoint.port = port;
} else {
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: getsockname() failed with error %d",
(void *)socket,
socket->io_handle.data.fd,
errno);
int aws_error = s_determine_socket_error(errno);
aws_raise_error(aws_error);
s_on_connection_error(socket, aws_error);
if (s_update_local_endpoint(socket)) {
s_on_connection_error(socket, aws_last_error());
return AWS_OP_ERR;
}

Expand Down Expand Up @@ -761,9 +794,6 @@ int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint
return AWS_OP_ERR;
}

int error_code = -1;

socket->local_endpoint = *local_endpoint;
AWS_LOGF_INFO(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: binding to %s:%d.",
Expand Down Expand Up @@ -813,31 +843,55 @@ int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint
return aws_raise_error(s_convert_pton_error(pton_err));
}

error_code = bind(socket->io_handle.data.fd, (struct sockaddr *)&address.sock_addr_types, sock_size);
if (bind(socket->io_handle.data.fd, (struct sockaddr *)&address.sock_addr_types, sock_size) != 0) {
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: bind failed with error code %d",
(void *)socket,
socket->io_handle.data.fd,
errno);

if (!error_code) {
if (socket->options.type == AWS_SOCKET_STREAM) {
socket->state = BOUND;
} else {
/* e.g. UDP is now readable */
socket->state = CONNECTED_READ;
}
AWS_LOGF_DEBUG(AWS_LS_IO_SOCKET, "id=%p fd=%d: successfully bound", (void *)socket, socket->io_handle.data.fd);
aws_raise_error(s_determine_socket_error(errno));
goto error;
}

return AWS_OP_SUCCESS;
if (s_update_local_endpoint(socket)) {
goto error;
}

socket->state = ERROR;
error_code = errno;
AWS_LOGF_ERROR(
if (socket->options.type == AWS_SOCKET_STREAM) {
socket->state = BOUND;
} else {
/* e.g. UDP is now readable */
socket->state = CONNECTED_READ;
}

AWS_LOGF_DEBUG(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: bind failed with error code %d",
"id=%p fd=%d: successfully bound to %s:%d",
(void *)socket,
socket->io_handle.data.fd,
error_code);
socket->local_endpoint.address,
socket->local_endpoint.port);

int aws_error = s_determine_socket_error(error_code);
return aws_raise_error(aws_error);
return AWS_OP_SUCCESS;

error:
socket->state = ERROR;
return AWS_OP_ERR;
}

int aws_socket_get_bound_address(const struct aws_socket *socket, struct aws_socket_endpoint *out_address) {
if (socket->local_endpoint.address[0] == 0) {
AWS_LOGF_ERROR(
AWS_LS_IO_SOCKET,
"id=%p fd=%d: Socket has no local address. Socket must be bound first.",
(void *)socket,
socket->io_handle.data.fd);
return aws_raise_error(AWS_IO_SOCKET_ILLEGAL_OPERATION_FOR_STATE);
}
*out_address = socket->local_endpoint;
return AWS_OP_SUCCESS;
}

int aws_socket_listen(struct aws_socket *socket, int backlog_size) {
Expand Down
Loading

0 comments on commit df07e42

Please sign in to comment.