tesseract  5.0.0
lang_model_test.cc
Go to the documentation of this file.
1 // (C) Copyright 2017, Google Inc.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 // http://www.apache.org/licenses/LICENSE-2.0
6 // Unless required by applicable law or agreed to in writing, software
7 // distributed under the License is distributed on an "AS IS" BASIS,
8 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 // See the License for the specific language governing permissions and
10 // limitations under the License.
11 
#include <algorithm> // for std::replace
#include <cstdlib>   // for EXIT_SUCCESS
#include <string>    // for std::string
#include <vector>    // for std::vector

#include "gmock/gmock.h" // for testing::ElementsAreArray

#include "include_gunit.h"
#include "lang_model_helpers.h"
#include "log.h" // for LOG
#include "lstmtrainer.h"

22 namespace tesseract {
23 
24 std::string TestDataNameToPath(const std::string &name) {
25  return file::JoinPath(TESTING_DIR, name);
26 }
27 
28 // This is an integration test that verifies that CombineLangModel works to
29 // the extent that an LSTMTrainer can be initialized with the result, and it
30 // can encode strings. More importantly, the test verifies that adding an extra
31 // character to the unicharset does not change the encoding of strings.
32 TEST(LangModelTest, AddACharacter) {
33  constexpr char kTestString[] = "Simple ASCII string to encode !@#$%&";
34  constexpr char kTestStringRupees[] = "ASCII string with Rupee symbol ₹";
35  // Setup the arguments.
36  std::string script_dir = LANGDATA_DIR;
37  std::string eng_dir = file::JoinPath(script_dir, "eng");
38  std::string unicharset_path = TestDataNameToPath("eng_beam.unicharset");
39  UNICHARSET unicharset;
40  EXPECT_TRUE(unicharset.load_from_file(unicharset_path.c_str()));
41  std::string version_str = "TestVersion";
43  std::string output_dir = FLAGS_test_tmpdir;
44  LOG(INFO) << "Output dir=" << output_dir << "\n";
45  std::string lang1 = "eng";
46  bool pass_through_recoder = false;
47  // If these reads fail, we get a warning message and an empty list of words.
48  std::vector<std::string> words = split(ReadFile(file::JoinPath(eng_dir, "eng.wordlist")), '\n');
49  EXPECT_GT(words.size(), 0);
50  std::vector<std::string> puncs = split(ReadFile(file::JoinPath(eng_dir, "eng.punc")), '\n');
51  EXPECT_GT(puncs.size(), 0);
52  std::vector<std::string> numbers = split(ReadFile(file::JoinPath(eng_dir, "eng.numbers")), '\n');
53  EXPECT_GT(numbers.size(), 0);
54  bool lang_is_rtl = false;
55  // Generate the traineddata file.
56  EXPECT_EQ(0, CombineLangModel(unicharset, script_dir, version_str, output_dir, lang1,
57  pass_through_recoder, words, puncs, numbers, lang_is_rtl, nullptr,
58  nullptr));
59  // Init a trainer with it, and encode kTestString.
60  std::string traineddata1 = file::JoinPath(output_dir, lang1, lang1) + ".traineddata";
61  LSTMTrainer trainer1;
62  trainer1.InitCharSet(traineddata1);
63  std::vector<int> labels1;
64  EXPECT_TRUE(trainer1.EncodeString(kTestString, &labels1));
65  std::string test1_decoded = trainer1.DecodeLabels(labels1);
66  std::string test1_str(&test1_decoded[0], test1_decoded.length());
67  LOG(INFO) << "Labels1=" << test1_str << "\n";
68 
69  // Add a new character to the unicharset and try again.
70  int size_before = unicharset.size();
71  unicharset.unichar_insert("₹");
72  SetupBasicProperties(/*report_errors*/ true, /*decompose (NFD)*/ false, &unicharset);
73  EXPECT_EQ(size_before + 1, unicharset.size());
74  // Generate the traineddata file.
75  std::string lang2 = "extended";
76  EXPECT_EQ(EXIT_SUCCESS, CombineLangModel(unicharset, script_dir, version_str, output_dir, lang2,
77  pass_through_recoder, words, puncs, numbers, lang_is_rtl,
78  nullptr, nullptr));
79  // Init a trainer with it, and encode kTestString.
80  std::string traineddata2 = file::JoinPath(output_dir, lang2, lang2) + ".traineddata";
81  LSTMTrainer trainer2;
82  trainer2.InitCharSet(traineddata2);
83  std::vector<int> labels2;
84  EXPECT_TRUE(trainer2.EncodeString(kTestString, &labels2));
85  std::string test2_decoded = trainer2.DecodeLabels(labels2);
86  std::string test2_str(&test2_decoded[0], test2_decoded.length());
87  LOG(INFO) << "Labels2=" << test2_str << "\n";
88  // encode kTestStringRupees.
89  std::vector<int> labels3;
90  EXPECT_TRUE(trainer2.EncodeString(kTestStringRupees, &labels3));
91  std::string test3_decoded = trainer2.DecodeLabels(labels3);
92  std::string test3_str(&test3_decoded[0], test3_decoded.length());
93  LOG(INFO) << "labels3=" << test3_str << "\n";
94  // Copy labels1 to a std::vector, renumbering the null char to match trainer2.
95  // Since Tensor Flow's CTC implementation insists on having the null be the
96  // last label, and we want to be compatible, null has to be renumbered when
97  // we add a class.
98  int null1 = trainer1.null_char();
99  int null2 = trainer2.null_char();
100  EXPECT_EQ(null1 + 1, null2);
101  std::vector<int> labels1_v(labels1.size());
102  for (unsigned i = 0; i < labels1.size(); ++i) {
103  if (labels1[i] == null1) {
104  labels1_v[i] = null2;
105  } else {
106  labels1_v[i] = labels1[i];
107  }
108  }
109  EXPECT_THAT(labels1_v, testing::ElementsAreArray(&labels2[0], labels2.size()));
110  // To make sure we we are not cheating somehow, we can now encode the Rupee
111  // symbol, which we could not do before.
112  EXPECT_FALSE(trainer1.EncodeString(kTestStringRupees, &labels1));
113  EXPECT_TRUE(trainer2.EncodeString(kTestStringRupees, &labels2));
114 }
115 
116 // Same as above test, for hin instead of eng
117 TEST(LangModelTest, AddACharacterHindi) {
118  constexpr char kTestString[] = "हिन्दी में एक लाइन लिखें";
119  constexpr char kTestStringRupees[] = "हिंदी में रूपये का चिन्ह प्रयोग करें ₹१००.००";
120  // Setup the arguments.
121  std::string script_dir = LANGDATA_DIR;
122  std::string hin_dir = file::JoinPath(script_dir, "hin");
123  std::string unicharset_path = TestDataNameToPath("hin_beam.unicharset");
124  UNICHARSET unicharset;
125  EXPECT_TRUE(unicharset.load_from_file(unicharset_path.c_str()));
126  std::string version_str = "TestVersion";
128  std::string output_dir = FLAGS_test_tmpdir;
129  LOG(INFO) << "Output dir=" << output_dir << "\n";
130  std::string lang1 = "hin";
131  bool pass_through_recoder = false;
132  // If these reads fail, we get a warning message and an empty list of words.
133  std::vector<std::string> words = split(ReadFile(file::JoinPath(hin_dir, "hin.wordlist")), '\n');
134  EXPECT_GT(words.size(), 0);
135  std::vector<std::string> puncs = split(ReadFile(file::JoinPath(hin_dir, "hin.punc")), '\n');
136  EXPECT_GT(puncs.size(), 0);
137  std::vector<std::string> numbers = split(ReadFile(file::JoinPath(hin_dir, "hin.numbers")), '\n');
138  EXPECT_GT(numbers.size(), 0);
139  bool lang_is_rtl = false;
140  // Generate the traineddata file.
141  EXPECT_EQ(0, CombineLangModel(unicharset, script_dir, version_str, output_dir, lang1,
142  pass_through_recoder, words, puncs, numbers, lang_is_rtl, nullptr,
143  nullptr));
144  // Init a trainer with it, and encode kTestString.
145  std::string traineddata1 = file::JoinPath(output_dir, lang1, lang1) + ".traineddata";
146  LSTMTrainer trainer1;
147  trainer1.InitCharSet(traineddata1);
148  std::vector<int> labels1;
149  EXPECT_TRUE(trainer1.EncodeString(kTestString, &labels1));
150  std::string test1_decoded = trainer1.DecodeLabels(labels1);
151  std::string test1_str(&test1_decoded[0], test1_decoded.length());
152  LOG(INFO) << "Labels1=" << test1_str << "\n";
153 
154  // Add a new character to the unicharset and try again.
155  int size_before = unicharset.size();
156  unicharset.unichar_insert("₹");
157  SetupBasicProperties(/*report_errors*/ true, /*decompose (NFD)*/ false, &unicharset);
158  EXPECT_EQ(size_before + 1, unicharset.size());
159  // Generate the traineddata file.
160  std::string lang2 = "extendedhin";
161  EXPECT_EQ(EXIT_SUCCESS, CombineLangModel(unicharset, script_dir, version_str, output_dir, lang2,
162  pass_through_recoder, words, puncs, numbers, lang_is_rtl,
163  nullptr, nullptr));
164  // Init a trainer with it, and encode kTestString.
165  std::string traineddata2 = file::JoinPath(output_dir, lang2, lang2) + ".traineddata";
166  LSTMTrainer trainer2;
167  trainer2.InitCharSet(traineddata2);
168  std::vector<int> labels2;
169  EXPECT_TRUE(trainer2.EncodeString(kTestString, &labels2));
170  std::string test2_decoded = trainer2.DecodeLabels(labels2);
171  std::string test2_str(&test2_decoded[0], test2_decoded.length());
172  LOG(INFO) << "Labels2=" << test2_str << "\n";
173  // encode kTestStringRupees.
174  std::vector<int> labels3;
175  EXPECT_TRUE(trainer2.EncodeString(kTestStringRupees, &labels3));
176  std::string test3_decoded = trainer2.DecodeLabels(labels3);
177  std::string test3_str(&test3_decoded[0], test3_decoded.length());
178  LOG(INFO) << "labels3=" << test3_str << "\n";
179  // Copy labels1 to a std::vector, renumbering the null char to match trainer2.
180  // Since Tensor Flow's CTC implementation insists on having the null be the
181  // last label, and we want to be compatible, null has to be renumbered when
182  // we add a class.
183  int null1 = trainer1.null_char();
184  int null2 = trainer2.null_char();
185  EXPECT_EQ(null1 + 1, null2);
186  std::vector<int> labels1_v(labels1.size());
187  for (unsigned i = 0; i < labels1.size(); ++i) {
188  if (labels1[i] == null1) {
189  labels1_v[i] = null2;
190  } else {
191  labels1_v[i] = labels1[i];
192  }
193  }
194  EXPECT_THAT(labels1_v, testing::ElementsAreArray(&labels2[0], labels2.size()));
195  // To make sure we we are not cheating somehow, we can now encode the Rupee
196  // symbol, which we could not do before.
197  EXPECT_FALSE(trainer1.EncodeString(kTestStringRupees, &labels1));
198  EXPECT_TRUE(trainer2.EncodeString(kTestStringRupees, &labels2));
199 }
200 
201 } // namespace tesseract
@ LOG
@ INFO
Definition: log.h:28
const std::vector< std::string > split(const std::string &s, char c)
Definition: helpers.h:41
void SetupBasicProperties(bool report_errors, bool decompose, UNICHARSET *unicharset)
std::string TestDataNameToPath(const std::string &name)
std::string ReadFile(const std::string &filename, FileReader reader)
int CombineLangModel(const UNICHARSET &unicharset, const std::string &script_dir, const std::string &version_str, const std::string &output_dir, const std::string &lang, bool pass_through_recoder, const std::vector< std::string > &words, const std::vector< std::string > &puncs, const std::vector< std::string > &numbers, bool lang_is_rtl, FileReader reader, FileWriter writer)
TEST(TesseractInstanceTest, TestMultipleTessInstances)
void unichar_insert(const char *const unichar_repr, OldUncleanUnichars old_style)
Definition: unicharset.cpp:654
bool load_from_file(const char *const filename, bool skip_fragments)
Definition: unicharset.h:391
size_t size() const
Definition: unicharset.h:355
std::string DecodeLabels(const std::vector< int > &labels)
bool EncodeString(const std::string &str, std::vector< int > *labels) const
Definition: lstmtrainer.h:253
bool InitCharSet(const std::string &traineddata_path)
Definition: lstmtrainer.h:99
static void MakeTmpdir()
Definition: include_gunit.h:38
static std::string JoinPath(const std::string &s1, const std::string &s2)
Definition: include_gunit.h:65