tesseract  5.0.0
tesseract::WeightMatrix Class Reference

#include <weightmatrix.h>

Public Member Functions

 WeightMatrix ()
 
int InitWeightsFloat (int no, int ni, bool use_adam, float weight_range, TRand *randomizer)
 
int RemapOutputs (const std::vector< int > &code_map)
 
void ConvertToInt ()
 
int RoundInputs (int size) const
 
bool is_int_mode () const
 
int NumOutputs () const
 
const TFloat * GetWeights (int index) const
 
TFloat GetDW (int i, int j) const
 
void InitBackward ()
 
bool Serialize (bool training, TFile *fp) const
 
bool DeSerialize (bool training, TFile *fp)
 
bool DeSerializeOld (bool training, TFile *fp)
 
void MatrixDotVector (const TFloat *u, TFloat *v) const
 
void MatrixDotVector (const int8_t *u, TFloat *v) const
 
void MultiplyAccumulate (const TFloat *v, TFloat *inout)
 
void VectorDotMatrix (const TFloat *u, TFloat *v) const
 
void SumOuterTransposed (const TransposedArray &u, const TransposedArray &v, bool parallel)
 
void Update (float learning_rate, float momentum, float adam_beta, int num_samples)
 
void AddDeltas (const WeightMatrix &other)
 
void CountAlternators (const WeightMatrix &other, TFloat *same, TFloat *changed) const
 
void Debug2D (const char *msg)
 

Detailed Description

Definition at line 70 of file weightmatrix.h.

Constructor & Destructor Documentation

◆ WeightMatrix()

tesseract::WeightMatrix::WeightMatrix ( )
inline

Definition at line 72 of file weightmatrix.h.

72 : int_mode_(false), use_adam_(false) {}

Member Function Documentation

◆ AddDeltas()

void tesseract::WeightMatrix::AddDeltas ( const WeightMatrix &  other)

Definition at line 486 of file weightmatrix.cpp.

486  {
487  assert(dw_.dim1() == other.dw_.dim1());
488  assert(dw_.dim2() == other.dw_.dim2());
489  dw_ += other.dw_;
490 }

◆ ConvertToInt()

void tesseract::WeightMatrix::ConvertToInt ( )

Definition at line 183 of file weightmatrix.cpp.

183  {
184  wi_.ResizeNoInit(wf_.dim1(), wf_.dim2());
185  scales_.reserve(wi_.dim1());
186  int dim2 = wi_.dim2();
187  for (int t = 0; t < wi_.dim1(); ++t) {
188  TFloat *f_line = wf_[t];
189  int8_t *i_line = wi_[t];
190  TFloat max_abs = 0;
191  for (int f = 0; f < dim2; ++f) {
192  TFloat abs_val = fabs(f_line[f]);
193  if (abs_val > max_abs) {
194  max_abs = abs_val;
195  }
196  }
197  TFloat scale = max_abs / INT8_MAX;
198  scales_.push_back(scale / INT8_MAX);
199  if (scale == 0.0) {
200  scale = 1.0;
201  }
202  for (int f = 0; f < dim2; ++f) {
203  i_line[f] = IntCastRounded(f_line[f] / scale);
204  }
205  }
206  wf_.Resize(1, 1, 0.0);
207  int_mode_ = true;
208  if (IntSimdMatrix::intSimdMatrix) {
209  int32_t rounded_num_out;
210  IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_, rounded_num_out);
211  scales_.resize(rounded_num_out);
212  }
213 }
int IntCastRounded(double x)
Definition: helpers.h:175
double TFloat
Definition: tesstypes.h:39
void Resize(int size1, int size2, const T &empty)
Definition: matrix.h:110
void ResizeNoInit(int size1, int size2, int pad=0)
Definition: matrix.h:94
static const IntSimdMatrix * intSimdMatrix
void Init(const GENERIC_2D_ARRAY< int8_t > &w, std::vector< int8_t > &shaped_w, int32_t &rounded_num_out) const

