Skip to content

Commit

Permalink
windows tcp: Lookup extended TCP function pointers at startup
Browse files Browse the repository at this point in the history
This avoids the need for a lock during listener or dialer initialization,
and it avoids the need to carry these pointers on those objects.
It also eliminates a potential failure case "post startup".
  • Loading branch information
gdamore committed Dec 29, 2024
1 parent 7b5515c commit 6a41461
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 105 deletions.
71 changes: 70 additions & 1 deletion src/platform/windows/win_tcp.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//
// Copyright 2018 Staysail Systems, Inc. <[email protected]>
// Copyright 2024 Staysail Systems, Inc. <[email protected]>
// Copyright 2018 Capitar IT Group BV <[email protected]>
//
// This software is supplied under the terms of the MIT License, a
Expand All @@ -15,18 +15,87 @@
#include <malloc.h>
#include <stdio.h>

static LPFN_ACCEPTEX acceptex;
static LPFN_GETACCEPTEXSOCKADDRS getacceptexsockaddrs;
static LPFN_CONNECTEX connectex;

int
nni_win_tcp_sysinit(void)
{
int rv;

WSADATA data;

if (WSAStartup(MAKEWORD(2, 2), &data) != 0) {
NNI_ASSERT(LOBYTE(data.wVersion) == 2);
NNI_ASSERT(HIBYTE(data.wVersion) == 2);
return (nni_win_error(GetLastError()));
}

DWORD nbytes;
GUID guid1 = WSAID_ACCEPTEX;
GUID guid2 = WSAID_GETACCEPTEXSOCKADDRS;
GUID guid3 = WSAID_CONNECTEX;

SOCKET s = socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP);
if (s == INVALID_SOCKET) {
rv = nni_win_error(GetLastError());
WSACleanup();
return (rv);
}
if ((WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid1,
sizeof(guid1), &acceptex, sizeof(acceptex), &nbytes, NULL,
NULL) == SOCKET_ERROR) ||
(WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid2,
sizeof(guid2), &getacceptexsockaddrs,
sizeof(getacceptexsockaddrs), &nbytes, NULL,
NULL) == SOCKET_ERROR) ||
(WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid3,
sizeof(guid3), &connectex, sizeof(connectex), &nbytes, NULL,
NULL) == SOCKET_ERROR)) {
rv = nni_win_error(GetLastError());
closesocket(s);
WSACleanup();
return (rv);
}

closesocket(s);
return (0);
}

int
nni_win_acceptex(SOCKET listen, SOCKET child, void *buf, LPOVERLAPPED olpd)
{
DWORD cnt = 0;
return (acceptex(listen, child, buf, 0, 256, 256, &cnt, olpd));
}

// This is called after a socket is accepted for the connection, and the buffer
// contains the peers socket addresses. It is is kind of weird, windows
// specific, and must be called only after acceptex. The caller should call
// setsockopt(s, SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT) after calling this.
void
nni_win_get_acceptex_sockaddrs(
void *buf, SOCKADDR_STORAGE *self, SOCKADDR_STORAGE *peer)
{
SOCKADDR *self_p;
SOCKADDR *peer_p;
int self_len;
int peer_len;

getacceptexsockaddrs(
buf, 0, 256, 256, &self_p, &self_len, &peer_p, &peer_len);

(void) memcpy(self, self_p, self_len);
(void) memcpy(peer, peer_p, peer_len);
}

int
nni_win_connectex(SOCKET s, SOCKADDR *peer, int peer_len, LPOVERLAPPED olpd)
{
return (connectex(s, peer, peer_len, NULL, 0, NULL, olpd));
}

