diff --git a/src/client_map.cc b/src/client_map.cc new file mode 100644 index 000000000..f558f98e3 --- /dev/null +++ b/src/client_map.cc @@ -0,0 +1,91 @@ +#include "client_map.h" + +namespace pikiwidb { + +uint32_t ClientMap::GetAllClientInfos(std::vector& results) { + // client info string type: ip, port, fd. + std::shared_lock client_map_lock(client_map_mutex_); + auto it = clients_.begin(); + while (it != clients_.end()) { + auto client = it->second.lock(); + if (client) { + results.emplace_back(client->GetClientInfo()); + } + it++; + } + return results.size(); +} + +bool ClientMap::AddClient(int id, std::weak_ptr client) { + std::unique_lock client_map_lock(client_map_mutex_); + if (clients_.find(id) == clients_.end()) { + clients_.insert({id, client}); + return true; + } + return false; +} + +ClientInfo ClientMap::GetClientsInfoById(int id) { + std::shared_lock client_map_lock(client_map_mutex_); + if (auto it = clients_.find(id); it != clients_.end()) { + if (auto client = it->second.lock(); client) { + return client->GetClientInfo(); + } + } + return ClientInfo::invalidClientInfo; +} + +bool ClientMap::RemoveClientById(int id) { + std::unique_lock client_map_lock(client_map_mutex_); + if (auto it = clients_.find(id); it != clients_.end()) { + clients_.erase(it); + return true; + } + return false; +} + +bool ClientMap::KillAllClients() { + std::shared_lock client_map_lock(client_map_mutex_); + auto it = clients_.begin(); + while (it != clients_.end()) { + auto client = it->second.lock(); + if (client) { + client_map_lock.unlock(); + client->Close(); + client_map_lock.lock(); + } + it++; + } + return true; +} + +bool ClientMap::KillClientByAddrPort(const std::string& addr_port) { + std::shared_lock client_map_lock(client_map_mutex_); + for (auto& [id, client_weak] : clients_) { + auto client = client_weak.lock(); + if (client) { + std::string client_ip_port = client->PeerIP() + ":" + std::to_string(client->PeerPort()); + if (client_ip_port == addr_port) { + client_map_lock.unlock(); + client->Close(); + return true; + } + } + } + return false; +} + +bool ClientMap::KillClientById(int client_id) { + std::shared_lock client_map_lock(client_map_mutex_); + if (auto it = clients_.find(client_id); it != clients_.end()) { + auto client = it->second.lock(); + if (client) { + client_map_lock.unlock(); + client->Close(); + return true; + } + } + return false; +} + +} // namespace pikiwidb \ No newline at end of file diff --git a/src/client_map.h b/src/client_map.h new file mode 100644 index 000000000..a1aa674f6 --- /dev/null +++ b/src/client_map.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include +#include +#include +#include "client.h" + +namespace pikiwidb { + +class ClientMap { + private: + ClientMap() = default; + // 禁用复制构造函数和赋值运算符 + + private: + std::map> clients_; + std::shared_mutex client_map_mutex_; + + public: + static ClientMap& getInstance() { + static ClientMap instance; + return instance; + } + + ClientMap(const ClientMap&) = delete; + ClientMap& operator=(const ClientMap&) = delete; + + // client info function + pikiwidb::ClientInfo GetClientsInfoById(int id); + uint32_t GetAllClientInfos(std::vector& results); + + bool AddClient(int id, std::weak_ptr); + + bool RemoveClientById(int id); + + bool KillAllClients(); + bool KillClientById(int client_id); + bool KillClientByAddrPort(const std::string& addr_port); +}; + +} // namespace pikiwidb \ No newline at end of file diff --git a/src/cmd_admin.cc b/src/cmd_admin.cc index 8e747409b..9b56cc453 100644 --- a/src/cmd_admin.cc +++ b/src/cmd_admin.cc @@ -21,6 +21,7 @@ #include "praft/praft.h" #include "pstd/env.h" +#include "client_map.h" #include "store.h" namespace pikiwidb { @@ -515,19 +516,20 @@ bool CmdClientKill::DoInitial(PClient* client) { void CmdClientKill::DoCmd(PClient* client) { bool ret; + auto& client_map = pikiwidb::ClientMap::getInstance(); switch (kill_type_) { case Type::ALL: { - ret = g_pikiwidb->KillAllClients(); + ret = client_map.KillAllClients(); break; } case Type::ADDR: { - ret = g_pikiwidb->KillClientByAddrPort(client->argv_[3]); + ret = client_map.KillClientByAddrPort(client->argv_[3]); break; } case Type::ID: { try { int client_id = stoi(client->argv_[3]); - ret = g_pikiwidb->KillClientById(client_id); + ret = client_map.KillClientById(client_id); } catch (const std::exception& e) { client->SetRes(CmdRes::kErrOther, "Invalid client id"); return; @@ -556,10 +558,11 @@ bool CmdClientList::DoInitial(PClient* client) { } void CmdClientList::DoCmd(PClient* client) { + auto& client_map = ClientMap::getInstance(); switch (list_type_) { case Type::DEFAULT: { std::vector client_infos; - g_pikiwidb->GetAllClientInfos(client_infos); + client_map.GetAllClientInfos(client_infos); client->AppendArrayLen(client_infos.size()); if (client_infos.size() == 0) { return; @@ -579,7 +582,7 @@ void CmdClientList::DoCmd(PClient* client) { for (size_t i = 3; i < client->argv_.size(); i++) { try { int client_id = std::stoi(client->argv_[i]); - auto client_info = g_pikiwidb->GetClientsInfoById(client_id); + auto client_info = client_map.GetClientsInfoById(client_id); if (client_info == ClientInfo::invalidClientInfo) { client->SetRes(CmdRes::kErrOther, "Invalid client id"); return; diff --git a/src/pikiwidb.cc b/src/pikiwidb.cc index 42b58eeb1..d67f1930e 100644 --- a/src/pikiwidb.cc +++ b/src/pikiwidb.cc @@ -22,6 +22,7 @@ #include "pstd/pstd_util.h" #include "client.h" +#include "client_map.h" #include "config.h" #include "helper.h" #include "pikiwidb_logo.h" @@ -118,90 +119,14 @@ void PikiwiDB::OnNewConnection(pikiwidb::TcpConnection* obj) { obj->SetOnDisconnect([](pikiwidb::TcpConnection* obj) { INFO("disconnect from {}", obj->GetPeerIP()); obj->GetContext()->SetState(pikiwidb::ClientState::kClosed); - g_pikiwidb->RemoveClientMetaById(obj->GetUniqueId()); + ClientMap::getInstance().RemoveClientById(obj->GetUniqueId()); }); obj->SetNodelay(true); obj->SetEventLoopSelector([this]() { return worker_threads_.ChooseNextWorkerEventLoop(); }); obj->SetSlaveEventLoopSelector([this]() { return slave_threads_.ChooseNextWorkerEventLoop(); }); // add new PClient to clients - clients.insert({client->GetUniqueId(), client}); -} - -uint32_t PikiwiDB::GetAllClientInfos(std::vector& results) { - // client info string type: ip, port, fd. - std::shared_lock client_map_lock(client_map_mutex); - auto it = clients.begin(); - while (it != clients.end()) { - auto client = it->second.lock(); - if (client) { - results.emplace_back(client->GetClientInfo()); - } - it++; - } - return results.size(); -} -ClientInfo PikiwiDB::GetClientsInfoById(int id) { - std::shared_lock client_map_lock(client_map_mutex); - if (auto it = clients.find(id); it != clients.end()) { - if (auto client = it->second.lock(); client) { - return client->GetClientInfo(); - } - } - return ClientInfo::invalidClientInfo; -} - -bool PikiwiDB::RemoveClientMetaById(int id) { - std::unique_lock client_map_lock(client_map_mutex); - if (auto it = clients.find(id); it != clients.end()) { - clients.erase(it); - return true; - } - return false; -} - -bool PikiwiDB::KillAllClients() { - std::shared_lock client_map_lock(client_map_mutex); - auto it = clients.begin(); - while (it != clients.end()) { - auto client = it->second.lock(); - if (client) { - client_map_lock.unlock(); - client->Close(); - client_map_lock.lock(); - } - it++; - } - return true; -} - -bool PikiwiDB::KillClientByAddrPort(const std::string& addr_port) { - std::shared_lock client_map_lock(client_map_mutex); - for (auto& [id, client_weak] : clients) { - auto client = client_weak.lock(); - if (client) { - std::string client_ip_port = client->PeerIP() + ":" + std::to_string(client->PeerPort()); - if (client_ip_port == addr_port) { - client_map_lock.unlock(); - client->Close(); - return true; - } - } - } - return false; -} - -bool PikiwiDB::KillClientById(int client_id) { - std::shared_lock client_map_lock(client_map_mutex); - if (auto it = clients.find(client_id); it != clients.end()) { - auto client = it->second.lock(); - if (client) { - client_map_lock.unlock(); - client->Close(); - return true; - } - } - return false; + ClientMap::getInstance().AddClient(client->GetUniqueId(), client); } bool PikiwiDB::Init() { diff --git a/src/pikiwidb.h b/src/pikiwidb.h index c4bfc1f23..61f423079 100644 --- a/src/pikiwidb.h +++ b/src/pikiwidb.h @@ -5,6 +5,7 @@ * of patent rights can be found in the PATENTS file in the same directory. */ +#include "client_map.h" #include "cmd_table_manager.h" #include "cmd_thread_pool.h" #include "common.h" @@ -45,16 +46,6 @@ class PikiwiDB final { void PushWriteTask(const std::shared_ptr& client) { worker_threads_.PushWriteTask(client); } - // client message function - uint32_t GetAllClientInfos(std::vector& results); - pikiwidb::ClientInfo GetClientsInfoById(int id); - - bool RemoveClientMetaById(int id); - - bool KillAllClients(); - bool KillClientByAddrPort(const std::string& addr_port); - bool KillClientById(int client_id); - public: PString cfg_file_; uint16_t port_{0}; @@ -71,9 +62,7 @@ class PikiwiDB final { pikiwidb::CmdThreadPool cmd_threads_; // pikiwidb::CmdTableManager cmd_table_manager_; // use std::list to store client pointer as a double linked list - std::shared_mutex client_map_mutex; std::mutex killer_mutex; - std::map> clients; uint32_t cmd_id_ = 0; std::atomic client_id_ = 0; };