Skip to content

Commit

Permalink
unittest: Add mastertrainer_test (only works partially)
Browse files Browse the repository at this point in the history
The test currently has subtests which fail because of missing files.

Signed-off-by: Stefan Weil <[email protected]>
  • Loading branch information
stweil committed Oct 12, 2018
1 parent f93fb9d commit 2916dc8
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 32 deletions.
4 changes: 4 additions & 0 deletions unittest/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ check_PROGRAMS = \
lang_model_test \
linlsq_test \
loadlang_test \
mastertrainer_test \
matrix_test \
nthitem_test \
osd_test \
Expand Down Expand Up @@ -185,6 +186,9 @@ linlsq_test_LDADD = $(GTEST_LIBS) $(TESS_LIBS)
loadlang_test_SOURCES = loadlang_test.cc
loadlang_test_LDADD = $(GTEST_LIBS) $(TESS_LIBS) $(LEPTONICA_LIBS)

mastertrainer_test_SOURCES = mastertrainer_test.cc
mastertrainer_test_LDADD = $(ABSEIL_LIBS) $(GTEST_LIBS) $(TRAINING_LIBS) $(TESS_LIBS)

matrix_test_SOURCES = matrix_test.cc
matrix_test_LDADD = $(GTEST_LIBS) $(TESS_LIBS)

Expand Down
87 changes: 55 additions & 32 deletions unittest/mastertrainer_test.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
// (C) Copyright 2017, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Although this is a trivial-looking test, it exercises a lot of code:
// SampleIterator has to correctly iterate over the correct characters, or
Expand All @@ -15,36 +25,49 @@
#include <cctype>
#include <cerrno>
#include <cstdlib>
#include <limits>
#include <utility>
#include <vector>

#include "tesseract/ccutil/genericvector.h"
#include "tesseract/ccutil/unicharset.h"
#include "tesseract/classify/errorcounter.h"
#include "tesseract/classify/mastertrainer.h"
#include "tesseract/classify/shapeclassifier.h"
#include "tesseract/classify/shapetable.h"
#include "tesseract/classify/trainingsample.h"
#include "tesseract/training/commontraining.h"
#include "tesseract/training/tessopt.h"
#include "absl/strings/numbers.h" // for safe_strto32
#include "absl/strings/str_split.h" // for absl::StrSplit

#include "include_gunit.h"

#include "genericvector.h"
#include "log.h" // for LOG
#include "unicharset.h"
#include "errorcounter.h"
#include "mastertrainer.h"
#include "shapeclassifier.h"
#include "shapetable.h"
#include "trainingsample.h"
#include "commontraining.h"
#include "tessopt.h" // tessoptind

// Commontraining command-line arguments for font_properties, xheights and
// unicharset.
DECLARE_string(F);
DECLARE_string(X);
DECLARE_string(U);
DECLARE_string(output_trainer);
DECLARE_STRING_PARAM_FLAG(F);
DECLARE_STRING_PARAM_FLAG(X);
DECLARE_STRING_PARAM_FLAG(U);
DECLARE_STRING_PARAM_FLAG(output_trainer);

// Specs of the MockClassifier.
const int kNumTopNErrs = 10;
const int kNumTop2Errs = kNumTopNErrs + 20;
const int kNumTop1Errs = kNumTop2Errs + 30;
const int kNumTopTopErrs = kNumTop1Errs + 25;
const int kNumNonReject = 1000;
const int kNumCorrect = kNumNonReject - kNumTop1Errs;
static const int kNumTopNErrs = 10;
static const int kNumTop2Errs = kNumTopNErrs + 20;
static const int kNumTop1Errs = kNumTop2Errs + 30;
static const int kNumTopTopErrs = kNumTop1Errs + 25;
static const int kNumNonReject = 1000;
static const int kNumCorrect = kNumNonReject - kNumTop1Errs;
// The total number of answers is given by the number of non-rejects plus
// all the multiple answers.
const int kNumAnswers = kNumNonReject + 2 * (kNumTop2Errs - kNumTopNErrs) +
static const int kNumAnswers = kNumNonReject + 2 * (kNumTop2Errs - kNumTopNErrs) +
(kNumTop1Errs - kNumTop2Errs) +
(kNumTopTopErrs - kNumTop1Errs);

