diff --git a/contrib/ruby/test/client_test.rb b/contrib/ruby/test/client_test.rb index 33ef1fc5..99cb784e 100644 --- a/contrib/ruby/test/client_test.rb +++ b/contrib/ruby/test/client_test.rb @@ -998,6 +998,18 @@ def test_close_terminate_parent_connection end end + def test_discard_closes_connection + client = new_tcp_client + + assert_equal [1], client.query("SELECT 1").to_a.first + + client.discard! + + assert_raises Trilogy::ConnectionClosed do + client.query("SELECT 1") + end + end + def test_discard_doesnt_terminate_parent_connection skip("Fork isn't supported on this platform") unless Process.respond_to?(:fork) diff --git a/inc/trilogy/socket.h b/inc/trilogy/socket.h index 55bb1283..967b5615 100644 --- a/inc/trilogy/socket.h +++ b/inc/trilogy/socket.h @@ -110,6 +110,5 @@ static inline int trilogy_sock_fd(trilogy_sock_t *sock) { return sock->fd_cb(soc trilogy_sock_t *trilogy_sock_new(const trilogy_sockopt_t *opts); int trilogy_sock_resolve(trilogy_sock_t *raw); int trilogy_sock_upgrade_ssl(trilogy_sock_t *raw); -int trilogy_sock_discard(trilogy_sock_t *sock); #endif diff --git a/src/client.c b/src/client.c index e78abd6e..0b43f12c 100644 --- a/src/client.c +++ b/src/client.c @@ -770,7 +770,7 @@ void trilogy_free(trilogy_conn_t *conn) int trilogy_discard(trilogy_conn_t *conn) { - int rc = trilogy_sock_discard(conn->socket); + int rc = trilogy_sock_shutdown(conn->socket); if (rc == TRILOGY_OK) { trilogy_free(conn); } diff --git a/src/socket.c b/src/socket.c index bf9bf920..324191b9 100644 --- a/src/socket.c +++ b/src/socket.c @@ -110,7 +110,53 @@ static int _cb_raw_close(trilogy_sock_t *_sock) return rc; } -static int _cb_raw_shutdown(trilogy_sock_t *_sock) { return shutdown(trilogy_sock_fd(_sock), SHUT_RDWR); } +static int _cb_shutdown_connect(trilogy_sock_t *_sock) { + (void)_sock; + return TRILOGY_CLOSED_CONNECTION; +} +static ssize_t _cb_shutdown_write(trilogy_sock_t *_sock, const void *buf, size_t nwrite) { + (void)_sock; + (void)buf; + (void)nwrite; + return TRILOGY_CLOSED_CONNECTION; +} +static ssize_t _cb_shutdown_read(trilogy_sock_t *_sock, void *buf, size_t nread) { + (void)_sock; + (void)buf; + (void)nread; + return TRILOGY_CLOSED_CONNECTION; +} +static int _cb_shutdown_wait(trilogy_sock_t *_sock, trilogy_wait_t wait) { + (void)_sock; + (void)wait; + return TRILOGY_OK; +} +static int _cb_shutdown_shutdown(trilogy_sock_t *_sock) { + (void)_sock; + return TRILOGY_OK; +} + +// Shutdown will close the underlying socket fd and replace all I/O operations with stubs which perform no action. +static int _cb_raw_shutdown(trilogy_sock_t *_sock) { + struct trilogy_sock *sock = (struct trilogy_sock *)_sock; + + // Replace all operations with stubs which return immediately + sock->base.connect_cb = _cb_shutdown_connect; + sock->base.read_cb = _cb_shutdown_read; + sock->base.write_cb = _cb_shutdown_write; + sock->base.wait_cb = _cb_shutdown_wait; + sock->base.shutdown_cb = _cb_shutdown_shutdown; + + // These "raw" callbacks won't attempt further operations on the socket and work correctly with fd set to -1 + sock->base.close_cb = _cb_raw_close; + sock->base.fd_cb = _cb_raw_fd; + + if (sock->fd != -1) + close(sock->fd); + sock->fd = -1; + + return TRILOGY_OK; +} static int set_nonblocking_fd(int sock) { @@ -341,15 +387,9 @@ static int _cb_ssl_shutdown(trilogy_sock_t *_sock) // we need to close it. The OpenSSL explicitly states // not to call SSL_shutdown on a broken SSL socket. SSL_free(sock->ssl); - // Reset the handlers since we tore down SSL, so we - // fall back to the regular methods for detecting - // we have a closed connection and for the cleanup. - sock->base.read_cb = _cb_raw_read; - sock->base.write_cb = _cb_raw_write; - sock->base.shutdown_cb = _cb_raw_shutdown; - sock->base.close_cb = _cb_raw_close; sock->ssl = NULL; + // This will rewrite the handlers return _cb_raw_shutdown(_sock); } @@ -624,28 +664,3 @@ int trilogy_sock_upgrade_ssl(trilogy_sock_t *_sock) sock->ssl = NULL; return TRILOGY_OPENSSL_ERR; } - -int trilogy_sock_discard(trilogy_sock_t *_sock) -{ - struct trilogy_sock *sock = (struct trilogy_sock *)_sock; - - if (sock->fd < 0) { - return TRILOGY_OK; - } - - int null_fd = open("/dev/null", O_RDWR | O_CLOEXEC); - if (null_fd < 0) { - return TRILOGY_SYSERR; - } - - if (dup2(null_fd, sock->fd) < 0) { - close(null_fd); - return TRILOGY_SYSERR; - } - - if (close(null_fd) < 0) { - return TRILOGY_SYSERR; - } - - return TRILOGY_OK; -}