Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug descriptors pmgdqh #47

Merged
merged 3 commits into from
Nov 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 18 additions & 14 deletions src/DescriptorsCommand.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,13 @@ DescriptorsCommand::DescriptorsCommand(const std::string& cmd_name) :
// but not when there is an input error (such as wrong set_name).
// In case of wrong input, we need to inform to the user what
// went wrong.
std::string DescriptorsCommand::get_set_path(const std::string& set_name,
std::string DescriptorsCommand::get_set_path(PMGDQuery &query_tx,
const std::string& set_name,
int& dim)
{
// Will issue a read-only transaction to check
// if the Set exists
PMGDQuery query(*_pmgd_qh);
PMGDQuery query(query_tx.get_pmgd_qh());

Json::Value constraints, link;
Json::Value name_arr;
Expand Down Expand Up @@ -271,7 +272,7 @@ int AddDescriptor::construct_protobuf(
}

int dimensions;
std::string set_path = get_set_path(set_name, dimensions);
std::string set_path = get_set_path(query, set_name, dimensions);

if (set_path.empty()) {
error["info"] = "Set " + set_name + " not found";
Expand Down Expand Up @@ -356,7 +357,7 @@ int ClassifyDescriptor::construct_protobuf(
const std::string set_name = cmd["set"].asString();

int dimensions;
const std::string set_path = get_set_path(set_name, dimensions);
const std::string set_path = get_set_path(query, set_name, dimensions);

if (set_path.empty()) {
error["status"] = RSCommand::Error;
Expand Down Expand Up @@ -482,7 +483,7 @@ int FindDescriptor::construct_protobuf(
const std::string set_name = cmd["set"].asString();

int dimensions;
const std::string set_path = get_set_path(set_name, dimensions);
const std::string set_path = get_set_path(query, set_name, dimensions);

if (set_path.empty()) {
cp_result["status"] = RSCommand::Error;
Expand Down Expand Up @@ -619,11 +620,11 @@ int FindDescriptor::construct_protobuf(
auto cache_obj_id = VCL::get_uint64();
cp_result["cache_obj_id"] = Json::Int64(cache_obj_id);

_cache_map[cache_obj_id] = IDDistancePair();
_cache_map[cache_obj_id] = new IDDistancePair();

IDDistancePair& pair = _cache_map[cache_obj_id];
std::vector<long>& ids = pair.first;
std::vector<float>& distances = pair.second;
IDDistancePair* pair = _cache_map[cache_obj_id];
std::vector<long>& ids = pair->first;
std::vector<float>& distances = pair->second;

set->search((float*)blob.data(), 1, k_neighbors, ids, distances);

Expand Down Expand Up @@ -801,9 +802,9 @@ Json::Value FindDescriptor::construct_responses(
long cache_obj_id = cache["cache_obj_id"].asInt64();

// Get from Cache
IDDistancePair& pair = _cache_map[cache_obj_id];
ids = &pair.first;
distances = &pair.second;
IDDistancePair* pair = _cache_map[cache_obj_id];
ids = &(pair->first);
distances = &(pair->second);

findDesc = json_responses[1];

Expand Down Expand Up @@ -875,8 +876,11 @@ Json::Value FindDescriptor::construct_responses(
}

if (cache.isMember("cache_obj_id")) {
// TODO CHECK THIS UNSAFE ERASE
_cache_map.unsafe_erase(cache["cache_obj_id"].asInt64());
// We remove the vectors associated with that entry to
// free memory, without removing the entry from _cache_map
// because tbb does not have a lock free way to do this.
IDDistancePair* pair = _cache_map[cache["cache_obj_id"].asInt64()];
delete pair;
}
}

Expand Down
14 changes: 6 additions & 8 deletions src/DescriptorsCommand.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,24 +51,22 @@ namespace VDMS{
{
protected:
DescriptorsManager* _dm;
PMGDQueryHandler* _pmgd_qh; // This needs to make read-transcations.

tbb::concurrent_unordered_map<long, IDDistancePair> _cache_map;
// IDDistancePair is a pointer so that we can free its content
// without having to use erase methods, which are not lock free
// for this data structure in tbb
tbb::concurrent_unordered_map<long, IDDistancePair*> _cache_map;

// Will return the path to the set and the dimensions
std::string get_set_path(const std::string& set, int& dim);
std::string get_set_path(PMGDQuery& query_tx,
const std::string& set, int& dim);

bool check_blob_size(const std::string& blob, const int dimensions,
const long n_desc);

public:
DescriptorsCommand(const std::string& cmd_name);

void set_pmgd_qh(PMGDQueryHandler* pmgd_qh)
{
_pmgd_qh = pmgd_qh;
}

virtual bool need_blob(const Json::Value& cmd) { return false; }

virtual int construct_protobuf(PMGDQuery& tx,
Expand Down
2 changes: 2 additions & 0 deletions src/PMGDQuery.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ namespace VDMS {
//This is a reference to avoid copies
Json::Value& get_json_responses() {return _json_responses;}

PMGDQueryHandler& get_pmgd_qh() {return _pmgd_qh;}

// If constraints is not null, this becomes a conditional AddNode
void AddNode(int ref,
const std::string& tag,
Expand Down
8 changes: 0 additions & 8 deletions src/QueryHandler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,6 @@ QueryHandler::QueryHandler()
,ch_tx_send("ch_tx_send")
#endif
{
((DescriptorsCommand*)_rs_cmds["AddDescriptorSet"])
->set_pmgd_qh(&_pmgd_qh);
((DescriptorsCommand*)_rs_cmds["AddDescriptor"])
->set_pmgd_qh(&_pmgd_qh);
((DescriptorsCommand*)_rs_cmds["ClassifyDescriptor"])
->set_pmgd_qh(&_pmgd_qh);
((DescriptorsCommand*)_rs_cmds["FindDescriptor"])
->set_pmgd_qh(&_pmgd_qh);
}

void QueryHandler::process_connection(comm::Connection *c)
Expand Down
75 changes: 58 additions & 17 deletions tests/python/TestEntities.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

class TestEntities(unittest.TestCase):

def addEntity(self, thID=0):
def addEntity(self, thID, results):

db = vdms.vdms()
db.connect(hostname, port)
Expand All @@ -63,9 +63,14 @@ def addEntity(self, thID=0):
response, res_arr = db.query(all_queries)
# print (db.get_last_response_str())

self.assertEqual(response[0]["AddEntity"]["status"], 0)
try:
self.assertEqual(response[0]["AddEntity"]["status"], 0)
except:
results[thID] = -1

results[thID] = 0

def findEntity(self, thID):
def findEntity(self, thID, results):

db = vdms.vdms()
db.connect(hostname, port)
Expand All @@ -89,33 +94,69 @@ def findEntity(self, thID):

response, res_arr = db.query(all_queries)

self.assertEqual(response[0]["FindEntity"]["status"], 0)
self.assertEqual(response[0]["FindEntity"]["entities"][0]
["lastname"], "Ferro")
self.assertEqual(response[0]["FindEntity"]["entities"][0]
["threadid"], thID)
try:

self.assertEqual(response[0]["FindEntity"]["status"], 0)
self.assertEqual(response[0]["FindEntity"]["entities"][0]
["lastname"], "Ferro")
self.assertEqual(response[0]["FindEntity"]["entities"][0]
["threadid"], thID)
except:
results[thID] = -1

results[thID] = 0

def test_runMultipleAdds(self):

# Test concurrent AddEntities
concurrency = 32
thread_arr = []
results = [None] * concurrency
for i in range(0,concurrency):
thread_add = Thread(target=self.addEntity,args=(i, results) )
thread_add.start()
thread_arr.append(thread_add)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are these set of changes for? I am lazy to figure it out myself :) seems like a different commit from a quick scan.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stuff for checking that concurrent read/writes worked. you only need to worry about it if the test fails :)


idx = 0
error_counter = 0
for th in thread_arr:
th.join()
if (results[idx] == -1):
error_counter += 1
idx += 1

def ztest_runMultipleAdds(self):
self.assertEqual(error_counter, 0)

simultaneous = 1000;
thread_arr = []
for i in range(1,simultaneous):
thread_add = Thread(target=self.addEntity,args=(i,) )

# Tests concurrent AddEntities and FindEntities (that should exists)
results = [None] * concurrency * 2
for i in range(0,concurrency):
addidx = concurrency + i
thread_add = Thread(target=self.addEntity,args=(addidx, results) )
thread_add.start()
thread_arr.append(thread_add)
time.sleep(0.002)

for i in range(1,simultaneous):
thread_find = Thread(target=self.findEntity,args=(i,) )
thread_find = Thread(
target=self.findEntity,args=(i, results) )
thread_find.start()
thread_arr.append(thread_find)

idx = 0
error_counter = 0
for th in thread_arr:
th.join();
if (results[idx] == -1):
error_counter += 1

idx += 1

self.assertEqual(error_counter, 0)

def test_addFindEntity(self):
self.addEntity(9000);
self.findEntity(9000);
results = [None] * 1
self.addEntity(0, results);
self.findEntity(0, results);

def test_addEntityWithLink(self):
db = vdms.vdms()
Expand Down
Loading