Skip to content

Commit

Permalink
Merge pull request #1786 from stweil/serialize
Browse files: browse the repository at this point in the history
Use new serialization API
  • Loading branch information
egorpugin authored Jul 18, 2018
2 parents 790e115 + b7b8dba commit 3a7f5e4
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 105 deletions.
21 changes: 8 additions & 13 deletions src/dict/dawg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,23 +315,20 @@ bool SquishedDawg::read_squished_dawg(TFile *file) {
// Read the magic number and check that it matches kDawgMagicNumber, as
// auto-endian fixing should make sure it is always correct.
int16_t magic;
if (file->FReadEndian(&magic, sizeof(magic), 1) != 1) return false;
if (!file->DeSerialize(&magic)) return false;
if (magic != kDawgMagicNumber) {
tprintf("Bad magic number on dawg: %d vs %d\n", magic, kDawgMagicNumber);
return false;
}

int32_t unicharset_size;
if (file->FReadEndian(&unicharset_size, sizeof(unicharset_size), 1) != 1)
return false;
if (file->FReadEndian(&num_edges_, sizeof(num_edges_), 1) != 1) return false;
if (!file->DeSerialize(&unicharset_size)) return false;
if (!file->DeSerialize(&num_edges_)) return false;
ASSERT_HOST(num_edges_ > 0); // DAWG should not be empty
Dawg::init(unicharset_size);

edges_ = new EDGE_RECORD[num_edges_];
if (file->FReadEndian(&edges_[0], sizeof(edges_[0]), num_edges_) !=
num_edges_)
return false;
if (!file->DeSerialize(&edges_[0], num_edges_)) return false;
if (debug_level_ > 2) {
tprintf("type: %d lang: %s perm: %d unicharset_size: %d num_edges: %d\n",
type_, lang_.string(), perm_, unicharset_size_, num_edges_);
Expand Down Expand Up @@ -382,9 +379,8 @@ bool SquishedDawg::write_squished_dawg(TFile *file) {

// Write the magic number to help detecting a change in endianness.
int16_t magic = kDawgMagicNumber;
if (file->FWrite(&magic, sizeof(magic), 1) != 1) return false;
if (file->FWrite(&unicharset_size_, sizeof(unicharset_size_), 1) != 1)
return false;
if (!file->Serialize(&magic)) return false;
if (!file->Serialize(&unicharset_size_)) return false;

// Count the number of edges in this Dawg.
num_edges = 0;
Expand All @@ -393,7 +389,7 @@ bool SquishedDawg::write_squished_dawg(TFile *file) {
num_edges++;

// Write edge count to file.
if (file->FWrite(&num_edges, sizeof(num_edges), 1) != 1) return false;
if (!file->Serialize(&num_edges)) return false;

if (debug_level_) {
tprintf("%d nodes in DAWG\n", node_count);
Expand All @@ -406,8 +402,7 @@ bool SquishedDawg::write_squished_dawg(TFile *file) {
old_index = next_node_from_edge_rec(edges_[edge]);
set_next_node(edge, node_map[old_index]);
temp_record = edges_[edge];
if (file->FWrite(&temp_record, sizeof(temp_record), 1) != 1)
return false;
if (!file->Serialize(&temp_record)) return false;
set_next_node(edge, old_index);
} while (!last_edge(edge++));

Expand Down
11 changes: 5 additions & 6 deletions src/lstm/convolve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,15 @@ Convolve::Convolve(const STRING& name, int ni, int half_x, int half_y)

// Writes to the given file. Returns false in case of error.
bool Convolve::Serialize(TFile* fp) const {
if (!Network::Serialize(fp)) return false;
if (fp->FWrite(&half_x_, sizeof(half_x_), 1) != 1) return false;
if (fp->FWrite(&half_y_, sizeof(half_y_), 1) != 1) return false;
return true;
return Network::Serialize(fp) &&
fp->Serialize(&half_x_) &&
fp->Serialize(&half_y_);
}

// Reads from the given file. Returns false in case of error.
bool Convolve::DeSerialize(TFile* fp) {
if (fp->FReadEndian(&half_x_, sizeof(half_x_), 1) != 1) return false;
if (fp->FReadEndian(&half_y_, sizeof(half_y_), 1) != 1) return false;
if (!fp->DeSerialize(&half_x_)) return false;
if (!fp->DeSerialize(&half_y_)) return false;
no_ = ni_ * (2*half_x_ + 1) * (2*half_y_ + 1);
return true;
}
Expand Down
4 changes: 2 additions & 2 deletions src/lstm/lstm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ void LSTM::DebugWeights() {
// Writes to the given file. Returns false in case of error.
bool LSTM::Serialize(TFile* fp) const {
if (!Network::Serialize(fp)) return false;
if (fp->FWrite(&na_, sizeof(na_), 1) != 1) return false;
if (!fp->Serialize(&na_)) return false;
for (int w = 0; w < WT_COUNT; ++w) {
if (w == GFS && !Is2D()) continue;
if (!gate_weights_[w].Serialize(IsTraining(), fp)) return false;
Expand All @@ -218,7 +218,7 @@ bool LSTM::Serialize(TFile* fp) const {
// Reads from the given file. Returns false in case of error.

bool LSTM::DeSerialize(TFile* fp) {
if (fp->FReadEndian(&na_, sizeof(na_), 1) != 1) return false;
if (!fp->DeSerialize(&na_)) return false;
if (type_ == NT_LSTM_SOFTMAX) {
nf_ = no_;
} else if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
Expand Down
36 changes: 14 additions & 22 deletions src/lstm/lstmrecognizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,13 @@ bool LSTMRecognizer::Serialize(const TessdataManager* mgr, TFile* fp) const {
if (!network_->Serialize(fp)) return false;
if (include_charsets && !GetUnicharset().save_to_file(fp)) return false;
if (!network_str_.Serialize(fp)) return false;
if (fp->FWrite(&training_flags_, sizeof(training_flags_), 1) != 1)
return false;
if (fp->FWrite(&training_iteration_, sizeof(training_iteration_), 1) != 1)
return false;
if (fp->FWrite(&sample_iteration_, sizeof(sample_iteration_), 1) != 1)
return false;
if (fp->FWrite(&null_char_, sizeof(null_char_), 1) != 1) return false;
if (fp->FWrite(&adam_beta_, sizeof(adam_beta_), 1) != 1) return false;
if (fp->FWrite(&learning_rate_, sizeof(learning_rate_), 1) != 1) return false;
if (fp->FWrite(&momentum_, sizeof(momentum_), 1) != 1) return false;
if (!fp->Serialize(&training_flags_)) return false;
if (!fp->Serialize(&training_iteration_)) return false;
if (!fp->Serialize(&sample_iteration_)) return false;
if (!fp->Serialize(&null_char_)) return false;
if (!fp->Serialize(&adam_beta_)) return false;
if (!fp->Serialize(&learning_rate_)) return false;
if (!fp->Serialize(&momentum_)) return false;
if (include_charsets && IsRecoding() && !recoder_.Serialize(fp)) return false;
return true;
}
Expand All @@ -109,18 +106,13 @@ bool LSTMRecognizer::DeSerialize(const TessdataManager* mgr, TFile* fp) {
if (include_charsets && !ccutil_.unicharset.load_from_file(fp, false))
return false;
if (!network_str_.DeSerialize(fp)) return false;
if (fp->FReadEndian(&training_flags_, sizeof(training_flags_), 1) != 1)
return false;
if (fp->FReadEndian(&training_iteration_, sizeof(training_iteration_), 1) !=
1)
return false;
if (fp->FReadEndian(&sample_iteration_, sizeof(sample_iteration_), 1) != 1)
return false;
if (fp->FReadEndian(&null_char_, sizeof(null_char_), 1) != 1) return false;
if (fp->FReadEndian(&adam_beta_, sizeof(adam_beta_), 1) != 1) return false;
if (fp->FReadEndian(&learning_rate_, sizeof(learning_rate_), 1) != 1)
return false;
if (fp->FReadEndian(&momentum_, sizeof(momentum_), 1) != 1) return false;
if (!fp->DeSerialize(&training_flags_)) return false;
if (!fp->DeSerialize(&training_iteration_)) return false;
if (!fp->DeSerialize(&sample_iteration_)) return false;
if (!fp->DeSerialize(&null_char_)) return false;
if (!fp->DeSerialize(&adam_beta_)) return false;
if (!fp->DeSerialize(&learning_rate_)) return false;
if (!fp->DeSerialize(&momentum_)) return false;
if (include_charsets && !LoadRecoder(fp)) return false;
if (!include_charsets && !LoadCharsets(mgr)) return false;
network_->SetRandomizer(&randomizer_);
Expand Down
90 changes: 30 additions & 60 deletions src/lstm/lstmtrainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,38 +431,25 @@ bool LSTMTrainer::TransitionTrainingStage(float error_threshold) {
bool LSTMTrainer::Serialize(SerializeAmount serialize_amount,
const TessdataManager* mgr, TFile* fp) const {
if (!LSTMRecognizer::Serialize(mgr, fp)) return false;
if (fp->FWrite(&learning_iteration_, sizeof(learning_iteration_), 1) != 1)
return false;
if (fp->FWrite(&prev_sample_iteration_, sizeof(prev_sample_iteration_), 1) !=
1)
return false;
if (fp->FWrite(&perfect_delay_, sizeof(perfect_delay_), 1) != 1) return false;
if (fp->FWrite(&last_perfect_training_iteration_,
sizeof(last_perfect_training_iteration_), 1) != 1)
return false;
if (!fp->Serialize(&learning_iteration_)) return false;
if (!fp->Serialize(&prev_sample_iteration_)) return false;
if (!fp->Serialize(&perfect_delay_)) return false;
if (!fp->Serialize(&last_perfect_training_iteration_)) return false;
for (int i = 0; i < ET_COUNT; ++i) {
if (!error_buffers_[i].Serialize(fp)) return false;
}
if (fp->FWrite(&error_rates_, sizeof(error_rates_), 1) != 1) return false;
if (fp->FWrite(&training_stage_, sizeof(training_stage_), 1) != 1)
return false;
if (!fp->Serialize(&error_rates_[0], countof(error_rates_))) return false;
if (!fp->Serialize(&training_stage_)) return false;
uint8_t amount = serialize_amount;
if (fp->FWrite(&amount, sizeof(amount), 1) != 1) return false;
if (!fp->Serialize(&amount)) return false;
if (serialize_amount == LIGHT) return true; // We are done.
if (fp->FWrite(&best_error_rate_, sizeof(best_error_rate_), 1) != 1)
return false;
if (fp->FWrite(&best_error_rates_, sizeof(best_error_rates_), 1) != 1)
return false;
if (fp->FWrite(&best_iteration_, sizeof(best_iteration_), 1) != 1)
return false;
if (fp->FWrite(&worst_error_rate_, sizeof(worst_error_rate_), 1) != 1)
return false;
if (fp->FWrite(&worst_error_rates_, sizeof(worst_error_rates_), 1) != 1)
return false;
if (fp->FWrite(&worst_iteration_, sizeof(worst_iteration_), 1) != 1)
return false;
if (fp->FWrite(&stall_iteration_, sizeof(stall_iteration_), 1) != 1)
return false;
if (!fp->Serialize(&best_error_rate_)) return false;
if (!fp->Serialize(&best_error_rates_[0], countof(best_error_rates_))) return false;
if (!fp->Serialize(&best_iteration_)) return false;
if (!fp->Serialize(&worst_error_rate_)) return false;
if (!fp->Serialize(&worst_error_rates_[0], countof(worst_error_rates_))) return false;
if (!fp->Serialize(&worst_iteration_)) return false;
if (!fp->Serialize(&stall_iteration_)) return false;
if (!best_model_data_.Serialize(fp)) return false;
if (!worst_model_data_.Serialize(fp)) return false;
if (serialize_amount != NO_BEST_TRAINER && !best_trainer_.Serialize(fp))
Expand All @@ -473,16 +460,14 @@ bool LSTMTrainer::Serialize(SerializeAmount serialize_amount,
if (!sub_data.Serialize(fp)) return false;
if (!best_error_history_.Serialize(fp)) return false;
if (!best_error_iterations_.Serialize(fp)) return false;
if (fp->FWrite(&improvement_steps_, sizeof(improvement_steps_), 1) != 1)
return false;
return true;
return fp->Serialize(&improvement_steps_);
}

// Reads from the given file. Returns false in case of error.
// NOTE: It is assumed that the trainer is never read cross-endian.
bool LSTMTrainer::DeSerialize(const TessdataManager* mgr, TFile* fp) {
if (!LSTMRecognizer::DeSerialize(mgr, fp)) return false;
if (fp->FRead(&learning_iteration_, sizeof(learning_iteration_), 1) != 1) {
if (!fp->DeSerialize(&learning_iteration_)) {
// Special case. If we successfully decoded the recognizer, but fail here
// then it means we were just given a recognizer, so issue a warning and
// allow it.
Expand All @@ -491,37 +476,24 @@ bool LSTMTrainer::DeSerialize(const TessdataManager* mgr, TFile* fp) {
network_->SetEnableTraining(TS_ENABLED);
return true;
}
if (fp->FReadEndian(&prev_sample_iteration_, sizeof(prev_sample_iteration_),
1) != 1)
return false;
if (fp->FReadEndian(&perfect_delay_, sizeof(perfect_delay_), 1) != 1)
return false;
if (fp->FReadEndian(&last_perfect_training_iteration_,
sizeof(last_perfect_training_iteration_), 1) != 1)
return false;
if (!fp->DeSerialize(&prev_sample_iteration_)) return false;
if (!fp->DeSerialize(&perfect_delay_)) return false;
if (!fp->DeSerialize(&last_perfect_training_iteration_)) return false;
for (int i = 0; i < ET_COUNT; ++i) {
if (!error_buffers_[i].DeSerialize(fp)) return false;
}
if (fp->FRead(&error_rates_, sizeof(error_rates_), 1) != 1) return false;
if (fp->FReadEndian(&training_stage_, sizeof(training_stage_), 1) != 1)
return false;
if (!fp->DeSerialize(&error_rates_[0], countof(error_rates_))) return false;
if (!fp->DeSerialize(&training_stage_)) return false;
uint8_t amount;
if (fp->FRead(&amount, sizeof(amount), 1) != 1) return false;
if (!fp->DeSerialize(&amount)) return false;
if (amount == LIGHT) return true; // Don't read the rest.
if (fp->FReadEndian(&best_error_rate_, sizeof(best_error_rate_), 1) != 1)
return false;
if (fp->FReadEndian(&best_error_rates_, sizeof(best_error_rates_), 1) != 1)
return false;
if (fp->FReadEndian(&best_iteration_, sizeof(best_iteration_), 1) != 1)
return false;
if (fp->FReadEndian(&worst_error_rate_, sizeof(worst_error_rate_), 1) != 1)
return false;
if (fp->FReadEndian(&worst_error_rates_, sizeof(worst_error_rates_), 1) != 1)
return false;
if (fp->FReadEndian(&worst_iteration_, sizeof(worst_iteration_), 1) != 1)
return false;
if (fp->FReadEndian(&stall_iteration_, sizeof(stall_iteration_), 1) != 1)
return false;
if (!fp->DeSerialize(&best_error_rate_)) return false;
if (!fp->DeSerialize(&best_error_rates_[0], countof(best_error_rates_))) return false;
if (!fp->DeSerialize(&best_iteration_)) return false;
if (!fp->DeSerialize(&worst_error_rate_)) return false;
if (!fp->DeSerialize(&worst_error_rates_[0], countof(worst_error_rates_))) return false;
if (!fp->DeSerialize(&worst_iteration_)) return false;
if (!fp->DeSerialize(&stall_iteration_)) return false;
if (!best_model_data_.DeSerialize(fp)) return false;
if (!worst_model_data_.DeSerialize(fp)) return false;
if (amount != NO_BEST_TRAINER && !best_trainer_.DeSerialize(fp)) return false;
Expand All @@ -536,9 +508,7 @@ bool LSTMTrainer::DeSerialize(const TessdataManager* mgr, TFile* fp) {
}
if (!best_error_history_.DeSerialize(fp)) return false;
if (!best_error_iterations_.DeSerialize(fp)) return false;
if (fp->FReadEndian(&improvement_steps_, sizeof(improvement_steps_), 1) != 1)
return false;
return true;
return fp->DeSerialize(&improvement_steps_);
}

// De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
Expand Down
4 changes: 2 additions & 2 deletions src/lstm/lstmtrainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class LSTMTrainer : public LSTMRecognizer {
return best_iteration_;
}
int learning_iteration() const { return learning_iteration_; }
int improvement_steps() const { return improvement_steps_; }
int32_t improvement_steps() const { return improvement_steps_; }
void set_perfect_delay(int delay) { perfect_delay_ = delay; }
const GenericVector<char>& best_trainer() const { return best_trainer_; }
// Returns the error that was just calculated by PrepareForBackward.
Expand Down Expand Up @@ -457,7 +457,7 @@ class LSTMTrainer : public LSTMRecognizer {
GenericVector<double> best_error_history_;
GenericVector<int> best_error_iterations_;
// Number of iterations since the best_error_rate_ was 2% more than it is now.
int improvement_steps_;
int32_t improvement_steps_;
// Number of iterations that yielded a non-zero delta error and thus provided
// significant learning. learning_iteration_ <= training_iteration_.
// learning_iteration_ is used to measure rate of learning progress.
Expand Down

0 comments on commit 3a7f5e4

Please sign in to comment.