31 #include <unordered_set>
38 5, 10, 16, 16, 16, 16, 16, 16, 16, 16,
41 static const char *kNodeContNames[] = {
"Anything",
"OnlyDup",
"NoDup"};
46 if (
code == null_char) {
55 if (depth > 0 &&
prev !=
nullptr) {
57 prev->
Print(null_char, unicharset, depth - 1);
65 int null_char,
bool simple_text,
Dict *dict)
71 space_delimited_(true),
72 is_simple_text_(simple_text),
73 null_char_(null_char) {
75 space_delimited_ =
false;
80 for (
auto data : beam_) {
83 for (
auto data : secondary_beam_) {
90 double cert_offset,
double worst_dict_cert,
91 const UNICHARSET *charset,
int lstm_choice_mode) {
93 int width = output.
Width();
94 if (lstm_choice_mode) {
97 for (
int t = 0; t < width; ++t) {
98 ComputeTopN(output.
f(t), output.
NumFeatures(), kBeamWidths[0]);
99 DecodeStep(output.
f(t), t, dict_ratio, cert_offset, worst_dict_cert,
101 if (lstm_choice_mode) {
102 SaveMostCertainChoices(output.
f(t), output.
NumFeatures(), charset, t);
107 double dict_ratio,
double cert_offset,
108 double worst_dict_cert,
111 int width = output.
dim1();
112 for (
int t = 0; t < width; ++t) {
113 ComputeTopN(output[t], output.
dim2(), kBeamWidths[0]);
114 DecodeStep(output[t], t, dict_ratio, cert_offset, worst_dict_cert, charset);
119 const NetworkIO &output,
double dict_ratio,
double cert_offset,
120 double worst_dict_cert,
const UNICHARSET *charset,
int lstm_choice_mode) {
121 for (
auto data : secondary_beam_) {
124 secondary_beam_.clear();
128 int width = output.
Width();
129 unsigned bucketNumber = 0;
130 for (
int t = 0; t < width; ++t) {
137 DecodeSecondaryStep(output.
f(t), t, dict_ratio, cert_offset,
138 worst_dict_cert, charset);
142 void RecodeBeamSearch::SaveMostCertainChoices(
const float *outputs,
146 std::vector<std::pair<const char *, float>> choices;
147 for (
int i = 0; i < num_outputs; ++i) {
148 if (outputs[i] >= 0.01f) {
150 if (i + 2 >= num_outputs) {
160 while (choices.size() > pos && choices[pos].second > outputs[i]) {
163 choices.insert(choices.begin() + pos,
164 std::pair<const char *, float>(
character, outputs[i]));
172 std::vector<std::vector<std::pair<const char *, float>>> segment;
180 std::vector<std::vector<std::pair<const char *, float>>>
182 std::vector<std::vector<std::vector<std::pair<const char *, float>>>>
183 *segmentedTimesteps) {
184 std::vector<std::vector<std::pair<const char *, float>>> combined_timesteps;
186 for (
auto &j : segmentedTimestep) {
187 combined_timesteps.push_back(j);
190 return combined_timesteps;
193 void RecodeBeamSearch::calculateCharBoundaries(std::vector<int> *starts,
194 std::vector<int> *ends,
195 std::vector<int> *char_bounds_,
197 char_bounds_->push_back(0);
198 for (
unsigned i = 0; i < ends->size(); ++i) {
199 int middle = ((*starts)[i + 1] - (*ends)[i]) / 2;
200 char_bounds_->push_back((*ends)[i] + middle);
202 char_bounds_->pop_back();
203 char_bounds_->push_back(maxWidth);
208 std::vector<int> *labels, std::vector<int> *xcoords)
const {
211 std::vector<const RecodeNode *> best_nodes;
212 ExtractBestPaths(&best_nodes,
nullptr);
215 int width = best_nodes.size();
217 int label = best_nodes[t]->code;
218 if (label != null_char_) {
219 labels->push_back(label);
220 xcoords->push_back(t);
222 while (++t < width && !is_simple_text_ && best_nodes[t]->code == label) {
225 xcoords->push_back(width);
231 bool debug,
const UNICHARSET *unicharset, std::vector<int> *unichar_ids,
232 std::vector<float> *certs, std::vector<float> *ratings,
233 std::vector<int> *xcoords)
const {
234 std::vector<const RecodeNode *> best_nodes;
235 ExtractBestPaths(&best_nodes,
nullptr);
236 ExtractPathAsUnicharIds(best_nodes, unichar_ids, certs, ratings, xcoords);
238 DebugPath(unicharset, best_nodes);
239 DebugUnicharPath(unicharset, best_nodes, *unichar_ids, *certs, *ratings,
246 float scale_factor,
bool debug,
249 int lstm_choice_mode) {
251 std::vector<int> unichar_ids;
252 std::vector<float> certs;
253 std::vector<float> ratings;
254 std::vector<int> xcoords;
255 std::vector<const RecodeNode *> best_nodes;
256 std::vector<const RecodeNode *> second_nodes;
258 ExtractBestPaths(&best_nodes, &second_nodes);
260 DebugPath(unicharset, best_nodes);
261 ExtractPathAsUnicharIds(second_nodes, &unichar_ids, &certs, &ratings,
263 tprintf(
"\nSecond choice path:\n");
264 DebugUnicharPath(unicharset, second_nodes, unichar_ids, certs, ratings,
270 ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings, &xcoords,
272 int num_ids = unichar_ids.size();
274 DebugUnicharPath(unicharset, best_nodes, unichar_ids, certs, ratings,
279 float prev_space_cert = 0.0f;
280 for (
int word_start = 0; word_start < num_ids; word_start = word_end) {
281 for (word_end = word_start + 1; word_end < num_ids; ++word_end) {
288 int index = xcoords[word_end];
289 if (best_nodes[index]->start_of_word) {
298 float space_cert = 0.0f;
299 if (word_end < num_ids && unichar_ids[word_end] ==
UNICHAR_SPACE) {
300 space_cert = certs[word_end];
303 word_start > 0 && unichar_ids[word_start - 1] ==
UNICHAR_SPACE;
306 InitializeWord(leading_space, line_box, word_start, word_end,
307 std::min(space_cert, prev_space_cert), unicharset,
308 xcoords, scale_factor);
309 for (
int i = word_start; i < word_end; ++i) {
310 auto *choices =
new BLOB_CHOICE_LIST;
311 BLOB_CHOICE_IT bc_it(choices);
312 auto *choice =
new BLOB_CHOICE(unichar_ids[i], ratings[i], certs[i], -1,
313 1.0f,
static_cast<float>(INT16_MAX), 0.0f,
315 int col = i - word_start;
316 choice->set_matrix_cell(col, col);
317 bc_it.add_after_then_move(choice);
320 int index = xcoords[word_end - 1];
323 prev_space_cert = space_cert;
324 if (word_end < num_ids && unichar_ids[word_end] ==
UNICHAR_SPACE) {
338 bool secondary)
const {
339 std::vector<std::vector<const RecodeNode *>> topology;
340 std::unordered_set<const RecodeNode *> visited;
341 const std::vector<RecodeBeam *> &beam = !secondary ? beam_ : secondary_beam_;
343 for (
int step = beam.size() - 1; step >= 0; --step) {
344 std::vector<const RecodeNode *> layer;
345 topology.push_back(layer);
348 for (
int step = beam.size() - 1; step >= 0; --step) {
349 std::vector<tesseract::RecodePair> &heaps = beam.at(step)->beams_->heap();
350 for (
auto node : heaps) {
353 while (curr !=
nullptr && !visited.count(curr)) {
354 visited.insert(curr);
355 topology[step - backtracker].push_back(curr);
363 for (
const std::vector<const RecodeNode *> &layer : topology) {
374 if (node->unichar_id != INVALID_UNICHAR_ID) {
376 intCode = node->unichar_id;
377 }
else if (node->code == null_char_) {
385 const char *prevCode;
387 if (node->prev !=
nullptr) {
388 prevScore = node->prev->score;
389 if (node->prev->unichar_id != INVALID_UNICHAR_ID) {
391 intPrevCode = node->prev->unichar_id;
392 }
else if (node->code == null_char_) {
403 tprintf(
"%x(|)%f(>)%x(|)%f\n", intPrevCode, prevScore, intCode,
406 tprintf(
"%s(|)%f(>)%s(|)%f\n", prevCode, prevScore, code, node->score);
421 std::vector<RecodeBeam *> ¤tBeam =
422 secondary_beam_.empty() ? beam_ : secondary_beam_;
425 std::vector<int> unichar_ids;
426 std::vector<float> certs;
427 std::vector<float> ratings;
428 std::vector<int> xcoords;
430 std::vector<tesseract::RecodePair> &heaps =
432 std::vector<const RecodeNode *> best_nodes;
433 std::vector<const RecodeNode *> best;
435 for (
auto entry : heaps) {
436 bool validChar =
false;
439 while (node !=
nullptr && backcounter < backpath) {
440 if (node->
code != null_char_ &&
449 best.push_back(&entry.data());
455 ExtractPath(best[0], &best_nodes, backpath);
456 ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings,
459 if (!unichar_ids.empty()) {
461 for (
unsigned i = 1; i < unichar_ids.size(); ++i) {
462 if (ratings[i] < ratings[bestPos]) {
468 for (
auto &node : best_nodes) {
469 if (node->unichar_id == unichar_ids[bestPos]) {
470 bestCode = node->code;
475 std::unordered_set<int> excludeCodeList;
476 for (
auto &best_node : best_nodes) {
477 if (best_node->code != null_char_) {
478 excludeCodeList.insert(best_node->code);
482 for (
auto elem : excludeCodeList) {
490 int id = unichar_ids[bestPos];
492 float rating = ratings[bestPos];
494 std::pair<const char *, float>(result, rating));
496 std::vector<std::pair<const char *, float>> choice;
497 int id = unichar_ids[bestPos];
499 float rating = ratings[bestPos];
500 choice.emplace_back(result, rating);
506 std::unordered_set<int> excludeCodeList;
510 std::vector<std::pair<const char *, float>> choice;
515 for (
auto data : secondary_beam_) {
518 secondary_beam_.clear();
523 for (
int p = 0; p < beam_size_; ++p) {
524 for (
int d = 0; d < 2; ++d) {
525 for (
int c = 0; c <
NC_COUNT; ++c) {
528 if (beam_[p]->beams_[index].empty()) {
532 tprintf(
"Position %d: %s+%s beam\n", p, d ?
"Dict" :
"Non-Dict",
534 DebugBeamPos(unicharset, beam_[p]->beams_[index]);
541 void RecodeBeamSearch::DebugBeamPos(
const UNICHARSET &unicharset,
543 std::vector<const RecodeNode *> unichar_bests(unicharset.
size());
545 int heap_size = heap.
size();
546 for (
int i = 0; i < heap_size; ++i) {
549 if (null_best ==
nullptr || null_best->
score < node->
score) {
553 if (unichar_bests[node->
unichar_id] ==
nullptr ||
559 for (
auto &unichar_best : unichar_bests) {
560 if (unichar_best !=
nullptr) {
561 const RecodeNode &node = *unichar_best;
562 node.
Print(null_char_, unicharset, 1);
565 if (null_best !=
nullptr) {
566 null_best->
Print(null_char_, unicharset, 1);
573 void RecodeBeamSearch::ExtractPathAsUnicharIds(
574 const std::vector<const RecodeNode *> &best_nodes,
575 std::vector<int> *unichar_ids, std::vector<float> *certs,
576 std::vector<float> *ratings, std::vector<int> *xcoords,
577 std::vector<int> *character_boundaries) {
578 unichar_ids->clear();
582 std::vector<int> starts;
583 std::vector<int> ends;
586 int width = best_nodes.size();
588 double certainty = 0.0;
590 while (t < width && best_nodes[t]->unichar_id == INVALID_UNICHAR_ID) {
591 double cert = best_nodes[t++]->certainty;
592 if (cert < certainty) {
599 int unichar_id = best_nodes[t]->unichar_id;
601 best_nodes[t]->permuter !=
NO_PERM) {
604 if (certainty < certs->back()) {
605 certs->back() = certainty;
607 ratings->back() += rating;
611 unichar_ids->push_back(unichar_id);
612 xcoords->push_back(t);
614 double cert = best_nodes[t++]->certainty;
618 best_nodes[t - 1]->permuter ==
NO_PERM)) {
622 }
while (t < width && best_nodes[t]->duplicate);
624 certs->push_back(certainty);
625 ratings->push_back(rating);
626 }
else if (!certs->empty()) {
627 if (certainty < certs->back()) {
628 certs->back() = certainty;
630 ratings->back() += rating;
633 starts.push_back(width);
634 if (character_boundaries !=
nullptr) {
635 calculateCharBoundaries(&starts, &ends, character_boundaries, width);
637 xcoords->push_back(width);
642 WERD_RES *RecodeBeamSearch::InitializeWord(
bool leading_space,
643 const TBOX &line_box,
int word_start,
644 int word_end,
float space_certainty,
645 const UNICHARSET *unicharset,
646 const std::vector<int> &xcoords,
647 float scale_factor) {
650 C_BLOB_IT b_it(&blobs);
651 for (
int i = word_start; i < word_end; ++i) {
653 TBOX box(
static_cast<int16_t
>(
657 static_cast<int16_t
>(
665 WERD *word =
new WERD(&blobs, leading_space,
nullptr);
667 auto *word_res =
new WERD_RES(word);
668 word_res->end = word_end - word_start + leading_space;
669 word_res->uch_set = unicharset;
670 word_res->combination =
true;
671 word_res->space_certainty = space_certainty;
672 word_res->ratings =
new MATRIX(word_end - word_start, 1);
678 void RecodeBeamSearch::ComputeTopN(
const float *outputs,
int num_outputs,
680 top_n_flags_.clear();
685 for (
int i = 0; i < num_outputs; ++i) {
686 if (top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key()) {
687 TopPair entry(outputs[i], i);
688 top_heap_.Push(&entry);
689 if (top_heap_.size() > top_n) {
690 top_heap_.Pop(&entry);
694 while (!top_heap_.empty()) {
696 top_heap_.Pop(&entry);
697 if (top_heap_.size() > 1) {
698 top_n_flags_[entry.data()] =
TN_TOPN;
700 top_n_flags_[entry.data()] =
TN_TOP2;
701 if (top_heap_.empty()) {
702 top_code_ = entry.data();
704 second_code_ = entry.data();
708 top_n_flags_[null_char_] =
TN_TOP2;
711 void RecodeBeamSearch::ComputeSecTopN(std::unordered_set<int> *exList,
712 const float *outputs,
int num_outputs,
714 top_n_flags_.clear();
719 for (
int i = 0; i < num_outputs; ++i) {
720 if ((top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key()) &&
722 TopPair entry(outputs[i], i);
723 top_heap_.Push(&entry);
724 if (top_heap_.size() > top_n) {
725 top_heap_.Pop(&entry);
729 while (!top_heap_.empty()) {
731 top_heap_.Pop(&entry);
732 if (top_heap_.size() > 1) {
733 top_n_flags_[entry.data()] =
TN_TOPN;
735 top_n_flags_[entry.data()] =
TN_TOP2;
736 if (top_heap_.empty()) {
737 top_code_ = entry.data();
739 second_code_ = entry.data();
743 top_n_flags_[null_char_] =
TN_TOP2;
749 void RecodeBeamSearch::DecodeStep(
const float *outputs,
int t,
750 double dict_ratio,
double cert_offset,
751 double worst_dict_cert,
752 const UNICHARSET *charset,
bool debug) {
753 if (t ==
static_cast<int>(beam_.size())) {
754 beam_.push_back(
new RecodeBeam);
756 RecodeBeam *step = beam_[t];
762 charset, dict_ratio, cert_offset, worst_dict_cert, step);
763 if (dict_ !=
nullptr) {
765 TN_TOP2, charset, dict_ratio, cert_offset,
766 worst_dict_cert, step);
769 RecodeBeam *prev = beam_[t - 1];
772 for (
int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
773 std::vector<const RecodeNode *> path;
774 ExtractPath(&prev->beams_[beam_index].get(i).data(), &path);
775 tprintf(
"Step %d: Dawg beam %d:\n", t, i);
776 DebugPath(charset, path);
779 for (
int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
780 std::vector<const RecodeNode *> path;
781 ExtractPath(&prev->beams_[beam_index].get(i).data(), &path);
782 tprintf(
"Step %d: Non-Dawg beam %d:\n", t, i);
783 DebugPath(charset, path);
791 for (
int tn = 0; tn <
TN_COUNT && total_beam == 0; ++tn) {
793 for (
int index = 0; index <
kNumBeams; ++index) {
797 for (
int i = prev->beams_[index].size() - 1; i >= 0; --i) {
798 ContinueContext(&prev->beams_[index].get(i).data(), index, outputs,
799 top_n, charset, dict_ratio, cert_offset,
800 worst_dict_cert, step);
803 for (
int index = 0; index <
kNumBeams; ++index) {
805 total_beam += step->beams_[index].size();
811 for (
int c = 0; c <
NC_COUNT; ++c) {
812 if (step->best_initial_dawgs_[c].code >= 0) {
815 PushHeapIfBetter(kBeamWidths[0], &step->best_initial_dawgs_[c],
822 void RecodeBeamSearch::DecodeSecondaryStep(
823 const float *outputs,
int t,
double dict_ratio,
double cert_offset,
824 double worst_dict_cert,
const UNICHARSET *charset,
bool debug) {
825 if (t ==
static_cast<int>(secondary_beam_.size())) {
826 secondary_beam_.push_back(
new RecodeBeam);
828 RecodeBeam *step = secondary_beam_[t];
833 charset, dict_ratio, cert_offset, worst_dict_cert, step);
834 if (dict_ !=
nullptr) {
836 TN_TOP2, charset, dict_ratio, cert_offset,
837 worst_dict_cert, step);
840 RecodeBeam *prev = secondary_beam_[t - 1];
843 for (
int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
844 std::vector<const RecodeNode *> path;
845 ExtractPath(&prev->beams_[beam_index].get(i).data(), &path);
846 tprintf(
"Step %d: Dawg beam %d:\n", t, i);
847 DebugPath(charset, path);
850 for (
int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
851 std::vector<const RecodeNode *> path;
852 ExtractPath(&prev->beams_[beam_index].get(i).data(), &path);
853 tprintf(
"Step %d: Non-Dawg beam %d:\n", t, i);
854 DebugPath(charset, path);
862 for (
int tn = 0; tn <
TN_COUNT && total_beam == 0; ++tn) {
864 for (
int index = 0; index <
kNumBeams; ++index) {
868 for (
int i = prev->beams_[index].size() - 1; i >= 0; --i) {
869 ContinueContext(&prev->beams_[index].get(i).data(), index, outputs,
870 top_n, charset, dict_ratio, cert_offset,
871 worst_dict_cert, step);
874 for (
int index = 0; index <
kNumBeams; ++index) {
876 total_beam += step->beams_[index].size();
882 for (
int c = 0; c <
NC_COUNT; ++c) {
883 if (step->best_initial_dawgs_[c].code >= 0) {
886 PushHeapIfBetter(kBeamWidths[0], &step->best_initial_dawgs_[c],
897 void RecodeBeamSearch::ContinueContext(
898 const RecodeNode *prev,
int index,
const float *outputs,
899 TopNState top_n_flag,
const UNICHARSET *charset,
double dict_ratio,
900 double cert_offset,
double worst_dict_cert, RecodeBeam *step) {
901 RecodedCharID prefix;
902 RecodedCharID full_code;
903 const RecodeNode *previous = prev;
907 for (
int p = length - 1; p >= 0; --p, previous = previous->prev) {
908 while (previous !=
nullptr &&
909 (previous->duplicate || previous->code == null_char_)) {
910 previous = previous->prev;
912 if (previous !=
nullptr) {
913 prefix.Set(p, previous->code);
914 full_code.Set(p, previous->code);
917 if (prev !=
nullptr && !is_simple_text_) {
918 if (top_n_flags_[prev->code] == top_n_flag) {
922 PushDupOrNoDawgIfBetter(length,
true, prev->code, prev->unichar_id,
923 cert, worst_dict_cert, dict_ratio, use_dawgs,
927 prev->code != null_char_) {
929 outputs[null_char_]) +
931 PushDupOrNoDawgIfBetter(length,
true, prev->code, prev->unichar_id,
932 cert, worst_dict_cert, dict_ratio, use_dawgs,
939 if (prev->code != null_char_ && length > 0 &&
940 top_n_flags_[null_char_] == top_n_flag) {
945 PushDupOrNoDawgIfBetter(length,
false, null_char_, INVALID_UNICHAR_ID,
946 cert, worst_dict_cert, dict_ratio, use_dawgs,
950 const std::vector<int> *final_codes = recoder_.
GetFinalCodes(prefix);
951 if (final_codes !=
nullptr) {
952 for (
int code : *final_codes) {
953 if (top_n_flags_[code] != top_n_flag) {
956 if (prev !=
nullptr && prev->code == code && !is_simple_text_) {
963 full_code.Set(length, code);
966 if (length == 0 && code == null_char_) {
967 unichar_id = INVALID_UNICHAR_ID;
969 if (unichar_id != INVALID_UNICHAR_ID && charset !=
nullptr &&
970 !charset->get_enabled(unichar_id)) {
973 ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio,
975 if (top_n_flag ==
TN_TOP2 && code != null_char_) {
976 float prob = outputs[code] + outputs[null_char_];
978 prev->code != null_char_ &&
979 ((prev->code == top_code_ && code == second_code_) ||
980 (code == top_code_ && prev->code == second_code_))) {
981 prob += outputs[prev->code];
984 ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio,
989 const std::vector<int> *next_codes = recoder_.
GetNextCodes(prefix);
990 if (next_codes !=
nullptr) {
991 for (
int code : *next_codes) {
992 if (top_n_flags_[code] != top_n_flag) {
995 if (prev !=
nullptr && prev->code == code && !is_simple_text_) {
999 PushDupOrNoDawgIfBetter(length + 1,
false, code, INVALID_UNICHAR_ID, cert,
1000 worst_dict_cert, dict_ratio, use_dawgs,
1002 if (top_n_flag ==
TN_TOP2 && code != null_char_) {
1003 float prob = outputs[code] + outputs[null_char_];
1004 if (prev !=
nullptr && prev_cont ==
NC_ANYTHING &&
1005 prev->code != null_char_ &&
1006 ((prev->code == top_code_ && code == second_code_) ||
1007 (code == top_code_ && prev->code == second_code_))) {
1008 prob += outputs[prev->code];
1011 PushDupOrNoDawgIfBetter(length + 1,
false, code, INVALID_UNICHAR_ID,
1012 cert, worst_dict_cert, dict_ratio, use_dawgs,
1020 void RecodeBeamSearch::ContinueUnichar(
int code,
int unichar_id,
float cert,
1021 float worst_dict_cert,
float dict_ratio,
1023 const RecodeNode *prev,
1026 if (cert > worst_dict_cert) {
1027 ContinueDawg(code, unichar_id, cert, cont, prev, step);
1031 PushHeapIfBetter(kBeamWidths[0], code, unichar_id,
TOP_CHOICE_PERM,
false,
1032 false,
false,
false, cert * dict_ratio, prev,
nullptr,
1034 if (dict_ !=
nullptr &&
1040 float dawg_cert = cert;
1054 dawg_cert *= dict_ratio;
1056 PushInitialDawgIfBetter(code, unichar_id, permuter,
false,
false,
1057 dawg_cert, cont, prev, step);
1065 void RecodeBeamSearch::ContinueDawg(
int code,
int unichar_id,
float cert,
1067 const RecodeNode *prev, RecodeBeam *step) {
1070 if (unichar_id == INVALID_UNICHAR_ID) {
1071 PushHeapIfBetter(kBeamWidths[0], code, unichar_id,
NO_PERM,
false,
false,
1072 false,
false, cert, prev,
nullptr, dawg_heap);
1077 if (prev !=
nullptr) {
1078 score += prev->score;
1080 if (dawg_heap->size() >= kBeamWidths[0] &&
1081 score <= dawg_heap->PeekTop().data().score &&
1082 nodawg_heap->size() >= kBeamWidths[0] &&
1083 score <= nodawg_heap->PeekTop().data().score) {
1086 const RecodeNode *uni_prev = prev;
1089 while (uni_prev !=
nullptr &&
1090 (uni_prev->unichar_id == INVALID_UNICHAR_ID || uni_prev->duplicate)) {
1091 uni_prev = uni_prev->prev;
1094 if (uni_prev !=
nullptr && uni_prev->end_of_word) {
1097 PushInitialDawgIfBetter(code, unichar_id, uni_prev->permuter,
false,
1098 false, cert, cont, prev, step);
1099 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, uni_prev->permuter,
1100 false,
false,
false,
false, cert, prev,
nullptr,
1104 }
else if (uni_prev !=
nullptr && uni_prev->start_of_dawg &&
1110 DawgPositionVector initial_dawgs;
1111 auto *updated_dawgs =
new DawgPositionVector;
1112 DawgArgs dawg_args(&initial_dawgs, updated_dawgs,
NO_PERM);
1113 bool word_start =
false;
1114 if (uni_prev ==
nullptr) {
1118 }
else if (uni_prev->dawgs !=
nullptr) {
1120 dawg_args.active_dawgs = uni_prev->dawgs;
1121 word_start = uni_prev->start_of_dawg;
1128 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, permuter,
false,
1129 word_start, dawg_args.valid_end,
false, cert, prev,
1130 dawg_args.updated_dawgs, dawg_heap);
1131 if (dawg_args.valid_end && !space_delimited_) {
1135 PushInitialDawgIfBetter(code, unichar_id, permuter, word_start,
true,
1136 cert, cont, prev, step);
1137 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, permuter,
false,
1138 word_start,
true,
false, cert, prev,
nullptr,
1142 delete updated_dawgs;
1149 void RecodeBeamSearch::PushInitialDawgIfBetter(
int code,
int unichar_id,
1151 bool start,
bool end,
float cert,
1153 const RecodeNode *prev,
1155 RecodeNode *best_initial_dawg = &step->best_initial_dawgs_[cont];
1157 if (prev !=
nullptr) {
1158 score += prev->score;
1160 if (best_initial_dawg->code < 0 || score > best_initial_dawg->score) {
1161 auto *initial_dawgs =
new DawgPositionVector;
1163 RecodeNode node(code, unichar_id, permuter,
true, start, end,
false, cert,
1164 score, prev, initial_dawgs,
1165 ComputeCodeHash(code,
false, prev));
1166 *best_initial_dawg = node;
1174 void RecodeBeamSearch::PushDupOrNoDawgIfBetter(
1175 int length,
bool dup,
int code,
int unichar_id,
float cert,
1176 float worst_dict_cert,
float dict_ratio,
bool use_dawgs,
1178 int index =
BeamIndex(use_dawgs, cont, length);
1180 if (cert > worst_dict_cert) {
1181 PushHeapIfBetter(kBeamWidths[length], code, unichar_id,
1182 prev ? prev->permuter :
NO_PERM,
false,
false,
false,
1183 dup, cert, prev,
nullptr, &step->beams_[index]);
1188 PushHeapIfBetter(kBeamWidths[length], code, unichar_id,
1190 false, dup, cert, prev,
nullptr, &step->beams_[index]);
1198 void RecodeBeamSearch::PushHeapIfBetter(
int max_size,
int code,
int unichar_id,
1200 bool word_start,
bool end,
bool dup,
1201 float cert,
const RecodeNode *prev,
1202 DawgPositionVector *d,
1205 if (prev !=
nullptr) {
1206 score += prev->score;
1208 if (heap->size() < max_size || score > heap->PeekTop().data().score) {
1209 uint64_t hash = ComputeCodeHash(code, dup, prev);
1210 RecodeNode node(code, unichar_id, permuter, dawg_start, word_start, end,
1211 dup, cert, score, prev, d, hash);
1212 if (UpdateHeapIfMatched(&node, heap)) {
1218 if (heap->size() > max_size) {
1228 void RecodeBeamSearch::PushHeapIfBetter(
int max_size, RecodeNode *node,
1230 if (heap->size() < max_size || node->score > heap->PeekTop().data().score) {
1231 if (UpdateHeapIfMatched(node, heap)) {
1237 if (heap->size() > max_size) {
1245 bool RecodeBeamSearch::UpdateHeapIfMatched(RecodeNode *new_node,
1250 std::vector<RecodePair> &nodes = heap->heap();
1251 for (
auto &i : nodes) {
1252 RecodeNode &node = i.data();
1253 if (node.code == new_node->code && node.code_hash == new_node->code_hash &&
1254 node.permuter == new_node->permuter &&
1255 node.start_of_dawg == new_node->start_of_dawg) {
1256 if (new_node->score > node.score) {
1260 i.key() = node.score;
1261 heap->Reshuffle(&i);
1270 uint64_t RecodeBeamSearch::ComputeCodeHash(
int code,
bool dup,
1271 const RecodeNode *prev)
const {
1272 uint64_t hash = prev ==
nullptr ? 0 : prev->code_hash;
1273 if (!dup && code != null_char_) {
1275 uint64_t carry = (((hash >> 32) * num_classes) >> 32);
1276 hash *= num_classes;
1287 void RecodeBeamSearch::ExtractBestPaths(
1288 std::vector<const RecodeNode *> *best_nodes,
1289 std::vector<const RecodeNode *> *second_nodes)
const {
1291 const RecodeNode *best_node =
nullptr;
1292 const RecodeNode *second_best_node =
nullptr;
1293 const RecodeBeam *last_beam = beam_[beam_size_ - 1];
1294 for (
int c = 0; c <
NC_COUNT; ++c) {
1299 for (
int is_dawg = 0; is_dawg < 2; ++is_dawg) {
1300 int beam_index =
BeamIndex(is_dawg, cont, 0);
1301 int heap_size = last_beam->beams_[beam_index].size();
1302 for (
int h = 0; h < heap_size; ++h) {
1303 const RecodeNode *node = &last_beam->beams_[beam_index].get(h).data();
1307 const RecodeNode *dawg_node = node;
1308 while (dawg_node !=
nullptr &&
1309 (dawg_node->unichar_id == INVALID_UNICHAR_ID ||
1310 dawg_node->duplicate)) {
1311 dawg_node = dawg_node->prev;
1313 if (dawg_node ==
nullptr ||
1314 (!dawg_node->end_of_word &&
1320 if (best_node ==
nullptr || node->score > best_node->score) {
1321 second_best_node = best_node;
1323 }
else if (second_best_node ==
nullptr ||
1324 node->score > second_best_node->score) {
1325 second_best_node = node;
1330 if (second_nodes !=
nullptr) {
1331 ExtractPath(second_best_node, second_nodes);
1333 ExtractPath(best_node, best_nodes);
1338 void RecodeBeamSearch::ExtractPath(
1339 const RecodeNode *node, std::vector<const RecodeNode *> *path)
const {
1341 while (node !=
nullptr) {
1342 path->push_back(node);
1345 std::reverse(path->begin(), path->end());
1348 void RecodeBeamSearch::ExtractPath(
const RecodeNode *node,
1349 std::vector<const RecodeNode *> *path,
1350 int limiter)
const {
1351 int pathcounter = 0;
1353 while (node !=
nullptr && pathcounter < limiter) {
1354 path->push_back(node);
1358 std::reverse(path->begin(), path->end());
1362 void RecodeBeamSearch::DebugPath(
1363 const UNICHARSET *unicharset,
1364 const std::vector<const RecodeNode *> &path)
const {
1365 for (
unsigned c = 0; c < path.size(); ++c) {
1366 const RecodeNode &node = *path[c];
1368 node.Print(null_char_, *unicharset, 1);
1373 void RecodeBeamSearch::DebugUnicharPath(
1374 const UNICHARSET *unicharset,
const std::vector<const RecodeNode *> &path,
1375 const std::vector<int> &unichar_ids,
const std::vector<float> &certs,
1376 const std::vector<float> &ratings,
const std::vector<int> &xcoords)
const {
1377 auto num_ids = unichar_ids.size();
1378 double total_rating = 0.0;
1379 for (
unsigned c = 0; c < num_ids; ++c) {
1380 int coord = xcoords[c];
1381 tprintf(
"%d %d=%s r=%g, c=%g, s=%d, e=%d, perm=%d\n", coord, unichar_ids[c],
1382 unicharset->debug_str(unichar_ids[c]).c_str(), ratings[c], certs[c],
1383 path[coord]->start_of_word, path[coord]->end_of_word,
1384 path[coord]->permuter);
1385 total_rating += ratings[c];
1387 tprintf(
"Path total rating = %g\n", total_rating);
KDPairInc< double, RecodeNode > RecodePair
void tprintf(const char *format,...)
GenericHeap< RecodePair > RecodeHeap
void put(ICOORD pos, const T &thing)
void FakeWordFromRatings(PermuterType permuter)
static C_BLOB * FakeBlob(const TBOX &box)
const Pair & get(int index) const
static const int kMaxCodeLen
const std::vector< int > * GetFinalCodes(const RecodedCharID &code) const
const std::vector< int > * GetNextCodes(const RecodedCharID &code) const
int DecodeUnichar(const RecodedCharID &code) const
const char * id_to_unichar(UNICHAR_ID id) const
bool IsSpaceDelimited(UNICHAR_ID unichar_id) const
std::string debug_str(UNICHAR_ID id) const
const char * id_to_unichar_ext(UNICHAR_ID id) const
bool IsSpaceDelimitedLang() const
Returns true if the language is space-delimited (not CJ, or T).
const UNICHARSET & getUnicharset() const
void default_dawgs(DawgPositionVector *anylength_dawgs, bool suppress_patterns) const
int def_letter_is_okay(void *void_dawg_args, const UNICHARSET &unicharset, UNICHAR_ID unichar_id, bool word_end) const
static float ProbToCertainty(float prob)
bool operator()(const RecodeNode *&node1, const RecodeNode *&node2)
void Print(int null_char, const UNICHARSET &unicharset, int depth) const
void Decode(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
static bool IsDawgFromBeamsIndex(int index)
std::vector< std::vector< std::pair< const char *, float > > > ctc_choices
static int LengthFromBeamsIndex(int index)
void DecodeSecondaryBeams(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
std::vector< std::vector< std::pair< const char *, float > > > timesteps
std::vector< std::vector< std::vector< std::pair< const char *, float > > > > segmentedTimesteps
void extractSymbolChoices(const UNICHARSET *unicharset)
std::vector< int > character_boundaries_
std::vector< std::unordered_set< int > > excludedUnichars
static NodeContinuation ContinuationFromBeamsIndex(int index)
std::vector< std::vector< std::pair< const char *, float > > > combineSegmentedTimesteps(std::vector< std::vector< std::vector< std::pair< const char *, float >>>> *segmentedTimesteps)
void PrintBeam2(bool uids, int num_outputs, const UNICHARSET *charset, bool secondary) const
void DebugBeams(const UNICHARSET &unicharset) const
void ExtractBestPathAsUnicharIds(bool debug, const UNICHARSET *unicharset, std::vector< int > *unichar_ids, std::vector< float > *certs, std::vector< float > *ratings, std::vector< int > *xcoords) const
void ExtractBestPathAsLabels(std::vector< int > *labels, std::vector< int > *xcoords) const
RecodeBeamSearch(const UnicharCompress &recoder, int null_char, bool simple_text, Dict *dict)
static const int kNumBeams
static constexpr float kMinCertainty
static int BeamIndex(bool is_dawg, NodeContinuation cont, int length)
void ExtractBestPathAsWords(const TBOX &line_box, float scale_factor, bool debug, const UNICHARSET *unicharset, PointerVector< WERD_RES > *words, int lstm_choice_mode=0)
void segmentTimestepsByCharacters()