void
nni_win_tcp_sysfini(void)
{
Expand Down
9 changes: 9 additions & 0 deletions src/platform/windows/win_tcp.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,13 @@ struct nni_tcp_conn {

extern int nni_win_tcp_init(nni_tcp_conn **, SOCKET);

// Following functions are wrappers around Windows functions that have to be
// looked up by pointer/GUID.
extern int nni_win_acceptex(
SOCKET listen, SOCKET child, void *buf, LPOVERLAPPED olpd);
extern void nni_win_get_acceptex_sockaddrs(
void *buf, SOCKADDR_STORAGE *self, SOCKADDR_STORAGE *peer);
extern int nni_win_connectex(
SOCKET s, SOCKADDR *peer, int peer_len, LPOVERLAPPED olpd);

#endif // NNG_PLATFORM_WIN_WINTCP_H
27 changes: 2 additions & 25 deletions src/platform/windows/win_tcpdial.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
#include <stdio.h>

struct nni_tcp_dialer {
LPFN_CONNECTEX connectex; // looked up name via ioctl
nni_list aios; // in flight connections
nni_list aios; // in flight connections
bool closed;
bool nodelay; // initial value for child conns
bool keepalive; // initial value for child conns
Expand All @@ -40,31 +39,10 @@ nni_tcp_dialer_init(nni_tcp_dialer **dp)
if ((d = NNI_ALLOC_STRUCT(d)) == NULL) {
return (NNG_ENOMEM);
}
ZeroMemory(d, sizeof(*d));
nni_mtx_init(&d->mtx);
nni_aio_list_init(&d->aios);
d->nodelay = true;

// Create a scratch socket for use with ioctl.
s = socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP);
if (s == INVALID_SOCKET) {
rv = nni_win_error(GetLastError());
nni_tcp_dialer_fini(d);
return (rv);
}

// Look up the function pointer.
if (WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid,
sizeof(guid), &d->connectex, sizeof(d->connectex), &nbytes,
NULL, NULL) == SOCKET_ERROR) {
rv = nni_win_error(GetLastError());
closesocket(s);
nni_tcp_dialer_fini(d);
return (rv);
}

closesocket(s);

*dp = d;
return (0);
}
Expand Down Expand Up @@ -255,8 +233,7 @@ nni_tcp_dial(nni_tcp_dialer *d, const nni_sockaddr *sa, nni_aio *aio)
nni_aio_list_append(&d->aios, aio);

// dialing is concurrent.
if (!d->connectex(s, (struct sockaddr *) &c->peername, len, NULL, 0,
NULL, &c->conn_io.olpd)) {
if (!nni_win_connectex(s, &c->peername, len, &c->conn_io.olpd)) {
if ((rv = GetLastError()) != ERROR_IO_PENDING) {
nni_aio_list_remove(aio);
nni_mtx_unlock(&d->mtx);
Expand Down
96 changes: 17 additions & 79 deletions src/platform/windows/win_tcplisten.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,72 +18,23 @@
#include "win_tcp.h"

typedef struct tcp_listener {
nng_stream_listener ops;
nng_sockaddr sa;
SOCKET s;
nni_list aios;
bool closed;
bool started;
bool nodelay; // initial value for child conns
bool keepalive; // initial value for child conns
bool running;
LPFN_ACCEPTEX acceptex;
LPFN_GETACCEPTEXSOCKADDRS getacceptexsockaddrs;
SOCKADDR_STORAGE ss;
nni_mtx mtx;
nni_reap_node reap;
nni_win_io accept_io;
int accept_rv;
nni_tcp_conn *pend_conn;
nng_stream_listener ops;
nng_sockaddr sa;
SOCKET s;
nni_list aios;
bool closed;
bool started;
bool nodelay; // initial value for child conns
bool keepalive; // initial value for child conns
bool running;
SOCKADDR_STORAGE ss;
nni_mtx mtx;
nni_reap_node reap;
nni_win_io accept_io;
int accept_rv;
nni_tcp_conn *pend_conn;
} tcp_listener;

// tcp_listener_funcs looks up function pointers we need for advanced accept
// functionality on Windows. Windows is weird.
static int
tcp_listener_funcs(tcp_listener *l)
{
static SRWLOCK lock = SRWLOCK_INIT;
static LPFN_ACCEPTEX acceptex;
static LPFN_GETACCEPTEXSOCKADDRS getacceptexsockaddrs;

AcquireSRWLockExclusive(&lock);
if (acceptex == NULL) {
int rv;
DWORD nbytes;
GUID guid1 = WSAID_ACCEPTEX;
GUID guid2 = WSAID_GETACCEPTEXSOCKADDRS;
SOCKET s = socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP);

if (s == INVALID_SOCKET) {
rv = nni_win_error(GetLastError());
ReleaseSRWLockExclusive(&lock);
return (rv);
}

// Look up the function pointer.
if ((WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid1,
sizeof(guid1), &acceptex, sizeof(acceptex), &nbytes,
NULL, NULL) == SOCKET_ERROR) ||
(WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid2,
sizeof(guid2), &getacceptexsockaddrs,
sizeof(getacceptexsockaddrs), &nbytes, NULL,
NULL) == SOCKET_ERROR)) {
rv = nni_win_error(GetLastError());
acceptex = NULL;
getacceptexsockaddrs = NULL;
ReleaseSRWLockExclusive(&lock);
closesocket(s);
return (rv);
}
closesocket(s);
}
ReleaseSRWLockExclusive(&lock);

l->acceptex = acceptex;
l->getacceptexsockaddrs = getacceptexsockaddrs;
return (0);
}

