Skip to content

Commit

Permalink
add client map to store client message
Browse files Browse the repository at this point in the history
  • Loading branch information
gukj-spel committed Jul 5, 2024
1 parent f5838bb commit 3e3f01d
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 95 deletions.
91 changes: 91 additions & 0 deletions src/client_map.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#include "client_map.h"

namespace pikiwidb {

uint32_t ClientMap::GetAllClientInfos(std::vector<ClientInfo>& results) {
// client info string type: ip, port, fd.
std::shared_lock<std::shared_mutex> client_map_lock(client_map_mutex_);
auto it = clients_.begin();
while (it != clients_.end()) {
auto client = it->second.lock();
if (client) {
results.emplace_back(client->GetClientInfo());
}
it++;
}
return results.size();
}

bool ClientMap::AddClient(int id, std::weak_ptr<PClient> client) {
std::unique_lock client_map_lock(client_map_mutex_);
if (clients_.find(id) == clients_.end()) {
clients_.insert({id, client});
return true;
}
return false;
}

ClientInfo ClientMap::GetClientsInfoById(int id) {
std::shared_lock client_map_lock(client_map_mutex_);
if (auto it = clients_.find(id); it != clients_.end()) {
if (auto client = it->second.lock(); client) {
return client->GetClientInfo();
}
}
return ClientInfo::invalidClientInfo;
}

bool ClientMap::RemoveClientById(int id) {
std::unique_lock client_map_lock(client_map_mutex_);
if (auto it = clients_.find(id); it != clients_.end()) {
clients_.erase(it);
return true;
}
return false;
}

bool ClientMap::KillAllClients() {
std::shared_lock<std::shared_mutex> client_map_lock(client_map_mutex_);
auto it = clients_.begin();
while (it != clients_.end()) {
auto client = it->second.lock();
if (client) {
client_map_lock.unlock();
client->Close();
client_map_lock.lock();
}
it++;
}
return true;
}

bool ClientMap::KillClientByAddrPort(const std::string& addr_port) {
std::shared_lock<std::shared_mutex> client_map_lock(client_map_mutex_);
for (auto& [id, client_weak] : clients_) {
auto client = client_weak.lock();
if (client) {
std::string client_ip_port = client->PeerIP() + ":" + std::to_string(client->PeerPort());
if (client_ip_port == addr_port) {
client_map_lock.unlock();
client->Close();
return true;
}
}
}
return false;
}

bool ClientMap::KillClientById(int client_id) {
std::shared_lock<std::shared_mutex> client_map_lock(client_map_mutex_);
if (auto it = clients_.find(client_id); it != clients_.end()) {
auto client = it->second.lock();
if (client) {
client_map_lock.unlock();
client->Close();
return true;
}
}
return false;
}

} // namespace pikiwidb
42 changes: 42 additions & 0 deletions src/client_map.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#pragma once

#include <map>
#include <memory>
#include <shared_mutex>
#include <string>
#include "client.h"

namespace pikiwidb {

class ClientMap {
private:
ClientMap() = default;
// 禁用复制构造函数和赋值运算符

private:
std::map<int, std::weak_ptr<PClient>> clients_;
std::shared_mutex client_map_mutex_;

public:
static ClientMap& getInstance() {
static ClientMap instance;
return instance;
}

ClientMap(const ClientMap&) = delete;
ClientMap& operator=(const ClientMap&) = delete;

// client info function
pikiwidb::ClientInfo GetClientsInfoById(int id);
uint32_t GetAllClientInfos(std::vector<ClientInfo>& results);

bool AddClient(int id, std::weak_ptr<PClient>);

bool RemoveClientById(int id);

bool KillAllClients();
bool KillClientById(int client_id);
bool KillClientByAddrPort(const std::string& addr_port);
};

} // namespace pikiwidb
13 changes: 8 additions & 5 deletions src/cmd_admin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "praft/praft.h"
#include "pstd/env.h"

#include "client_map.h"
#include "store.h"