// Parses |str| as a signed integer that fits in an int and stores it in
// |*pResult|.
//
// The base is auto-detected by strtol (base 0): a "0x"/"0X" prefix selects
// hexadecimal, a leading "0" selects octal, anything else is decimal.
// Leading whitespace is skipped by strtol; trailing whitespace (for example
// a newline left over after splitting a report line) is tolerated.
//
// Returns true only if the whole string was a valid number within the range
// of int. On failure returns false and sets |*pResult| to 0 so that callers
// never read an indeterminate value. Returns false if |pResult| is null.
static bool safe_strto32(const std::string& str, int* pResult)
{
  if (pResult == nullptr) {
    return false;
  }
  *pResult = 0;

  const char* const begin = str.c_str();
  const char* const limit = begin + str.size();
  char* end = nullptr;

  // strtol reports overflow of long only through errno, so clear it first.
  errno = 0;
  const long n = std::strtol(begin, &end, 0);

  // No digits consumed (empty or non-numeric input), or out of range of long.
  if (end == begin || errno == ERANGE) {
    return false;
  }
  // long may be wider than int; reject values that would be truncated.
  if (n < std::numeric_limits<int>::min() ||
      n > std::numeric_limits<int>::max()) {
    return false;
  }
  // Anything other than whitespace after the number (including an embedded
  // NUL) means the input was not purely numeric.
  for (const char* p = end; p != limit; ++p) {
    if (!std::isspace(static_cast<unsigned char>(*p))) {
      return false;
    }
  }

  *pResult = static_cast<int>(n);
  return true;
}

namespace tesseract {

// Mock ShapeClassifier that cheats by looking at the correct answer, and
Expand Down Expand Up @@ -138,13 +161,13 @@ const double kMin1lDistance = 0.25;
// The fixture for testing Tesseract.
class MasterTrainerTest : public testing::Test {
protected:
string TestDataNameToPath(const string& name) {
return file::JoinPath(FLAGS_test_srcdir, "testdata/" + name);
std::string TestDataNameToPath(const std::string& name) {
return file::JoinPath(TESTING_DIR, name);
}
string TessdataPath() {
return file::JoinPath(FLAGS_test_srcdir, "tessdata");
std::string TessdataPath() {
return TESSDATA_DIR;
}
string TmpNameToPath(const string& name) {
std::string TmpNameToPath(const std::string& name) {
return file::JoinPath(FLAGS_test_tmpdir, name);
}

Expand All @@ -161,11 +184,11 @@ class MasterTrainerTest : public testing::Test {
// if load_from_tmp, then reloads a master trainer that was saved by a
// previous call in which it was false.
void LoadMasterTrainer() {
FLAGS_output_trainer = TmpNameToPath("tmp_trainer");
FLAGS_F = TestDataNameToPath("font_properties");
FLAGS_X = TestDataNameToPath("eng.xheights");
FLAGS_U = TestDataNameToPath("eng.unicharset");
string tr_file_name(TestDataNameToPath("eng.Arial.exp0.tr"));
FLAGS_output_trainer = TmpNameToPath("tmp_trainer").c_str();
FLAGS_F = file::JoinPath(LANGDATA_DIR, "font_properties").c_str();
FLAGS_X = TestDataNameToPath("eng.xheights").c_str();
FLAGS_U = file::JoinPath(LANGDATA_DIR, "eng/eng.unicharset").c_str();
std::string tr_file_name(TestDataNameToPath("eng.Arial.exp0.tr"));
const char* argv[] = {tr_file_name.c_str()};
int argc = 1;
STRING file_prefix;
Expand Down Expand Up @@ -256,8 +279,8 @@ TEST_F(MasterTrainerTest, ErrorCounterTest) {
false, shape_classifier,
&accuracy_report);
LOG(INFO) << accuracy_report.string();
string result_string = accuracy_report.string();
std::vector<string> results =
std::string result_string = accuracy_report.string();
std::vector<std::string> results =
absl::StrSplit(result_string, '\t', absl::SkipEmpty());
EXPECT_EQ(tesseract::CT_SIZE + 1, results.size());
int result_values[tesseract::CT_SIZE];
Expand Down

0 comments on commit 2916dc8

Please sign in to comment.