static void tcp_listener_accepted(tcp_listener *l);
static void tcp_listener_doaccept(tcp_listener *l);
static void tcp_listener_free(void *arg);
Expand Down Expand Up @@ -269,10 +220,6 @@ tcp_accept_cancel(nni_aio *aio, void *arg, int rv)
static void
tcp_listener_accepted(tcp_listener *l)
{
int len1;
int len2;
SOCKADDR *sa1;
SOCKADDR *sa2;
BOOL nd;
BOOL ka;
nni_tcp_conn *c;
Expand All @@ -281,15 +228,11 @@ tcp_listener_accepted(tcp_listener *l)
aio = nni_list_first(&l->aios);
c = l->pend_conn;
l->pend_conn = NULL;
len1 = (int) sizeof(c->sockname);
len2 = (int) sizeof(c->peername);
ka = l->keepalive;
nd = l->nodelay;

nni_aio_list_remove(aio);
l->getacceptexsockaddrs(c->buf, 0, 256, 256, &sa1, &len1, &sa2, &len2);
memcpy(&c->sockname, sa1, len1);
memcpy(&c->peername, sa2, len2);
nni_win_get_acceptex_sockaddrs(c->buf, &c->sockname, &c->peername);

(void) setsockopt(c->s, SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT,
(char *) &l->s, sizeof(l->s));
Expand Down Expand Up @@ -337,8 +280,7 @@ tcp_listener_doaccept(tcp_listener *l)
}
c->listener = l;
l->pend_conn = c;
if (l->acceptex(l->s, s, c->buf, 0, 256, 256, &cnt,
&l->accept_io.olpd)) {
if (nni_win_acceptex(l->s, s, c->buf, &l->accept_io.olpd)) {
// completed synchronously
tcp_listener_accepted(l);
continue;
Expand Down Expand Up @@ -546,10 +488,6 @@ tcp_listener_alloc_addr(nng_stream_listener **lp, const nng_sockaddr *sa)
nni_mtx_init(&l->mtx);
nni_aio_list_init(&l->aios);
nni_win_io_init(&l->accept_io, tcp_accept_cb, l);
if ((rv = tcp_listener_funcs(l)) != 0) {
NNI_FREE_STRUCT(l);
return (rv);
}

// We assume these defaults -- not everyone will agree, but anyone
// can change them.
Expand Down

0 comments on commit 6a41461

Please sign in to comment.