Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: hash command hrandfield #98

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/base_cmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ const std::string kCmdNameHKeys = "hkeys";
const std::string kCmdNameHLen = "hlen";
const std::string kCmdNameHStrLen = "hstrlen";
const std::string kCmdNameHScan = "hscan";
const std::string kCmdNameHRandField = "hrandfield";
const std::string kCmdNameHVals = "hvals";

// set cmd
Expand Down Expand Up @@ -275,4 +276,4 @@ class BaseCmdGroup : public BaseCmd {
private:
std::map<std::string, std::unique_ptr<BaseCmd>> subCmds_;
};
} // namespace pikiwidb
} // namespace pikiwidb
54 changes: 54 additions & 0 deletions src/cmd_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,59 @@ void HScanCmd::DoCmd(PClient* client) {
}
}

HRandFieldCmd::HRandFieldCmd(const std::string& name, int16_t arity)
: BaseCmd(name, arity, kCmdFlagsReadonly, kAclCategoryRead | kAclCategoryHash) {}

bool HRandFieldCmd::DoInitial(PClient* client) {
client->SetKey(client->argv_[1]);
return true;
}

void HRandFieldCmd::DoCmd(PClient* client) {
// parse arguments
const auto& argv = client->argv_;
int64_t count{1};
bool with_values{false};
if (argv.size() > 2) {
// redis checks the integer argument first and then the number of parameters
if (pstd::String2int(argv[2], &count) == 0) {
client->SetRes(CmdRes::kInvalidInt);
return;
}
if (argv.size() > 4) {
client->SetRes(CmdRes::kSyntaxErr);
return;
}
if (argv.size() > 3) {
if (kWithValueString != pstd::StringToLower(argv[3])) {
client->SetRes(CmdRes::kSyntaxErr);
return;
}
with_values = true;
}
}

// execute command
std::vector<std::string> res;
auto s = PSTORE.GetBackend()->HRandField(client->Key(), count, with_values, &res);
if (s.IsNotFound()) {
client->AppendString("");
return;
}
if (!s.ok()) {
client->SetRes(CmdRes::kErrOther, s.ToString());
return;
}

// reply to client
if (argv.size() > 2) {
client->AppendArrayLenUint64(res.size());
}
for (const auto& item : res) {
client->AppendString(item);
}
}

HValsCmd::HValsCmd(const std::string& name, int16_t arity)
: BaseCmd(name, arity, kCmdFlagsReadonly, kAclCategoryRead | kAclCategoryHash) {}

Expand All @@ -300,4 +353,5 @@ void HValsCmd::DoCmd(PClient* client) {
client->SetRes(CmdRes::kErrOther, "hvals cmd error");
}
}

} // namespace pikiwidb
13 changes: 13 additions & 0 deletions src/cmd_hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,19 @@ class HScanCmd : public BaseCmd {
static constexpr const char *kCountSymbol = "count";
};

class HRandFieldCmd : public BaseCmd {
public:
HRandFieldCmd(const std::string &name, int16_t arity);

protected:
bool DoInitial(PClient *client) override;

private:
void DoCmd(PClient *client) override;

static constexpr const char *kWithValueString = "withvalues";
};

