diff --git a/trpc/runtime/iomodel/reactor/default/tcp_acceptor.cc b/trpc/runtime/iomodel/reactor/default/tcp_acceptor.cc index 5f754465..6d693a8f 100644 --- a/trpc/runtime/iomodel/reactor/default/tcp_acceptor.cc +++ b/trpc/runtime/iomodel/reactor/default/tcp_acceptor.cc @@ -111,7 +111,8 @@ int TcpAcceptor::HandleReadEvent() { info.conn_info.remote_addr = std::move(peer_addr); if (!accept_handler(info)) { - close(conn_fd); + // Do not close socket when accept_handler fail(accept_handler will close), or may double-close socket + TRPC_LOG_ERROR("tcp accept handler return false"); } } else { TRPC_ASSERT("tcp accept handler empty"); diff --git a/trpc/runtime/iomodel/reactor/default/tcp_acceptor_test.cc b/trpc/runtime/iomodel/reactor/default/tcp_acceptor_test.cc index 2fc5d01d..698071ff 100644 --- a/trpc/runtime/iomodel/reactor/default/tcp_acceptor_test.cc +++ b/trpc/runtime/iomodel/reactor/default/tcp_acceptor_test.cc @@ -115,4 +115,55 @@ TEST_F(TcpAcceptorTest, All) { ASSERT_EQ(ret, 1); } +TEST_F(TcpAcceptorTest, TcpAcceptorFail) { + Latch l(1); + std::thread t1([this, &l]() { + l.count_down(); + this->reactor_->Run(); + }); + + l.wait(); + + int ret = 0; + + Latch accept_latch(1); + auto&& accept_handler = [&ret, &accept_latch](const AcceptConnectionInfo& connection_info) { + ret = 1; + accept_latch.count_down(); + return false; + }; + + // EnableListen fail + // Invalid IP "1.1.1.1" may bind socket fail + NetworkAddress bad_addr = NetworkAddress("1.1.1.1", 10000, NetworkAddress::IpType::kIpV4); + RefPtr fail_acceptor = MakeRefCounted(reactor_.get(), bad_addr); + bool listen_fail = fail_acceptor->EnableListen(); + EXPECT_EQ(listen_fail, false); + + NetworkAddress addr = NetworkAddress(trpc::util::GenRandomAvailablePort(), false, NetworkAddress::IpType::kIpV4); + + RefPtr acceptor = MakeRefCounted(reactor_.get(), addr); + acceptor->SetAcceptHandleFunction(std::move(accept_handler)); + acceptor->SetAcceptSetSocketOptFunction([](Socket& socket) { return; }); + acceptor->EnableListen(); + + std::unique_ptr socket = std::make_unique(Socket::CreateTcpSocket(false)); + + int succ = socket->Connect(addr); + + EXPECT_EQ(succ, 0); + + accept_latch.wait(); + + socket->Close(); + + acceptor->DisableListen(); + + reactor_->Stop(); + + t1.join(); + + EXPECT_EQ(ret, 1); +} + } // namespace trpc::testing diff --git a/trpc/runtime/iomodel/reactor/default/uds_acceptor.cc b/trpc/runtime/iomodel/reactor/default/uds_acceptor.cc index 30739684..9eff9380 100644 --- a/trpc/runtime/iomodel/reactor/default/uds_acceptor.cc +++ b/trpc/runtime/iomodel/reactor/default/uds_acceptor.cc @@ -105,7 +105,8 @@ int UdsAcceptor::HandleReadEvent() { info.conn_info.is_net = false; if (!accept_handler(info)) { - close(conn_fd); + // Do not close socket when accept_handler fail(accept_handler will close), or may double-close socket + TRPC_LOG_ERROR("unix accept handler return false"); } } else { TRPC_ASSERT("unix accept handler empty"); diff --git a/trpc/runtime/iomodel/reactor/default/uds_acceptor_test.cc b/trpc/runtime/iomodel/reactor/default/uds_acceptor_test.cc index 25db34a6..2d5e0261 100644 --- a/trpc/runtime/iomodel/reactor/default/uds_acceptor_test.cc +++ b/trpc/runtime/iomodel/reactor/default/uds_acceptor_test.cc @@ -97,6 +97,57 @@ TEST_F(UdsAcceptorTest, UdsAcceptorOk) { ASSERT_EQ(ret, 1); } +TEST_F(UdsAcceptorTest, UdsAcceptorFail) { + Latch l(1); + std::thread t1([this, &l]() { + l.count_down(); + this->reactor_->Run(); + }); + + l.wait(); + + int ret = 0; + char path[] = "uds_accpetor_test.socket"; + UnixAddress addr = UnixAddress(path); + + Latch accept_latch(1); + auto&& accept_handler = [&ret, &accept_latch](const AcceptConnectionInfo& connection_info) { + ret = 1; + accept_latch.count_down(); + return false; + }; + + sockaddr_un bad_socket; + UnixAddress bad_addr = UnixAddress(&bad_socket); + RefPtr fail_acceptor = MakeRefCounted(reactor_.get(), bad_addr); + bool listen_fail = fail_acceptor->EnableListen(); + EXPECT_EQ(listen_fail, false); + + RefPtr acceptor = MakeRefCounted(reactor_.get(), addr); + acceptor->SetAcceptHandleFunction(std::move(accept_handler)); + auto func = [](Socket& s) {}; + acceptor->SetAcceptSetSocketOptFunction(func); + acceptor->EnableListen(); + + std::unique_ptr socket = std::make_unique(Socket::CreateUnixSocket()); + + int succ = socket->Connect(addr); + + EXPECT_EQ(succ, 0); + + accept_latch.wait(); + + socket->Close(); + + acceptor->DisableListen(); + + reactor_->Stop(); + + t1.join(); + + EXPECT_EQ(ret, 1); +} + } // namespace testing } // namespace trpc diff --git a/trpc/runtime/iomodel/reactor/fiber/fiber_acceptor.cc b/trpc/runtime/iomodel/reactor/fiber/fiber_acceptor.cc index e31516f8..c2203539 100644 --- a/trpc/runtime/iomodel/reactor/fiber/fiber_acceptor.cc +++ b/trpc/runtime/iomodel/reactor/fiber/fiber_acceptor.cc @@ -122,7 +122,8 @@ FiberConnection::EventAction FiberAcceptor::OnTcpReadable() { info.conn_info.remote_addr = std::move(peer_addr); if (!accept_handler_(info)) { - ::close(conn_fd); + // Do not close socket when accept_handler fail(accept_handler will close), or may double-close socket + TRPC_LOG_ERROR("FiberAcceptor::OnTcpReadable accept handler return false"); } } else { TRPC_LOG_ERROR("FiberAcceptor::OnTcpReadable accept handler empty."); @@ -162,7 +163,8 @@ FiberConnection::EventAction FiberAcceptor::OnUdsReadable() { info.conn_info.is_net = false; if (!accept_handler_(info)) { - ::close(conn_fd); + // Do not close socket when accept_handler fail(accept_handler will close), or may double-close socket + TRPC_LOG_ERROR("FiberAcceptor::OnUdsReadable accept handler return false"); } } else { TRPC_LOG_ERROR("FiberAcceptor::OnUdsReadable accept handler empty."); diff --git a/trpc/runtime/iomodel/reactor/fiber/fiber_tcp_connection_test.cc b/trpc/runtime/iomodel/reactor/fiber/fiber_tcp_connection_test.cc index 5ad92a16..2465b46d 100644 --- a/trpc/runtime/iomodel/reactor/fiber/fiber_tcp_connection_test.cc +++ b/trpc/runtime/iomodel/reactor/fiber/fiber_tcp_connection_test.cc @@ -113,6 +113,25 @@ class FiberTcpConnectionTestImpl { acceptor_->SetAcceptHandleFunction(std::move(accept_handler)); std::cout << "fiber acceptor listen" << std::endl; acceptor_->Listen(); + + bad_tcp_accept_addr_ = NetworkAddress("0.0.0.0", 50000, NetworkAddress::IpType::kIpV4); + tcp_fail_acceptor_ = MakeRefCounted(reactor_, bad_tcp_accept_addr_); + trpc::AcceptConnectionFunction tcp_fail_accept_handler = [this](AcceptConnectionInfo& connection_info) { + return false; + }; + tcp_fail_acceptor_->SetAcceptHandleFunction(std::move(tcp_fail_accept_handler)); + std::cout << "tcp_fail_acceptor_ listen" << std::endl; + tcp_fail_acceptor_->Listen(); + + char path[] = "fiber_uds_accpet_fail_test.socket"; + UnixAddress fiber_uds_accpet_addr = UnixAddress(path); + uds_fail_acceptor_ = MakeRefCounted(reactor_, fiber_uds_accpet_addr); + trpc::AcceptConnectionFunction uds_fail_accept_handler = [this](AcceptConnectionInfo& connection_info) { + return false; + }; + uds_fail_acceptor_->SetAcceptHandleFunction(std::move(uds_fail_accept_handler)); + std::cout << "uds_fail_acceptor_ listen" << std::endl; + uds_fail_acceptor_->Listen(); } void TearDown() { @@ -125,6 +144,12 @@ class FiberTcpConnectionTestImpl { acceptor_->Stop(); acceptor_->Join(); + + tcp_fail_acceptor_->Stop(); + tcp_fail_acceptor_->Join(); + + uds_fail_acceptor_->Stop(); + uds_fail_acceptor_->Join(); } template @@ -168,6 +193,9 @@ class FiberTcpConnectionTestImpl { RefPtr server_conn_{nullptr}; std::atomic server_received_{0}; std::atomic client_received_{0}; + NetworkAddress bad_tcp_accept_addr_; + RefPtr tcp_fail_acceptor_{nullptr}; + RefPtr uds_fail_acceptor_{nullptr}; }; class FiberTcpConnectionTest : public ::testing::Test { @@ -227,6 +255,24 @@ TEST_F(FiberTcpConnectionTest, WriteEmpty) { EXPECT_EQ(0, client_conn->Send(std::move(msg))); } +TEST_F(FiberTcpConnectionTest, TcpAcceptFail) { + std::unique_ptr socket = std::make_unique(Socket::CreateTcpSocket(false)); + NetworkAddress bad_tcp_accept_addr = NetworkAddress("0.0.0.0", 50000, NetworkAddress::IpType::kIpV4); + int succ = socket->Connect(bad_tcp_accept_addr); + EXPECT_EQ(succ, 0); +} + +TEST_F(FiberTcpConnectionTest, UdsAcceptFail) { + char path[] = "fiber_uds_accpet_fail_test.socket"; + UnixAddress fiber_uds_accpet_addr = UnixAddress(path); + + std::unique_ptr socket = std::make_unique(Socket::CreateUnixSocket()); + + int succ = socket->Connect(fiber_uds_accpet_addr); + + EXPECT_EQ(succ, 0); +} + class WriteErrorIoHandler : public DefaultIoHandler { public: explicit WriteErrorIoHandler(Connection* conn) : DefaultIoHandler(conn) {}