◆ CountAlternators()

void tesseract::WeightMatrix::CountAlternators ( const WeightMatrix &  other,
TFloat *  same,
TFloat *  changed 
) const

Definition at line 495 of file weightmatrix.cpp.

496  {
497  int num_outputs = updates_.dim1();
498  int num_inputs = updates_.dim2();
499  assert(num_outputs == other.updates_.dim1());
500  assert(num_inputs == other.updates_.dim2());
501  for (int i = 0; i < num_outputs; ++i) {
502  const TFloat *this_i = updates_[i];
503  const TFloat *other_i = other.updates_[i];
504  for (int j = 0; j < num_inputs; ++j) {
505  TFloat product = this_i[j] * other_i[j];
506  if (product < 0.0) {
507  *changed -= product;
508  } else {
509  *same += product;
510  }
511  }
512  }
513 }

◆ Debug2D()

void tesseract::WeightMatrix::Debug2D ( const char *  msg)

Definition at line 527 of file weightmatrix.cpp.

527  {
528  STATS histogram(0, kHistogramBuckets);
529  if (int_mode_) {
530  for (int i = 0; i < wi_.dim1(); ++i) {
531  for (int j = 0; j < wi_.dim2(); ++j) {
532  HistogramWeight(wi_[i][j] * scales_[i], &histogram);
533  }
534  }
535  } else {
536  for (int i = 0; i < wf_.dim1(); ++i) {
537  for (int j = 0; j < wf_.dim2(); ++j) {
538  HistogramWeight(wf_[i][j], &histogram);
539  }
540  }
541  }
542  tprintf("%s\n", msg);
543  histogram.print();
544 }
void tprintf(const char *format,...)
Definition: tprintf.cpp:41
const int kHistogramBuckets

◆ DeSerialize()

bool tesseract::WeightMatrix::DeSerialize ( bool  training,
TFile *  fp 
)

Definition at line 280 of file weightmatrix.cpp.

280  {
281  uint8_t mode;
282  if (!fp->DeSerialize(&mode)) {
283  return false;
284  }
285  int_mode_ = (mode & kInt8Flag) != 0;
286  use_adam_ = (mode & kAdamFlag) != 0;
287  if ((mode & kDoubleFlag) == 0) {
288  return DeSerializeOld(training, fp);
289  }
290  if (int_mode_) {
291  if (!wi_.DeSerialize(fp)) {
292  return false;
293  }
294  uint32_t size;
295  if (!fp->DeSerialize(&size)) {
296  return false;
297  }
298 #ifdef FAST_FLOAT
299  scales_.reserve(size);
300  for (auto n = size; n > 0; n--) {
301  double val;
302  if (!fp->DeSerialize(&val)) {
303  return false;
304  }
305  scales_.push_back(val / INT8_MAX);
306  }
307 #else
308  scales_.resize(size);
309  if (!fp->DeSerialize(&scales_[0], size)) {
310  return false;
311  }
312  for (auto &scale : scales_) {
313  scale /= INT8_MAX;
314  }
315 #endif
316  if (IntSimdMatrix::intSimdMatrix) {
317  int32_t rounded_num_out;
318  IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_, rounded_num_out);
319  scales_.resize(rounded_num_out);
320  }
321  } else {
322  if (!tesseract::DeSerialize(fp, wf_)) {
323  return false;
324  }
325  if (training) {
326  InitBackward();
327  if (!tesseract::DeSerialize(fp, updates_)) {
328  return false;
329  }
330  if (use_adam_) {
331  if (!tesseract::DeSerialize(fp, dw_sq_sum_)) {
332  return false;
333  }
334  }
335  }
336  }
337  return true;
338 }
const int kInt8Flag
bool DeSerialize(bool swap, FILE *fp, std::vector< T > &data)
Definition: helpers.h:220
const int kDoubleFlag
const int kAdamFlag
bool DeSerialize(bool swap, FILE *fp)
Definition: matrix.h:175
bool DeSerializeOld(bool training, TFile *fp)

