Skip to content

Commit

Permalink
BugFix: double close when accept fail (#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
yujun411522 authored Nov 3, 2023
1 parent 983ad47 commit ef36f11
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 4 deletions.
3 changes: 2 additions & 1 deletion trpc/runtime/iomodel/reactor/default/tcp_acceptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ int TcpAcceptor::HandleReadEvent() {
info.conn_info.remote_addr = std::move(peer_addr);

if (!accept_handler(info)) {
close(conn_fd);
// Do not close socket when accept_handler fail(accept_handler will close), or may double-close socket
TRPC_LOG_ERROR("tcp accept handler return false");
}
} else {
TRPC_ASSERT("tcp accept handler empty");
Expand Down
51 changes: 51 additions & 0 deletions trpc/runtime/iomodel/reactor/default/tcp_acceptor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,55 @@ TEST_F(TcpAcceptorTest, All) {
ASSERT_EQ(ret, 1);
}

TEST_F(TcpAcceptorTest, TcpAcceptorFail) {
Latch l(1);
std::thread t1([this, &l]() {
l.count_down();
this->reactor_->Run();
});

l.wait();

int ret = 0;

Latch accept_latch(1);
auto&& accept_handler = [&ret, &accept_latch](const AcceptConnectionInfo& connection_info) {
ret = 1;
accept_latch.count_down();
return false;
};

// EnableListen fail
// Invalid IP "1.1.1.1" may bind socket fail
NetworkAddress bad_addr = NetworkAddress("1.1.1.1", 10000, NetworkAddress::IpType::kIpV4);
RefPtr<TcpAcceptor> fail_acceptor = MakeRefCounted<TcpAcceptor>(reactor_.get(), bad_addr);
bool listen_fail = fail_acceptor->EnableListen();
EXPECT_EQ(listen_fail, false);

NetworkAddress addr = NetworkAddress(trpc::util::GenRandomAvailablePort(), false, NetworkAddress::IpType::kIpV4);

RefPtr<TcpAcceptor> acceptor = MakeRefCounted<TcpAcceptor>(reactor_.get(), addr);
acceptor->SetAcceptHandleFunction(std::move(accept_handler));
acceptor->SetAcceptSetSocketOptFunction([](Socket& socket) { return; });
acceptor->EnableListen();

std::unique_ptr<Socket> socket = std::make_unique<Socket>(Socket::CreateTcpSocket(false));

int succ = socket->Connect(addr);

EXPECT_EQ(succ, 0);

accept_latch.wait();

socket->Close();

acceptor->DisableListen();

reactor_->Stop();

t1.join();

EXPECT_EQ(ret, 1);
}

} // namespace trpc::testing
3 changes: 2 additions & 1 deletion trpc/runtime/iomodel/reactor/default/uds_acceptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ int UdsAcceptor::HandleReadEvent() {
info.conn_info.is_net = false;

if (!accept_handler(info)) {
close(conn_fd);
// Do not close socket when accept_handler fail(accept_handler will close), or may double-close socket
TRPC_LOG_ERROR("unix accept handler return false");
}
} else {
TRPC_ASSERT("unix accept handler empty");
Expand Down
51 changes: 51 additions & 0 deletions trpc/runtime/iomodel/reactor/default/uds_acceptor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,57 @@ TEST_F(UdsAcceptorTest, UdsAcceptorOk) {
ASSERT_EQ(ret, 1);
}

TEST_F(UdsAcceptorTest, UdsAcceptorFail) {
Latch l(1);
std::thread t1([this, &l]() {
l.count_down();
this->reactor_->Run();
});

l.wait();

int ret = 0;
char path[] = "uds_accpetor_test.socket";
UnixAddress addr = UnixAddress(path);

Latch accept_latch(1);
auto&& accept_handler = [&ret, &accept_latch](const AcceptConnectionInfo& connection_info) {
ret = 1;
accept_latch.count_down();
return false;
};

sockaddr_un bad_socket;
UnixAddress bad_addr = UnixAddress(&bad_socket);
RefPtr<UdsAcceptor> fail_acceptor = MakeRefCounted<UdsAcceptor>(reactor_.get(), bad_addr);
bool listen_fail = fail_acceptor->EnableListen();
EXPECT_EQ(listen_fail, false);

RefPtr<UdsAcceptor> acceptor = MakeRefCounted<UdsAcceptor>(reactor_.get(), addr);
acceptor->SetAcceptHandleFunction(std::move(accept_handler));
auto func = [](Socket& s) {};
acceptor->SetAcceptSetSocketOptFunction(func);
acceptor->EnableListen();

std::unique_ptr<Socket> socket = std::make_unique<Socket>(Socket::CreateUnixSocket());

int succ = socket->Connect(addr);

EXPECT_EQ(succ, 0);

accept_latch.wait();

socket->Close();

acceptor->DisableListen();

reactor_->Stop();

t1.join();

EXPECT_EQ(ret, 1);
}

} // namespace testing

} // namespace trpc
6 changes: 4 additions & 2 deletions trpc/runtime/iomodel/reactor/fiber/fiber_acceptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ FiberConnection::EventAction FiberAcceptor::OnTcpReadable() {
info.conn_info.remote_addr = std::move(peer_addr);

if (!accept_handler_(info)) {
::close(conn_fd);
// Do not close socket when accept_handler fail(accept_handler will close), or may double-close socket
TRPC_LOG_ERROR("FiberAcceptor::OnTcpReadable accept handler return false");
}
} else {
TRPC_LOG_ERROR("FiberAcceptor::OnTcpReadable accept handler empty.");
Expand Down Expand Up @@ -162,7 +163,8 @@ FiberConnection::EventAction FiberAcceptor::OnUdsReadable() {
info.conn_info.is_net = false;

if (!accept_handler_(info)) {
::close(conn_fd);
// Do not close socket when accept_handler fail(accept_handler will close), or may double-close socket
TRPC_LOG_ERROR("FiberAcceptor::OnUdsReadable accept handler return false");
}
} else {
TRPC_LOG_ERROR("FiberAcceptor::OnUdsReadable accept handler empty.");
Expand Down
46 changes: 46 additions & 0 deletions trpc/runtime/iomodel/reactor/fiber/fiber_tcp_connection_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,25 @@ class FiberTcpConnectionTestImpl {
acceptor_->SetAcceptHandleFunction(std::move(accept_handler));
std::cout << "fiber acceptor listen" << std::endl;
acceptor_->Listen();

bad_tcp_accept_addr_ = NetworkAddress("0.0.0.0", 50000, NetworkAddress::IpType::kIpV4);
tcp_fail_acceptor_ = MakeRefCounted<FiberAcceptor>(reactor_, bad_tcp_accept_addr_);
trpc::AcceptConnectionFunction tcp_fail_accept_handler = [this](AcceptConnectionInfo& connection_info) {
return false;
};
tcp_fail_acceptor_->SetAcceptHandleFunction(std::move(tcp_fail_accept_handler));
std::cout << "tcp_fail_acceptor_ listen" << std::endl;
tcp_fail_acceptor_->Listen();

char path[] = "fiber_uds_accpet_fail_test.socket";
UnixAddress fiber_uds_accpet_addr = UnixAddress(path);
uds_fail_acceptor_ = MakeRefCounted<FiberAcceptor>(reactor_, fiber_uds_accpet_addr);
trpc::AcceptConnectionFunction uds_fail_accept_handler = [this](AcceptConnectionInfo& connection_info) {
return false;
};
uds_fail_acceptor_->SetAcceptHandleFunction(std::move(uds_fail_accept_handler));
std::cout << "uds_fail_acceptor_ listen" << std::endl;
uds_fail_acceptor_->Listen();
}

void TearDown() {
Expand All @@ -125,6 +144,12 @@ class FiberTcpConnectionTestImpl {

acceptor_->Stop();
acceptor_->Join();

tcp_fail_acceptor_->Stop();
tcp_fail_acceptor_->Join();

uds_fail_acceptor_->Stop();
uds_fail_acceptor_->Join();
}

template <class IoHandlerType>
Expand Down Expand Up @@ -168,6 +193,9 @@ class FiberTcpConnectionTestImpl {
RefPtr<FiberTcpConnection> server_conn_{nullptr};
std::atomic<std::size_t> server_received_{0};
std::atomic<std::size_t> client_received_{0};
NetworkAddress bad_tcp_accept_addr_;
RefPtr<FiberAcceptor> tcp_fail_acceptor_{nullptr};
RefPtr<FiberAcceptor> uds_fail_acceptor_{nullptr};
};

class FiberTcpConnectionTest : public ::testing::Test {
Expand Down Expand Up @@ -227,6 +255,24 @@ TEST_F(FiberTcpConnectionTest, WriteEmpty) {
EXPECT_EQ(0, client_conn->Send(std::move(msg)));
}

TEST_F(FiberTcpConnectionTest, TcpAcceptFail) {
std::unique_ptr<Socket> socket = std::make_unique<Socket>(Socket::CreateTcpSocket(false));
NetworkAddress bad_tcp_accept_addr = NetworkAddress("0.0.0.0", 50000, NetworkAddress::IpType::kIpV4);
int succ = socket->Connect(bad_tcp_accept_addr);
EXPECT_EQ(succ, 0);
}

TEST_F(FiberTcpConnectionTest, UdsAcceptFail) {
char path[] = "fiber_uds_accpet_fail_test.socket";
UnixAddress fiber_uds_accpet_addr = UnixAddress(path);

std::unique_ptr<Socket> socket = std::make_unique<Socket>(Socket::CreateUnixSocket());

int succ = socket->Connect(fiber_uds_accpet_addr);

EXPECT_EQ(succ, 0);
}

class WriteErrorIoHandler : public DefaultIoHandler {
public:
explicit WriteErrorIoHandler(Connection* conn) : DefaultIoHandler(conn) {}
Expand Down

1 comment on commit ef36f11

@guzitajiu123
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good job

Please sign in to comment.