tesseract  5.0.0
networkbuilder.cpp
Go to the documentation of this file.
1 // File: networkbuilder.cpp
3 // Description: Class to parse the network description language and
4 // build a corresponding network.
5 // Author: Ray Smith
6 //
7 // (C) Copyright 2014, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 #include "networkbuilder.h"
20 
21 #include "convolve.h"
22 #include "fullyconnected.h"
23 #include "input.h"
24 #include "lstm.h"
25 #include "maxpool.h"
26 #include "network.h"
27 #include "parallel.h"
28 #include "reconfig.h"
29 #include "reversed.h"
30 #include "series.h"
31 #include "unicharset.h"
32 
33 namespace tesseract {
34 
35 // Builds a network with a network_spec in the network description
36 // language, to recognize a character set of num_outputs size.
37 // If append_index is non-negative, then *network must be non-null and the
38 // given network_spec will be appended to *network AFTER append_index, with
39 // the top of the input *network discarded.
40 // Note that network_spec is call by value to allow a non-const char* pointer
41 // into the string for BuildFromString.
42 // net_flags control network behavior according to the NetworkFlags enum.
43 // The resulting network is returned via **network.
44 // Returns false if something failed.
45 bool NetworkBuilder::InitNetwork(int num_outputs, const char *network_spec, int append_index,
46  int net_flags, float weight_range, TRand *randomizer,
47  Network **network) {
48  NetworkBuilder builder(num_outputs);
49  Series *bottom_series = nullptr;
50  StaticShape input_shape;
51  if (append_index >= 0) {
52  // Split the current network after the given append_index.
53  ASSERT_HOST(*network != nullptr && (*network)->type() == NT_SERIES);
54  auto *series = static_cast<Series *>(*network);
55  Series *top_series = nullptr;
56  series->SplitAt(append_index, &bottom_series, &top_series);
57  if (bottom_series == nullptr || top_series == nullptr) {
58  tprintf("Yikes! Splitting current network failed!!\n");
59  return false;
60  }
61  input_shape = bottom_series->OutputShape(input_shape);
62  delete top_series;
63  }
64  *network = builder.BuildFromString(input_shape, &network_spec);
65  if (*network == nullptr) {
66  return false;
67  }
68  (*network)->SetNetworkFlags(net_flags);
69  (*network)->InitWeights(weight_range, randomizer);
70  (*network)->SetupNeedsBackprop(false);
71  if (bottom_series != nullptr) {
72  bottom_series->AppendSeries(*network);
73  *network = bottom_series;
74  }
75  (*network)->CacheXScaleFactor((*network)->XScaleFactor());
76  return true;
77 }
78 
// Helper advances *str past any leading spaces, tabs and newlines.
static void SkipWhitespace(const char **str) {
  const char *cursor = *str;
  for (;;) {
    const char ch = *cursor;
    if (ch != ' ' && ch != '\t' && ch != '\n') {
      break;
    }
    ++cursor;
  }
  *str = cursor;
}
85 
86 // Parses the given string and returns a network according to the network
87 // description language in networkbuilder.h
88 Network *NetworkBuilder::BuildFromString(const StaticShape &input_shape, const char **str) {
89  SkipWhitespace(str);
90  char code_ch = **str;
91  if (code_ch == '[') {
92  return ParseSeries(input_shape, nullptr, str);
93  }
94  if (input_shape.depth() == 0) {
95  // There must be an input at this point.
96  return ParseInput(str);
97  }
98  switch (code_ch) {
99  case '(':
100  return ParseParallel(input_shape, str);
101  case 'R':
102  return ParseR(input_shape, str);
103  case 'S':
104  return ParseS(input_shape, str);
105  case 'C':
106  return ParseC(input_shape, str);
107  case 'M':
108  return ParseM(input_shape, str);
109  case 'L':
110  return ParseLSTM(input_shape, str);
111  case 'F':
112  return ParseFullyConnected(input_shape, str);
113  case 'O':
114  return ParseOutput(input_shape, str);
115  default:
116  tprintf("Invalid network spec:%s\n", *str);
117  return nullptr;
118  }
119  return nullptr;
120 }
121 
122 // Parses an input specification and returns the result, which may include a
123 // series.
124 Network *NetworkBuilder::ParseInput(const char **str) {
125  // There must be an input at this point.
126  int length = 0;
127  int batch, height, width, depth;
128  int num_converted = sscanf(*str, "%d,%d,%d,%d%n", &batch, &height, &width, &depth, &length);
129  StaticShape shape;
130  shape.SetShape(batch, height, width, depth);
131  // num_converted may or may not include the length.
132  if (num_converted != 4 && num_converted != 5) {
133  tprintf("Must specify an input layer as the first layer, not %s!!\n", *str);
134  return nullptr;
135  }
136  *str += length;
137  auto *input = new Input("Input", shape);
138  // We want to allow [<input>rest of net... or <input>[rest of net... so we
139  // have to check explicitly for '[' here.
140  SkipWhitespace(str);
141  if (**str == '[') {
142  return ParseSeries(shape, input, str);
143  }
144  return input;
145 }
146 
147 // Parses a sequential series of networks, defined by [<net><net>...].
148 Network *NetworkBuilder::ParseSeries(const StaticShape &input_shape, Input *input_layer,
149  const char **str) {
150  StaticShape shape = input_shape;
151  auto *series = new Series("Series");
152  ++*str;
153  if (input_layer != nullptr) {
154  series->AddToStack(input_layer);
155  shape = input_layer->OutputShape(shape);
156  }
157  Network *network = nullptr;
158  while (**str != '\0' && **str != ']' && (network = BuildFromString(shape, str)) != nullptr) {
159  shape = network->OutputShape(shape);
160  series->AddToStack(network);
161  }
162  if (**str != ']') {
163  tprintf("Missing ] at end of [Series]!\n");
164  delete series;
165  return nullptr;
166  }
167  ++*str;
168  return series;
169 }
170 
171 // Parses a parallel set of networks, defined by (<net><net>...).
172 Network *NetworkBuilder::ParseParallel(const StaticShape &input_shape, const char **str) {
173  auto *parallel = new Parallel("Parallel", NT_PARALLEL);
174  ++*str;
175  Network *network = nullptr;
176  while (**str != '\0' && **str != ')' &&
177  (network = BuildFromString(input_shape, str)) != nullptr) {
178  parallel->AddToStack(network);
179  }
180  if (**str != ')') {
181  tprintf("Missing ) at end of (Parallel)!\n");
182  delete parallel;
183  return nullptr;
184  }
185  ++*str;
186  return parallel;
187 }
188 
189 // Parses a network that begins with 'R'.
190 Network *NetworkBuilder::ParseR(const StaticShape &input_shape, const char **str) {
191  char dir = (*str)[1];
192  if (dir == 'x' || dir == 'y') {
193  std::string name = "Reverse";
194  name += dir;
195  *str += 2;
196  Network *network = BuildFromString(input_shape, str);
197  if (network == nullptr) {
198  return nullptr;
199  }
200  auto *rev = new Reversed(name, dir == 'y' ? NT_YREVERSED : NT_XREVERSED);
201  rev->SetNetwork(network);
202  return rev;
203  }
204  char *end;
205  int replicas = strtol(*str + 1, &end, 10);
206  *str = end;
207  if (replicas <= 0) {
208  tprintf("Invalid R spec!:%s\n", end);
209  return nullptr;
210  }
211  auto *parallel = new Parallel("Replicated", NT_REPLICATED);
212  const char *str_copy = *str;
213  for (int i = 0; i < replicas; ++i) {
214  str_copy = *str;
215  Network *network = BuildFromString(input_shape, &str_copy);
216  if (network == nullptr) {
217  tprintf("Invalid replicated network!\n");
218  delete parallel;
219  return nullptr;
220  }
221  parallel->AddToStack(network);
222  }
223  *str = str_copy;
224  return parallel;
225 }
226 
227 // Parses a network that begins with 'S'.
228 Network *NetworkBuilder::ParseS(const StaticShape &input_shape, const char **str) {
229  char *end;
230  int y = strtol(*str + 1, &end, 10);
231  *str = end;
232  if (**str == ',') {
233  int x = strtol(*str + 1, &end, 10);
234  *str = end;
235  if (y <= 0 || x <= 0) {
236  tprintf("Invalid S spec!:%s\n", *str);
237  return nullptr;
238  }
239  return new Reconfig("Reconfig", input_shape.depth(), x, y);
240  } else if (**str == '(') {
241  // TODO(rays) Add Generic reshape.
242  tprintf("Generic reshape not yet implemented!!\n");
243  return nullptr;
244  }
245  tprintf("Invalid S spec!:%s\n", *str);
246  return nullptr;
247 }
248 
249 // Helper returns the fully-connected type for the character code.
250 static NetworkType NonLinearity(char func) {
251  switch (func) {
252  case 's':
253  return NT_LOGISTIC;
254  case 't':
255  return NT_TANH;
256  case 'r':
257  return NT_RELU;
258  case 'l':
259  return NT_LINEAR;
260  case 'm':
261  return NT_SOFTMAX;
262  case 'p':
263  return NT_POSCLIP;
264  case 'n':
265  return NT_SYMCLIP;
266  default:
267  return NT_NONE;
268  }
269 }
270 
271 // Parses a network that begins with 'C'.
272 Network *NetworkBuilder::ParseC(const StaticShape &input_shape, const char **str) {
273  NetworkType type = NonLinearity((*str)[1]);
274  if (type == NT_NONE) {
275  tprintf("Invalid nonlinearity on C-spec!: %s\n", *str);
276  return nullptr;
277  }
278  int y = 0, x = 0, d = 0;
279  char *end;
280  if ((y = strtol(*str + 2, &end, 10)) <= 0 || *end != ',' ||
281  (x = strtol(end + 1, &end, 10)) <= 0 || *end != ',' || (d = strtol(end + 1, &end, 10)) <= 0) {
282  tprintf("Invalid C spec!:%s\n", end);
283  return nullptr;
284  }
285  *str = end;
286  if (x == 1 && y == 1) {
287  // No actual convolution. Just a FullyConnected on the current depth, to
288  // be slid over all batch,y,x.
289  return new FullyConnected("Conv1x1", input_shape.depth(), d, type);
290  }
291  auto *series = new Series("ConvSeries");
292  auto *convolve = new Convolve("Convolve", input_shape.depth(), x / 2, y / 2);
293  series->AddToStack(convolve);
294  StaticShape fc_input = convolve->OutputShape(input_shape);
295  series->AddToStack(new FullyConnected("ConvNL", fc_input.depth(), d, type));
296  return series;
297 }
298 
299 // Parses a network that begins with 'M'.
300 Network *NetworkBuilder::ParseM(const StaticShape &input_shape, const char **str) {
301  int y = 0, x = 0;
302  char *end;
303  if ((*str)[1] != 'p' || (y = strtol(*str + 2, &end, 10)) <= 0 || *end != ',' ||
304  (x = strtol(end + 1, &end, 10)) <= 0) {
305  tprintf("Invalid Mp spec!:%s\n", *str);
306  return nullptr;
307  }
308  *str = end;
309  return new Maxpool("Maxpool", input_shape.depth(), x, y);
310 }
311 
312 // Parses an LSTM network, either individual, bi- or quad-directional.
313 Network *NetworkBuilder::ParseLSTM(const StaticShape &input_shape, const char **str) {
314  bool two_d = false;
315  NetworkType type = NT_LSTM;
316  const char *spec_start = *str;
317  int chars_consumed = 1;
318  int num_outputs = 0;
319  char key = (*str)[chars_consumed], dir = 'f', dim = 'x';
320  if (key == 'S') {
321  type = NT_LSTM_SOFTMAX;
322  num_outputs = num_softmax_outputs_;
323  ++chars_consumed;
324  } else if (key == 'E') {
326  num_outputs = num_softmax_outputs_;
327  ++chars_consumed;
328  } else if (key == '2' &&
329  (((*str)[2] == 'x' && (*str)[3] == 'y') || ((*str)[2] == 'y' && (*str)[3] == 'x'))) {
330  chars_consumed = 4;
331  dim = (*str)[3];
332  two_d = true;
333  } else if (key == 'f' || key == 'r' || key == 'b') {
334  dir = key;
335  dim = (*str)[2];
336  if (dim != 'x' && dim != 'y') {
337  tprintf("Invalid dimension (x|y) in L Spec!:%s\n", *str);
338  return nullptr;
339  }
340  chars_consumed = 3;
341  if ((*str)[chars_consumed] == 's') {
342  ++chars_consumed;
343  type = NT_LSTM_SUMMARY;
344  }
345  } else {
346  tprintf("Invalid direction (f|r|b) in L Spec!:%s\n", *str);
347  return nullptr;
348  }
349  char *end;
350  int num_states = strtol(*str + chars_consumed, &end, 10);
351  if (num_states <= 0) {
352  tprintf("Invalid number of states in L Spec!:%s\n", *str);
353  return nullptr;
354  }
355  *str = end;
356  Network *lstm = nullptr;
357  if (two_d) {
358  lstm = BuildLSTMXYQuad(input_shape.depth(), num_states);
359  } else {
360  if (num_outputs == 0) {
361  num_outputs = num_states;
362  }
363  std::string name(spec_start, *str - spec_start);
364  lstm = new LSTM(name, input_shape.depth(), num_states, num_outputs, false, type);
365  if (dir != 'f') {
366  auto *rev = new Reversed("RevLSTM", NT_XREVERSED);
367  rev->SetNetwork(lstm);
368  lstm = rev;
369  }
370  if (dir == 'b') {
371  name += "LTR";
372  auto *parallel = new Parallel("BidiLSTM", NT_PAR_RL_LSTM);
373  parallel->AddToStack(
374  new LSTM(name, input_shape.depth(), num_states, num_outputs, false, type));
375  parallel->AddToStack(lstm);
376  lstm = parallel;
377  }
378  }
379  if (dim == 'y') {
380  auto *rev = new Reversed("XYTransLSTM", NT_XYTRANSPOSE);
381  rev->SetNetwork(lstm);
382  lstm = rev;
383  }
384  return lstm;
385 }
386 
387 // Builds a set of 4 lstms with x and y reversal, running in true parallel.
388 Network *NetworkBuilder::BuildLSTMXYQuad(int num_inputs, int num_states) {
389  auto *parallel = new Parallel("2DLSTMQuad", NT_PAR_2D_LSTM);
390  parallel->AddToStack(new LSTM("L2DLTRDown", num_inputs, num_states, num_states, true, NT_LSTM));
391  auto *rev = new Reversed("L2DLTRXRev", NT_XREVERSED);
392  rev->SetNetwork(new LSTM("L2DRTLDown", num_inputs, num_states, num_states, true, NT_LSTM));
393  parallel->AddToStack(rev);
394  rev = new Reversed("L2DRTLYRev", NT_YREVERSED);
395  rev->SetNetwork(new LSTM("L2DRTLUp", num_inputs, num_states, num_states, true, NT_LSTM));
396  auto *rev2 = new Reversed("L2DXRevU", NT_XREVERSED);
397  rev2->SetNetwork(rev);
398  parallel->AddToStack(rev2);
399  rev = new Reversed("L2DXRevY", NT_YREVERSED);
400  rev->SetNetwork(new LSTM("L2DLTRDown", num_inputs, num_states, num_states, true, NT_LSTM));
401  parallel->AddToStack(rev);
402  return parallel;
403 }
404 
405 // Helper builds a truly (0-d) fully connected layer of the given type.
406 static Network *BuildFullyConnected(const StaticShape &input_shape, NetworkType type,
407  const std::string &name, int depth) {
408  if (input_shape.height() == 0 || input_shape.width() == 0) {
409  tprintf("Fully connected requires positive height and width, had %d,%d\n", input_shape.height(),
410  input_shape.width());
411  return nullptr;
412  }
413  int input_size = input_shape.height() * input_shape.width();
414  int input_depth = input_size * input_shape.depth();
415  Network *fc = new FullyConnected(name, input_depth, depth, type);
416  if (input_size > 1) {
417  auto *series = new Series("FCSeries");
418  series->AddToStack(
419  new Reconfig("FCReconfig", input_shape.depth(), input_shape.width(), input_shape.height()));
420  series->AddToStack(fc);
421  fc = series;
422  }
423  return fc;
424 }
425 
426 // Parses a Fully connected network.
427 Network *NetworkBuilder::ParseFullyConnected(const StaticShape &input_shape, const char **str) {
428  const char *spec_start = *str;
429  NetworkType type = NonLinearity((*str)[1]);
430  if (type == NT_NONE) {
431  tprintf("Invalid nonlinearity on F-spec!: %s\n", *str);
432  return nullptr;
433  }
434  char *end;
435  int depth = strtol(*str + 2, &end, 10);
436  if (depth <= 0) {
437  tprintf("Invalid F spec!:%s\n", *str);
438  return nullptr;
439  }
440  *str = end;
441  std::string name(spec_start, *str - spec_start);
442  return BuildFullyConnected(input_shape, type, name, depth);
443 }
444 
445 // Parses an Output spec.
446 Network *NetworkBuilder::ParseOutput(const StaticShape &input_shape, const char **str) {
447  char dims_ch = (*str)[1];
448  if (dims_ch != '0' && dims_ch != '1' && dims_ch != '2') {
449  tprintf("Invalid dims (2|1|0) in output spec!:%s\n", *str);
450  return nullptr;
451  }
452  char type_ch = (*str)[2];
453  if (type_ch != 'l' && type_ch != 's' && type_ch != 'c') {
454  tprintf("Invalid output type (l|s|c) in output spec!:%s\n", *str);
455  return nullptr;
456  }
457  char *end;
458  int depth = strtol(*str + 3, &end, 10);
459  if (depth != num_softmax_outputs_) {
460  tprintf("Warning: given outputs %d not equal to unicharset of %d.\n", depth,
461  num_softmax_outputs_);
462  depth = num_softmax_outputs_;
463  }
464  *str = end;
465  NetworkType type = NT_SOFTMAX;
466  if (type_ch == 'l') {
467  type = NT_LOGISTIC;
468  } else if (type_ch == 's') {
469  type = NT_SOFTMAX_NO_CTC;
470  }
471  if (dims_ch == '0') {
472  // Same as standard fully connected.
473  return BuildFullyConnected(input_shape, type, "Output", depth);
474  } else if (dims_ch == '2') {
475  // We don't care if x and/or y are variable.
476  return new FullyConnected("Output2d", input_shape.depth(), depth, type);
477  }
478  // For 1-d y has to be fixed, and if not 1, moved to depth.
479  if (input_shape.height() == 0) {
480  tprintf("Fully connected requires fixed height!\n");
481  return nullptr;
482  }
483  int input_size = input_shape.height();
484  int input_depth = input_size * input_shape.depth();
485  Network *fc = new FullyConnected("Output", input_depth, depth, type);
486  if (input_size > 1) {
487  auto *series = new Series("FCSeries");
488  series->AddToStack(new Reconfig("FCReconfig", input_shape.depth(), 1, input_shape.height()));
489  series->AddToStack(fc);
490  fc = series;
491  }
492  return fc;
493 }
494 
495 } // namespace tesseract.
#define ASSERT_HOST(x)
Definition: errcode.h:59
void tprintf(const char *format,...)
Definition: tprintf.cpp:41
NetworkType
Definition: network.h:41
@ NT_LINEAR
Definition: network.h:65
@ NT_RELU
Definition: network.h:64
@ NT_XREVERSED
Definition: network.h:54
@ NT_LSTM
Definition: network.h:58
@ NT_SOFTMAX
Definition: network.h:66
@ NT_NONE
Definition: network.h:42
@ NT_LOGISTIC
Definition: network.h:60
@ NT_LSTM_SOFTMAX_ENCODED
Definition: network.h:74
@ NT_PARALLEL
Definition: network.h:47
@ NT_SYMCLIP
Definition: network.h:62
@ NT_PAR_2D_LSTM
Definition: network.h:51
@ NT_LSTM_SUMMARY
Definition: network.h:59
@ NT_YREVERSED
Definition: network.h:55
@ NT_POSCLIP
Definition: network.h:61
@ NT_LSTM_SOFTMAX
Definition: network.h:73
@ NT_XYTRANSPOSE
Definition: network.h:56
@ NT_SERIES
Definition: network.h:52
@ NT_SOFTMAX_NO_CTC
Definition: network.h:67
@ NT_TANH
Definition: network.h:63
@ NT_PAR_RL_LSTM
Definition: network.h:49
@ NT_REPLICATED
Definition: network.h:48
virtual void SetNetworkFlags(uint32_t flags)
Definition: network.cpp:131
NetworkType type() const
Definition: network.h:110
TESS_API void AppendSeries(Network *src)
Definition: series.cpp:192
TESS_API void SplitAt(unsigned last_start, Series **start, Series **end)
Definition: series.cpp:163
StaticShape OutputShape(const StaticShape &input_shape) const override
Definition: series.cpp:34
void CacheXScaleFactor(int factor) override
Definition: series.cpp:100
void SetShape(int batch, int height, int width, int depth)
Definition: static_shape.h:71
static bool InitNetwork(int num_outputs, const char *network_spec, int append_index, int net_flags, float weight_range, TRand *randomizer, Network **network)
Network * BuildFromString(const StaticShape &input_shape, const char **str)