◆ DeSerializeOld()

bool tesseract::WeightMatrix::DeSerializeOld ( bool  training,
TFile *  fp 
)

Definition at line 342 of file weightmatrix.cpp.

342  {
343 #ifdef FAST_FLOAT
344  // Not implemented.
345  ASSERT_HOST(!"not implemented");
346  return false;
347 #else
348  if (int_mode_) {
349  if (!wi_.DeSerialize(fp)) {
350  return false;
351  }
352  std::vector<float> old_scales;
353  if (!fp->DeSerialize(old_scales)) {
354  return false;
355  }
356  scales_.reserve(old_scales.size());
357  for (float old_scale : old_scales) {
358  scales_.push_back(old_scale);
359  }
360  } else {
361  GENERIC_2D_ARRAY<float> float_array;
362  if (!float_array.DeSerialize(fp)) {
363  return false;
364  }
365  FloatToDouble(float_array, wf_);
366  }
367  if (training) {
368  InitBackward();
369  GENERIC_2D_ARRAY<float> float_array;
370  if (!float_array.DeSerialize(fp)) {
371  return false;
372  }
373  FloatToDouble(float_array, updates_);
374  // Errs was only used in int training, which is now dead.
375  if (!float_array.DeSerialize(fp)) {
376  return false;
377  }
378  }
379  return true;
380 #endif
381 }
#define ASSERT_HOST(x)
Definition: errcode.h:59

◆ GetDW()

TFloat tesseract::WeightMatrix::GetDW ( int  i,
int  j 
) const
inline

Definition at line 115 of file weightmatrix.h.

115  {
116  return dw_(i, j);
117  }

◆ GetWeights()

const TFloat* tesseract::WeightMatrix::GetWeights ( int  index) const
inline

Definition at line 111 of file weightmatrix.h.

111  {
112  return wf_[index];
113  }

◆ InitBackward()

void tesseract::WeightMatrix::InitBackward ( )

Definition at line 217 of file weightmatrix.cpp.

217  {
218  int no = int_mode_ ? wi_.dim1() : wf_.dim1();
219  int ni = int_mode_ ? wi_.dim2() : wf_.dim2();
220  dw_.Resize(no, ni, 0.0);
221  updates_.Resize(no, ni, 0.0);
222  wf_t_.Transpose(wf_);
223  if (use_adam_) {
224  dw_sq_sum_.Resize(no, ni, 0.0);
225  }
226 }
void Transpose(const GENERIC_2D_ARRAY< TFloat > &input)

◆ InitWeightsFloat()

int tesseract::WeightMatrix::InitWeightsFloat ( int  no,
int  ni,
bool  use_adam,
float  weight_range,
TRand *  randomizer 
)

Definition at line 130 of file weightmatrix.cpp.

131  {
132  int_mode_ = false;
133  wf_.Resize(no, ni, 0.0);
134  if (randomizer != nullptr) {
135  for (int i = 0; i < no; ++i) {
136  for (int j = 0; j < ni; ++j) {
137  wf_[i][j] = randomizer->SignedRand(weight_range);
138  }
139  }
140  }
141  use_adam_ = use_adam;
142  InitBackward();
143  return ni * no;
144 }

◆ is_int_mode()

bool tesseract::WeightMatrix::is_int_mode ( ) const
inline

Definition at line 104 of file weightmatrix.h.

104  {
105  return int_mode_;
106  }

◆ MatrixDotVector() [1/2]

void tesseract::WeightMatrix::MatrixDotVector ( const int8_t *  u,
TFloat *  v 
) const

Definition at line 393 of file weightmatrix.cpp.