namespace pikiwidb {
Expand Down Expand Up @@ -308,19 +309,20 @@ bool CmdClientKill::DoInitial(PClient* client) {

void CmdClientKill::DoCmd(PClient* client) {
bool ret;
auto& client_map = pikiwidb::ClientMap::getInstance();
switch (kill_type_) {
case Type::ALL: {
ret = g_pikiwidb->KillAllClients();
ret = client_map.KillAllClients();
break;
}
case Type::ADDR: {
ret = g_pikiwidb->KillClientByAddrPort(client->argv_[3]);
ret = client_map.KillClientByAddrPort(client->argv_[3]);
break;
}
case Type::ID: {
try {
int client_id = stoi(client->argv_[3]);
ret = g_pikiwidb->KillClientById(client_id);
ret = client_map.KillClientById(client_id);
} catch (const std::exception& e) {
client->SetRes(CmdRes::kErrOther, "Invalid client id");
return;
Expand Down Expand Up @@ -349,10 +351,11 @@ bool CmdClientList::DoInitial(PClient* client) {
}

void CmdClientList::DoCmd(PClient* client) {
auto& client_map = ClientMap::getInstance();
switch (list_type_) {
case Type::DEFAULT: {
std::vector<pikiwidb::ClientInfo> client_infos;
g_pikiwidb->GetAllClientInfos(client_infos);
client_map.GetAllClientInfos(client_infos);
client->AppendArrayLen(client_infos.size());
if (client_infos.size() == 0) {
return;
Expand All @@ -372,7 +375,7 @@ void CmdClientList::DoCmd(PClient* client) {
for (size_t i = 3; i < client->argv_.size(); i++) {
try {
int client_id = std::stoi(client->argv_[i]);
auto client_info = g_pikiwidb->GetClientsInfoById(client_id);
auto client_info = client_map.GetClientsInfoById(client_id);
if (client_info == ClientInfo::invalidClientInfo) {
client->SetRes(CmdRes::kErrOther, "Invalid client id");
return;
Expand Down
81 changes: 3 additions & 78 deletions src/pikiwidb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "pstd/pstd_util.h"

#include "client.h"
#include "client_map.h"
#include "config.h"
#include "helper.h"
#include "pikiwidb_logo.h"
Expand Down Expand Up @@ -117,90 +118,14 @@ void PikiwiDB::OnNewConnection(pikiwidb::TcpConnection* obj) {
obj->SetOnDisconnect([](pikiwidb::TcpConnection* obj) {
INFO("disconnect from {}", obj->GetPeerIP());
obj->GetContext<pikiwidb::PClient>()->SetState(pikiwidb::ClientState::kClosed);
g_pikiwidb->RemoveClientMetaById(obj->GetUniqueId());
ClientMap::getInstance().RemoveClientById(obj->GetUniqueId());
});
obj->SetNodelay(true);
obj->SetEventLoopSelector([this]() { return worker_threads_.ChooseNextWorkerEventLoop(); });
obj->SetSlaveEventLoopSelector([this]() { return slave_threads_.ChooseNextWorkerEventLoop(); });

// add new PClient to clients
clients.insert({client->GetUniqueId(), client});
}

uint32_t PikiwiDB::GetAllClientInfos(std::vector<ClientInfo>& results) {
// client info string type: ip, port, fd.
std::shared_lock<std::shared_mutex> client_map_lock(client_map_mutex);
auto it = clients.begin();
while (it != clients.end()) {
auto client = it->second.lock();
if (client) {
results.emplace_back(client->GetClientInfo());
}
it++;
}
return results.size();
}
ClientInfo PikiwiDB::GetClientsInfoById(int id) {
std::shared_lock client_map_lock(client_map_mutex);
if (auto it = clients.find(id); it != clients.end()) {
if (auto client = it->second.lock(); client) {
return client->GetClientInfo();
}
}
return ClientInfo::invalidClientInfo;
}

bool PikiwiDB::RemoveClientMetaById(int id) {
std::unique_lock client_map_lock(client_map_mutex);
if (auto it = clients.find(id); it != clients.end()) {
clients.erase(it);
return true;
}
return false;
}

bool PikiwiDB::KillAllClients() {
std::shared_lock<std::shared_mutex> client_map_lock(client_map_mutex);
auto it = clients.begin();
while (it != clients.end()) {
auto client = it->second.lock();
if (client) {
client_map_lock.unlock();
client->Close();
client_map_lock.lock();
}
it++;
}
return true;
}

bool PikiwiDB::KillClientByAddrPort(const std::string& addr_port) {
std::shared_lock<std::shared_mutex> client_map_lock(client_map_mutex);
for (auto& [id, client_weak] : clients) {
auto client = client_weak.lock();
if (client) {
std::string client_ip_port = client->PeerIP() + ":" + std::to_string(client->PeerPort());
if (client_ip_port == addr_port) {
client_map_lock.unlock();
client->Close();
return true;
}
}
}
return false;
}

bool PikiwiDB::KillClientById(int client_id) {
std::shared_lock<std::shared_mutex> client_map_lock(client_map_mutex);
if (auto it = clients.find(client_id); it != clients.end()) {
auto client = it->second.lock();
if (client) {
client_map_lock.unlock();
client->Close();
return true;
}
}
return false;
ClientMap::getInstance().AddClient(client->GetUniqueId(), client);
}

bool PikiwiDB::Init() {
Expand Down
13 changes: 1 addition & 12 deletions src/pikiwidb.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* of patent rights can be found in the PATENTS file in the same directory.
*/

#include "client_map.h"
#include "cmd_table_manager.h"
#include "cmd_thread_pool.h"
#include "common.h"
Expand Down Expand Up @@ -45,16 +46,6 @@ class PikiwiDB final {

void PushWriteTask(const std::shared_ptr<pikiwidb::PClient>& client) { worker_threads_.PushWriteTask(client); }

// client message function
uint32_t GetAllClientInfos(std::vector<pikiwidb::ClientInfo>& results);
pikiwidb::ClientInfo GetClientsInfoById(int id);

bool RemoveClientMetaById(int id);

bool KillAllClients();
bool KillClientByAddrPort(const std::string& addr_port);
bool KillClientById(int client_id);

public:
PString cfg_file_;
uint16_t port_{0};
Expand All @@ -71,9 +62,7 @@ class PikiwiDB final {
pikiwidb::CmdThreadPool cmd_threads_;
// pikiwidb::CmdTableManager cmd_table_manager_;
// use std::list to store client pointer as a double linked list
std::shared_mutex client_map_mutex;
std::mutex killer_mutex;
std::map<int, std::weak_ptr<pikiwidb::PClient>> clients;
uint32_t cmd_id_ = 0;
std::atomic<int64_t> client_id_ = 0;
};
Expand Down

0 comments on commit 3e3f01d

Please sign in to comment.