Skip to content

Commit

Permalink
Implement WeightMatrix::Serialize for TFloat
Browse files Browse the repository at this point in the history
Signed-off-by: Stefan Weil <[email protected]>
  • Loading branch information
stweil committed Jul 20, 2021
1 parent 4a73913 commit 0d412a8
Showing 1 changed file with 11 additions and 17 deletions.
28 changes: 11 additions & 17 deletions src/lstm/weightmatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ const int kInt8Flag = 1;
const int kAdamFlag = 4;
// Flag on mode to indicate that this weightmatrix uses TFloat. Set
// independently of kInt8Flag as even in int mode the scales can
// be float or TFloat.
// be float or double.
const int kDoubleFlag = 128;

// Writes to the given file. Returns false in case of error.
Expand All @@ -246,25 +246,19 @@ bool WeightMatrix::Serialize(bool training, TFile *fp) const {
if (!wi_.Serialize(fp)) {
return false;
}
// The scales stored in memory have an extra factor applied to them
// to allow faster operation. We have to remove that factor here
// before writing to disc.
auto scales = scales_;
for (auto &scale : scales) {
scale *= INT8_MAX;
}
uint32_t size = scales.size();
uint32_t size = scales_.size();
if (!fp->Serialize(&size)) {
return false;
}
#ifdef FAST_FLOAT
assert(!"not implemented");
return false;
#else
if (!fp->Serialize(&scales[0], size)) {
return false;
for (auto scale : scales_) {
// The scales stored in memory have an extra factor applied to them
// to allow faster operation. We have to remove that factor here
// before writing to disc.
double value = scale * INT8_MAX;
if (!fp->Serialize(&value)) {
return false;
}
}
#endif
} else {
if (!tesseract::Serialize(fp, wf_)) {
return false;
Expand Down Expand Up @@ -348,7 +342,7 @@ bool WeightMatrix::DeSerialize(bool training, TFile *fp) {
bool WeightMatrix::DeSerializeOld(bool training, TFile *fp) {
#ifdef FAST_FLOAT
// Not implemented.
assert(!"not implemented");
ASSERT_HOST(!"not implemented");
return false;
#else
if (int_mode_) {
Expand Down

0 comments on commit 0d412a8

Please sign in to comment.