393  {
394  assert(int_mode_);
395  if (IntSimdMatrix::intSimdMatrix) {
396  IntSimdMatrix::intSimdMatrix->matrixDotVectorFunction(wi_.dim1(), wi_.dim2(), &shaped_w_[0],
397  &scales_[0], u, v);
398  } else {
399  IntSimdMatrix::MatrixDotVector(wi_, scales_, u, v);
400  }
401 }
static void MatrixDotVector(const GENERIC_2D_ARRAY< int8_t > &w, const std::vector< TFloat > &scales, const int8_t *u, TFloat *v)
MatrixDotVectorFunction matrixDotVectorFunction

◆ MatrixDotVector() [2/2]

void tesseract::WeightMatrix::MatrixDotVector ( const TFloat *  u,
TFloat *  v 
) const

Definition at line 388 of file weightmatrix.cpp.

388  {
389  assert(!int_mode_);
390  MatrixDotVectorInternal(wf_, true, false, u, v);
391 }

◆ MultiplyAccumulate()

void tesseract::WeightMatrix::MultiplyAccumulate ( const TFloat *  v,
TFloat *  inout 
)

Definition at line 405 of file weightmatrix.cpp.

405  {
406  assert(!int_mode_);
407  assert(wf_.dim1() == 1);
408  int n = wf_.dim2();
409  const TFloat *u = wf_[0];
410  for (int i = 0; i < n; ++i) {
411  inout[i] += u[i] * v[i];
412  }
413 }

◆ NumOutputs()

int tesseract::WeightMatrix::NumOutputs ( ) const
inline

Definition at line 107 of file weightmatrix.h.

107  {
108  return int_mode_ ? wi_.dim1() : wf_.dim1();
109  }

◆ RemapOutputs()

int tesseract::WeightMatrix::RemapOutputs ( const std::vector< int > &  code_map)

Definition at line 151 of file weightmatrix.cpp.

151  {
152  GENERIC_2D_ARRAY<TFloat> old_wf(wf_);
153  int old_no = wf_.dim1();
154  int new_no = code_map.size();
155  int ni = wf_.dim2();
156  std::vector<TFloat> means(ni, 0.0);
157  for (int c = 0; c < old_no; ++c) {
158  const TFloat *weights = wf_[c];
159  for (int i = 0; i < ni; ++i) {
160  means[i] += weights[i];
161  }
162  }
163  for (auto &mean : means) {
164  mean /= old_no;
165  }
166  wf_.Resize(new_no, ni, 0.0);
167  InitBackward();
168  for (int dest = 0; dest < new_no; ++dest) {
169  int src = code_map[dest];
170  const TFloat *src_data = src >= 0 ? old_wf[src] : means.data();
171  memcpy(wf_[dest], src_data, ni * sizeof(*src_data));
172  }
173  return ni * new_no;
174 }

◆ RoundInputs()

int tesseract::WeightMatrix::RoundInputs ( int  size) const
inline

Definition at line 96 of file weightmatrix.h.

96  {
97  if (!int_mode_ || !IntSimdMatrix::intSimdMatrix) {
98  return size;
99  }
100  return IntSimdMatrix::intSimdMatrix->RoundInputs(size);
101  }
int RoundInputs(int size) const
Definition: intsimdmatrix.h:70

◆ Serialize()

bool tesseract::WeightMatrix::Serialize ( bool  training,
TFile *  fp 
) const

Definition at line 238 of file weightmatrix.cpp.

238  {
239  // For backward compatibility, add kDoubleFlag to mode to indicate the doubles
240  // format, without errs, so we can detect and read old format weight matrices.
241  uint8_t mode = (int_mode_ ? kInt8Flag : 0) | (use_adam_ ? kAdamFlag : 0) | kDoubleFlag;
242  if (!fp->Serialize(&mode)) {
243  return false;
244  }
245  if (int_mode_) {
246  if (!wi_.Serialize(fp)) {
247  return false;
248  }
249  uint32_t size = scales_.size();
250  if (!fp->Serialize(&size)) {
251  return false;
252  }
253  for (auto scale : scales_) {
254  // The scales stored in memory have an extra factor applied to them
255  // to allow faster operation. We have to remove that factor here
256  // before writing to disc.
257  double value = scale * INT8_MAX;
258  if (!fp->Serialize(&value)) {
259  return false;
260  }
261  }
262  } else {
263  if (!tesseract::Serialize(fp, wf_)) {
264  return false;
265  }
266  if (training) {
267  if (!tesseract::Serialize(fp, updates_)) {
268  return false;
269  }
270  if (use_adam_ && !tesseract::Serialize(fp, dw_sq_sum_)) {
271  return false;
272  }
273  }
274  }
275  return true;
276 }
bool Serialize(FILE *fp, const std::vector< T > &data)
Definition: helpers.h:251
bool Serialize(FILE *fp) const
Definition: matrix.h:150

