tesseract  5.0.0
tesseract::LSTMTrainerTest Class Reference

#include <lstm_test.h>

Inheritance diagram for tesseract::LSTMTrainerTest:

Protected Member Functions

void SetUp () override
 
 LSTMTrainerTest ()=default
 
std::string TestDataNameToPath (const std::string &name)
 
std::string TessDataNameToPath (const std::string &name)
 
std::string TestingNameToPath (const std::string &name)
 
void SetupTrainerEng (const std::string &network_spec, const std::string &model_name, bool recode, bool adam)
 
void SetupTrainer (const std::string &network_spec, const std::string &model_name, const std::string &unicharset_file, const std::string &lstmf_file, bool recode, bool adam, float learning_rate, bool layer_specific, const std::string &kLang)
 
double TrainIterations (int max_iterations)
 
double TestIterations (int max_iterations)
 
double TestIntMode (int test_iterations)
 
void TestEncodeDecode (const std::string &lang, const std::string &str, bool recode)
 
void TestEncodeDecodeBoth (const std::string &lang, const std::string &str)
 

Protected Attributes

std::unique_ptr< LSTMTrainer > trainer_
 

Detailed Description

Definition at line 45 of file lstm_test.h.

Constructor & Destructor Documentation

◆ LSTMTrainerTest()

tesseract::LSTMTrainerTest::LSTMTrainerTest ( )
[protected], [default]

Member Function Documentation

◆ SetUp()

void tesseract::LSTMTrainerTest::SetUp ( )
[inline], [override], [protected]

Definition at line 47 of file lstm_test.h.

47  {
48  std::locale::global(std::locale(""));
49  file::MakeTmpdir();
50  }
static void MakeTmpdir()
Definition: include_gunit.h:38

◆ SetupTrainer()

void tesseract::LSTMTrainerTest::SetupTrainer ( const std::string &  network_spec,
const std::string &  model_name,
const std::string &  unicharset_file,
const std::string &  lstmf_file,
bool  recode,
bool  adam,
float  learning_rate,
bool  layer_specific,
const std::string &  kLang 
)
[inline], [protected]

Definition at line 68 of file lstm_test.h.

70  {
71  // constexpr char kLang[] = "eng"; // Exact value doesn't matter.
72  std::string unicharset_name = TestDataNameToPath(unicharset_file);
73  UNICHARSET unicharset;
74  ASSERT_TRUE(unicharset.load_from_file(unicharset_name.c_str(), false));
75  std::string script_dir = file::JoinPath(LANGDATA_DIR, "");
76  std::vector<std::string> words;
77  EXPECT_EQ(0, CombineLangModel(unicharset, script_dir, "", FLAGS_test_tmpdir, kLang, !recode,
78  words, words, words, false, nullptr, nullptr));
79  std::string model_path = file::JoinPath(FLAGS_test_tmpdir, model_name);
80  std::string checkpoint_path = model_path + "_checkpoint";
81  trainer_ = std::make_unique<LSTMTrainer>(model_path.c_str(), checkpoint_path.c_str(), 0, 0);
82  trainer_->InitCharSet(
83  file::JoinPath(FLAGS_test_tmpdir, kLang, kLang) + ".traineddata");
84  int net_mode = adam ? NF_ADAM : 0;
85  // Adam needs a higher learning rate, due to not multiplying the effective
86  // rate by 1/(1-momentum).
87  if (adam) {
88  learning_rate *= 20.0f;
89  }
90  if (layer_specific) {
91  net_mode |= NF_LAYER_SPECIFIC_LR;
92  }
93  EXPECT_TRUE(
94  trainer_->InitNetwork(network_spec.c_str(), -1, net_mode, 0.1, learning_rate, 0.9, 0.999));
95  std::vector<std::string> filenames;
96  filenames.emplace_back(TestDataNameToPath(lstmf_file).c_str());
97  EXPECT_TRUE(trainer_->LoadAllTrainingData(filenames, CS_SEQUENTIAL, false));
98  LOG(INFO) << "Setup network:" << model_name << "\n";
99  }
@ LOG
@ INFO
Definition: log.h:28
@ NF_LAYER_SPECIFIC_LR
Definition: network.h:85
@ NF_ADAM
Definition: network.h:86
@ CS_SEQUENTIAL
Definition: imagedata.h:49
int CombineLangModel(const UNICHARSET &unicharset, const std::string &script_dir, const std::string &version_str, const std::string &output_dir, const std::string &lang, bool pass_through_recoder, const std::vector< std::string > &words, const std::vector< std::string > &puncs, const std::vector< std::string > &numbers, bool lang_is_rtl, FileReader reader, FileWriter writer)
static std::string JoinPath(const std::string &s1, const std::string &s2)
Definition: include_gunit.h:65
std::unique_ptr< LSTMTrainer > trainer_
Definition: lstm_test.h:180
std::string TestDataNameToPath(const std::string &name)
Definition: lstm_test.h:53

◆ SetupTrainerEng()

void tesseract::LSTMTrainerTest::SetupTrainerEng ( const std::string &  network_spec,
const std::string &  model_name,
bool  recode,
bool  adam 
)
[inline], [protected]

Definition at line 63 of file lstm_test.h.

64  {
65  SetupTrainer(network_spec, model_name, "eng/eng.unicharset", "eng.Arial.exp0.lstmf", recode,
66  adam, 5e-4, false, "eng");
67  }
void SetupTrainer(const std::string &network_spec, const std::string &model_name, const std::string &unicharset_file, const std::string &lstmf_file, bool recode, bool adam, float learning_rate, bool layer_specific, const std::string &kLang)
Definition: lstm_test.h:68

◆ TessDataNameToPath()

std::string tesseract::LSTMTrainerTest::TessDataNameToPath ( const std::string &  name)
[inline], [protected]