class HValsCmd : public BaseCmd {
public:
HValsCmd(const std::string &name, int16_t arity);
Expand Down
1 change: 1 addition & 0 deletions src/cmd_table_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ void CmdTableManager::InitCmdTable() {
ADD_COMMAND(HLen, 2);
ADD_COMMAND(HStrLen, 3);
ADD_COMMAND(HScan, -3);
ADD_COMMAND(HRandField, -2);
ADD_COMMAND(HVals, 2);

// set
Expand Down
3 changes: 3 additions & 0 deletions src/storage/include/storage/storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,9 @@ class Storage {
Status HScanx(const Slice& key, const std::string& start_field, const std::string& pattern, int64_t count,
std::vector<FieldValue>* field_values, std::string* next_field);

// Return random field(s) and value(s) from the hash value stored at key.
Status HRandField(const Slice& key, int64_t count, bool with_values, std::vector<std::string>* res);

// Iterate over a Hash table of fields by specified range
// return next_field that the user need to use as the start_field argument
// in the next call
Expand Down
73 changes: 73 additions & 0 deletions src/storage/src/redis_hashes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include "src/redis_hashes.h"

#include <memory>
#include <numeric>
#include <random>

#include <fmt/core.h>
#include <glog/logging.h>
Expand Down Expand Up @@ -923,6 +925,77 @@ Status RedisHashes::HScanx(const Slice& key, const std::string& start_field, con
return Status::OK();
}

Status RedisHashes::HRandField(const Slice& key, int64_t count, bool with_values, std::vector<std::string>* res) {
std::string meta_value;
Status s = db_->Get(default_read_options_, handles_[0], key, &meta_value);
if (!s.ok()) {
return s;
}
ParsedHashesMetaValue parsed_hashes_meta_value(&meta_value);
auto hlen = parsed_hashes_meta_value.count();
if (parsed_hashes_meta_value.IsStale() || hlen == 0) {
return Status::NotFound();
}

if (count >= hlen) {
// case 1: count > 0 and >= hlen, return all fv
if (!with_values) {
return HKeys(key, res);
}
std::vector<FieldValue> fvs;
s = HGetall(key, &fvs);
for (const auto& [field, value] : fvs) {
res->push_back(field);
res->push_back(value);
}
return s;
}

std::vector<uint32_t> idxs;
if (count == 1) {
// special case of case 3
idxs.push_back(rand() % hlen);
} else if (count < 0) {
// case 2: count < 0, allow duplication
while (idxs.size() < -count) {
idxs.push_back(rand() % hlen);
}
std::sort(idxs.begin(), idxs.end());
} else {
// case 3: count > 0 and < hlen, no duplication
std::vector<uint32_t> range(hlen);
std::iota(range.begin(), range.end(), 0);
std::random_device rd;
std::mt19937 g(rd());
std::shuffle(range.begin(), range.end(), g);
idxs.insert(idxs.cend(), range.begin(), range.begin() + count);
std::sort(idxs.begin(), idxs.end());
}

HashesDataKey hashes_data_key(key, parsed_hashes_meta_value.version(), "");
Slice prefix = hashes_data_key.Encode();
auto tmp_iter = db_->NewIterator(default_read_options_, handles_[1]);
std::unique_ptr<rocksdb::Iterator> iter{tmp_iter};
iter->Seek(prefix);
uint32_t save_idx{};
for (auto idx : idxs) {
while (save_idx < idx) {
iter->Next();
save_idx++;
}
if (!iter->Valid()) {
res->clear();
return Status::IOError(fmt::format("Should search for the data starting with {}", prefix.ToString()));
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里理应找到对应的kv,iterator应该是valid的。如果不是,我这里设为了IOError,如果有其他更好的选择,请提出建议

}
ParsedHashesDataKey datakey(iter->key());
res->push_back(datakey.field().ToString());
if (with_values) {
res->push_back(iter->value().ToString());
}
}
return Status::OK();
}

Status RedisHashes::PKHScanRange(const Slice& key, const Slice& field_start, const std::string& field_end,
const Slice& pattern, int32_t limit, std::vector<FieldValue>* field_values,
std::string* next_field) {
Expand Down
1 change: 1 addition & 0 deletions src/storage/src/redis_hashes.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class RedisHashes : public Redis {
std::vector<FieldValue>* field_values, int64_t* next_cursor);
Status HScanx(const Slice& key, const std::string& start_field, const std::string& pattern, int64_t count,
std::vector<FieldValue>* field_values, std::string* next_field);
Status HRandField(const Slice& key, int64_t count, bool with_values, std::vector<std::string>* res);
Status PKHScanRange(const Slice& key, const Slice& field_start, const std::string& field_end, const Slice& pattern,
int32_t limit, std::vector<FieldValue>* field_values, std::string* next_field);
Status PKHRScanRange(const Slice& key, const Slice& field_start, const std::string& field_end, const Slice& pattern,
Expand Down
4 changes: 4 additions & 0 deletions src/storage/src/storage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,10 @@ Status Storage::HScanx(const Slice& key, const std::string& start_field, const s
return hashes_db_->HScanx(key, start_field, pattern, count, field_values, next_field);
}

Status Storage::HRandField(const Slice& key, int64_t count, bool with_values, std::vector<std::string>* res) {
return hashes_db_->HRandField(key, count, with_values, res);
}

Status Storage::PKHScanRange(const Slice& key, const Slice& field_start, const std::string& field_end,
const Slice& pattern, int32_t limit, std::vector<FieldValue>* field_values,
std::string* next_field) {
Expand Down
84 changes: 84 additions & 0 deletions tests/hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,88 @@ var _ = Describe("Hash", Ordered, func() {
log.Println("Cmd HSET Begin")
Expect(client.HSet(ctx, "myhash", "one").Val()).NotTo(Equal("FooBar"))
})

It("Cmd HRandField Test", func() {
// set test data
key := "hrandfield"
kvs := map[string]string{
"field0": "value0",
"field1": "value1",
"field2": "value2",
}
for f, v := range kvs {
client.HSet(ctx, key, f, v)
}

num_test := 10
for i := 0; i < num_test; i++ {
// count < hlen
{
// without values
fields := client.HRandField(ctx, key, 2).Val()
Expect(len(fields)).To(Equal(2))
Expect(fields[0]).ToNot(Equal(fields[1]))
for _, field := range fields {
_, ok := kvs[field]
Expect(ok).To(BeTrue())
}

// with values
fvs := client.HRandFieldWithValues(ctx, key, 2).Val()
Expect(len(fvs)).To(Equal(2))
Expect(fvs[0].Key).ToNot(Equal(fvs[1].Key))
for _, fv := range fvs {
val, ok := kvs[fv.Key]
Expect(ok).To(BeTrue())
Expect(val).To(Equal(fv.Value))
}
}

// count > hlen
{
fields := client.HRandField(ctx, key, 10).Val()
Expect(len(fields)).To(Equal(3))
Expect(fields[0]).ToNot(Equal(fields[1]))
Expect(fields[2]).ToNot(Equal(fields[1]))
Expect(fields[0]).ToNot(Equal(fields[2]))
for _, field := range fields {
_, ok := kvs[field]
Expect(ok).To(BeTrue())
}

fvs := client.HRandFieldWithValues(ctx, key, 10).Val()
Expect(len(fvs)).To(Equal(3))
Expect(fvs[0].Key).ToNot(Equal(fvs[1].Key))
Expect(fvs[2].Key).ToNot(Equal(fvs[1].Key))
Expect(fvs[0].Key).ToNot(Equal(fvs[2].Key))
for _, fv := range fvs {
val, ok := kvs[fv.Key]
Expect(ok).To(BeTrue())
Expect(val).To(Equal(fv.Value))
}
}

// count < 0
{
fields := client.HRandField(ctx, key, -10).Val()
Expect(len(fields)).To(Equal(10))
for _, field := range fields {
_, ok := kvs[field]
Expect(ok).To(BeTrue())
}

fvs := client.HRandFieldWithValues(ctx, key, -10).Val()
Expect(len(fvs)).To(Equal(10))
for _, fv := range fvs {
val, ok := kvs[fv.Key]
Expect(ok).To(BeTrue())
Expect(val).To(Equal(fv.Value))
}
}
}

// the key not exist
res1 := client.HRandField(ctx, "not_exist_key", 1).Val()
Expect(len(res1)).To(Equal(0))
})
})