Skip to content

Commit

Permalink
feat: add sort commands (#357)
Browse files Browse the repository at this point in the history
* remove unused variables and move parser func to Doinitial
  • Loading branch information
haiyang426 authored Jul 28, 2024
1 parent 5c5b95f commit 110fa8c
Show file tree
Hide file tree
Showing 7 changed files with 338 additions and 5 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pikiwidb.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
run: |
cd ../tests
go mod tidy
go test
go test -timeout 15m
build_on_ubuntu:
runs-on: ubuntu-latest
Expand All @@ -67,4 +67,4 @@ jobs:
run: |
cd ../tests
go mod tidy
go test
go test -timeout 15m
1 change: 1 addition & 0 deletions src/base_cmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ const std::string kSubCmdNameDebugHelp = "help";
const std::string kSubCmdNameDebugOOM = "oom";
const std::string kSubCmdNameDebugSegfault = "segfault";
const std::string kCmdNameInfo = "info";
const std::string kCmdNameSort = "sort";

// hash cmd
const std::string kCmdNameHSet = "hset";
Expand Down
208 changes: 208 additions & 0 deletions src/cmd_admin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,15 @@
*/

#include "cmd_admin.h"
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>
#include "db.h"

#include "braft/raft.h"
#include "pstd_string.h"
#include "rocksdb/version.h"

#include "pikiwidb.h"
Expand Down Expand Up @@ -257,4 +263,206 @@ void CmdDebugSegfault::DoCmd(PClient* client) {
*ptr = 0;
}

SortCmd::SortCmd(const std::string& name, int16_t arity)
: BaseCmd(name, arity, kCmdFlagsAdmin | kCmdFlagsWrite, kAclCategoryAdmin) {}

bool SortCmd::DoInitial(PClient* client) {
InitialArgument();
client->SetKey(client->argv_[1]);
size_t argc = client->argv_.size();
for (int i = 2; i < argc; ++i) {
int leftargs = argc - i - 1;
if (strcasecmp(client->argv_[i].data(), "asc") == 0) {
desc_ = 0;
} else if (strcasecmp(client->argv_[i].data(), "desc") == 0) {
desc_ = 1;
} else if (strcasecmp(client->argv_[i].data(), "alpha") == 0) {
alpha_ = 1;
} else if (strcasecmp(client->argv_[i].data(), "limit") == 0 && leftargs >= 2) {
if (pstd::String2int(client->argv_[i + 1], &offset_) == 0 ||
pstd::String2int(client->argv_[i + 2], &count_) == 0) {
client->SetRes(CmdRes::kSyntaxErr);
return false;
}
i += 2;
} else if (strcasecmp(client->argv_[i].data(), "store") == 0 && leftargs >= 1) {
store_key_ = client->argv_[i + 1];
i++;
} else if (strcasecmp(client->argv_[i].data(), "by") == 0 && leftargs >= 1) {
sortby_ = client->argv_[i + 1];
if (sortby_.find('*') == std::string::npos) {
dontsort_ = 1;
}
i++;
} else if (strcasecmp(client->argv_[i].data(), "get") == 0 && leftargs >= 1) {
get_patterns_.push_back(client->argv_[i + 1]);
i++;
} else {
client->SetRes(CmdRes::kSyntaxErr);
return false;
}
}

Status s;
s = PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->LRange(client->Key(), 0, -1, &ret_);
if (s.ok()) {
return true;
} else if (!s.IsNotFound()) {
client->SetRes(CmdRes::kErrOther, s.ToString());
return false;
}

s = PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->SMembers(client->Key(), &ret_);
if (s.ok()) {
return true;
} else if (!s.IsNotFound()) {
client->SetRes(CmdRes::kErrOther, s.ToString());
return false;
}

std::vector<storage::ScoreMember> score_members;
s = PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->ZRange(client->Key(), 0, -1, &score_members);
if (s.ok()) {
for (auto& c : score_members) {
ret_.emplace_back(c.member);
}
return true;
} else if (!s.IsNotFound()) {
client->SetRes(CmdRes::kErrOther, s.ToString());
return false;
}
client->SetRes(CmdRes::kErrOther, "Unknown Type");
return false;
}

void SortCmd::DoCmd(PClient* client) {
std::vector<RedisSortObject> sort_ret(ret_.size());
for (size_t i = 0; i < ret_.size(); ++i) {
sort_ret[i].obj = ret_[i];
}

if (!dontsort_) {
for (size_t i = 0; i < ret_.size(); ++i) {
std::string byval;
if (!sortby_.empty()) {
auto lookup = lookupKeyByPattern(client, sortby_, ret_[i]);
if (!lookup.has_value()) {
byval = ret_[i];
} else {
byval = std::move(lookup.value());
}
} else {
byval = ret_[i];
}

if (alpha_) {
sort_ret[i].u = byval;
} else {
double double_byval;
if (pstd::String2d(byval, &double_byval)) {
sort_ret[i].u = double_byval;
} else {
client->SetRes(CmdRes::kErrOther, "One or more scores can't be converted into double");
return;
}
}
}

std::sort(sort_ret.begin(), sort_ret.end(), [this](const RedisSortObject& a, const RedisSortObject& b) {
if (this->alpha_) {
std::string score_a = std::get<std::string>(a.u);
std::string score_b = std::get<std::string>(b.u);
return !this->desc_ ? score_a < score_b : score_a > score_b;
} else {
double score_a = std::get<double>(a.u);
double score_b = std::get<double>(b.u);
return !this->desc_ ? score_a < score_b : score_a > score_b;
}
});

size_t sort_size = sort_ret.size();

count_ = count_ >= 0 ? count_ : sort_size;
offset_ = (offset_ >= 0 && offset_ < sort_size) ? offset_ : sort_size;
count_ = (offset_ + count_ < sort_size) ? count_ : sort_size - offset_;

size_t m_start = offset_;
size_t m_end = offset_ + count_;

ret_.clear();
if (get_patterns_.empty()) {
get_patterns_.emplace_back("#");
}

for (; m_start < m_end; m_start++) {
for (const std::string& pattern : get_patterns_) {
std::optional<std::string> val = lookupKeyByPattern(client, pattern, sort_ret[m_start].obj);
if (val.has_value()) {
ret_.push_back(val.value());
} else {
ret_.emplace_back("");
}
}
}
}

if (store_key_.empty()) {
client->AppendStringVector(ret_);
} else {
uint64_t reply_num = 0;
storage::Status s = PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->RPush(store_key_, ret_, &reply_num);
if (s.ok()) {
client->AppendInteger(reply_num);
} else {
client->SetRes(CmdRes::kErrOther, s.ToString());
}
}
}

std::optional<std::string> SortCmd::lookupKeyByPattern(PClient* client, const std::string& pattern,
const std::string& subst) {
if (pattern == "#") {
return subst;
}

auto match_pos = pattern.find('*');
if (match_pos == std::string::npos) {
return std::nullopt;
}

std::string field;
auto arrow_pos = pattern.find("->", match_pos + 1);
if (arrow_pos != std::string::npos && arrow_pos + 2 < pattern.size()) {
field = pattern.substr(arrow_pos + 2);
}

std::string key = pattern.substr(0, match_pos + 1);
key.replace(match_pos, 1, subst);

std::string value;
storage::Status s;
if (!field.empty()) {
s = PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->HGet(key, field, &value);
} else {
s = PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->Get(key, &value);
}

if (!s.ok()) {
return std::nullopt;
}

return value;
}

void SortCmd::InitialArgument() {
desc_ = 0;
alpha_ = 0;
offset_ = 0;
count_ = -1;
dontsort_ = 0;
store_key_.clear();
sortby_.clear();
get_patterns_.clear();
ret_.clear();
}
} // namespace pikiwidb
31 changes: 31 additions & 0 deletions src/cmd_admin.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

