From 110fa8c7a16bb70b2a6fcebead8b3827356477dd Mon Sep 17 00:00:00 2001 From: haiyang426 <51285701+haiyang426@users.noreply.github.com> Date: Sun, 28 Jul 2024 10:19:22 +0800 Subject: [PATCH] feat: add sort commands (#357) * remove unused variables and move parser func to Doinitial --- .github/workflows/pikiwidb.yml | 4 +- src/base_cmd.h | 1 + src/cmd_admin.cc | 208 +++++++++++++++++++++++++++++++++ src/cmd_admin.h | 31 +++++ src/cmd_set.cc | 6 +- src/cmd_table_manager.cc | 1 + tests/admin_test.go | 92 +++++++++++++++ 7 files changed, 338 insertions(+), 5 deletions(-) diff --git a/.github/workflows/pikiwidb.yml b/.github/workflows/pikiwidb.yml index c6ea3d576..36d6e15d0 100644 --- a/.github/workflows/pikiwidb.yml +++ b/.github/workflows/pikiwidb.yml @@ -43,7 +43,7 @@ jobs: run: | cd ../tests go mod tidy - go test + go test -timeout 15m build_on_ubuntu: runs-on: ubuntu-latest @@ -67,4 +67,4 @@ jobs: run: | cd ../tests go mod tidy - go test \ No newline at end of file + go test -timeout 15m \ No newline at end of file diff --git a/src/base_cmd.h b/src/base_cmd.h index 7c5507333..470052ccc 100644 --- a/src/base_cmd.h +++ b/src/base_cmd.h @@ -89,6 +89,7 @@ const std::string kSubCmdNameDebugHelp = "help"; const std::string kSubCmdNameDebugOOM = "oom"; const std::string kSubCmdNameDebugSegfault = "segfault"; const std::string kCmdNameInfo = "info"; +const std::string kCmdNameSort = "sort"; // hash cmd const std::string kCmdNameHSet = "hset"; diff --git a/src/cmd_admin.cc b/src/cmd_admin.cc index 11314f140..db48b8cae 100644 --- a/src/cmd_admin.cc +++ b/src/cmd_admin.cc @@ -6,9 +6,15 @@ */ #include "cmd_admin.h" +#include +#include +#include +#include +#include #include "db.h" #include "braft/raft.h" +#include "pstd_string.h" #include "rocksdb/version.h" #include "pikiwidb.h" @@ -257,4 +263,206 @@ void CmdDebugSegfault::DoCmd(PClient* client) { *ptr = 0; } +SortCmd::SortCmd(const std::string& name, int16_t arity) + : BaseCmd(name, arity, kCmdFlagsAdmin | kCmdFlagsWrite, kAclCategoryAdmin) {} + +bool SortCmd::DoInitial(PClient* client) { + InitialArgument(); + client->SetKey(client->argv_[1]); + size_t argc = client->argv_.size(); + for (int i = 2; i < argc; ++i) { + int leftargs = argc - i - 1; + if (strcasecmp(client->argv_[i].data(), "asc") == 0) { + desc_ = 0; + } else if (strcasecmp(client->argv_[i].data(), "desc") == 0) { + desc_ = 1; + } else if (strcasecmp(client->argv_[i].data(), "alpha") == 0) { + alpha_ = 1; + } else if (strcasecmp(client->argv_[i].data(), "limit") == 0 && leftargs >= 2) { + if (pstd::String2int(client->argv_[i + 1], &offset_) == 0 || + pstd::String2int(client->argv_[i + 2], &count_) == 0) { + client->SetRes(CmdRes::kSyntaxErr); + return false; + } + i += 2; + } else if (strcasecmp(client->argv_[i].data(), "store") == 0 && leftargs >= 1) { + store_key_ = client->argv_[i + 1]; + i++; + } else if (strcasecmp(client->argv_[i].data(), "by") == 0 && leftargs >= 1) { + sortby_ = client->argv_[i + 1]; + if (sortby_.find('*') == std::string::npos) { + dontsort_ = 1; + } + i++; + } else if (strcasecmp(client->argv_[i].data(), "get") == 0 && leftargs >= 1) { + get_patterns_.push_back(client->argv_[i + 1]); + i++; + } else { + client->SetRes(CmdRes::kSyntaxErr); + return false; + } + } + + Status s; + s = PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->LRange(client->Key(), 0, -1, &ret_); + if (s.ok()) { + return true; + } else if (!s.IsNotFound()) { + client->SetRes(CmdRes::kErrOther, s.ToString()); + return false; + } + + s = PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->SMembers(client->Key(), &ret_); + if (s.ok()) { + return true; + } else if (!s.IsNotFound()) { + client->SetRes(CmdRes::kErrOther, s.ToString()); + return false; + } + + std::vector score_members; + s = PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->ZRange(client->Key(), 0, -1, &score_members); + if (s.ok()) { + for (auto& c : score_members) { + ret_.emplace_back(c.member); + } + return true; + } else if (!s.IsNotFound()) { + client->SetRes(CmdRes::kErrOther, s.ToString()); + return false; + } + client->SetRes(CmdRes::kErrOther, "Unknown Type"); + return false; +} + +void SortCmd::DoCmd(PClient* client) { + std::vector sort_ret(ret_.size()); + for (size_t i = 0; i < ret_.size(); ++i) { + sort_ret[i].obj = ret_[i]; + } + + if (!dontsort_) { + for (size_t i = 0; i < ret_.size(); ++i) { + std::string byval; + if (!sortby_.empty()) { + auto lookup = lookupKeyByPattern(client, sortby_, ret_[i]); + if (!lookup.has_value()) { + byval = ret_[i]; + } else { + byval = std::move(lookup.value()); + } + } else { + byval = ret_[i]; + } + + if (alpha_) { + sort_ret[i].u = byval; + } else { + double double_byval; + if (pstd::String2d(byval, &double_byval)) { + sort_ret[i].u = double_byval; + } else { + client->SetRes(CmdRes::kErrOther, "One or more scores can't be converted into double"); + return; + } + } + } + + std::sort(sort_ret.begin(), sort_ret.end(), [this](const RedisSortObject& a, const RedisSortObject& b) { + if (this->alpha_) { + std::string score_a = std::get(a.u); + std::string score_b = std::get(b.u); + return !this->desc_ ? score_a < score_b : score_a > score_b; + } else { + double score_a = std::get(a.u); + double score_b = std::get(b.u); + return !this->desc_ ? score_a < score_b : score_a > score_b; + } + }); + + size_t sort_size = sort_ret.size(); + + count_ = count_ >= 0 ? count_ : sort_size; + offset_ = (offset_ >= 0 && offset_ < sort_size) ? offset_ : sort_size; + count_ = (offset_ + count_ < sort_size) ? count_ : sort_size - offset_; + + size_t m_start = offset_; + size_t m_end = offset_ + count_; + + ret_.clear(); + if (get_patterns_.empty()) { + get_patterns_.emplace_back("#"); + } + + for (; m_start < m_end; m_start++) { + for (const std::string& pattern : get_patterns_) { + std::optional val = lookupKeyByPattern(client, pattern, sort_ret[m_start].obj); + if (val.has_value()) { + ret_.push_back(val.value()); + } else { + ret_.emplace_back(""); + } + } + } + } + + if (store_key_.empty()) { + client->AppendStringVector(ret_); + } else { + uint64_t reply_num = 0; + storage::Status s = PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->RPush(store_key_, ret_, &reply_num); + if (s.ok()) { + client->AppendInteger(reply_num); + } else { + client->SetRes(CmdRes::kErrOther, s.ToString()); + } + } +} + +std::optional SortCmd::lookupKeyByPattern(PClient* client, const std::string& pattern, + const std::string& subst) { + if (pattern == "#") { + return subst; + } + + auto match_pos = pattern.find('*'); + if (match_pos == std::string::npos) { + return std::nullopt; + } + + std::string field; + auto arrow_pos = pattern.find("->", match_pos + 1); + if (arrow_pos != std::string::npos && arrow_pos + 2 < pattern.size()) { + field = pattern.substr(arrow_pos + 2); + } + + std::string key = pattern.substr(0, match_pos + 1); + key.replace(match_pos, 1, subst); + + std::string value; + storage::Status s; + if (!field.empty()) { + s = PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->HGet(key, field, &value); + } else { + s = PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->Get(key, &value); + } + + if (!s.ok()) { + return std::nullopt; + } + + return value; +} + +void SortCmd::InitialArgument() { + desc_ = 0; + alpha_ = 0; + offset_ = 0; + count_ = -1; + dontsort_ = 0; + store_key_.clear(); + sortby_.clear(); + get_patterns_.clear(); + ret_.clear(); +} } // namespace pikiwidb diff --git a/src/cmd_admin.h b/src/cmd_admin.h index c78164093..7c6eaa610 100644 --- a/src/cmd_admin.h +++ b/src/cmd_admin.h @@ -7,6 +7,8 @@ #pragma once +#include +#include #include "base_cmd.h" #include "config.h" @@ -172,4 +174,33 @@ class CmdDebugSegfault : public BaseCmd { void DoCmd(PClient* client) override; }; +class SortCmd : public BaseCmd { + public: + SortCmd(const std::string& name, int16_t arity); + + protected: + bool DoInitial(PClient* client) override; + + private: + void DoCmd(PClient* client) override; + + void InitialArgument(); + std::optional lookupKeyByPattern(PClient* client, const std::string& pattern, const std::string& subst); + + struct RedisSortObject { + std::string obj; + std::variant u; + }; + + int desc_ = 0; + int alpha_ = 0; + size_t offset_ = 0; + size_t count_ = -1; + int dontsort_ = 0; + std::string store_key_; + std::string sortby_; + std::vector get_patterns_; + std::vector ret_; +}; + } // namespace pikiwidb diff --git a/src/cmd_set.cc b/src/cmd_set.cc index 8a750dbb7..c2a76cc97 100644 --- a/src/cmd_set.cc +++ b/src/cmd_set.cc @@ -158,11 +158,11 @@ bool SCardCmd::DoInitial(PClient* client) { void SCardCmd::DoCmd(PClient* client) { int32_t reply_Num = 0; storage::Status s = PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->SCard(client->Key(), &reply_Num); - if (!s.ok()) { - client->SetRes(CmdRes::kSyntaxErr, "scard cmd error"); + if (s.ok() || s.IsNotFound()) { + client->AppendInteger(reply_Num); return; } - client->AppendInteger(reply_Num); + client->SetRes(CmdRes::kSyntaxErr, "scard cmd error"); } SMoveCmd::SMoveCmd(const std::string& name, int16_t arity) diff --git a/src/cmd_table_manager.cc b/src/cmd_table_manager.cc index 5a335465b..c5c667404 100644 --- a/src/cmd_table_manager.cc +++ b/src/cmd_table_manager.cc @@ -57,6 +57,7 @@ void CmdTableManager::InitCmdTable() { ADD_SUBCOMMAND(Debug, Help, 2); ADD_SUBCOMMAND(Debug, OOM, 2); ADD_SUBCOMMAND(Debug, Segfault, 2); + ADD_COMMAND(Sort, -2); // server ADD_COMMAND(Flushdb, 1); diff --git a/tests/admin_test.go b/tests/admin_test.go index 8d8713c31..05e463b78 100644 --- a/tests/admin_test.go +++ b/tests/admin_test.go @@ -159,4 +159,96 @@ var _ = Describe("Admin", Ordered, func() { // Expect(res.Err()).NotTo(HaveOccurred()) // Expect(res.Val()).To(Equal(map[string]string{"timeout": "0"})) }) + + It("Cmd Sort", func() { + size, err := client.LPush(ctx, "list", "1").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(size).To(Equal(int64(1))) + + size, err = client.LPush(ctx, "list", "3").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(size).To(Equal(int64(2))) + + size, err = client.LPush(ctx, "list", "2").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(size).To(Equal(int64(3))) + + els, err := client.Sort(ctx, "list", &redis.Sort{ + Offset: 0, + Count: 2, + Order: "ASC", + }).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(els).To(Equal([]string{"1", "2"})) + + del := client.Del(ctx, "list") + Expect(del.Err()).NotTo(HaveOccurred()) + }) + + It("should Sort and Get", Label("NonRedisEnterprise"), func() { + size, err := client.LPush(ctx, "list", "1").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(size).To(Equal(int64(1))) + + size, err = client.LPush(ctx, "list", "3").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(size).To(Equal(int64(2))) + + size, err = client.LPush(ctx, "list", "2").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(size).To(Equal(int64(3))) + + err = client.Set(ctx, "object_2", "value2", 0).Err() + Expect(err).NotTo(HaveOccurred()) + + { + els, err := client.Sort(ctx, "list", &redis.Sort{ + Get: []string{"object_*"}, + }).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(els).To(Equal([]string{"", "value2", ""})) + } + + { + els, err := client.SortInterfaces(ctx, "list", &redis.Sort{ + Get: []string{"object_*"}, + }).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(els).To(Equal([]interface{}{nil, "value2", nil})) + } + del := client.Del(ctx, "list") + Expect(del.Err()).NotTo(HaveOccurred()) + }) + + It("should Sort and Store", Label("NonRedisEnterprise"), func() { + size, err := client.LPush(ctx, "list", "1").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(size).To(Equal(int64(1))) + + size, err = client.LPush(ctx, "list", "3").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(size).To(Equal(int64(2))) + + size, err = client.LPush(ctx, "list", "2").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(size).To(Equal(int64(3))) + + n, err := client.SortStore(ctx, "list", "list2", &redis.Sort{ + Offset: 0, + Count: 2, + Order: "ASC", + }).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(int64(2))) + + els, err := client.LRange(ctx, "list2", 0, -1).Result() + Expect(err).NotTo(HaveOccurred()) + Expect(els).To(Equal([]string{"1", "2"})) + + del := client.Del(ctx, "list") + Expect(del.Err()).NotTo(HaveOccurred()) + + del2 := client.Del(ctx, "list2") + Expect(del2.Err()).NotTo(HaveOccurred()) + }) })