◆ SumOuterTransposed()

void tesseract::WeightMatrix::SumOuterTransposed ( const TransposedArray &  u,
const TransposedArray &  v,
bool  parallel 
)

Definition at line 429 of file weightmatrix.cpp.

430  {
431  assert(!int_mode_);
432  int num_outputs = dw_.dim1();
433  assert(u.dim1() == num_outputs);
434  assert(u.dim2() == v.dim2());
435  int num_inputs = dw_.dim2() - 1;
436  int num_samples = u.dim2();
437  // v is missing the last element in dim1.
438  assert(v.dim1() == num_inputs);
439 #ifdef _OPENMP
440 # pragma omp parallel for num_threads(4) if (in_parallel)
441 #endif
442  for (int i = 0; i < num_outputs; ++i) {
443  TFloat *dwi = dw_[i];
444  const TFloat *ui = u[i];
445  for (int j = 0; j < num_inputs; ++j) {
446  dwi[j] = DotProduct(ui, v[j], num_samples);
447  }
448  // The last element of v is missing, presumed 1.0f.
449  TFloat total = 0;
450  for (int k = 0; k < num_samples; ++k) {
451  total += ui[k];
452  }
453  dwi[num_inputs] = total;
454  }
455 }
DotProductFunction DotProduct
Definition: simddetect.cpp:79

◆ Update()

void tesseract::WeightMatrix::Update ( float  learning_rate,
float  momentum,
float  adam_beta,
int  num_samples 
)

Definition at line 460 of file weightmatrix.cpp.

460  {
461  assert(!int_mode_);
462  if (use_adam_ && momentum > 0.0f && num_samples > 0 && num_samples < kAdamCorrectionIterations) {
463  learning_rate *= sqrt(1.0f - pow(adam_beta, num_samples));
464  learning_rate /= 1.0f - pow(momentum, num_samples);
465  }
466  if (use_adam_ && num_samples > 0 && momentum > 0.0f) {
467  dw_sq_sum_.SumSquares(dw_, adam_beta);
468  dw_ *= learning_rate * (1.0f - momentum);
469  updates_ *= momentum;
470  updates_ += dw_;
471  wf_.AdamUpdate(updates_, dw_sq_sum_, learning_rate * kAdamEpsilon);
472  } else {
473  dw_ *= learning_rate;
474  updates_ += dw_;
475  if (momentum > 0.0f) {
476  wf_ += updates_;
477  }
478  if (momentum >= 0.0f) {
479  updates_ *= momentum;
480  }
481  }
482  wf_t_.Transpose(wf_);
483 }
const TFloat kAdamEpsilon
const int kAdamCorrectionIterations
void AdamUpdate(const GENERIC_2D_ARRAY< T > &sum, const GENERIC_2D_ARRAY< T > &sqsum, const T &epsilon)
Definition: matrix.h:429
void SumSquares(const GENERIC_2D_ARRAY< T > &src, const T &decay_factor)
Definition: matrix.h:419

◆ VectorDotMatrix()

void tesseract::WeightMatrix::VectorDotMatrix ( const TFloat *  u,
TFloat *  v 
) const

Definition at line 419 of file weightmatrix.cpp.

419  {
420  assert(!int_mode_);
421  MatrixDotVectorInternal(wf_t_, false, true, u, v);
422 }

The documentation for this class was generated from the following files: