diff --git a/src/base_cmd.h b/src/base_cmd.h index 470052ccc..e8ff8c36c 100644 --- a/src/base_cmd.h +++ b/src/base_cmd.h @@ -76,6 +76,13 @@ const std::string kCmdNameUnwatch = "unwatch"; const std::string kCmdNameDiscard = "discard"; // admin +const std::string kCmdNameClient = "client"; +const std::string kSubCmdNameClientGetname = "getname"; +const std::string kSubCmdNameClientSetname = "setname"; +const std::string kSubCmdNameClientId = "id"; +const std::string kSubCmdNameClientList = "list"; +const std::string kSubCmdNameClientKill = "kill"; + const std::string kCmdNameConfig = "config"; const std::string kSubCmdNameConfigGet = "get"; const std::string kSubCmdNameConfigSet = "set"; diff --git a/src/client.cc b/src/client.cc index 248a1cd2f..2c12043ef 100644 --- a/src/client.cc +++ b/src/client.cc @@ -21,6 +21,8 @@ namespace pikiwidb { +const ClientInfo ClientInfo::invalidClientInfo = {0, "", -1}; + void CmdRes::RedisAppendLen(std::string& str, int64_t ori, const std::string& prefix) { str.append(prefix); str.append(pstd::Int2string(ori)); @@ -451,7 +453,7 @@ void PClient::OnConnect() { std::string PClient::PeerIP() const { if (!addr_.IsValid()) { - ERROR("Invalid address detected for client {}", uniqueID()); + ERROR("Invalid address detected for client {}", GetUniqueID()); return ""; } return addr_.GetIP(); @@ -459,7 +461,7 @@ std::string PClient::PeerIP() const { int PClient::PeerPort() const { if (!addr_.IsValid()) { - ERROR("Invalid address detected for client {}", uniqueID()); + ERROR("Invalid address detected for client {}", GetUniqueID()); return 0; } return addr_.GetPort(); @@ -506,7 +508,9 @@ bool PClient::isClusterCmdTarget() const { return PRAFT.GetClusterCmdCtx().GetPeerIp() == PeerIP() && PRAFT.GetClusterCmdCtx().GetPort() == PeerPort(); } -uint64_t PClient::uniqueID() const { return GetConnId(); } +uint64_t PClient::GetUniqueID() const { return GetConnId(); } + +ClientInfo PClient::GetClientInfo() const { return {GetUniqueID(), PeerIP().c_str(), PeerPort()}; } bool PClient::Watch(int dbno, const std::string& key) { DEBUG("Client {} watch {}, db {}", name_, key, dbno); @@ -515,12 +519,12 @@ bool PClient::Watch(int dbno, const std::string& key) { bool PClient::NotifyDirty(int dbno, const std::string& key) { if (IsFlagOn(kClientFlagDirty)) { - INFO("client is already dirty {}", uniqueID()); + INFO("client is already dirty {}", GetUniqueID()); return true; } if (watch_keys_[dbno].contains(key)) { - INFO("{} client become dirty because key {} in db {}", uniqueID(), key, dbno); + INFO("{} client become dirty because key {} in db {}", GetUniqueID(), key, dbno); SetFlag(kClientFlagDirty); return true; } else { diff --git a/src/client.h b/src/client.h index 1365a1db3..6f3f2c674 100644 --- a/src/client.h +++ b/src/client.h @@ -20,7 +20,6 @@ #include "storage/storage.h" namespace pikiwidb { - class CmdRes { public: enum CmdRet { @@ -118,6 +117,14 @@ enum class ClientState { class DB; struct PSlaveInfo; +struct ClientInfo { + uint64_t client_id; + std::string ip; + int port; + static const ClientInfo invalidClientInfo; + bool operator==(const ClientInfo& ci) const { return client_id == ci.client_id; } +}; + class PClient : public std::enable_shared_from_this, public CmdRes { public: // PClient() = delete; @@ -129,6 +136,8 @@ class PClient : public std::enable_shared_from_this, public CmdRes { std::string PeerIP() const; int PeerPort() const; + const int GetFd() const; + ClientInfo GetClientInfo() const; // bool SendPacket(const std::string& buf); // bool SendPacket(const void* data, size_t size); @@ -217,6 +226,7 @@ class PClient : public std::enable_shared_from_this, public CmdRes { void SetAuth() { auth_ = true; } bool GetAuth() const { return auth_; } + uint64_t GetUniqueID() const; void RewriteCmd(std::vector& params) { parser_.SetParams(params); } void Reexecutecommand() { this->executeCommand(); } @@ -244,7 +254,6 @@ class PClient : public std::enable_shared_from_this, public CmdRes { int processInlineCmd(const char*, size_t, std::vector&); void reset(); bool isPeerMaster() const; - uint64_t uniqueID() const; bool isClusterCmdTarget() const; diff --git a/src/client_map.cc b/src/client_map.cc new file mode 100644 index 000000000..12319917c --- /dev/null +++ b/src/client_map.cc @@ -0,0 +1,103 @@ +#include "client_map.h" +#include "log.h" + +namespace pikiwidb { + +uint32_t ClientMap::GetAllClientInfos(std::vector& results) { + // client info string type: ip, port, fd. + std::shared_lock client_map_lock(client_map_mutex_); + for (auto& [id, client_weak] : clients_) { + if (auto client = client_weak.lock()) { + results.emplace_back(client->GetClientInfo()); + } + } + return results.size(); +} + +bool ClientMap::AddClient(int id, std::weak_ptr client) { + std::unique_lock client_map_lock(client_map_mutex_); + if (clients_.find(id) == clients_.end()) { + clients_.insert({id, client}); + return true; + } + return false; +} + +ClientInfo ClientMap::GetClientsInfoById(int id) { + std::shared_lock client_map_lock(client_map_mutex_); + if (auto it = clients_.find(id); it != clients_.end()) { + if (auto client = it->second.lock(); client) { + return client->GetClientInfo(); + } + } + ERROR("Client with ID {} not found in GetClientsInfoById", id); + return ClientInfo::invalidClientInfo; +} + +bool ClientMap::RemoveClientById(int id) { + std::unique_lock client_map_lock(client_map_mutex_); + if (auto it = clients_.find(id); it != clients_.end()) { + clients_.erase(it); + INFO("Removed client with ID {}", id); + return true; + } + return false; +} + +bool ClientMap::KillAllClients() { + std::vector> clients_to_close; + { + std::shared_lock client_map_lock(client_map_mutex_); + for (auto& [id, client_weak] : clients_) { + if (auto client = client_weak.lock()) { + clients_to_close.push_back(client); + } + } + } + for (auto& client : clients_to_close) { + client->Close(); + } + return true; +} + +bool ClientMap::KillClientByAddrPort(const std::string& addr_port) { + std::shared_ptr client_to_close; + { + std::shared_lock client_map_lock(client_map_mutex_); + for (auto& [id, client_weak] : clients_) { + if (auto client = client_weak.lock()) { + std::string client_ip_port = client->PeerIP() + ":" + std::to_string(client->PeerPort()); + if (client_ip_port == addr_port) { + client_to_close = client; + break; + } + } + } + } + if (client_to_close) { + client_to_close->Close(); + return true; + } + return false; +} + +bool ClientMap::KillClientById(int client_id) { + std::shared_ptr client_to_close; + { + std::shared_lock client_map_lock(client_map_mutex_); + if (auto it = clients_.find(client_id); it != clients_.end()) { + if (auto client = it->second.lock()) { + client_to_close = client; + } + } + } + if (client_to_close) { + INFO("Closing client with ID {}", client_id); + client_to_close->Close(); + INFO("Client with ID {} closed", client_id); + return true; + } + return false; +} + +} // namespace pikiwidb \ No newline at end of file diff --git a/src/client_map.h b/src/client_map.h new file mode 100644 index 000000000..aba1e4ee6 --- /dev/null +++ b/src/client_map.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include +#include +#include +#include "client.h" + +namespace pikiwidb { +class ClientMap { + private: + ClientMap() = default; + // 禁用复制构造函数和赋值运算符 + + private: + std::map> clients_; + std::shared_mutex client_map_mutex_; + + public: + static ClientMap& getInstance() { + static ClientMap instance; + return instance; + } + + ClientMap(const ClientMap&) = delete; + ClientMap& operator=(const ClientMap&) = delete; + + // client info function + pikiwidb::ClientInfo GetClientsInfoById(int id); + uint32_t GetAllClientInfos(std::vector& results); + + bool AddClient(int id, std::weak_ptr); + + bool RemoveClientById(int id); + + bool KillAllClients(); + bool KillClientById(int client_id); + bool KillClientByAddrPort(const std::string& addr_port); +}; + +} // namespace pikiwidb \ No newline at end of file diff --git a/src/cmd_admin.cc b/src/cmd_admin.cc index 8f601deb8..79b252750 100644 --- a/src/cmd_admin.cc +++ b/src/cmd_admin.cc @@ -21,6 +21,7 @@ #include "praft/praft.h" #include "pstd/env.h" +#include "client_map.h" #include "store.h" namespace pikiwidb { @@ -468,4 +469,140 @@ void SortCmd::InitialArgument() { get_patterns_.clear(); ret_.clear(); } +CmdClient::CmdClient(const std::string& name, int arity) + : BaseCmdGroup(name, kCmdFlagsReadonly | kCmdFlagsAdmin, kAclCategoryAdmin) {} + +bool CmdClient::HasSubCommand() const { return true; } + +CmdClientGetname::CmdClientGetname(const std::string& name, int16_t arity) + : BaseCmd(name, arity, kCmdFlagsAdmin | kCmdFlagsReadonly, kAclCategoryAdmin) {} + +bool CmdClientGetname::DoInitial(PClient* client) { return true; } + +void CmdClientGetname::DoCmd(PClient* client) { client->AppendString(client->GetName()); } + +CmdClientSetname::CmdClientSetname(const std::string& name, int16_t arity) + : BaseCmd(name, arity, kCmdFlagsAdmin | kCmdFlagsWrite, kAclCategoryAdmin) {} + +bool CmdClientSetname::DoInitial(PClient* client) { return true; } + +void pikiwidb::CmdClientSetname::DoCmd(PClient* client) { + client->SetName(client->argv_[2]); + client->SetRes(CmdRes::kOK); +} + +CmdClientId::CmdClientId(const std::string& name, int16_t arity) + : BaseCmd(name, arity, kCmdFlagsAdmin | kCmdFlagsReadonly, kAclCategoryAdmin) {} + +bool CmdClientId::DoInitial(PClient* client) { return true; } + +void CmdClientId::DoCmd(PClient* client) { client->AppendInteger(client->GetUniqueID()); } + +CmdClientKill::CmdClientKill(const std::string& name, int16_t arity) + : BaseCmd(name, arity, kCmdFlagsAdmin, kAclCategoryAdmin) {} + +bool CmdClientKill::DoInitial(PClient* client) { + if (client->argv_.size() == 3 && strcasecmp(client->argv_[2].data(), "all") == 0) { + kill_type_ = Type::ALL; + return true; + } else if (client->argv_.size() == 4 && strcasecmp(client->argv_[2].data(), "addr") == 0) { + kill_type_ = Type::ADDR; + return true; + } else if (client->argv_.size() == 4 && strcasecmp(client->argv_[2].data(), "id") == 0) { + kill_type_ = Type::ID; + return true; + } else { + client->SetRes(CmdRes::kWrongNum, client->CmdName()); + return false; + } +} + +void CmdClientKill::DoCmd(PClient* client) { + bool ret; + auto& client_map = pikiwidb::ClientMap::getInstance(); + switch (kill_type_) { + case Type::ALL: { + ret = client_map.KillAllClients(); + break; + } + case Type::ADDR: { + ret = client_map.KillClientByAddrPort(client->argv_[3]); + break; + } + case Type::ID: { + try { + int client_id = stoi(client->argv_[3]); + ret = client_map.KillClientById(client_id); + } catch (const std::exception& e) { + client->SetRes(CmdRes::kErrOther, "Invalid client id"); + return; + } + } + default: + break; + } + ret == true ? client->SetRes(CmdRes::kOK) : client->SetRes(CmdRes::kErrOther, "No such client"); +} + +CmdClientList::CmdClientList(const std::string& name, int16_t arity) + : BaseCmd(name, arity, kCmdFlagsAdmin | kCmdFlagsReadonly, kAclCategoryAdmin) {} + +bool CmdClientList::DoInitial(PClient* client) { + if (client->argv_.size() == 2) { + list_type_ = Type::DEFAULT; + return true; + } + if (client->argv_.size() > 3 && strcasecmp(client->argv_[2].data(), "id") == 0) { + list_type_ = Type::ID; + return true; + } + client->SetRes(CmdRes::kErrOther, "Syntax error, try CLIENT (LIST [ID client_id_1, client_id_2...])"); + return false; +} + +void CmdClientList::DoCmd(PClient* client) { + auto& client_map = ClientMap::getInstance(); + switch (list_type_) { + case Type::DEFAULT: { + std::vector client_infos; + client_map.GetAllClientInfos(client_infos); + client->AppendArrayLen(client_infos.size()); + if (client_infos.size() == 0) { + return; + } + char buf[128]; + for (auto& client_info : client_infos) { + // client-> + snprintf(buf, sizeof(buf), "ID=%ld IP=%s PORT=%d\n", client_info.client_id, client_info.ip.c_str(), + client_info.port); + client->AppendString(std::string(buf)); + } + break; + } + case Type::ID: { + client->AppendArrayLen(client->argv_.size() - 3); + + for (size_t i = 3; i < client->argv_.size(); i++) { + try { + int client_id = std::stoi(client->argv_[i]); + auto client_info = client_map.GetClientsInfoById(client_id); + if (client_info == ClientInfo::invalidClientInfo) { + client->SetRes(CmdRes::kErrOther, "Invalid client id"); + return; + } + char buf[128]; + snprintf(buf, sizeof(buf), "ID=%ld IP=%s PORT=%d\n", client_info.client_id, client_info.ip.c_str(), + client_info.port); + client->AppendString(std::string(buf)); + } catch (const std::exception& e) { + client->SetRes(CmdRes::kErrOther, "Invalid client id"); + return; + } + } + break; + } + default: + break; + } +} } // namespace pikiwidb diff --git a/src/cmd_admin.h b/src/cmd_admin.h index 7c6eaa610..3259af81f 100644 --- a/src/cmd_admin.h +++ b/src/cmd_admin.h @@ -81,6 +81,84 @@ class FlushallCmd : public BaseCmd { void DoCmd(PClient* client) override; }; +class CmdClient : public BaseCmdGroup { + public: + CmdClient(const std::string& name, int arity); + bool HasSubCommand() const override; + + protected: + std::string operation_, info_; + bool DoInitial(PClient* client) override { return true; } + + private: + const static std::string CLIENT_LIST_S; + const static std::string CLIENT_KILL_S; + + void DoCmd(PClient* client) override {} +}; + +class CmdClientGetname : public BaseCmd { + public: + CmdClientGetname(const std::string& name, int16_t arity); + + protected: + bool DoInitial(PClient* client) override; + + private: + void DoCmd(PClient* client) override; +}; + +class CmdClientSetname : public BaseCmd { + public: + CmdClientSetname(const std::string& name, int16_t arity); + + protected: + bool DoInitial(PClient* client) override; + + private: + void DoCmd(PClient* client) override; +}; + +class CmdClientId : public BaseCmd { + public: + CmdClientId(const std::string& name, int16_t arity); + + protected: + bool DoInitial(PClient* client) override; + + private: + void DoCmd(PClient* client) override; +}; + +class CmdClientList : public BaseCmd { + private: + enum class Type { DEFAULT, IDLE, ADDR, ID } list_type_; + std::string info_; + + public: + CmdClientList(const std::string& name, int16_t arity); + + protected: + bool DoInitial(PClient* client) override; + + private: + void DoCmd(PClient* client) override; +}; + +class CmdClientKill : public BaseCmd { + private: + enum class Type { ALL, ADDR, ID } kill_type_; + + public: + CmdClientKill(const std::string& name, int16_t arity); + + protected: + bool DoInitial(PClient* client) override; + + private: + void DoCmd(PClient* client) override; +}; + class SelectCmd : public BaseCmd { public: SelectCmd(const std::string& name, int16_t arity); diff --git a/src/cmd_table_manager.cc b/src/cmd_table_manager.cc index 0b10a40c5..cbc4f315a 100644 --- a/src/cmd_table_manager.cc +++ b/src/cmd_table_manager.cc @@ -60,6 +60,13 @@ void CmdTableManager::InitCmdTable() { ADD_SUBCOMMAND(Debug, Segfault, 2); ADD_COMMAND(Sort, -2); + ADD_COMMAND_GROUP(Client, -2); + ADD_SUBCOMMAND(Client, Getname, 2); + ADD_SUBCOMMAND(Client, Setname, 3); + ADD_SUBCOMMAND(Client, Id, 2); + ADD_SUBCOMMAND(Client, List, -2); + ADD_SUBCOMMAND(Client, Kill, -3); + // server ADD_COMMAND(Flushdb, 1); ADD_COMMAND(Flushall, 1); diff --git a/src/pikiwidb.cc b/src/pikiwidb.cc index a5ba33804..6246a252f 100644 --- a/src/pikiwidb.cc +++ b/src/pikiwidb.cc @@ -22,6 +22,7 @@ #include "pstd/pstd_util.h" #include "client.h" +#include "client_map.h" #include "config.h" #include "helper.h" #include "pikiwidb_logo.h" @@ -109,6 +110,8 @@ void PikiwiDB::OnNewConnection(uint64_t connId, std::shared_ptrSetSocketAddr(addr); client->OnConnect(); + // add new PClient to clients + ClientMap::getInstance().AddClient(client->GetUniqueID(), client); } bool PikiwiDB::Init() { @@ -156,6 +159,7 @@ bool PikiwiDB::Init() { event_server_->SetOnCreate([](uint64_t connID, std::shared_ptr& client, const net::SocketAddr& addr) { client->SetSocketAddr(addr); client->OnConnect(); + ClientMap::getInstance().AddClient(client->GetUniqueID(), client); INFO("New connection from fd:{} IP:{} port:{}", connID, addr.GetIP(), addr.GetPort()); }); @@ -166,6 +170,7 @@ bool PikiwiDB::Init() { event_server_->SetOnClose([](std::shared_ptr& client, std::string&& msg) { INFO("Close connection id:{} msg:{}", client->GetConnId(), msg); client->OnClose(); + ClientMap::getInstance().RemoveClientById(client->GetUniqueID()); }); event_server_->InitTimer(10); diff --git a/src/pikiwidb.h b/src/pikiwidb.h index 3f354ee01..d5158c128 100644 --- a/src/pikiwidb.h +++ b/src/pikiwidb.h @@ -78,6 +78,7 @@ class PikiwiDB final { std::unique_ptr>> event_server_; uint32_t cmd_id_ = 0; + std::atomic client_id_ = 0; }; extern std::unique_ptr g_pikiwidb; diff --git a/tests/admin_test.go b/tests/admin_test.go index 05e463b78..47ea47cb7 100644 --- a/tests/admin_test.go +++ b/tests/admin_test.go @@ -251,4 +251,22 @@ var _ = Describe("Admin", Ordered, func() { del2 := client.Del(ctx, "list2") Expect(del2.Err()).NotTo(HaveOccurred()) }) + It("Cmd Client", func() { + get := client.ClientGetName(ctx) + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("clientxxx")) + + resId := client.ClientID(ctx).Err() + Expect(resId).NotTo(HaveOccurred()) + Expect(client.ClientID(ctx).Val()).To(BeNumerically(">=", 0)) + + resKillFilter := client.ClientKillByFilter(ctx, "ADDR", "1.1.1.1:1111") + Expect(resKillFilter.Err()).To(MatchError("ERR No such client")) + Expect(resKillFilter.Val()).To(Equal(int64(0))) + + resKillFilter = client.ClientKillByFilter(ctx, "ID", "1") + Expect(resKillFilter.Err()).To(MatchError("ERR No such client")) + Expect(resKillFilter.Val()).To(Equal(int64(0))) + }) + })