Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: predict words longer than 4 syllables #832

Merged
merged 6 commits into from
Mar 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 59 additions & 23 deletions src/rime/dict/dictionary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,22 @@ struct Chunk {
size_t size = 0;
size_t cursor = 0;
string remaining_code; // for predictive queries
size_t matching_code_size = 0;
double credibility = 0.0;

Chunk() = default;
Chunk(Table* t, const Code& c, const table::Entry* e, double cr = 0.0)
: table(t), code(c), entries(e), size(1), cursor(0), credibility(cr) {}
// Builds a chunk that refers to exactly one entry: size is fixed to 1 and
// cursor to 0.
//   t  - table the entry comes from
//   c  - code stored for the entry
//   e  - pointer to the entry
//   m  - stored as matching_code_size; is_exact_match() and
//        is_predictive_match() compare it against code.size()
//   cr - stored as credibility, defaults to 0.0
Chunk(Table* t,
const Code& c,
const table::Entry* e,
size_t m,
double cr = 0.0)
: table(t),
code(c),
entries(e),
size(1),
cursor(0),
matching_code_size(m),
credibility(cr) {}
Chunk(Table* t, const TableAccessor& a, double cr = 0.0)
: Chunk(t, a, string(), cr) {}
Chunk(Table* t, const TableAccessor& a, const string& r, double cr = 0.0)
Expand All @@ -39,7 +50,12 @@ struct Chunk {
size(a.remaining()),
cursor(0),
remaining_code(r),
matching_code_size(a.index_code().size()),
credibility(cr) {}

// True when matching_code_size equals the number of syllables in code.
bool is_exact_match() const { return matching_code_size == code.size(); }

// True when matching_code_size is smaller than the number of syllables in
// code, i.e. fewer syllables were matched than the code contains.
bool is_predictive_match() const { return matching_code_size < code.size(); }
};

struct QueryResult {
Expand All @@ -51,35 +67,49 @@ bool compare_chunk_by_head_element(const Chunk& a, const Chunk& b) {
return false;
if (!b.entries || b.cursor >= b.size)
return true;
if (a.is_exact_match() != b.is_exact_match())
return a.is_exact_match() > b.is_exact_match();
if (a.remaining_code.length() != b.remaining_code.length())
return a.remaining_code.length() < b.remaining_code.length();
return a.credibility + a.entries[a.cursor].weight >
b.credibility + b.entries[b.cursor].weight; // by weight desc
}

size_t match_extra_code(const table::Code* extra_code,
size_t depth,
const SyllableGraph& syll_graph,
size_t current_pos) {
// Result returned by match_extra_code().
struct CodeMatch {
// false when no match was found; depth and end_pos are then both 0.
bool success;
// Depth in the extra code at which matching stopped. The caller adds it to
// the index code size to obtain the chunk's matching_code_size.
size_t depth;
// Position in the syllable graph where the match ends; used as the key of
// the collector slot the chunk is added to.
size_t end_pos;
};

CodeMatch match_extra_code(const table::Code* extra_code,
size_t depth,
const SyllableGraph& syll_graph,
size_t current_pos,
bool predict_word) {
const CodeMatch kFailed{false, 0, 0};
if (!extra_code || depth >= extra_code->size)
return current_pos; // success
if (current_pos >= syll_graph.interpreted_length)
return 0; // failure (possibly success for completion in the future)
return {true, depth, current_pos};
if (current_pos >= syll_graph.interpreted_length) {
if (predict_word)
return {true, depth, syll_graph.interpreted_length};
else
return kFailed;
}
auto index = syll_graph.indices.find(current_pos);
if (index == syll_graph.indices.end())
return 0;
return kFailed;
SyllableId current_syll_id = extra_code->at[depth];
auto spellings = index->second.find(current_syll_id);
if (spellings == index->second.end())
return 0;
size_t best_match = 0;
return kFailed;
CodeMatch best_match = kFailed;
for (const SpellingProperties* props : spellings->second) {
size_t match_end_pos =
match_extra_code(extra_code, depth + 1, syll_graph, props->end_pos);
if (!match_end_pos)
CodeMatch match = match_extra_code(extra_code, depth + 1, syll_graph,
props->end_pos, predict_word);
if (!match.success)
continue;
if (match_end_pos > best_match)
best_match = match_end_pos;
if (match.end_pos > best_match.end_pos)
best_match = match;
}
return best_match;
}
Expand Down Expand Up @@ -127,6 +157,9 @@ an<DictEntry> DictEntryIterator::Peek() {
entry_->comment = "~" + chunk.remaining_code;
entry_->remaining_code_length = chunk.remaining_code.length();
}
if (chunk.is_predictive_match()) {
entry_->matching_code_size = chunk.matching_code_size;
}
}
return entry_;
}
Expand Down Expand Up @@ -199,6 +232,7 @@ static void lookup_table(Table* table,
DictEntryCollector* collector,
const SyllableGraph& syllable_graph,
size_t start_pos,
bool predict_word,
double initial_credibility) {
TableQueryResult result;
if (!table->Query(syllable_graph, start_pos, &result)) {
Expand All @@ -211,12 +245,13 @@ static void lookup_table(Table* table,
double cr = initial_credibility + a.credibility();
if (a.extra_code()) {
do {
size_t actual_end_pos = dictionary::match_extra_code(
a.extra_code(), 0, syllable_graph, end_pos);
if (actual_end_pos == 0)
dictionary::CodeMatch match = dictionary::match_extra_code(
a.extra_code(), 0, syllable_graph, end_pos, predict_word);
if (!match.success)
continue;
(*collector)[actual_end_pos].AddChunk(
{table, a.code(), a.entry(), cr});
size_t matching_code_size = a.index_code().size() + match.depth;
(*collector)[match.end_pos].AddChunk(
{table, a.code(), a.entry(), matching_code_size, cr});
} while (a.Next());
} else {
(*collector)[end_pos].AddChunk({table, a, cr});
Expand All @@ -227,6 +262,7 @@ static void lookup_table(Table* table,

an<DictEntryCollector> Dictionary::Lookup(const SyllableGraph& syllable_graph,
size_t start_pos,
bool predict_word,
double initial_credibility) {
if (!loaded())
return nullptr;
Expand All @@ -235,7 +271,7 @@ an<DictEntryCollector> Dictionary::Lookup(const SyllableGraph& syllable_graph,
if (!table->IsOpen())
continue;
lookup_table(table.get(), collector.get(), syllable_graph, start_pos,
initial_credibility);
predict_word, initial_credibility);
}
if (collector->empty())
return nullptr;
Expand Down
1 change: 1 addition & 0 deletions src/rime/dict/dictionary.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class Dictionary : public Class<Dictionary, const Ticket&> {

RIME_API an<DictEntryCollector> Lookup(const SyllableGraph& syllable_graph,
size_t start_pos,
bool predict_word = false,
double initial_credibility = 0.0);
// if predictive is true, do an expand search with limit,
// otherwise do an exact match.
Expand Down
68 changes: 60 additions & 8 deletions src/rime/dict/user_dictionary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
#include <rime/ticket.h>
#include <rime/algo/dynamics.h>
#include <rime/algo/syllabifier.h>
#include <rime/algo/strings.h>
#include <rime/dict/db.h>
#include <rime/dict/table.h>
#include <rime/dict/user_dictionary.h>
#include <rime/dict/vocabulary.h>

namespace rime {

struct DfsState {
size_t depth_limit;
size_t predict_word_from_depth;
TickCount present_tick;
Code code;
vector<double> credibility;
Expand All @@ -32,13 +35,15 @@ struct DfsState {
string key;
string value;

// Number of syllable ids currently held in code.
size_t depth() const { return code.size(); }

// True if key starts with prefix immediately followed by a tab character.
bool IsExactMatch(const string& prefix) {
return boost::starts_with(key, prefix + '\t');
}
// True if key starts with prefix; nothing is required of what follows it.
bool IsPrefixMatch(const string& prefix) {
return boost::starts_with(key, prefix);
}
void RecruitEntry(size_t pos);
void RecruitEntry(size_t pos, map<string, SyllableId>* syllabary = nullptr);
bool NextEntry() {
if (!accessor->GetNextRecord(&key, &value)) {
key.clear();
Expand All @@ -63,11 +68,30 @@ struct DfsState {
}
};

void DfsState::RecruitEntry(size_t pos) {
void DfsState::RecruitEntry(size_t pos, map<string, SyllableId>* syllabary) {
string full_code;
auto e = UserDictionary::CreateDictEntry(key, value, present_tick,
credibility.back());
credibility.back(),
syllabary ? &full_code : nullptr);
if (e) {
e->code = code;
if (syllabary) {
vector<string> syllables =
strings::split(full_code, " ", strings::SplitBehavior::SkipToken);
Code numeric_code;
for (auto s = syllables.begin(); s != syllables.end(); ++s) {
auto found = syllabary->find(*s);
if (found == syllabary->end()) {
LOG(ERROR) << "failed to recruit dict entry '" << e->text
<< "', unrecognized syllable: " << *s;
return;
}
numeric_code.push_back(found->second);
}
e->code = numeric_code;
e->matching_code_size = code.size();
} else {
e->code = code;
}
DLOG(INFO) << "add entry at pos " << pos;
query_result[pos].push_back(e);
}
Expand Down Expand Up @@ -230,10 +254,36 @@ void UserDictionary::DfsLookup(const SyllableGraph& syll_graph,
if (!state->NextEntry()) // reached the end of db
break;
}
// the caller can limit the number of syllables to look up
if ((!state->depth_limit || state->code.size() < state->depth_limit) &&
state->IsPrefixMatch(prefix)) { // 'b |e ' vs. 'b e f \tBefore'
DfsLookup(syll_graph, end_pos, prefix, state);
auto next_index = syll_graph.indices.find(end_pos);
if (next_index == syll_graph.indices.end()) {
// reached the end of input, predict word if requested
if (state->predict_word_from_depth != 0 &&
state->depth() >= state->predict_word_from_depth) {
while (state->IsPrefixMatch(prefix)) {
DLOG(INFO) << "prefix match found for '" << prefix << "'.";
if (syllabary_.empty()) {
Syllabary syllabary;
if (!table_->GetSyllabary(&syllabary)) {
LOG(ERROR) << "failed to get syllabary for user dict: "
<< name();
break;
}
SyllableId syllable_id = 0;
for (auto s = syllabary.begin(); s != syllabary.end(); ++s) {
syllabary_[*s] = syllable_id++;
}
}
state->RecruitEntry(end_pos, &syllabary_);
if (!state->NextEntry()) // reached the end of db
break;
}
}
} else {
// the caller can limit the number of syllables to look up
if ((!state->depth_limit || state->depth() < state->depth_limit) &&
state->IsPrefixMatch(prefix)) { // 'b |e ' vs. 'b e f \tBefore'
DfsLookup(syll_graph, end_pos, prefix, state);
}
}
}
if (!state->IsPrefixMatch(current_prefix)) // 'b |' vs. 'g o \tGo'
Expand All @@ -254,12 +304,14 @@ an<UserDictEntryCollector> UserDictionary::Lookup(
const SyllableGraph& syll_graph,
size_t start_pos,
size_t depth_limit,
size_t predict_word_from_depth,
double initial_credibility) {
if (!table_ || !prism_ || !loaded() ||
start_pos >= syll_graph.interpreted_length)
return nullptr;
DfsState state;
state.depth_limit = depth_limit;
state.predict_word_from_depth = predict_word_from_depth;
FetchTickCount();
state.present_tick = tick_ + 1;
state.credibility.push_back(initial_credibility);
Expand Down
4 changes: 3 additions & 1 deletion src/rime/dict/user_dictionary.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class UserDictionary : public Class<UserDictionary, const Ticket&> {
an<UserDictEntryCollector> Lookup(const SyllableGraph& syllable_graph,
size_t start_pos,
size_t depth_limit = 0,
size_t predict_word_from_depth = 0,
double initial_credibility = 0.0);
size_t LookupWords(UserDictEntryIterator* result,
const string& input,
Expand All @@ -82,7 +83,7 @@ class UserDictionary : public Class<UserDictionary, const Ticket&> {
const string& value,
TickCount present_tick,
double credibility = 0.0,
string* full_code = NULL);
string* full_code = nullptr);

protected:
bool Initialize();
Expand All @@ -98,6 +99,7 @@ class UserDictionary : public Class<UserDictionary, const Ticket&> {
an<Db> db_;
an<Table> table_;
an<Prism> prism_;
map<string, SyllableId> syllabary_;
TickCount tick_ = 0;
time_t transaction_time_ = 0;
};
Expand Down
5 changes: 5 additions & 0 deletions src/rime/dict/vocabulary.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ using SyllableId = int32_t;

class Code : public vector<SyllableId> {
public:
Code() = default;
// Constructs a Code holding a copy of the syllable ids in [begin, end).
Code(const Code::const_iterator& begin, const Code::const_iterator& end)
: vector<SyllableId>(begin, end) {}

static const size_t kIndexCodeMaxLength = 3;

bool operator<(const Code& other) const;
Expand Down Expand Up @@ -48,6 +52,7 @@ struct DictEntry {
double weight = 0.0;
int commit_count = 0;
int remaining_code_length = 0;
int matching_code_size = 0;

DictEntry() = default;
ShortDictEntry ToShort() const;
Expand Down
Loading
Loading