#pragma once

#include <optional>
#include <variant>
#include "base_cmd.h"
#include "config.h"

Expand Down Expand Up @@ -172,4 +174,33 @@ class CmdDebugSegfault : public BaseCmd {
void DoCmd(PClient* client) override;
};

class SortCmd : public BaseCmd {
public:
SortCmd(const std::string& name, int16_t arity);

protected:
bool DoInitial(PClient* client) override;

private:
void DoCmd(PClient* client) override;

void InitialArgument();
std::optional<std::string> lookupKeyByPattern(PClient* client, const std::string& pattern, const std::string& subst);

struct RedisSortObject {
std::string obj;
std::variant<double, std::string> u;
};

int desc_ = 0;
int alpha_ = 0;
size_t offset_ = 0;
size_t count_ = -1;
int dontsort_ = 0;
std::string store_key_;
std::string sortby_;
std::vector<std::string> get_patterns_;
std::vector<std::string> ret_;
};

} // namespace pikiwidb
6 changes: 3 additions & 3 deletions src/cmd_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,11 @@ bool SCardCmd::DoInitial(PClient* client) {
void SCardCmd::DoCmd(PClient* client) {
int32_t reply_Num = 0;
storage::Status s = PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->SCard(client->Key(), &reply_Num);
if (!s.ok()) {
client->SetRes(CmdRes::kSyntaxErr, "scard cmd error");
if (s.ok() || s.IsNotFound()) {
client->AppendInteger(reply_Num);
return;
}
client->AppendInteger(reply_Num);
client->SetRes(CmdRes::kSyntaxErr, "scard cmd error");
}

SMoveCmd::SMoveCmd(const std::string& name, int16_t arity)
Expand Down
1 change: 1 addition & 0 deletions src/cmd_table_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ void CmdTableManager::InitCmdTable() {
ADD_SUBCOMMAND(Debug, Help, 2);
ADD_SUBCOMMAND(Debug, OOM, 2);
ADD_SUBCOMMAND(Debug, Segfault, 2);
ADD_COMMAND(Sort, -2);

// server
ADD_COMMAND(Flushdb, 1);
Expand Down
Loading

0 comments on commit 110fa8c

Please sign in to comment.