tesseract 5.0.0
plumbing.h
// File:        plumbing.h
// Description: Base class for networks that organize other networks
//              eg series or parallel.
// Author:      Ray Smith
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef TESSERACT_LSTM_PLUMBING_H_
#define TESSERACT_LSTM_PLUMBING_H_

#include "matrix.h"
#include "network.h"

namespace tesseract {

// Holds a collection of other networks and forwards calls to each of them.
class Plumbing : public Network {
public:
  // ni_ and no_ will be set by AddToStack.
  explicit Plumbing(const std::string &name);
  ~Plumbing() override {
    for (auto data : stack_) {
      delete data;
    }
  }

  // Returns the required shape input to the network.
  StaticShape InputShape() const override {
    return stack_[0]->InputShape();
  }
  std::string spec() const override {
    return "Sub-classes of Plumbing must implement spec()!";
  }

  // Returns true if the given type is derived from Plumbing, and thus contains
  // multiple sub-networks that can have their own learning rate.
  bool IsPlumbingType() const override {
    return true;
  }

  // Suspends/Enables training by setting the training_ flag. Serialize and
  // DeSerialize only operate on the run-time data if state is false.
  void SetEnableTraining(TrainingState state) override;

  // Sets flags that control the action of the network. See NetworkFlags enum
  // for bit values.
  void SetNetworkFlags(uint32_t flags) override;

  // Sets up the network for training. Initializes weights using weights of
  // scale `range` picked according to the random number generator `randomizer`.
  // Note that randomizer is a borrowed pointer that should outlive the network
  // and should not be deleted by any of the networks.
  // Returns the number of weights initialized.
  int InitWeights(float range, TRand *randomizer) override;
  // Recursively searches the network for softmaxes with old_no outputs,
  // and remaps their outputs according to code_map. See network.h for details.
  int RemapOutputs(int old_no, const std::vector<int> &code_map) override;

  // Converts a float network to an int network.
  void ConvertToInt() override;

  // Provides a pointer to a TRand for any networks that care to use it.
  // Note that randomizer is a borrowed pointer that should outlive the network
  // and should not be deleted by any of the networks.
  void SetRandomizer(TRand *randomizer) override;

  // Adds the given network to the stack.
  virtual void AddToStack(Network *network);

  // Sets needs_to_backprop_ to needs_backprop and returns true if
  // needs_backprop || any weights in this network so the next layer forward
  // can be told to produce backprop for this layer if needed.
  bool SetupNeedsBackprop(bool needs_backprop) override;

  // Returns an integer reduction factor that the network applies to the
  // time sequence. Assumes that any 2-d is already eliminated. Used for
  // scaling bounding boxes of truth data.
  // WARNING: if GlobalMinimax is used to vary the scale, this will return
  // the last used scale factor. Call it before any forward, and it will return
  // the minimum scale factor of the paths through the GlobalMinimax.
  int XScaleFactor() const override;

  // Provides the (minimum) x scale factor to the network (of interest only to
  // input units) so they can determine how to scale bounding boxes.
  void CacheXScaleFactor(int factor) override;

  // Provides debug output on the weights.
  void DebugWeights() override;

  // Returns the current stack.
  const std::vector<Network *> &stack() const {
    return stack_;
  }
  // Returns a set of strings representing the layer-ids of all layers below.
  TESS_API
  void EnumerateLayers(const std::string *prefix, std::vector<std::string> &layers) const;
  // Returns a pointer to the network layer corresponding to the given id.
  TESS_API
  Network *GetLayer(const char *id) const;
  // Returns the learning rate for a specific layer of the stack.
  float LayerLearningRate(const char *id) {
    const float *lr_ptr = LayerLearningRatePtr(id);
    ASSERT_HOST(lr_ptr != nullptr);
    return *lr_ptr;
  }
  // Scales the learning rate for a specific layer of the stack.
  void ScaleLayerLearningRate(const char *id, double factor) {
    float *lr_ptr = LayerLearningRatePtr(id);
    ASSERT_HOST(lr_ptr != nullptr);
    *lr_ptr *= factor;
  }

  // Set the learning rate for a specific layer of the stack to the given value.
  void SetLayerLearningRate(const char *id, float learning_rate) {
    float *lr_ptr = LayerLearningRatePtr(id);
    ASSERT_HOST(lr_ptr != nullptr);
    *lr_ptr = learning_rate;
  }

  // Returns a pointer to the learning rate for the given layer id.
  TESS_API
  float *LayerLearningRatePtr(const char *id);

  // Writes to the given file. Returns false in case of error.
  bool Serialize(TFile *fp) const override;
  // Reads from the given file. Returns false in case of error.
  bool DeSerialize(TFile *fp) override;

  // Updates the weights using the given learning rate, momentum and adam_beta.
  // num_samples is used in the adam computation iff use_adam_ is true.
  void Update(float learning_rate, float momentum, float adam_beta, int num_samples) override;
  // Sums the products of weight updates in *this and other, splitting into
  // positive (same direction) in *same and negative (different direction) in
  // *changed.
  void CountAlternators(const Network &other, TFloat *same, TFloat *changed) const override;

protected:
  // The networks.
  std::vector<Network *> stack_;
  // Layer-specific learning rate iff network_flags_ & NF_LAYER_SPECIFIC_LR.
  // One element for each element of stack_.
  std::vector<float> learning_rates_;
};

} // namespace tesseract.

#endif // TESSERACT_LSTM_PLUMBING_H_
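
The layer-addressing members declared above (EnumerateLayers and GetLayer) identify each sub-network by a colon-separated id such as ":0" or ":1:2". The following is a minimal sketch, not part of the Tesseract sources, of how they can be used together on an already-built Plumbing-derived network (for example a Series); passing nullptr as the prefix enumerates ids relative to the given network, and the helper name PrintLayerSpecs and the output format are illustrative assumptions only.

#include <cstdio>
#include <string>
#include <vector>

#include "plumbing.h"

// Prints the id and spec() of every non-Plumbing layer below `net`.
void PrintLayerSpecs(const tesseract::Plumbing &net) {
  std::vector<std::string> layer_ids;
  net.EnumerateLayers(nullptr, layer_ids);  // nullptr prefix: ids relative to `net`.
  for (const std::string &id : layer_ids) {
    const tesseract::Network *layer = net.GetLayer(id.c_str());
    if (layer != nullptr) {
      printf("%s: %s\n", id.c_str(), layer->spec().c_str());
    }
  }
}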
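The per-layer learning-rate accessors are only meaningful when layer-specific learning rates are in use (the NF_LAYER_SPECIFIC_LR flag mentioned in the comment on learning_rates_). Below is a sketch, under that assumption, of scaling every layer's rate by a constant factor; checking LayerLearningRatePtr for nullptr first avoids the ASSERT_HOST inside ScaleLayerLearningRate for layers that have no rate of their own. The helper name ScaleAllLayerLearningRates is an illustrative assumption.

#include <string>
#include <vector>

#include "plumbing.h"

// Multiplies the learning rate of every layer below `net` by `factor`.
void ScaleAllLayerLearningRates(tesseract::Plumbing &net, double factor) {
  std::vector<std::string> layer_ids;
  net.EnumerateLayers(nullptr, layer_ids);
  for (const std::string &id : layer_ids) {
    if (net.LayerLearningRatePtr(id.c_str()) != nullptr) {
      net.ScaleLayerLearningRate(id.c_str(), factor);
    }
  }
}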