From dfc3e9691feaf918e822e344ee37b58529797921 Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Wed, 18 Jul 2018 19:02:01 +0200 Subject: [PATCH 1/5] SquishedDawg: Use new serialization API Signed-off-by: Stefan Weil --- src/dict/dawg.cpp | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/src/dict/dawg.cpp b/src/dict/dawg.cpp index 27a28bf792..db626a4a4b 100644 --- a/src/dict/dawg.cpp +++ b/src/dict/dawg.cpp @@ -315,23 +315,20 @@ bool SquishedDawg::read_squished_dawg(TFile *file) { // Read the magic number and check that it matches kDawgMagicNumber, as // auto-endian fixing should make sure it is always correct. int16_t magic; - if (file->FReadEndian(&magic, sizeof(magic), 1) != 1) return false; + if (!file->DeSerialize(&magic)) return false; if (magic != kDawgMagicNumber) { tprintf("Bad magic number on dawg: %d vs %d\n", magic, kDawgMagicNumber); return false; } int32_t unicharset_size; - if (file->FReadEndian(&unicharset_size, sizeof(unicharset_size), 1) != 1) - return false; - if (file->FReadEndian(&num_edges_, sizeof(num_edges_), 1) != 1) return false; + if (!file->DeSerialize(&unicharset_size)) return false; + if (!file->DeSerialize(&num_edges_)) return false; ASSERT_HOST(num_edges_ > 0); // DAWG should not be empty Dawg::init(unicharset_size); edges_ = new EDGE_RECORD[num_edges_]; - if (file->FReadEndian(&edges_[0], sizeof(edges_[0]), num_edges_) != - num_edges_) - return false; + if (!file->DeSerialize(&edges_[0], num_edges_)) return false; if (debug_level_ > 2) { tprintf("type: %d lang: %s perm: %d unicharset_size: %d num_edges: %d\n", type_, lang_.string(), perm_, unicharset_size_, num_edges_); @@ -382,9 +379,8 @@ bool SquishedDawg::write_squished_dawg(TFile *file) { // Write the magic number to help detecting a change in endianness. int16_t magic = kDawgMagicNumber; - if (file->FWrite(&magic, sizeof(magic), 1) != 1) return false; - if (file->FWrite(&unicharset_size_, sizeof(unicharset_size_), 1) != 1) - return false; + if (!file->Serialize(&magic)) return false; + if (!file->Serialize(&unicharset_size_)) return false; // Count the number of edges in this Dawg. num_edges = 0; @@ -393,7 +389,7 @@ bool SquishedDawg::write_squished_dawg(TFile *file) { num_edges++; // Write edge count to file. - if (file->FWrite(&num_edges, sizeof(num_edges), 1) != 1) return false; + if (!file->Serialize(&num_edges)) return false; if (debug_level_) { tprintf("%d nodes in DAWG\n", node_count); @@ -406,8 +402,7 @@ bool SquishedDawg::write_squished_dawg(TFile *file) { old_index = next_node_from_edge_rec(edges_[edge]); set_next_node(edge, node_map[old_index]); temp_record = edges_[edge]; - if (file->FWrite(&temp_record, sizeof(temp_record), 1) != 1) - return false; + if (!file->Serialize(&temp_record)) return false; set_next_node(edge, old_index); } while (!last_edge(edge++)); From f4449ba41a689018f55eccc939bf6008278153f8 Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Wed, 18 Jul 2018 19:07:49 +0200 Subject: [PATCH 2/5] Convolve: Use new serialization API Signed-off-by: Stefan Weil --- src/lstm/convolve.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/lstm/convolve.cpp b/src/lstm/convolve.cpp index 04d0afeb69..32518017f3 100644 --- a/src/lstm/convolve.cpp +++ b/src/lstm/convolve.cpp @@ -32,16 +32,15 @@ Convolve::Convolve(const STRING& name, int ni, int half_x, int half_y) // Writes to the given file. Returns false in case of error. bool Convolve::Serialize(TFile* fp) const { - if (!Network::Serialize(fp)) return false; - if (fp->FWrite(&half_x_, sizeof(half_x_), 1) != 1) return false; - if (fp->FWrite(&half_y_, sizeof(half_y_), 1) != 1) return false; - return true; + return Network::Serialize(fp) && + fp->Serialize(&half_x_) && + fp->Serialize(&half_y_); } // Reads from the given file. Returns false in case of error. bool Convolve::DeSerialize(TFile* fp) { - if (fp->FReadEndian(&half_x_, sizeof(half_x_), 1) != 1) return false; - if (fp->FReadEndian(&half_y_, sizeof(half_y_), 1) != 1) return false; + if (!fp->DeSerialize(&half_x_)) return false; + if (!fp->DeSerialize(&half_y_)) return false; no_ = ni_ * (2*half_x_ + 1) * (2*half_y_ + 1); return true; } From 45a7ccf2d2f429950c62d33fd838fed2711d1981 Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Wed, 18 Jul 2018 19:09:06 +0200 Subject: [PATCH 3/5] LSTM: Use new serialization API Signed-off-by: Stefan Weil --- src/lstm/lstm.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lstm/lstm.cpp b/src/lstm/lstm.cpp index 979c66b48b..904325f1f5 100644 --- a/src/lstm/lstm.cpp +++ b/src/lstm/lstm.cpp @@ -206,7 +206,7 @@ void LSTM::DebugWeights() { // Writes to the given file. Returns false in case of error. bool LSTM::Serialize(TFile* fp) const { if (!Network::Serialize(fp)) return false; - if (fp->FWrite(&na_, sizeof(na_), 1) != 1) return false; + if (!fp->Serialize(&na_)) return false; for (int w = 0; w < WT_COUNT; ++w) { if (w == GFS && !Is2D()) continue; if (!gate_weights_[w].Serialize(IsTraining(), fp)) return false; @@ -218,7 +218,7 @@ bool LSTM::Serialize(TFile* fp) const { // Reads from the given file. Returns false in case of error. bool LSTM::DeSerialize(TFile* fp) { - if (fp->FReadEndian(&na_, sizeof(na_), 1) != 1) return false; + if (!fp->DeSerialize(&na_)) return false; if (type_ == NT_LSTM_SOFTMAX) { nf_ = no_; } else if (type_ == NT_LSTM_SOFTMAX_ENCODED) { From 1dcda1aa8a7a38238a5280e471afe479cf35c464 Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Wed, 18 Jul 2018 19:10:16 +0200 Subject: [PATCH 4/5] LSTMRecognizer: Use new serialization API Signed-off-by: Stefan Weil --- src/lstm/lstmrecognizer.cpp | 36 ++++++++++++++---------------------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/src/lstm/lstmrecognizer.cpp b/src/lstm/lstmrecognizer.cpp index 523305ef5a..1b3ecee351 100644 --- a/src/lstm/lstmrecognizer.cpp +++ b/src/lstm/lstmrecognizer.cpp @@ -84,16 +84,13 @@ bool LSTMRecognizer::Serialize(const TessdataManager* mgr, TFile* fp) const { if (!network_->Serialize(fp)) return false; if (include_charsets && !GetUnicharset().save_to_file(fp)) return false; if (!network_str_.Serialize(fp)) return false; - if (fp->FWrite(&training_flags_, sizeof(training_flags_), 1) != 1) - return false; - if (fp->FWrite(&training_iteration_, sizeof(training_iteration_), 1) != 1) - return false; - if (fp->FWrite(&sample_iteration_, sizeof(sample_iteration_), 1) != 1) - return false; - if (fp->FWrite(&null_char_, sizeof(null_char_), 1) != 1) return false; - if (fp->FWrite(&adam_beta_, sizeof(adam_beta_), 1) != 1) return false; - if (fp->FWrite(&learning_rate_, sizeof(learning_rate_), 1) != 1) return false; - if (fp->FWrite(&momentum_, sizeof(momentum_), 1) != 1) return false; + if (!fp->Serialize(&training_flags_)) return false; + if (!fp->Serialize(&training_iteration_)) return false; + if (!fp->Serialize(&sample_iteration_)) return false; + if (!fp->Serialize(&null_char_)) return false; + if (!fp->Serialize(&adam_beta_)) return false; + if (!fp->Serialize(&learning_rate_)) return false; + if (!fp->Serialize(&momentum_)) return false; if (include_charsets && IsRecoding() && !recoder_.Serialize(fp)) return false; return true; } @@ -109,18 +106,13 @@ bool LSTMRecognizer::DeSerialize(const TessdataManager* mgr, TFile* fp) { if (include_charsets && !ccutil_.unicharset.load_from_file(fp, false)) return false; if (!network_str_.DeSerialize(fp)) return false; - if (fp->FReadEndian(&training_flags_, sizeof(training_flags_), 1) != 1) - return false; - if (fp->FReadEndian(&training_iteration_, sizeof(training_iteration_), 1) != - 1) - return false; - if (fp->FReadEndian(&sample_iteration_, sizeof(sample_iteration_), 1) != 1) - return false; - if (fp->FReadEndian(&null_char_, sizeof(null_char_), 1) != 1) return false; - if (fp->FReadEndian(&adam_beta_, sizeof(adam_beta_), 1) != 1) return false; - if (fp->FReadEndian(&learning_rate_, sizeof(learning_rate_), 1) != 1) - return false; - if (fp->FReadEndian(&momentum_, sizeof(momentum_), 1) != 1) return false; + if (!fp->DeSerialize(&training_flags_)) return false; + if (!fp->DeSerialize(&training_iteration_)) return false; + if (!fp->DeSerialize(&sample_iteration_)) return false; + if (!fp->DeSerialize(&null_char_)) return false; + if (!fp->DeSerialize(&adam_beta_)) return false; + if (!fp->DeSerialize(&learning_rate_)) return false; + if (!fp->DeSerialize(&momentum_)) return false; if (include_charsets && !LoadRecoder(fp)) return false; if (!include_charsets && !LoadCharsets(mgr)) return false; network_->SetRandomizer(&randomizer_); From b7b8dba5dbb8c27337cb6a4425bebf35fc4e4581 Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Wed, 18 Jul 2018 19:23:13 +0200 Subject: [PATCH 5/5] LSTMTrainer: Use new serialization API Improve also portability by using int32_t instead of int for a serialized member variable. Signed-off-by: Stefan Weil --- src/lstm/lstmtrainer.cpp | 90 ++++++++++++++-------------------------- src/lstm/lstmtrainer.h | 4 +- 2 files changed, 32 insertions(+), 62 deletions(-) diff --git a/src/lstm/lstmtrainer.cpp b/src/lstm/lstmtrainer.cpp index dfabb3058a..0a7e47bde7 100644 --- a/src/lstm/lstmtrainer.cpp +++ b/src/lstm/lstmtrainer.cpp @@ -431,38 +431,25 @@ bool LSTMTrainer::TransitionTrainingStage(float error_threshold) { bool LSTMTrainer::Serialize(SerializeAmount serialize_amount, const TessdataManager* mgr, TFile* fp) const { if (!LSTMRecognizer::Serialize(mgr, fp)) return false; - if (fp->FWrite(&learning_iteration_, sizeof(learning_iteration_), 1) != 1) - return false; - if (fp->FWrite(&prev_sample_iteration_, sizeof(prev_sample_iteration_), 1) != - 1) - return false; - if (fp->FWrite(&perfect_delay_, sizeof(perfect_delay_), 1) != 1) return false; - if (fp->FWrite(&last_perfect_training_iteration_, - sizeof(last_perfect_training_iteration_), 1) != 1) - return false; + if (!fp->Serialize(&learning_iteration_)) return false; + if (!fp->Serialize(&prev_sample_iteration_)) return false; + if (!fp->Serialize(&perfect_delay_)) return false; + if (!fp->Serialize(&last_perfect_training_iteration_)) return false; for (int i = 0; i < ET_COUNT; ++i) { if (!error_buffers_[i].Serialize(fp)) return false; } - if (fp->FWrite(&error_rates_, sizeof(error_rates_), 1) != 1) return false; - if (fp->FWrite(&training_stage_, sizeof(training_stage_), 1) != 1) - return false; + if (!fp->Serialize(&error_rates_[0], countof(error_rates_))) return false; + if (!fp->Serialize(&training_stage_)) return false; uint8_t amount = serialize_amount; - if (fp->FWrite(&amount, sizeof(amount), 1) != 1) return false; + if (!fp->Serialize(&amount)) return false; if (serialize_amount == LIGHT) return true; // We are done. - if (fp->FWrite(&best_error_rate_, sizeof(best_error_rate_), 1) != 1) - return false; - if (fp->FWrite(&best_error_rates_, sizeof(best_error_rates_), 1) != 1) - return false; - if (fp->FWrite(&best_iteration_, sizeof(best_iteration_), 1) != 1) - return false; - if (fp->FWrite(&worst_error_rate_, sizeof(worst_error_rate_), 1) != 1) - return false; - if (fp->FWrite(&worst_error_rates_, sizeof(worst_error_rates_), 1) != 1) - return false; - if (fp->FWrite(&worst_iteration_, sizeof(worst_iteration_), 1) != 1) - return false; - if (fp->FWrite(&stall_iteration_, sizeof(stall_iteration_), 1) != 1) - return false; + if (!fp->Serialize(&best_error_rate_)) return false; + if (!fp->Serialize(&best_error_rates_[0], countof(best_error_rates_))) return false; + if (!fp->Serialize(&best_iteration_)) return false; + if (!fp->Serialize(&worst_error_rate_)) return false; + if (!fp->Serialize(&worst_error_rates_[0], countof(worst_error_rates_))) return false; + if (!fp->Serialize(&worst_iteration_)) return false; + if (!fp->Serialize(&stall_iteration_)) return false; if (!best_model_data_.Serialize(fp)) return false; if (!worst_model_data_.Serialize(fp)) return false; if (serialize_amount != NO_BEST_TRAINER && !best_trainer_.Serialize(fp)) @@ -473,16 +460,14 @@ bool LSTMTrainer::Serialize(SerializeAmount serialize_amount, if (!sub_data.Serialize(fp)) return false; if (!best_error_history_.Serialize(fp)) return false; if (!best_error_iterations_.Serialize(fp)) return false; - if (fp->FWrite(&improvement_steps_, sizeof(improvement_steps_), 1) != 1) - return false; - return true; + return fp->Serialize(&improvement_steps_); } // Reads from the given file. Returns false in case of error. // NOTE: It is assumed that the trainer is never read cross-endian. bool LSTMTrainer::DeSerialize(const TessdataManager* mgr, TFile* fp) { if (!LSTMRecognizer::DeSerialize(mgr, fp)) return false; - if (fp->FRead(&learning_iteration_, sizeof(learning_iteration_), 1) != 1) { + if (!fp->DeSerialize(&learning_iteration_)) { // Special case. If we successfully decoded the recognizer, but fail here // then it means we were just given a recognizer, so issue a warning and // allow it. @@ -491,37 +476,24 @@ bool LSTMTrainer::DeSerialize(const TessdataManager* mgr, TFile* fp) { network_->SetEnableTraining(TS_ENABLED); return true; } - if (fp->FReadEndian(&prev_sample_iteration_, sizeof(prev_sample_iteration_), - 1) != 1) - return false; - if (fp->FReadEndian(&perfect_delay_, sizeof(perfect_delay_), 1) != 1) - return false; - if (fp->FReadEndian(&last_perfect_training_iteration_, - sizeof(last_perfect_training_iteration_), 1) != 1) - return false; + if (!fp->DeSerialize(&prev_sample_iteration_)) return false; + if (!fp->DeSerialize(&perfect_delay_)) return false; + if (!fp->DeSerialize(&last_perfect_training_iteration_)) return false; for (int i = 0; i < ET_COUNT; ++i) { if (!error_buffers_[i].DeSerialize(fp)) return false; } - if (fp->FRead(&error_rates_, sizeof(error_rates_), 1) != 1) return false; - if (fp->FReadEndian(&training_stage_, sizeof(training_stage_), 1) != 1) - return false; + if (!fp->DeSerialize(&error_rates_[0], countof(error_rates_))) return false; + if (!fp->DeSerialize(&training_stage_)) return false; uint8_t amount; - if (fp->FRead(&amount, sizeof(amount), 1) != 1) return false; + if (!fp->DeSerialize(&amount)) return false; if (amount == LIGHT) return true; // Don't read the rest. - if (fp->FReadEndian(&best_error_rate_, sizeof(best_error_rate_), 1) != 1) - return false; - if (fp->FReadEndian(&best_error_rates_, sizeof(best_error_rates_), 1) != 1) - return false; - if (fp->FReadEndian(&best_iteration_, sizeof(best_iteration_), 1) != 1) - return false; - if (fp->FReadEndian(&worst_error_rate_, sizeof(worst_error_rate_), 1) != 1) - return false; - if (fp->FReadEndian(&worst_error_rates_, sizeof(worst_error_rates_), 1) != 1) - return false; - if (fp->FReadEndian(&worst_iteration_, sizeof(worst_iteration_), 1) != 1) - return false; - if (fp->FReadEndian(&stall_iteration_, sizeof(stall_iteration_), 1) != 1) - return false; + if (!fp->DeSerialize(&best_error_rate_)) return false; + if (!fp->DeSerialize(&best_error_rates_[0], countof(best_error_rates_))) return false; + if (!fp->DeSerialize(&best_iteration_)) return false; + if (!fp->DeSerialize(&worst_error_rate_)) return false; + if (!fp->DeSerialize(&worst_error_rates_[0], countof(worst_error_rates_))) return false; + if (!fp->DeSerialize(&worst_iteration_)) return false; + if (!fp->DeSerialize(&stall_iteration_)) return false; if (!best_model_data_.DeSerialize(fp)) return false; if (!worst_model_data_.DeSerialize(fp)) return false; if (amount != NO_BEST_TRAINER && !best_trainer_.DeSerialize(fp)) return false; @@ -536,9 +508,7 @@ bool LSTMTrainer::DeSerialize(const TessdataManager* mgr, TFile* fp) { } if (!best_error_history_.DeSerialize(fp)) return false; if (!best_error_iterations_.DeSerialize(fp)) return false; - if (fp->FReadEndian(&improvement_steps_, sizeof(improvement_steps_), 1) != 1) - return false; - return true; + return fp->DeSerialize(&improvement_steps_); } // De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the diff --git a/src/lstm/lstmtrainer.h b/src/lstm/lstmtrainer.h index 0fb152f4c9..d488b6953a 100644 --- a/src/lstm/lstmtrainer.h +++ b/src/lstm/lstmtrainer.h @@ -147,7 +147,7 @@ class LSTMTrainer : public LSTMRecognizer { return best_iteration_; } int learning_iteration() const { return learning_iteration_; } - int improvement_steps() const { return improvement_steps_; } + int32_t improvement_steps() const { return improvement_steps_; } void set_perfect_delay(int delay) { perfect_delay_ = delay; } const GenericVector& best_trainer() const { return best_trainer_; } // Returns the error that was just calculated by PrepareForBackward. @@ -457,7 +457,7 @@ class LSTMTrainer : public LSTMRecognizer { GenericVector best_error_history_; GenericVector best_error_iterations_; // Number of iterations since the best_error_rate_ was 2% more than it is now. - int improvement_steps_; + int32_t improvement_steps_; // Number of iterations that yielded a non-zero delta error and thus provided // significant learning. learning_iteration_ <= training_iteration_. // learning_iteration_ is used to measure rate of learning progress.