Part 2 of separating out the unicharset from the LSTM model, fixing command line for training
theraysmith committed Aug 2, 2017
1 parent 61adbdf commit 2633fef
Showing 19 changed files with 625 additions and 222 deletions.
29 changes: 15 additions & 14 deletions dict/dawg.cpp
@@ -339,16 +339,15 @@ bool SquishedDawg::read_squished_dawg(TFile *file) {
   return true;
 }
 
-NODE_MAP SquishedDawg::build_node_map(inT32 *num_nodes) const {
+std::unique_ptr<EDGE_REF[]> SquishedDawg::build_node_map(
+    inT32 *num_nodes) const {
   EDGE_REF edge;
-  NODE_MAP node_map;
+  std::unique_ptr<EDGE_REF[]> node_map(new EDGE_REF[num_edges_]);
   inT32 node_counter;
   inT32 num_edges;
 
-  node_map = (NODE_MAP) malloc(sizeof(EDGE_REF) * num_edges_);
-
   for (edge = 0; edge < num_edges_; edge++)  // init all slots
-    node_map [edge] = -1;
+    node_map[edge] = -1;
 
   node_counter = num_forward_edges(0);
 
@@ -366,33 +365,34 @@ NODE_MAP SquishedDawg::build_node_map(inT32 *num_nodes) const {
       edge--;
     }
   }
-  return (node_map);
+  return node_map;
 }
 
-void SquishedDawg::write_squished_dawg(FILE *file) {
+bool SquishedDawg::write_squished_dawg(TFile *file) {
   EDGE_REF edge;
   inT32 num_edges;
   inT32 node_count = 0;
-  NODE_MAP node_map;
   EDGE_REF old_index;
   EDGE_RECORD temp_record;
 
   if (debug_level_) tprintf("write_squished_dawg\n");
 
-  node_map = build_node_map(&node_count);
+  std::unique_ptr<EDGE_REF[]> node_map(build_node_map(&node_count));
 
   // Write the magic number to help detecting a change in endianness.
   inT16 magic = kDawgMagicNumber;
-  fwrite(&magic, sizeof(inT16), 1, file);
-  fwrite(&unicharset_size_, sizeof(inT32), 1, file);
+  if (file->FWrite(&magic, sizeof(magic), 1) != 1) return false;
+  if (file->FWrite(&unicharset_size_, sizeof(unicharset_size_), 1) != 1)
+    return false;
 
   // Count the number of edges in this Dawg.
   num_edges = 0;
   for (edge=0; edge < num_edges_; edge++)
     if (forward_edge(edge))
       num_edges++;
 
-  fwrite(&num_edges, sizeof(inT32), 1, file);  // write edge count to file
+  // Write edge count to file.
+  if (file->FWrite(&num_edges, sizeof(num_edges), 1) != 1) return false;
 
   if (debug_level_) {
     tprintf("%d nodes in DAWG\n", node_count);
@@ -405,7 +405,8 @@ void SquishedDawg::write_squished_dawg(FILE *file) {
       old_index = next_node_from_edge_rec(edges_[edge]);
       set_next_node(edge, node_map[old_index]);
       temp_record = edges_[edge];
-      fwrite(&(temp_record), sizeof(EDGE_RECORD), 1, file);
+      if (file->FWrite(&temp_record, sizeof(temp_record), 1) != 1)
+        return false;
       set_next_node(edge, old_index);
     } while (!last_edge(edge++));
 
@@ -416,7 +417,7 @@ void SquishedDawg::write_squished_dawg(FILE *file) {
      edge--;
    }
  }
-  free(node_map);
+  return true;
 }
 
 }  // namespace tesseract
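
The hunks above swap the malloc()'d NODE_MAP (and the matching free() at the end of write_squished_dawg) for a std::unique_ptr<EDGE_REF[]>, which is what lets the new FWrite error paths bail out with return false without leaking the node map. A minimal, self-contained sketch of that ownership pattern, using stand-in types rather than the actual tesseract classes:

#include <cstdint>
#include <memory>

using EdgeRef = std::int64_t;  // stand-in for tesseract's EDGE_REF

// Builds an index table; the unique_ptr frees the array automatically on
// every return path, so no explicit free() is needed anywhere.
std::unique_ptr<EdgeRef[]> build_table(int num_edges) {
  std::unique_ptr<EdgeRef[]> table(new EdgeRef[num_edges]);
  for (int i = 0; i < num_edges; ++i) table[i] = -1;  // init all slots
  return table;  // ownership moves to the caller
}

bool write_table(int num_edges) {
  std::unique_ptr<EdgeRef[]> table = build_table(num_edges);
  for (int i = 0; i < num_edges; ++i) {
    if (table[i] != -1) return false;  // early error return: no leak
  }
  return true;  // the array is also released here on success
}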
26 changes: 15 additions & 11 deletions dict/dawg.h
@@ -31,9 +31,10 @@
               I n c l u d e s
 ----------------------------------------------------------------------*/
 
+#include <memory>
 #include "elst.h"
-#include "ratngs.h"
 #include "params.h"
+#include "ratngs.h"
 #include "tesscallback.h"
 
 #ifndef __GNUC__
@@ -483,18 +484,22 @@ class SquishedDawg : public Dawg {
   void print_node(NODE_REF node, int max_num_edges) const;
 
   /// Writes the squished/reduced Dawg to a file.
-  void write_squished_dawg(FILE *file);
+  bool write_squished_dawg(TFile *file);
 
   /// Opens the file with the given filename and writes the
   /// squished/reduced Dawg to the file.
-  void write_squished_dawg(const char *filename) {
-    FILE *file = fopen(filename, "wb");
-    if (file == NULL) {
-      tprintf("Error opening %s\n", filename);
-      exit(1);
+  bool write_squished_dawg(const char *filename) {
+    TFile file;
+    file.OpenWrite(nullptr);
+    if (!this->write_squished_dawg(&file)) {
+      tprintf("Error serializing %s\n", filename);
+      return false;
     }
-    this->write_squished_dawg(file);
-    fclose(file);
+    if (!file.CloseWrite(filename, nullptr)) {
+      tprintf("Error writing file %s\n", filename);
+      return false;
+    }
+    return true;
   }
 
  private:
@@ -549,8 +554,7 @@ class SquishedDawg : public Dawg {
     tprintf("__________________________\n");
   }
   /// Constructs a mapping from the memory node indices to disk node indices.
-  NODE_MAP build_node_map(inT32 *num_nodes) const;
-
+  std::unique_ptr<EDGE_REF[]> build_node_map(inT32 *num_nodes) const;
 
   // Member variables.
   EDGE_ARRAY edges_;
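The new filename overload above buffers the serialized DAWG in memory via TFile::OpenWrite(nullptr) and only touches the disk in CloseWrite(), returning false instead of calling exit(1) on failure. The class below is a hypothetical stand-in (not tesseract's TFile) that illustrates the same buffer-then-flush contract the wrapper relies on:

#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical in-memory writer: bytes accumulate in a buffer and reach the
// file system only when CloseWrite() succeeds in opening the target file.
class MemWriter {
 public:
  void OpenWrite() { buffer_.clear(); }
  bool FWrite(const void* data, std::size_t size, std::size_t count) {
    const char* bytes = static_cast<const char*>(data);
    buffer_.insert(buffer_.end(), bytes, bytes + size * count);
    return true;
  }
  bool CloseWrite(const char* filename) {
    std::FILE* fp = std::fopen(filename, "wb");
    if (fp == nullptr) return false;
    std::size_t written = std::fwrite(buffer_.data(), 1, buffer_.size(), fp);
    return std::fclose(fp) == 0 && written == buffer_.size();
  }

 private:
  std::vector<char> buffer_;
};

Serializing to memory first means a failed write_squished_dawg leaves no partial file behind, and the caller decides how to handle the error instead of the library terminating the process.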
39 changes: 17 additions & 22 deletions dict/trie.cpp
@@ -290,51 +290,46 @@ bool Trie::read_and_add_word_list(const char *filename,
                                   const UNICHARSET &unicharset,
                                   Trie::RTLReversePolicy reverse_policy) {
   GenericVector<STRING> word_list;
-  if (!read_word_list(filename, unicharset, reverse_policy, &word_list))
-    return false;
+  if (!read_word_list(filename, &word_list)) return false;
   word_list.sort(sort_strings_by_dec_length);
-  return add_word_list(word_list, unicharset);
+  return add_word_list(word_list, unicharset, reverse_policy);
 }
 
 bool Trie::read_word_list(const char *filename,
-                          const UNICHARSET &unicharset,
-                          Trie::RTLReversePolicy reverse_policy,
                           GenericVector<STRING>* words) {
   FILE *word_file;
-  char string[CHARS_PER_LINE];
+  char line_str[CHARS_PER_LINE];
   int word_count = 0;
 
   word_file = fopen(filename, "rb");
   if (word_file == NULL) return false;
 
-  while (fgets(string, CHARS_PER_LINE, word_file) != NULL) {
-    chomp_string(string);  // remove newline
-    WERD_CHOICE word(string, unicharset);
-    if ((reverse_policy == RRP_REVERSE_IF_HAS_RTL &&
-         word.has_rtl_unichar_id()) ||
-        reverse_policy == RRP_FORCE_REVERSE) {
-      word.reverse_and_mirror_unichar_ids();
-    }
+  while (fgets(line_str, sizeof(line_str), word_file) != NULL) {
+    chomp_string(line_str);  // remove newline
+    STRING word_str(line_str);
     ++word_count;
     if (debug_level_ && word_count % 10000 == 0)
       tprintf("Read %d words so far\n", word_count);
-    if (word.length() != 0 && !word.contains_unichar_id(INVALID_UNICHAR_ID)) {
-      words->push_back(word.unichar_string());
-    } else if (debug_level_) {
-      tprintf("Skipping invalid word %s\n", string);
-      if (debug_level_ >= 3) word.print();
-    }
+    words->push_back(word_str);
   }
   if (debug_level_)
     tprintf("Read %d words total.\n", word_count);
   fclose(word_file);
   return true;
 }
 
-bool Trie::add_word_list(const GenericVector<STRING>& words,
-                         const UNICHARSET &unicharset) {
+bool Trie::add_word_list(const GenericVector<STRING> &words,
+                         const UNICHARSET &unicharset,
+                         Trie::RTLReversePolicy reverse_policy) {
   for (int i = 0; i < words.size(); ++i) {
     WERD_CHOICE word(words[i].string(), unicharset);
+    if (word.length() == 0 || word.contains_unichar_id(INVALID_UNICHAR_ID))
+      continue;
+    if ((reverse_policy == RRP_REVERSE_IF_HAS_RTL &&
+         word.has_rtl_unichar_id()) ||
+        reverse_policy == RRP_FORCE_REVERSE) {
+      word.reverse_and_mirror_unichar_ids();
+    }
     if (!word_in_dawg(word)) {
       add_word_to_dawg(word);
       if (!word_in_dawg(word)) {
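With this refactor, read_word_list is pure file I/O that returns raw STRINGs, while unichar-id conversion, validity filtering, and RTL reversal all move into add_word_list. A hedged sketch of how a caller might drive the split API (the include paths, the BuildWordDawg wrapper, and the Trie/UNICHARSET setup are assumptions; the two method signatures come from the diff above):

#include "trie.h"        // tesseract::Trie (assumed include path)
#include "unicharset.h"  // UNICHARSET

bool BuildWordDawg(tesseract::Trie* trie, const UNICHARSET& unicharset,
                   const char* wordlist_file) {
  GenericVector<STRING> words;
  // Step 1: plain file reading -- no unicharset or reverse policy needed.
  if (!trie->read_word_list(wordlist_file, &words)) return false;
  // (read_and_add_word_list additionally sorts by decreasing length here.)
  // Step 2: unichar-id conversion, validity filtering and RTL reversal.
  return trie->add_word_list(words, unicharset,
                             tesseract::Trie::RRP_REVERSE_IF_HAS_RTL);
}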
12 changes: 5 additions & 7 deletions dict/trie.h
@@ -177,18 +177,16 @@ class Trie : public Dawg {
                               const UNICHARSET &unicharset,
                               Trie::RTLReversePolicy reverse);
 
-  // Reads a list of words from the given file, applying the reverse_policy,
-  // according to information in the unicharset.
+  // Reads a list of words from the given file.
   // Returns false on error.
   bool read_word_list(const char *filename,
-                      const UNICHARSET &unicharset,
-                      Trie::RTLReversePolicy reverse_policy,
                       GenericVector<STRING>* words);
   // Adds a list of words previously read using read_word_list to the trie
-  // using the given unicharset to convert to unichar-ids.
+  // using the given unicharset and reverse_policy to convert to unichar-ids.
   // Returns false on error.
-  bool add_word_list(const GenericVector<STRING>& words,
-                     const UNICHARSET &unicharset);
+  bool add_word_list(const GenericVector<STRING> &words,
+                     const UNICHARSET &unicharset,
+                     Trie::RTLReversePolicy reverse_policy);
 
   // Inserts the list of patterns from the given file into the Trie.
   // The pattern list file should contain one pattern per line in UTF-8 format.
76 changes: 12 additions & 64 deletions lstm/lstmtrainer.cpp
@@ -130,22 +130,6 @@ bool LSTMTrainer::TryLoadingCheckpoint(const char* filename) {
   return checkpoint_reader_->Run(data, this);
 }
 
-// Initializes the character set encode/decode mechanism.
-// train_flags control training behavior according to the TrainingFlags
-// enum, including character set encoding.
-// script_dir is required for TF_COMPRESS_UNICHARSET, and, if provided,
-// fully initializes the unicharset from the universal unicharsets.
-// Note: Call before InitNetwork!
-void LSTMTrainer::InitCharSet(const UNICHARSET& unicharset,
-                              const STRING& script_dir, int train_flags) {
-  EmptyConstructor();
-  training_flags_ = train_flags;
-  ccutil_.unicharset.CopyFrom(unicharset);
-  null_char_ = GetUnicharset().has_special_codes() ? UNICHAR_BROKEN
-                                                   : GetUnicharset().size();
-  SetUnicharsetProperties(script_dir);
-}
-
 // Initializes the trainer with a network_spec in the network description
 // net_flags control network behavior according to the NetworkFlags enum.
 // There isn't really much difference between them - only where the effects
@@ -278,9 +262,10 @@ void LSTMTrainer::DebugNetwork() {
 // Loads a set of lstmf files that were created using the lstm.train config to
 // tesseract into memory ready for training. Returns false if nothing was
 // loaded.
-bool LSTMTrainer::LoadAllTrainingData(const GenericVector<STRING>& filenames) {
+bool LSTMTrainer::LoadAllTrainingData(const GenericVector<STRING>& filenames,
+                                      CachingStrategy cache_strategy) {
   training_data_.Clear();
-  return training_data_.LoadDocuments(filenames, CacheStrategy(), file_reader_);
+  return training_data_.LoadDocuments(filenames, cache_strategy, file_reader_);
 }
 
 // Keeps track of best and locally worst char error_rate and launches tests
Expand Down Expand Up @@ -908,6 +893,15 @@ bool LSTMTrainer::ReadLocalTrainingDump(const TessdataManager* mgr,
return DeSerialize(mgr, &fp);
}

// Writes the full recognition traineddata to the given filename.
bool LSTMTrainer::SaveTraineddata(const STRING& filename) {
GenericVector<char> recognizer_data;
SaveRecognitionDump(&recognizer_data);
mgr_.OverwriteEntry(TESSDATA_LSTM, &recognizer_data[0],
recognizer_data.size());
return mgr_.SaveFile(filename, file_writer_);
}

// Writes the recognizer to memory, so that it can be used for testing later.
void LSTMTrainer::SaveRecognitionDump(GenericVector<char>* data) const {
TFile fp;
Expand Down Expand Up @@ -964,52 +958,6 @@ void LSTMTrainer::EmptyConstructor() {
InitIterations();
}

// Sets the unicharset properties using the given script_dir as a source of
// script unicharsets. If the flag TF_COMPRESS_UNICHARSET is true, also sets
// up the recoder_ to simplify the unicharset.
void LSTMTrainer::SetUnicharsetProperties(const STRING& script_dir) {
tprintf("Setting unichar properties\n");
for (int s = 0; s < GetUnicharset().get_script_table_size(); ++s) {
if (strcmp("NULL", GetUnicharset().get_script_from_script_id(s)) == 0)
continue;
// Load the unicharset for the script if available.
STRING filename = script_dir + "/" +
GetUnicharset().get_script_from_script_id(s) +
".unicharset";
UNICHARSET script_set;
GenericVector<char> data;
if ((*file_reader_)(filename, &data) &&
script_set.load_from_inmemory_file(&data[0], data.size())) {
tprintf("Setting properties for script %s\n",
GetUnicharset().get_script_from_script_id(s));
ccutil_.unicharset.SetPropertiesFromOther(script_set);
}
}
if (IsRecoding()) {
STRING filename = script_dir + "/radical-stroke.txt";
GenericVector<char> data;
if ((*file_reader_)(filename, &data)) {
data += '\0';
STRING stroke_table = &data[0];
if (recoder_.ComputeEncoding(GetUnicharset(), null_char_,
&stroke_table)) {
RecodedCharID code;
recoder_.EncodeUnichar(null_char_, &code);
null_char_ = code(0);
// Space should encode as itself.
recoder_.EncodeUnichar(UNICHAR_SPACE, &code);
ASSERT_HOST(code(0) == UNICHAR_SPACE);
return;
}
} else {
tprintf("Failed to load radical-stroke info from: %s\n",
filename.string());
}
}
training_flags_ |= TF_COMPRESS_UNICHARSET;
recoder_.SetupPassThrough(GetUnicharset());
}

// Outputs the string and periodically displays the given network inputs
// as an image in the given window, and the corresponding labels at the
// corresponding x_starts.
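Together with the header changes below, LoadAllTrainingData now lets the caller choose the CachingStrategy, and SaveTraineddata writes a complete traineddata (the LSTM model plus the unicharset, recoder, and dawgs already held by mgr_) rather than a bare recognition dump. A hedged sketch of a training driver using the new entry points (the RunTraining wrapper, the include path, and the CS_SEQUENTIAL value from imagedata.h are assumptions; the two method signatures come from the diff):

#include "lstmtrainer.h"  // tesseract::LSTMTrainer (assumed include path)

bool RunTraining(tesseract::LSTMTrainer* trainer,
                 const GenericVector<STRING>& lstmf_files,
                 const STRING& output_traineddata) {
  // The caching strategy is now a caller decision rather than a fixed
  // CacheStrategy() call inside the trainer.
  if (!trainer->LoadAllTrainingData(lstmf_files, tesseract::CS_SEQUENTIAL))
    return false;
  // ... training iterations elided ...
  // Emit the full traineddata in one step.
  return trainer->SaveTraineddata(output_traineddata);
}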
22 changes: 5 additions & 17 deletions lstm/lstmtrainer.h
@@ -101,14 +101,6 @@ class LSTMTrainer : public LSTMRecognizer {
   // false in case of failure.
   bool TryLoadingCheckpoint(const char* filename);
 
-  // Initializes the character set encode/decode mechanism.
-  // train_flags control training behavior according to the TrainingFlags
-  // enum, including character set encoding.
-  // script_dir is required for TF_COMPRESS_UNICHARSET, and, if provided,
-  // fully initializes the unicharset from the universal unicharsets.
-  // Note: Call before InitNetwork!
-  void InitCharSet(const UNICHARSET& unicharset, const STRING& script_dir,
-                   int train_flags);
   // Initializes the character set encode/decode mechanism directly from a
   // previously setup traineddata containing dawgs, UNICHARSET and
   // UnicharCompress. Note: Call before InitNetwork!
@@ -186,7 +178,8 @@ class LSTMTrainer : public LSTMRecognizer {
   // Loads a set of lstmf files that were created using the lstm.train config to
   // tesseract into memory ready for training. Returns false if nothing was
   // loaded.
-  bool LoadAllTrainingData(const GenericVector<STRING>& filenames);
+  bool LoadAllTrainingData(const GenericVector<STRING>& filenames,
+                           CachingStrategy cache_strategy);
 
   // Keeps track of best and locally worst error rate, using internally computed
   // values. See MaintainCheckpointsSpecific for more detail.
@@ -315,12 +308,12 @@ class LSTMTrainer : public LSTMRecognizer {
   // Sets up the data for MaintainCheckpoints from a light ReadTrainingDump.
   void SetupCheckpointInfo();
 
+  // Writes the full recognition traineddata to the given filename.
+  bool SaveTraineddata(const STRING& filename);
+
   // Writes the recognizer to memory, so that it can be used for testing later.
   void SaveRecognitionDump(GenericVector<char>* data) const;
 
-  // Writes current best model to a file, unless it has already been written.
-  bool SaveBestModel(FileWriter writer) const;
-
   // Returns a suitable filename for a training dump, based on the model_base_,
   // the iteration and the error rates.
   STRING DumpFilename() const;
@@ -336,11 +329,6 @@ class LSTMTrainer : public LSTMRecognizer {
   // Factored sub-constructor sets up reasonable default values.
   void EmptyConstructor();
 
-  // Sets the unicharset properties using the given script_dir as a source of
-  // script unicharsets. If the flag TF_COMPRESS_UNICHARSET is true, also sets
-  // up the recoder_ to simplify the unicharset.
-  void SetUnicharsetProperties(const STRING& script_dir);
-
   // Outputs the string and periodically displays the given network inputs
   // as an image in the given window, and the corresponding labels at the
   // corresponding x_starts.