20 #ifndef THIRD_PARTY_TESSERACT_LSTM_RECODEBEAM_H_
21 #define THIRD_PARTY_TESSERACT_LSTM_RECODEBEAM_H_
35 #include <unordered_set>
135 memcpy(
this, &src,
sizeof(src));
143 void Print(
int null_char,
const UNICHARSET &unicharset,
int depth)
const;
192 void Decode(
const NetworkIO &output,
double dict_ratio,
double cert_offset,
193 double worst_dict_cert,
const UNICHARSET *charset,
int lstm_choice_mode = 0);
195 double worst_dict_cert,
const UNICHARSET *charset);
197 void DecodeSecondaryBeams(
const NetworkIO &output,
double dict_ratio,
double cert_offset,
198 double worst_dict_cert,
const UNICHARSET *charset,
199 int lstm_choice_mode = 0);
202 void ExtractBestPathAsLabels(std::vector<int> *labels, std::vector<int> *xcoords)
const;
205 void ExtractBestPathAsUnicharIds(
bool debug,
const UNICHARSET *unicharset,
206 std::vector<int> *unichar_ids, std::vector<float> *certs,
207 std::vector<float> *ratings, std::vector<int> *xcoords)
const;
210 void ExtractBestPathAsWords(
const TBOX &line_box,
float scale_factor,
bool debug,
212 int lstm_choice_mode = 0);
215 void DebugBeams(
const UNICHARSET &unicharset)
const;
222 void extractSymbolChoices(
const UNICHARSET *unicharset);
225 void PrintBeam2(
bool uids,
int num_outputs,
const UNICHARSET *charset,
bool secondary)
const;
227 void segmentTimestepsByCharacters();
228 std::vector<std::vector<std::pair<const char *, float>>>
230 combineSegmentedTimesteps(
231 std::vector<std::vector<std::vector<std::pair<const char *, float>>>> *segmentedTimesteps);
234 std::vector<std::vector<std::pair<const char *, float>>>
timesteps;
237 std::vector<std::vector<std::pair<const char *, float>>>
ctc_choices;
251 static const int kNumBeams = 2 *
NC_COUNT * kNumLengths;
254 return index % kNumLengths;
260 return index / (kNumLengths *
NC_COUNT) > 0;
264 return (is_dawg *
NC_COUNT + cont) * kNumLengths + length;
274 for (
auto &beam : beams_) {
278 for (
auto &best_initial_dawg : best_initial_dawgs_) {
279 best_initial_dawg = empty;
298 RecodeNode best_initial_dawgs_[
NC_COUNT];
300 using TopPair = KDPairInc<float, int>;
303 void DebugBeamPos(
const UNICHARSET &unicharset,
const RecodeHeap &heap)
const;
307 static void ExtractPathAsUnicharIds(
const std::vector<const RecodeNode *> &best_nodes,
308 std::vector<int> *unichar_ids, std::vector<float> *certs,
309 std::vector<float> *ratings, std::vector<int> *xcoords,
310 std::vector<int> *character_boundaries =
nullptr);
314 WERD_RES *InitializeWord(
bool leading_space,
const TBOX &line_box,
int word_start,
int word_end,
315 float space_certainty,
const UNICHARSET *unicharset,
316 const std::vector<int> &xcoords,
float scale_factor);
320 void ComputeTopN(
const float *outputs,
int num_outputs,
int top_n);
322 void ComputeSecTopN(std::unordered_set<int> *exList,
const float *outputs,
int num_outputs,
328 void DecodeStep(
const float *outputs,
int t,
double dict_ratio,
double cert_offset,
329 double worst_dict_cert,
const UNICHARSET *charset,
bool debug =
false);
331 void DecodeSecondaryStep(
const float *outputs,
int t,
double dict_ratio,
double cert_offset,
332 double worst_dict_cert,
const UNICHARSET *charset,
bool debug =
false);
335 void SaveMostCertainChoices(
const float *outputs,
int num_outputs,
const UNICHARSET *charset,
340 static void calculateCharBoundaries(std::vector<int> *starts, std::vector<int> *ends,
341 std::vector<int> *character_boundaries_,
int maxWidth);
347 void ContinueContext(
const RecodeNode *prev,
int index,
const float *outputs,
348 TopNState top_n_flag,
const UNICHARSET *unicharset,
double dict_ratio,
349 double cert_offset,
double worst_dict_cert, RecodeBeam *step);
351 void ContinueUnichar(
int code,
int unichar_id,
float cert,
float worst_dict_cert,
353 const RecodeNode *prev, RecodeBeam *step);
356 void ContinueDawg(
int code,
int unichar_id,
float cert,
NodeContinuation cont,
357 const RecodeNode *prev, RecodeBeam *step);
360 void PushInitialDawgIfBetter(
int code,
int unichar_id,
PermuterType permuter,
bool start,
366 void PushDupOrNoDawgIfBetter(
int length,
bool dup,
int code,
int unichar_id,
float cert,
367 float worst_dict_cert,
float dict_ratio,
bool use_dawgs,
371 void PushHeapIfBetter(
int max_size,
int code,
int unichar_id,
PermuterType permuter,
372 bool dawg_start,
bool word_start,
bool end,
bool dup,
float cert,
373 const RecodeNode *prev, DawgPositionVector *d,
RecodeHeap *heap);
376 void PushHeapIfBetter(
int max_size, RecodeNode *node,
RecodeHeap *heap);
379 bool UpdateHeapIfMatched(RecodeNode *new_node,
RecodeHeap *heap);
381 uint64_t ComputeCodeHash(
int code,
bool dup,
const RecodeNode *prev)
const;
386 void ExtractBestPaths(std::vector<const RecodeNode *> *best_nodes,
387 std::vector<const RecodeNode *> *second_nodes)
const;
390 void ExtractPath(
const RecodeNode *node, std::vector<const RecodeNode *> *path)
const;
391 void ExtractPath(
const RecodeNode *node, std::vector<const RecodeNode *> *path,
394 void DebugPath(
const UNICHARSET *unicharset,
const std::vector<const RecodeNode *> &path)
const;
396 void DebugUnicharPath(
const UNICHARSET *unicharset,
const std::vector<const RecodeNode *> &path,
397 const std::vector<int> &unichar_ids,
const std::vector<float> &certs,
398 const std::vector<float> &ratings,
const std::vector<int> &xcoords)
const;
403 const UnicharCompress &recoder_;
405 std::vector<RecodeBeam *> beam_;
407 std::vector<RecodeBeam *> secondary_beam_;
412 std::vector<TopNState> top_n_flags_;
417 GenericHeap<TopPair> top_heap_;
422 bool space_delimited_;
425 bool is_simple_text_;
const float kMinCertainty
GenericHeap< RecodePair > RecodeHeap
static const int kMaxCodeLen
RecodeNode(int c, int uni_id, PermuterType perm, bool dawg_start, bool word_start, bool end, bool dup, float cert, float s, const RecodeNode *p, DawgPositionVector *d, uint64_t hash)
void Print(int null_char, const UNICHARSET &unicharset, int depth) const
RecodeNode & operator=(const RecodeNode &src)
DawgPositionVector * dawgs
RecodeNode(const RecodeNode &src)
static bool IsDawgFromBeamsIndex(int index)
std::vector< std::vector< std::pair< const char *, float > > > ctc_choices
static int LengthFromBeamsIndex(int index)
std::vector< std::vector< std::pair< const char *, float > > > timesteps
std::vector< std::vector< std::vector< std::pair< const char *, float > > > > segmentedTimesteps
std::vector< int > character_boundaries_
std::vector< std::unordered_set< int > > excludedUnichars
static NodeContinuation ContinuationFromBeamsIndex(int index)
static int BeamIndex(bool is_dawg, NodeContinuation cont, int length)