tesseract 5.0.0
network.h
// File:        network.h
// Description: Base class for neural network implementations.
// Author:      Ray Smith
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef TESSERACT_LSTM_NETWORK_H_
#define TESSERACT_LSTM_NETWORK_H_

#include "helpers.h"
#include "matrix.h"
#include "networkio.h"
#include "serialis.h"
#include "static_shape.h"
#include "tprintf.h"

#include <cmath>
#include <cstdio>

struct Pix;

namespace tesseract {

class ScrollView;
class TBOX;
class ImageData;
class NetworkScratch;

// Enum to store the run-time type of a Network. Keep in sync with kTypeNames.
enum NetworkType {
  NT_NONE,  // The naked base class.
  NT_INPUT, // Inputs from an image.
  // Plumbing networks combine other networks or rearrange the inputs.
  NT_CONVOLVE,    // Duplicates inputs in a sliding window neighborhood.
  NT_MAXPOOL,     // Chooses the max result from a rectangle.
  NT_PARALLEL,    // Runs networks in parallel.
  NT_REPLICATED,  // Runs identical networks in parallel.
  NT_PAR_RL_LSTM, // Runs LTR and RTL LSTMs in parallel.
  NT_PAR_UD_LSTM, // Runs Up and Down LSTMs in parallel.
  NT_PAR_2D_LSTM, // Runs 4 LSTMs in parallel.
  NT_SERIES,      // Executes a sequence of layers.
  NT_RECONFIG,    // Scales the time/y size but makes the output deeper.
  NT_XREVERSED,   // Reverses the x-direction of the inputs/outputs.
  NT_YREVERSED,   // Reverses the y-direction of the inputs/outputs.
  NT_XYTRANSPOSE, // Transposes x and y (for just a single op).
  // Functional networks actually calculate stuff.
  NT_LSTM,           // Long-Short-Term-Memory block.
  NT_LSTM_SUMMARY,   // LSTM that only keeps its last output.
  NT_LOGISTIC,       // Fully connected logistic nonlinearity.
  NT_POSCLIP,        // Fully connected rect lin version of logistic.
  NT_SYMCLIP,        // Fully connected rect lin version of tanh.
  NT_TANH,           // Fully connected with tanh nonlinearity.
  NT_RELU,           // Fully connected with rectifier nonlinearity.
  NT_LINEAR,         // Fully connected with no nonlinearity.
  NT_SOFTMAX,        // Softmax uses exponential normalization, with CTC.
  NT_SOFTMAX_NO_CTC, // Softmax uses exponential normalization, no CTC.
  // The SOFTMAX LSTMs both have an extra softmax layer on top, but inside,
  // with the outputs fed back to the input of the LSTM at the next timestep.
  // The ENCODED version binary-encodes the softmax outputs, providing log2 of
  // the number of outputs as additional inputs, and the other version just
  // provides all the softmax outputs as additional inputs.
  NT_LSTM_SOFTMAX,         // 1-d LSTM with built-in fully connected softmax.
  NT_LSTM_SOFTMAX_ENCODED, // 1-d LSTM with built-in binary-encoded softmax.
  // A TensorFlow graph encapsulated as a Tesseract network.
  NT_TENSORFLOW,

  NT_COUNT // Array size.
};
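
// Illustrative note (not part of the original header): for a softmax with
// N outputs, the ENCODED variant feeds back only ceil(log2(N)) binary-coded
// inputs instead of all N softmax outputs. E.g. with N = 100 classes,
// NT_LSTM_SOFTMAX feeds back 100 extra inputs per timestep, while
// NT_LSTM_SOFTMAX_ENCODED feeds back only ceil(log2(100)) = 7.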

// Enum of Network behavior flags. Can in theory be set for each individual
// network element.
enum NetworkFlags {
  // Network forward/backprop behavior.
  NF_LAYER_SPECIFIC_LR = 64, // Separate learning rate for each layer.
  NF_ADAM = 128,             // Weight-specific learning rate.
};

// State of training and desired state used in SetEnableTraining.
enum TrainingState {
  // Valid states of training_.
  TS_DISABLED,     // Disabled permanently.
  TS_ENABLED,      // Enabled for backprop and to write a training dump.
                   // Re-enable from ANY disabled state.
  TS_TEMP_DISABLE, // Temporarily disabled to write a recognition dump.
  // Valid only for SetEnableTraining.
  TS_RE_ENABLE,    // Re-Enable from TS_TEMP_DISABLE, but not TS_DISABLED.
};

// Base class for network types. Not quite an abstract base class, but almost.
// Most of the time no isolated Network exists, except prior to
// deserialization.
class TESS_API Network {
public:
  Network();
  Network(NetworkType type, const std::string &name, int ni, int no);
  virtual ~Network() = default;

  // Accessors.
  NetworkType type() const {
    return type_;
  }
  bool IsTraining() const {
    return training_ == TS_ENABLED;
  }
  bool needs_to_backprop() const {
    return needs_to_backprop_;
  }
  int num_weights() const {
    return num_weights_;
  }
  int NumInputs() const {
    return ni_;
  }
  int NumOutputs() const {
    return no_;
  }
  // Returns the required shape input to the network.
  virtual StaticShape InputShape() const {
    StaticShape result;
    return result;
  }
  // Returns the shape output from the network given an input shape (which may
  // be partially unknown, i.e. zero).
  virtual StaticShape OutputShape(const StaticShape &input_shape) const {
    StaticShape result(input_shape);
    result.set_depth(no_);
    return result;
  }
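
  // Illustrative sketch (not from the original header): the base
  // implementation copies the input shape and only overrides the depth, e.g.:
  //   StaticShape in = net->InputShape();     // depth matches NumInputs()
  //   StaticShape out = net->OutputShape(in); // same shape, depth == NumOutputs()
  // Subclasses that rescale time or height override this accordingly.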
  const std::string &name() const {
    return name_;
  }
  virtual std::string spec() const {
    return "?";
  }
  bool TestFlag(NetworkFlags flag) const {
    return (network_flags_ & flag) != 0;
  }
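
  // Usage sketch (illustrative, not from the original header): flags are a
  // bitmask, so they can be combined and tested individually:
  //   net->SetNetworkFlags(NF_ADAM | NF_LAYER_SPECIFIC_LR);
  //   if (net->TestFlag(NF_ADAM)) {
  //     // Weight-specific (adam) learning rates are in effect.
  //   }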

  // Initialization and administrative functions that are mostly provided
  // by Plumbing.
  // Returns true if the given type is derived from Plumbing, and thus contains
  // multiple sub-networks that can have their own learning rate.
  virtual bool IsPlumbingType() const {
    return false;
  }

  // Suspends/Enables/Permanently disables training by setting the training_
  // flag. Serialize and DeSerialize only operate on the run-time data if state
  // is TS_ENABLED or TS_TEMP_DISABLE. Specifying TS_TEMP_DISABLE will
  // temporarily disable layers in state TS_ENABLED, allowing a trainer to
  // serialize as if it were a recognizer.
  // TS_RE_ENABLE will re-enable layers that were previously in any disabled
  // state. If in TS_TEMP_DISABLE then the flag is just changed, but if in
  // TS_DISABLED, the deltas in the weight matrices are reinitialized so that a
  // recognizer can be converted back to a trainer.
  virtual void SetEnableTraining(TrainingState state);
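
  // Usage sketch (illustrative, not from the original header): a trainer can
  // write a recognizer-style dump without permanently losing training state:
  //   net->SetEnableTraining(TS_TEMP_DISABLE); // serialize as a recognizer
  //   net->Serialize(fp);                      // fp: an open TFile*
  //   net->SetEnableTraining(TS_RE_ENABLE);    // resume training afterwards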

  // Sets flags that control the action of the network. See NetworkFlags enum
  // for bit values.
  virtual void SetNetworkFlags(uint32_t flags);

  // Sets up the network for training. Initializes weights using weights of
  // scale `range` picked according to the random number generator `randomizer`.
  // Note that randomizer is a borrowed pointer that should outlive the network
  // and should not be deleted by any of the networks.
  // Returns the number of weights initialized.
  virtual int InitWeights(float range, TRand *randomizer);
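
  // Usage sketch (illustrative; the seed and range values are hypothetical):
  //   TRand randomizer;
  //   randomizer.set_seed(0x1234567);              // deterministic init
  //   int n = net->InitWeights(0.1f, &randomizer); // n = weights initialized
  // The randomizer must outlive the network; the network does not own it.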
  // Changes the number of outputs to the outside world to the size of the
  // given code_map. Recursively searches the entire network for Softmax
  // layers that have exactly old_no outputs, and operates only on those,
  // leaving all others unchanged. This enables networks with multiple output
  // layers to get all their softmaxes updated, but if an internal layer uses
  // one of those softmaxes for input, then the inputs will effectively be
  // scrambled.
  // TODO(rays) Fix this before any such network is implemented.
  // The softmaxes are resized by copying the old weight matrix entries for
  // each output from code_map[output] where non-negative, and using the mean
  // (over all outputs) of the existing weights for all outputs with negative
  // code_map entries. Returns the new number of weights.
  virtual int RemapOutputs([[maybe_unused]] int old_no,
                           [[maybe_unused]] const std::vector<int> &code_map) {
    return 0;
  }
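
  // Worked example (illustrative, not from the original header): with
  // old_no == 3 and code_map == {2, -1, 0, 1}, the resized softmax has 4
  // outputs: new output 0 copies old output 2's weights, new outputs 2 and 3
  // copy old outputs 0 and 1, and new output 1 (code_map entry -1) gets the
  // mean of the old weights.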

  // Converts a float network to an int network.
  virtual void ConvertToInt() {}

  // Provides a pointer to a TRand for any networks that care to use it.
  // Note that randomizer is a borrowed pointer that should outlive the network
  // and should not be deleted by any of the networks.
  virtual void SetRandomizer(TRand *randomizer);

  // Sets needs_to_backprop_ to needs_backprop and returns true if
  // needs_backprop is true or this network has any weights, so that the next
  // layer forward can be told to produce backprop for this layer if needed.
  virtual bool SetupNeedsBackprop(bool needs_backprop);

  // Returns the most recent reduction factor that the network applied to the
  // time sequence. Assumes that any 2-d is already eliminated. Used for
  // scaling bounding boxes of truth data and calculating result bounding
  // boxes.
  // WARNING: if GlobalMinimax is used to vary the scale, this will return
  // the last used scale factor. Call it before any forward, and it will
  // return the minimum scale factor of the paths through the GlobalMinimax.
  virtual int XScaleFactor() const {
    return 1;
  }
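
  // Illustrative note (not from the original header): a network that halves
  // the time dimension twice (e.g. two stride-2 reductions in series) would
  // return 4, so a truth bounding box at pixel x maps to output timestep x/4.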

  // Provides the (minimum) x scale factor to the network (of interest only to
  // input units) so they can determine how to scale bounding boxes.
  virtual void CacheXScaleFactor([[maybe_unused]] int factor) {}

  // Provides debug output on the weights.
  virtual void DebugWeights() = 0;

  // Writes to the given file. Returns false in case of error.
  // Should be overridden by subclasses, but called by their Serialize.
  virtual bool Serialize(TFile *fp) const;
  // Reads from the given file. Returns false in case of error.
  // Should be overridden by subclasses, but NOT called by their DeSerialize.
  virtual bool DeSerialize(TFile *fp) = 0;

public:
  // Updates the weights using the given learning rate, momentum and adam_beta.
  // num_samples is used in the adam computation iff use_adam_ is true.
  virtual void Update([[maybe_unused]] float learning_rate,
                      [[maybe_unused]] float momentum,
                      [[maybe_unused]] float adam_beta,
                      [[maybe_unused]] int num_samples) {}
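
  // Usage sketch (illustrative hyperparameter values, not from the original
  // header):
  //   net->Update(1e-4f /*learning_rate*/, 0.9f /*momentum*/,
  //               0.999f /*adam_beta*/, batch_samples /*num_samples*/);
  // where batch_samples is hypothetical; per the comment above, num_samples
  // is only consulted in the adam computation.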
  // Sums the products of weight updates in *this and other, splitting into
  // positive (same direction) in *same and negative (different direction) in
  // *changed.
  virtual void CountAlternators([[maybe_unused]] const Network &other,
                                [[maybe_unused]] TFloat *same,
                                [[maybe_unused]] TFloat *changed) const {}

  // Reads from the given file. Returns nullptr in case of error.
  // Determines the type of the serialized class and calls its DeSerialize
  // on a new object of the appropriate type, which is returned.
  static Network *CreateFromFile(TFile *fp);
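
  // Usage sketch (illustrative; assumes TFile::Open with a nullptr reader
  // reads from the filesystem, and the filename is hypothetical):
  //   TFile fp;
  //   if (fp.Open("eng.lstm", nullptr)) {
  //     Network *net = Network::CreateFromFile(&fp); // nullptr on error
  //     // ... use net ...
  //     delete net;
  //   }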

  // Runs forward propagation of activations on the input line.
  // Note that input and output are both 2-d arrays.
  // The 1st index is the time element. In a 1-d network, it might be the pixel
  // position on the textline. In a 2-d network, the linearization is defined
  // by the stride_map. (See networkio.h).
  // The 2nd index of input is the network inputs/outputs, and the dimension
  // of the input must match NumInputs() of this network.
  // The output array will be resized as needed so that its 1st dimension is
  // always equal to the number of output values, and its second dimension is
  // always NumOutputs(). Note that all this detail is encapsulated away inside
  // NetworkIO, as are the internals of the scratch memory space used by the
  // network. See networkscratch.h for that.
  // If input_transpose is not nullptr, then it contains the transpose of input,
  // and the caller guarantees that it will still be valid on the next call to
  // backward. The callee is therefore at liberty to save the pointer and
  // reference it on a call to backward. This is a bit ugly, but it makes it
  // possible for a replicating parallel to calculate the input transpose once
  // instead of all the replicated networks having to do it.
  virtual void Forward(bool debug, const NetworkIO &input,
                       const TransposedArray *input_transpose,
                       NetworkScratch *scratch, NetworkIO *output) = 0;

  // Runs backward propagation of errors on fwd_deltas.
  // Note that fwd_deltas and back_deltas are both 2-d arrays as with Forward.
  // Returns false if back_deltas was not set, due to there being no point in
  // propagating further backwards. Thus most complete networks will always
  // return false from Backward!
  virtual bool Backward(bool debug, const NetworkIO &fwd_deltas,
                        NetworkScratch *scratch, NetworkIO *back_deltas) = 0;
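
  // Usage sketch (illustrative, not from the original header):
  //   NetworkScratch scratch;
  //   NetworkIO outputs;
  //   net->Forward(false, inputs, nullptr, &scratch, &outputs);
  //   // ... compute fwd_deltas from outputs and the training target ...
  //   NetworkIO back_deltas;
  //   if (!net->Backward(false, fwd_deltas, &scratch, &back_deltas)) {
  //     // Nothing upstream needs back_deltas; typical for a complete network.
  //   }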

  // === Debug image display methods. ===
  // Displays the image of the matrix to the forward window.
  void DisplayForward(const NetworkIO &matrix);
  // Displays the image of the matrix to the backward window.
  void DisplayBackward(const NetworkIO &matrix);

  // Creates the window if needed, otherwise clears it.
  static void ClearWindow(bool tess_coords, const char *window_name, int width,
                          int height, ScrollView **window);

  // Displays the pix in the given window and returns the height of the pix.
  // The pix is pixDestroyed.
  static int DisplayImage(Image pix, ScrollView *window);

protected:
  // Returns a random number in [-range, range].
  TFloat Random(TFloat range);

protected:
  NetworkType type_;       // Type of the derived network class.
  TrainingState training_; // Are we currently training?
  bool needs_to_backprop_; // This network needs to output back_deltas.
  int32_t network_flags_;  // Behavior control flags in NetworkFlags.
  int32_t ni_;             // Number of input values.
  int32_t no_;             // Number of output values.
  int32_t num_weights_;    // Number of weights in this and sub-network.
  std::string name_;       // A unique name for this layer.

  // NOT-serialized debug data.
  ScrollView *forward_win_;  // Recognition debug display window.
  ScrollView *backward_win_; // Training debug display window.
  TRand *randomizer_;        // Random number generator.
};

} // namespace tesseract.

#endif // TESSERACT_LSTM_NETWORK_H_