diff --git a/src/rime/commit_history.h b/src/rime/commit_history.h index 8a56c03b57..dfe8ea0fe3 100644 --- a/src/rime/commit_history.h +++ b/src/rime/commit_history.h @@ -29,6 +29,9 @@ class CommitHistory : public list { void Push(const KeyEvent& key_event); void Push(const Composition& composition, const string& input); string repr() const; + string latest_text() const { + return empty() ? string() : back().text; + } }; } // Namespace rime diff --git a/src/rime/composition.cc b/src/rime/composition.cc index 500d7b8dd6..d234365020 100644 --- a/src/rime/composition.cc +++ b/src/rime/composition.cc @@ -5,6 +5,7 @@ // 2011-06-19 GONG Chen // #include +#include #include #include #include @@ -167,4 +168,16 @@ string Composition::GetDebugText() const { return result; } +string Composition::GetTextBefore(size_t pos) const { + if (empty()) return string(); + for (const auto& seg : boost::adaptors::reverse(*this)) { + if (seg.end <= pos) { + if (auto cand = seg.GetSelectedCandidate()) { + return cand->text(); + } + } + } + return string(); +} + } // namespace rime diff --git a/src/rime/composition.h b/src/rime/composition.h index 5502947fe3..7bbb2bdfcf 100644 --- a/src/rime/composition.h +++ b/src/rime/composition.h @@ -29,6 +29,8 @@ class Composition : public Segmentation { string GetCommitText() const; string GetScriptText() const; string GetDebugText() const; + // Returns text of the last segment before the given position. + string GetTextBefore(size_t pos) const; }; } // namespace rime diff --git a/src/rime/gear/contextual_translation.cc b/src/rime/gear/contextual_translation.cc new file mode 100644 index 0000000000..6266195ec9 --- /dev/null +++ b/src/rime/gear/contextual_translation.cc @@ -0,0 +1,60 @@ +#include +#include +#include +#include + +namespace rime { + +const int kContextualSearchLimit = 32; + +bool ContextualTranslation::Replenish() { + vector> queue; + size_t end_pos = 0; + while (!translation_->exhausted() && + cache_.size() + queue.size() < kContextualSearchLimit) { + auto cand = translation_->Peek(); + DLOG(INFO) << cand->text() << " cache/queue: " + << cache_.size() << "/" << queue.size(); + if (cand->type() == "phrase" || cand->type() == "table") { + if (end_pos != cand->end()) { + end_pos = cand->end(); + AppendToCache(queue); + } + queue.push_back(Evaluate(As(cand))); + } else { + AppendToCache(queue); + cache_.push_back(cand); + } + if (!translation_->Next()) { + break; + } + } + AppendToCache(queue); + return !cache_.empty(); +} + +an ContextualTranslation::Evaluate(an phrase) { + auto sentence = New(phrase->language()); + sentence->Offset(phrase->start()); + bool is_rear = phrase->end() == input_.length(); + sentence->Extend(phrase->entry(), phrase->end(), is_rear, preceding_text_, + grammar_); + phrase->set_weight(sentence->weight()); + DLOG(INFO) << "contextual suggestion: " << phrase->text() + << " weight: " << phrase->weight(); + return phrase; +} + +static bool compare_by_weight_desc(const an& a, const an& b) { + return a->weight() > b->weight(); +} + +void ContextualTranslation::AppendToCache(vector>& queue) { + if (queue.empty()) return; + DLOG(INFO) << "appending to cache " << queue.size() << " candidates."; + std::sort(queue.begin(), queue.end(), compare_by_weight_desc); + std::copy(queue.begin(), queue.end(), std::back_inserter(cache_)); + queue.clear(); +} + +} // namespace rime diff --git a/src/rime/gear/contextual_translation.h b/src/rime/gear/contextual_translation.h new file mode 100644 index 0000000000..e817ea4c23 --- /dev/null +++ b/src/rime/gear/contextual_translation.h @@ -0,0 +1,38 @@ +// +// Copyright RIME Developers +// Distributed under the BSD License +// + +#include +#include + +namespace rime { + +class Candidate; +class Grammar; +class Phrase; + +class ContextualTranslation : public PrefetchTranslation { + public: + ContextualTranslation(an translation, + string input, + string preceding_text, + Grammar* grammar) + : PrefetchTranslation(translation), + input_(input), + preceding_text_(preceding_text), + grammar_(grammar) {} + + protected: + bool Replenish() override; + + private: + an Evaluate(an phrase); + void AppendToCache(vector>& queue); + + string input_; + string preceding_text_; + Grammar* grammar_; +}; + +} // namespace rime diff --git a/src/rime/gear/poet.cc b/src/rime/gear/poet.cc index afd3310aa6..5a5f870387 100644 --- a/src/rime/gear/poet.cc +++ b/src/rime/gear/poet.cc @@ -6,6 +6,8 @@ // // 2011-10-06 GONG Chen // +#include +#include #include #include #include @@ -21,14 +23,27 @@ inline static Grammar* create_grammar(Config* config) { return nullptr; } -Poet::Poet(const Language* language, Config* config) +Poet::Poet(const Language* language, Config* config, Compare compare) : language_(language), - grammar_(create_grammar(config)) {} + grammar_(create_grammar(config)), + compare_(compare) {} Poet::~Poet() {} +bool Poet::LeftAssociateCompare(const Sentence& one, const Sentence& other) { + return one.weight() < other.weight() || ( // left associate if even + one.weight() == other.weight() && ( + one.size() > other.size() || ( // less components is more favorable + one.size() == other.size() && + std::lexicographical_compare(one.syllable_lengths().begin(), + one.syllable_lengths().end(), + other.syllable_lengths().begin(), + other.syllable_lengths().end())))); +} + an Poet::MakeSentence(const WordGraph& graph, - size_t total_length) { + size_t total_length, + const string& preceding_text) { // TODO: save more intermediate sentence candidates map> sentences; sentences[0] = New(language_); @@ -43,16 +58,17 @@ an Poet::MakeSentence(const WordGraph& graph, if (start_pos == 0 && end_pos == total_length) continue; // exclude single words from the result DLOG(INFO) << "end pos: " << end_pos; + bool is_rear = end_pos == total_length; const DictEntryList& entries(x.second); - for (size_t i = 0; i < entries.size(); ++i) { - const auto& entry(entries[i]); + for (const auto& entry : entries) { auto new_sentence = New(*sentences[start_pos]); - bool is_rear = end_pos == total_length; - new_sentence->Extend(*entry, end_pos, is_rear, grammar_.get()); + new_sentence->Extend( + *entry, end_pos, is_rear, preceding_text, grammar_.get()); if (sentences.find(end_pos) == sentences.end() || - sentences[end_pos]->weight() < new_sentence->weight()) { - DLOG(INFO) << "updated sentences " << end_pos << ") with '" - << new_sentence->text() << "', " << new_sentence->weight(); + compare_(*sentences[end_pos], *new_sentence)) { + DLOG(INFO) << "updated sentences " << end_pos << ") with " + << new_sentence->text() << " weight: " + << new_sentence->weight(); sentences[end_pos] = std::move(new_sentence); } } diff --git a/src/rime/gear/poet.h b/src/rime/gear/poet.h index dfe8bcb1ce..ff55ed55f9 100644 --- a/src/rime/gear/poet.h +++ b/src/rime/gear/poet.h @@ -11,8 +11,10 @@ #define RIME_POET_H_ #include +#include #include #include +#include namespace rime { @@ -23,14 +25,42 @@ class Language; class Poet { public: - Poet(const Language* language, Config* config); + // sentence "less", used to compare sentences of the same input range. + using Compare = function; + + static bool CompareWeight(const Sentence& one, const Sentence& other) { + return one.weight() < other.weight(); + } + static bool LeftAssociateCompare(const Sentence& one, const Sentence& other); + + Poet(const Language* language, Config* config, + Compare compare = CompareWeight); ~Poet(); - an MakeSentence(const WordGraph& graph, size_t total_length); + an MakeSentence(const WordGraph& graph, + size_t total_length, + const string& preceding_text); + + template + an ContextualWeighted(an translation, + const string& input, + size_t start, + TranslatorT* translator) { + if (!translator->contextual_suggestions() || !grammar_) { + return translation; + } + auto preceding_text = translator->GetPrecedingText(start); + if (preceding_text.empty()) { + return translation; + } + return New( + translation, input, preceding_text, grammar_.get()); + } - protected: + private: const Language* language_; the grammar_; + Compare compare_; }; } // namespace rime diff --git a/src/rime/gear/script_translator.cc b/src/rime/gear/script_translator.cc index ebf16fecf9..7ca3a2ac3b 100644 --- a/src/rime/gear/script_translator.cc +++ b/src/rime/gear/script_translator.cc @@ -194,7 +194,11 @@ an ScriptTranslator::Query(const string& input, enable_user_dict ? user_dict_.get() : NULL)) { return nullptr; } - return New(result); + auto deduped = New(result); + if (contextual_suggestions_) { + return poet_->ContextualWeighted(deduped, input, segment.start, this); + } + return deduped; } string ScriptTranslator::FormatPreedit(const string& preedit) { @@ -214,6 +218,12 @@ string ScriptTranslator::Spell(const Code& code) { return result; } +string ScriptTranslator::GetPrecedingText(size_t start) const { + return !contextual_suggestions_ ? string() : + start > 0 ? engine_->context()->composition().GetTextBefore(start) : + engine_->context()->commit_history().latest_text(); +} + bool ScriptTranslator::Memorize(const CommitEntry& commit_entry) { bool update_elements = false; // avoid updating single character entries within a phrase which is @@ -538,12 +548,15 @@ an ScriptTranslation::MakeSentence(Dictionary* dict, } } } - auto sentence = poet_->MakeSentence(graph, syllable_graph.interpreted_length); - if (sentence) { + if (auto sentence = + poet_->MakeSentence(graph, + syllable_graph.interpreted_length, + translator_->GetPrecedingText(start_))) { sentence->Offset(start_); sentence->set_syllabifier(syllabifier_); + return sentence; } - return sentence; + return nullptr; } } // namespace rime diff --git a/src/rime/gear/script_translator.h b/src/rime/gear/script_translator.h index ea799e65e3..66f60a2caa 100644 --- a/src/rime/gear/script_translator.h +++ b/src/rime/gear/script_translator.h @@ -37,6 +37,7 @@ class ScriptTranslator : public Translator, string FormatPreedit(const string& preedit); string Spell(const Code& code); + string GetPrecedingText(size_t start) const; // options int max_homophones() const { return max_homophones_; } diff --git a/src/rime/gear/table_translator.cc b/src/rime/gear/table_translator.cc index 3fbbf0b8b0..2dcc4524c1 100644 --- a/src/rime/gear/table_translator.cc +++ b/src/rime/gear/table_translator.cc @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -17,7 +18,7 @@ #include #include #include -#include +#include #include #include #include @@ -219,10 +220,9 @@ TableTranslator::TableTranslator(const Ticket& ticket) &max_phrase_length_); config->GetInt(name_space_ + "/max_homographs", &max_homographs_); - if (enable_sentence_ || sentence_over_completion_) { - if (auto* grammar_component = Grammar::Require("grammar")) { - grammar_.reset(grammar_component->Create(config)); - } + if (enable_sentence_ || sentence_over_completion_ || + contextual_suggestions_) { + poet_.reset(new Poet(language(), config, Poet::LeftAssociateCompare)); } } if (enable_encoder_ && user_dict_) { @@ -307,11 +307,12 @@ an TableTranslator::Query(const string& input, translation = sentence + translation; } } - if (translation) { - translation = New(translation); - } if (translation && translation->exhausted()) { - translation.reset(); // discard futile translation + return nullptr; + } + translation = New(translation); + if (contextual_suggestions_) { + return poet_->ContextualWeighted(translation, input, segment.start, this); } return translation; } @@ -343,7 +344,9 @@ bool TableTranslator::Memorize(const CommitEntry& commit_entry) { } string phrase; for (; it != history.rend(); ++it) { - if (it->type != "table" && it->type != "sentence" && it->type != "uniquified") + if (it->type != "table" && + it->type != "sentence" && + it->type != "uniquified") break; if (phrase.empty()) { phrase = it->text; // last word @@ -363,6 +366,12 @@ bool TableTranslator::Memorize(const CommitEntry& commit_entry) { return true; } +string TableTranslator::GetPrecedingText(size_t start) const { + return !contextual_suggestions_ ? string() : + start > 0 ? engine_->context()->composition().GetTextBefore(start) : + engine_->context()->commit_history().latest_text(); +} + // SentenceSyllabifier class SentenceSyllabifier : public PhraseSyllabifier { @@ -558,20 +567,22 @@ TableTranslator::MakeSentence(const string& input, size_t start, const int max_entries = max_homographs_; DictEntryCollector collector; UserDictEntryCollector user_phrase_collector; - map> sentences; - sentences[0] = New(language()); + WordGraph graph; + hash_set vertices = {0}; for (size_t start_pos = 0; start_pos < input.length(); ++start_pos) { - if (sentences.find(start_pos) == sentences.end()) + // find next reachable vertex in word graph + if (vertices.find(start_pos) == vertices.end()) continue; string active_input = input.substr(start_pos); string active_key = active_input + ' '; - UserDictEntryCollector collected_entries; + UserDictEntryCollector& collected_entries(graph[start_pos]); // lookup dictionaries if (user_dict_ && user_dict_->loaded()) { for (size_t len = 1; len <= active_input.length(); ++len) { size_t consumed_length = consume_trailing_delimiters(len, active_input, delimiters_); - auto& dest(collected_entries[consumed_length]); + size_t end_pos = start_pos + consumed_length; + auto& dest(collected_entries[end_pos]); if (dest.size() >= max_entries) continue; DLOG(INFO) << "active input: " << active_input << "[0, " << len << ")"; @@ -583,13 +594,14 @@ TableTranslator::MakeSentence(const string& input, size_t start, uter.AddFilter(CharsetFilter::FilterDictEntry); } if (!uter.exhausted()) { + vertices.insert(end_pos); if (start_pos == 0 && max_entries > 1) { UserDictEntryIterator uter_copy(uter); collect_entries(dest, uter_copy, max_entries); } else { collect_entries(dest, uter, max_entries); } - if (start_pos == 0) { + if (include_prefix_phrases && start_pos == 0) { // also provide words for manual composition // uter must not be consumed uter.Release(&user_phrase_collector[consumed_length]); @@ -607,7 +619,8 @@ TableTranslator::MakeSentence(const string& input, size_t start, for (size_t len = 1; len <= active_input.length(); ++len) { size_t consumed_length = consume_trailing_delimiters(len, active_input, delimiters_); - auto& dest(collected_entries[consumed_length]); + size_t end_pos = start_pos + consumed_length; + auto& dest(collected_entries[end_pos]); if (!dest.empty()) continue; DLOG(INFO) << "active input: " << active_input << "[0, " << len << ")"; @@ -619,13 +632,14 @@ TableTranslator::MakeSentence(const string& input, size_t start, uter.AddFilter(CharsetFilter::FilterDictEntry); } if (!uter.exhausted()) { + vertices.insert(end_pos); if (start_pos == 0 && max_entries > 1) { UserDictEntryIterator uter_copy(uter); collect_entries(dest, uter_copy, max_entries); } else { collect_entries(dest, uter, max_entries); } - if (start_pos == 0) { + if (include_prefix_phrases && start_pos == 0) { // also provide words for manual composition // uter must not be consumed uter.Release(&user_phrase_collector[consumed_length]); @@ -648,7 +662,8 @@ TableTranslator::MakeSentence(const string& input, size_t start, continue; size_t consumed_length = consume_trailing_delimiters(m.length, active_input, delimiters_); - auto& dest(collected_entries[consumed_length]); + size_t end_pos = start_pos + consumed_length; + auto& dest(collected_entries[end_pos]); if (dest.size() >= max_entries) continue; DictEntryIterator iter; @@ -657,13 +672,14 @@ TableTranslator::MakeSentence(const string& input, size_t start, iter.AddFilter(CharsetFilter::FilterDictEntry); } if (!iter.exhausted()) { + vertices.insert(end_pos); if (start_pos == 0 && max_entries - dest.size() > 1) { DictEntryIterator iter_copy = iter; collect_entries(dest, iter_copy, max_entries); } else { collect_entries(dest, iter, max_entries); } - if (start_pos == 0) { + if (include_prefix_phrases && start_pos == 0) { // also provide words for manual composition // iter must not be consumed collector[consumed_length] = std::move(iter); @@ -673,38 +689,23 @@ TableTranslator::MakeSentence(const string& input, size_t start, } } } - for (size_t len = 1; len <= active_input.length(); ++len) { - const auto& entries(collected_entries[len]); - if (entries.empty()) - continue; - size_t end_pos = start_pos + len; - bool is_rear = end_pos == input.length(); - for (const auto& entry : entries) { - // create a new sentence - auto new_sentence = New(*sentences[start_pos]); - new_sentence->Extend(*entry, end_pos, is_rear, grammar_.get()); - // compare and update sentences - if (sentences.find(end_pos) == sentences.end() || - sentences[end_pos]->weight() <= new_sentence->weight()) { - sentences[end_pos] = std::move(new_sentence); - } - } - } } - an result; - if (sentences.find(input.length()) != sentences.end()) { - result = Cached( + if (auto sentence = poet_->MakeSentence(graph, + input.length(), + GetPrecedingText(start))) { + auto result = Cached( this, - std::move(sentences[input.length()]), - include_prefix_phrases ? std::move(collector) : DictEntryCollector(), - include_prefix_phrases ? std::move(user_phrase_collector) : UserDictEntryCollector(), + std::move(sentence), + std::move(collector), + std::move(user_phrase_collector), input, start); if (result && filter_by_charset) { - result = New(result); + return New(result); } + return result; } - return result; + return nullptr; } } // namespace rime diff --git a/src/rime/gear/table_translator.h b/src/rime/gear/table_translator.h index 3948fcb730..ae1d464247 100644 --- a/src/rime/gear/table_translator.h +++ b/src/rime/gear/table_translator.h @@ -19,7 +19,7 @@ namespace rime { -class Grammar; +class Poet; class UnityTableEncoder; class TableTranslator : public Translator, @@ -29,13 +29,13 @@ class TableTranslator : public Translator, TableTranslator(const Ticket& ticket); virtual an Query(const string& input, - const Segment& segment); + const Segment& segment); virtual bool Memorize(const CommitEntry& commit_entry); an MakeSentence(const string& input, - size_t start, - bool include_prefix_phrases = false); - + size_t start, + bool include_prefix_phrases = false); + string GetPrecedingText(size_t start) const; UnityTableEncoder* encoder() const { return encoder_.get(); } protected: @@ -46,8 +46,8 @@ class TableTranslator : public Translator, bool encode_commit_history_ = true; int max_phrase_length_ = 5; int max_homographs_ = 1; + the poet_; the encoder_; - the grammar_; }; class TableTranslation : public Translation { diff --git a/src/rime/gear/translator_commons.cc b/src/rime/gear/translator_commons.cc index 63ee160484..8e305fcd92 100644 --- a/src/rime/gear/translator_commons.cc +++ b/src/rime/gear/translator_commons.cc @@ -91,8 +91,10 @@ bool Spans::HasVertex(size_t vertex) const { void Sentence::Extend(const DictEntry& entry, size_t end_pos, bool is_rear, + const string& preceding_text, Grammar* grammar) { - entry_->weight += Grammar::Evaluate(entry_->text, entry, is_rear, grammar); + const string& context = empty() ? preceding_text : text(); + entry_->weight += Grammar::Evaluate(context, entry, is_rear, grammar); entry_->text.append(entry.text); entry_->code.insert(entry_->code.end(), entry.code.begin(), @@ -101,7 +103,7 @@ void Sentence::Extend(const DictEntry& entry, syllable_lengths_.push_back(end_pos - end()); set_end(end_pos); DLOG(INFO) << "extend sentence " << end_pos << ") " - << entry_->text << " : " << entry_->weight; + << text() << " weight: " << weight(); } void Sentence::Offset(size_t offset) { @@ -118,6 +120,8 @@ TranslatorOptions::TranslatorOptions(const Ticket& ticket) { config->GetString(ticket.name_space + "/delimiter", &delimiters_) || config->GetString("speller/delimiter", &delimiters_); config->GetString(ticket.name_space + "/tag", &tag_); + config->GetBool(ticket.name_space + "/contextual_suggestions", + &contextual_suggestions_); config->GetBool(ticket.name_space + "/enable_completion", &enable_completion_); config->GetBool(ticket.name_space + "/strict_spelling", diff --git a/src/rime/gear/translator_commons.h b/src/rime/gear/translator_commons.h index 88e05aa97f..343fd74859 100644 --- a/src/rime/gear/translator_commons.h +++ b/src/rime/gear/translator_commons.h @@ -71,7 +71,9 @@ class Language; class Phrase : public Candidate { public: Phrase(const Language* language, - const string& type, size_t start, size_t end, + const string& type, + size_t start, + size_t end, const an& entry) : Candidate(type, start, end), language_(language), @@ -89,8 +91,8 @@ class Phrase : public Candidate { void set_syllabifier(an syllabifier) { syllabifier_ = syllabifier; } - double weight() const { return entry_->weight; } + void set_weight(double weight) { entry_->weight = weight; } Code& code() const { return entry_->code; } const DictEntry& entry() const { return *entry_; } const Language* language() const { return language_; } @@ -112,9 +114,7 @@ class Grammar; class Sentence : public Phrase { public: Sentence(const Language* language) - : Phrase(language, "sentence", 0, 0, New()) { - entry_->weight = 0.0; - } + : Phrase(language, "sentence", 0, 0, New()) {} Sentence(const Sentence& other) : Phrase(other), components_(other.components_), @@ -124,9 +124,18 @@ class Sentence : public Phrase { void Extend(const DictEntry& entry, size_t end_pos, bool is_rear, + const string& preceding_text, Grammar* grammar); void Offset(size_t offset); + bool empty() const { + return components_.empty(); + } + + size_t size() const { + return components_.size(); + } + const vector& components() const { return components_; } @@ -151,6 +160,10 @@ class TranslatorOptions { const string& delimiters() const { return delimiters_; } const string& tag() const { return tag_; } void set_tag(const string& tag) { tag_ = tag; } + bool contextual_suggestions() const { return contextual_suggestions_; } + void set_contextual_suggestions(bool enabled) { + contextual_suggestions_ = enabled; + } bool enable_completion() const { return enable_completion_; } void set_enable_completion(bool enabled) { enable_completion_ = enabled; } bool strict_spelling() const { return strict_spelling_; } @@ -163,6 +176,7 @@ class TranslatorOptions { protected: string delimiters_; string tag_ = "abc"; + bool contextual_suggestions_ = false; bool enable_completion_ = true; bool strict_spelling_ = false; double initial_quality_ = 0.;