tesseract 5.0.0
lstmtrainer_test.cc
Go to the documentation of this file.
1 // (C) Copyright 2017, Google Inc.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 // http://www.apache.org/licenses/LICENSE-2.0
6 // Unless required by applicable law or agreed to in writing, software
7 // distributed under the License is distributed on an "AS IS" BASIS,
8 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 // See the License for the specific language governing permissions and
10 // limitations under the License.
11 
12 #include <allheaders.h>
13 #include <tesseract/baseapi.h>
14 #include "lstm_test.h"
15 
16 namespace tesseract {
17 
18 TEST_F(LSTMTrainerTest, EncodesEng) {
19  TestEncodeDecodeBoth("eng", "The quick brown 'fox' jumps over: the lazy dog!");
20 }
21 
22 TEST_F(LSTMTrainerTest, EncodesKan) {
23  TestEncodeDecodeBoth("kan", "ಫ್ರಬ್ರವರಿ ತತ್ವಾಂಶಗಳೆಂದರೆ ಮತ್ತು ಜೊತೆಗೆ ಕ್ರಮವನ್ನು");
24 }
25 
26 TEST_F(LSTMTrainerTest, EncodesKor) {
27  TestEncodeDecodeBoth("kor", "이는 것으로 다시 넣을 수는 있지만 선택의 의미는");
28 }
29 
30 TEST_F(LSTMTrainerTest, MapCoder) {
31  LSTMTrainer fra_trainer;
32  fra_trainer.InitCharSet(TestDataNameToPath("fra/fra.traineddata"));
33  LSTMTrainer deu_trainer;
34  deu_trainer.InitCharSet(TestDataNameToPath("deu/deu.traineddata"));
35  // A string that uses characters common to French and German.
36  std::string kTestStr = "The quick brown 'fox' jumps over: the lazy dog!";
37  std::vector<int> deu_labels;
38  EXPECT_TRUE(deu_trainer.EncodeString(kTestStr.c_str(), &deu_labels));
39  // The french trainer cannot decode them correctly.
40  std::string badly_decoded = fra_trainer.DecodeLabels(deu_labels);
41  std::string bad_str(&badly_decoded[0], badly_decoded.length());
42  LOG(INFO) << "bad_str fra=" << bad_str << "\n";
43  EXPECT_NE(kTestStr, bad_str);
44  // Encode the string as fra.
45  std::vector<int> fra_labels;
46  EXPECT_TRUE(fra_trainer.EncodeString(kTestStr.c_str(), &fra_labels));
47  // Use the mapper to compute what the labels are as deu.
48  std::vector<int> mapping =
49  fra_trainer.MapRecoder(deu_trainer.GetUnicharset(), deu_trainer.GetRecoder());
50  std::vector<int> mapped_fra_labels(fra_labels.size(), -1);
51  for (unsigned i = 0; i < fra_labels.size(); ++i) {
52  mapped_fra_labels[i] = mapping[fra_labels[i]];
53  EXPECT_NE(-1, mapped_fra_labels[i]) << "i=" << i << ", ch=" << kTestStr[i];
54  EXPECT_EQ(mapped_fra_labels[i], deu_labels[i])
55  << "i=" << i << ", ch=" << kTestStr[i] << " has deu label=" << deu_labels[i]
56  << ", but mapped to " << mapped_fra_labels[i];
57  }
58  // The german trainer can now decode them correctly.
59  std::string decoded = deu_trainer.DecodeLabels(mapped_fra_labels);
60  std::string ok_str(&decoded[0], decoded.length());
61  LOG(INFO) << "ok_str deu=" << ok_str << "\n";
62  EXPECT_EQ(kTestStr, ok_str);
63 }
64 
65 // Tests that the actual fra model can be converted to the deu character set
66 // and still read an eng image with 100% accuracy.
67 TEST_F(LSTMTrainerTest, ConvertModel) {
68  // Setup a trainer with a deu charset.
69  LSTMTrainer deu_trainer;
70  deu_trainer.InitCharSet(TestDataNameToPath("deu/deu.traineddata"));
71  // Load the fra traineddata, strip out the model, and save to a tmp file.
72  TessdataManager mgr;
73  std::string fra_data = file::JoinPath(TESSDATA_DIR "_best", "fra.traineddata");
74  CHECK(mgr.Init(fra_data.c_str()));
75  LOG(INFO) << "Load " << fra_data << "\n";
77  std::string model_path = file::JoinPath(FLAGS_test_tmpdir, "fra.lstm");
78  CHECK(mgr.ExtractToFile(model_path.c_str()));
79  LOG(INFO) << "Extract " << model_path << "\n";
80  // Load the fra model into the deu_trainer, and save the converted model.
81  CHECK(deu_trainer.TryLoadingCheckpoint(model_path.c_str(), fra_data.c_str()));
82  LOG(INFO) << "Checkpoint load for " << model_path << " and " << fra_data << "\n";
83  std::string deu_data = file::JoinPath(FLAGS_test_tmpdir, "deu.traineddata");
84  CHECK(deu_trainer.SaveTraineddata(deu_data.c_str()));
85  LOG(INFO) << "Save " << deu_data << "\n";
86  // Now run the saved model on phototest. (See BasicTesseractTest in
87  // baseapi_test.cc).
88  TessBaseAPI api;
89  api.Init(FLAGS_test_tmpdir, "deu", tesseract::OEM_LSTM_ONLY);
90  Image src_pix = pixRead(TestingNameToPath("phototest.tif").c_str());
91  CHECK(src_pix);
92  api.SetImage(src_pix);
93  std::unique_ptr<char[]> result(api.GetUTF8Text());
94  std::string truth_text;
95  CHECK_OK(
96  file::GetContents(TestingNameToPath("phototest.gold.txt"), &truth_text, file::Defaults()));
97 
98  EXPECT_STREQ(truth_text.c_str(), result.get());
99  src_pix.destroy();
100 }
101 
102 } // namespace tesseract
@ LOG
#define CHECK(condition)
Definition: include_gunit.h:76
#define CHECK_OK(test)
Definition: include_gunit.h:84
@ INFO
Definition: log.h:28
std::string TestDataNameToPath(const std::string &name)
TEST_F(EuroText, FastLatinOCR)
int Init(const char *datapath, const char *language, OcrEngineMode mode, char **configs, int configs_size, const std::vector< std::string > *vars_vec, const std::vector< std::string > *vars_values, bool set_only_non_debug_params)
Definition: baseapi.cpp:365
void SetImage(const unsigned char *imagedata, int width, int height, int bytes_per_pixel, int bytes_per_line)
Definition: baseapi.cpp:573
void destroy()
Definition: image.cpp:32
bool ExtractToFile(const char *filename)
bool Init(const char *data_file_name)
std::string DecodeLabels(const std::vector< int > &labels)
const UNICHARSET & GetUnicharset() const
const UnicharCompress & GetRecoder() const
bool EncodeString(const std::string &str, std::vector< int > *labels) const
Definition: lstmtrainer.h:253
bool InitCharSet(const std::string &traineddata_path)
Definition: lstmtrainer.h:99
bool SaveTraineddata(const char *filename)
bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata)
std::vector< int > MapRecoder(const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const
static int Defaults()
Definition: include_gunit.h:61
static void MakeTmpdir()
Definition: include_gunit.h:38
static std::string JoinPath(const std::string &s1, const std::string &s2)
Definition: include_gunit.h:65
static bool GetContents(const std::string &filename, std::string *out, int)
Definition: include_gunit.h:52