Definition at line 56 of file lstm_test.h.

56  {
57  return file::JoinPath(TESSDATA_DIR, "" + name);
58  }

◆ TestDataNameToPath()

std::string tesseract::LSTMTrainerTest::TestDataNameToPath ( const std::string &  name)
[inline], [protected]

Definition at line 53 of file lstm_test.h.

53  {
54  return file::JoinPath(TESTDATA_DIR, "" + name);
55  }

◆ TestEncodeDecode()

void tesseract::LSTMTrainerTest::TestEncodeDecode ( const std::string &  lang,
const std::string &  str,
bool  recode 
)
[inline], [protected]

Definition at line 163 of file lstm_test.h.

163  {
164  std::string unicharset_name = lang + "/" + lang + ".unicharset";
165  std::string lstmf_name = lang + ".Arial_Unicode_MS.exp0.lstmf";
166  SetupTrainer("[1,1,0,32 Lbx100 O1c1]", "bidi-lstm", unicharset_name, lstmf_name, recode, true,
167  5e-4f, true, lang);
168  std::vector<int> labels;
169  EXPECT_TRUE(trainer_->EncodeString(str.c_str(), &labels));
170  std::string decoded = trainer_->DecodeLabels(labels);
171  std::string decoded_str(&decoded[0], decoded.length());
172  EXPECT_EQ(str, decoded_str);
173  }

◆ TestEncodeDecodeBoth()

void tesseract::LSTMTrainerTest::TestEncodeDecodeBoth ( const std::string &  lang,
const std::string &  str 
)
[inline], [protected]

Definition at line 175 of file lstm_test.h.

175  {
176  TestEncodeDecode(lang, str, false);
177  TestEncodeDecode(lang, str, true);
178  }
void TestEncodeDecode(const std::string &lang, const std::string &str, bool recode)
Definition: lstm_test.h:163

◆ TestingNameToPath()

std::string tesseract::LSTMTrainerTest::TestingNameToPath ( const std::string &  name)
[inline], [protected]

Definition at line 59 of file lstm_test.h.

59  {
60  return file::JoinPath(TESTING_DIR, "" + name);
61  }

◆ TestIntMode()

double tesseract::LSTMTrainerTest::TestIntMode ( int  test_iterations)
[inline], [protected]

Definition at line 148 of file lstm_test.h.

148  {
149  std::vector<char> trainer_data;
150  EXPECT_TRUE(trainer_->SaveTrainingDump(NO_BEST_TRAINER, *trainer_, &trainer_data));
151  // Get the error on the next few iterations in float mode.
152  double float_err = TestIterations(test_iterations);
153  // Restore the dump, convert to int and test error on that.
154  EXPECT_TRUE(trainer_->ReadTrainingDump(trainer_data, *trainer_));
155  trainer_->ConvertToInt();
156  double int_err = TestIterations(test_iterations);
157  EXPECT_LT(int_err, float_err + 1.0);
158  return int_err - float_err;
159  }
@ NO_BEST_TRAINER
Definition: lstmtrainer.h:61
double TestIterations(int max_iterations)
Definition: lstm_test.h:126

◆ TestIterations()

double tesseract::LSTMTrainerTest::TestIterations ( int  max_iterations)
[inline], [protected]

Definition at line 126 of file lstm_test.h.

126  {
127  CHECK_GT(max_iterations, 0);
128  int iteration = trainer_->sample_iteration();
129  double mean_error = 0.0;
130  int error_count = 0;
131  while (error_count < max_iterations) {
132  const ImageData &trainingdata =
133  *trainer_->mutable_training_data()->GetPageBySerial(iteration);
134  NetworkIO fwd_outputs, targets;
135  if (trainer_->PrepareForBackward(&trainingdata, &fwd_outputs, &targets) != UNENCODABLE) {
136  mean_error += trainer_->NewSingleError(ET_CHAR_ERROR);
137  ++error_count;
138  }
139  trainer_->SetIteration(++iteration);
140  }
141  mean_error *= 100.0 / max_iterations;
142  LOG(INFO) << "Tester error rate = " << mean_error << "\n";
143  return mean_error;
144  }
#define CHECK_GT(test, value)
Definition: include_gunit.h:81
@ ET_CHAR_ERROR
Definition: lstmtrainer.h:44

◆ TrainIterations()

double tesseract::LSTMTrainerTest::TrainIterations ( int  max_iterations)
[inline], [protected]

Definition at line 101 of file lstm_test.h.

101  {
102  int iteration = trainer_->training_iteration();
103  int iteration_limit = iteration + max_iterations;
104  double best_error = 100.0;
105  do {
106  std::string log_str;
107  int target_iteration = iteration + kBatchIterations;
108  // Train a few.
109  double mean_error = 0.0;
110  while (iteration < target_iteration && iteration < iteration_limit) {
111  trainer_->TrainOnLine(trainer_.get(), false);
112  iteration = trainer_->training_iteration();
113  mean_error += trainer_->LastSingleError(ET_CHAR_ERROR);
114  }
115  trainer_->MaintainCheckpoints(nullptr, log_str);
116  iteration = trainer_->training_iteration();
117  mean_error *= 100.0 / kBatchIterations;
118  if (mean_error < best_error) {
119  best_error = mean_error;
120  }
121  } while (iteration < iteration_limit);
122  LOG(INFO) << "Trainer error rate = " << best_error << "\n";
123  return best_error;
124  }
const int kBatchIterations
Definition: lstm_test.h:36

Member Data Documentation

◆ trainer_

std::unique_ptr<LSTMTrainer> tesseract::LSTMTrainerTest::trainer_
protected

Definition at line 180 of file lstm_test.h.


The documentation for this class was generated from the following file: