From 95606398f518efd320f7063bf899cf03e5969e76 Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Sat, 12 Jan 2019 23:00:31 +0100 Subject: [PATCH] Clean code for IntSimdMatrix Signed-off-by: Stefan Weil --- src/arch/intsimdmatrix.cpp | 41 ++++-------------- src/arch/intsimdmatrix.h | 67 ++++++++++-------------------- src/arch/intsimdmatrixavx2.cpp | 76 ++++++++++++++++++++++++++++++++-- src/arch/intsimdmatrixsse.cpp | 29 ++++++++++++- src/arch/simddetect.cpp | 8 ++-- src/lstm/weightmatrix.cpp | 13 ++++-- unittest/intsimdmatrix_test.cc | 18 ++++---- 7 files changed, 155 insertions(+), 97 deletions(-) diff --git a/src/arch/intsimdmatrix.cpp b/src/arch/intsimdmatrix.cpp index cd831ce760..abb076e3a8 100644 --- a/src/arch/intsimdmatrix.cpp +++ b/src/arch/intsimdmatrix.cpp @@ -77,42 +77,17 @@ void IntSimdMatrix::Init(const GENERIC_2D_ARRAY& w, // u is imagined to have an extra element at the end with value 1, to // implement the bias, but it doesn't actually have it. void IntSimdMatrix::MatrixDotVector(const GENERIC_2D_ARRAY& w, - const std::vector& shaped_w, const GenericVector& scales, - const int8_t* u, double* v) const { + const int8_t* u, double* v) { int num_out = w.dim1(); int num_in = w.dim2() - 1; - if (partial_funcs_.empty()) { - // Base implementation. - for (int i = 0; i < num_out; ++i) { - const int8_t* wi = w[i]; - int total = 0; - for (int j = 0; j < num_in; ++j) total += wi[j] * u[j]; - // Add in the bias and correct for integer values. - v[i] = (static_cast(total) / INT8_MAX + wi[num_in]) * scales[i]; - } - } else { - const int8_t* w_data = shaped_w.data(); - const double* scales_data = &scales[0]; - // Each call to a partial_func_ produces group_size outputs, except the - // last one, which can produce less. - int group_size = num_outputs_per_register_ * max_output_registers_; - int rounded_num_in = Roundup(num_in, num_inputs_per_group_); - int rounded_num_out = RoundOutputs(num_out); - int output = 0; - for (auto fn : partial_funcs_) { - // The amount of w_data consumed by each call to fn. - int w_step = (rounded_num_in + 1) * group_size; - // Run with this group size, until it would produce too much output, then - // switch to a smaller size. - for (; output + group_size <= rounded_num_out; output += group_size) { - (*fn)(w_data, scales_data, u, rounded_num_in, num_out - output, v); - w_data += w_step; - scales_data += group_size; - v += group_size; - } - group_size /= 2; - } + // Base implementation. + for (int i = 0; i < num_out; ++i) { + const int8_t* wi = w[i]; + int total = 0; + for (int j = 0; j < num_in; ++j) total += wi[j] * u[j]; + // Add in the bias and correct for integer values. + v[i] = (static_cast(total) / INT8_MAX + wi[num_in]) * scales[i]; } } diff --git a/src/arch/intsimdmatrix.h b/src/arch/intsimdmatrix.h index 66034020ac..592c0fbe0f 100644 --- a/src/arch/intsimdmatrix.h +++ b/src/arch/intsimdmatrix.h @@ -58,35 +58,8 @@ namespace tesseract { // NOTE that, although the subclasses execute on different SIMD hardware, no // virtual methods are needed, as the constructor sets up everything that // is required to allow the base class implementation to do all the work. -class IntSimdMatrix { - public: - // Function to compute part of a matrix.vector multiplication. The weights - // are in a very specific order (see above) in w, which is multiplied by - // u of length num_in, to produce output v after scaling the integer results - // by the corresponding member of scales. - // The amount of w and scales consumed is fixed and not available to the - // caller. The number of outputs written to v will be at most num_out. - typedef void (*PartialFunc)(const int8_t* w, const double* scales, - const int8_t* u, int num_in, int num_out, - double* v); - - IntSimdMatrix(int num_outputs_per_register, int max_output_registers, int num_inputs_per_register, int num_inputs_per_group, int num_input_groups, std::vector partial_funcs) : - // Number of 32 bit outputs held in each register. - num_outputs_per_register_(num_outputs_per_register), - // Maximum number of registers that we will use to hold outputs. - max_output_registers_(max_output_registers), - // Number of 8 bit inputs in the inputs register. - num_inputs_per_register_(num_inputs_per_register), - // Number of inputs in each weight group. - num_inputs_per_group_(num_inputs_per_group), - // Number of groups of inputs to be broadcast. - num_input_groups_(num_input_groups), - // A series of functions to compute a partial result. - partial_funcs_(partial_funcs) - {} - - // Computes a reshaped copy of the weight matrix w. If there are no - // partial_funcs_, it does nothing. +struct IntSimdMatrix { + // Computes a reshaped copy of the weight matrix w. void Init(const GENERIC_2D_ARRAY& w, std::vector& shaped_w) const; // Rounds the size up to a multiple of the input register size (in int8_t). @@ -102,20 +75,11 @@ class IntSimdMatrix { // u is of size W.dim2() - 1 and the output v is of size W.dim1(). // u is imagined to have an extra element at the end with value 1, to // implement the bias, but it doesn't actually have it. - // Computes the base C++ implementation, if there are no partial_funcs_. - // NOTE: The size of the input vector (u) must be padded using - // RoundInputs above. - // The input will be over-read to the extent of the padding. There are no - // alignment requirements. - void MatrixDotVector(const GENERIC_2D_ARRAY& w, const std::vector& shaped_w, - const GenericVector& scales, const int8_t* u, - double* v) const; - - static const IntSimdMatrix* intSimdMatrix; - static const IntSimdMatrix IntSimdMatrixAVX2; - static const IntSimdMatrix IntSimdMatrixSSE; + // Computes the base C++ implementation. + static void MatrixDotVector(const GENERIC_2D_ARRAY& w, + const GenericVector& scales, const int8_t* u, + double* v); - protected: // Rounds the input up to a multiple of the given factor. static int Roundup(int input, int factor) { return (input + factor - 1) / factor * factor; @@ -131,8 +95,23 @@ class IntSimdMatrix { int num_inputs_per_group_; // Number of groups of inputs to be broadcast. int num_input_groups_; - // A series of functions to compute a partial result. - std::vector partial_funcs_; + + // Computes matrix.vector v = Wu. + // u is of size W.dim2() - 1 and the output v is of size W.dim1(). + // u is imagined to have an extra element at the end with value 1, to + // implement the bias, but it doesn't actually have it. + // Uses an optimized implementation with partial funcs. + // NOTE: The size of the input vector (u) must be padded using + // RoundInputs above. + // The input will be over-read to the extent of the padding. There are no + // alignment requirements. + typedef void (*MatrixDotVectorFunction)(int dim1, int dim2, + const int8_t* wi, const double* scales, const int8_t* u, double* v); + MatrixDotVectorFunction matrixDotVectorFunction; + + static const IntSimdMatrix* intSimdMatrix; + static const IntSimdMatrix intSimdMatrixAVX2; + static const IntSimdMatrix intSimdMatrixSSE; }; } // namespace tesseract diff --git a/src/arch/intsimdmatrixavx2.cpp b/src/arch/intsimdmatrixavx2.cpp index 0943ba9b03..05654ac051 100644 --- a/src/arch/intsimdmatrixavx2.cpp +++ b/src/arch/intsimdmatrixavx2.cpp @@ -40,6 +40,13 @@ constexpr int kNumInputsPerGroup = 4; // Number of groups of inputs to be broadcast. constexpr int kNumInputGroups = kNumInputsPerRegister / kNumInputsPerGroup; +// Functions to compute part of a matrix.vector multiplication. The weights +// are in a very specific order (see above) in w, which is multiplied by +// u of length num_in, to produce output v after scaling the integer results +// by the corresponding member of scales. +// The amount of w and scales consumed is fixed and not available to the +// caller. The number of outputs written to v will be at most num_out. + // Computes one set of 4x8 products of inputs and weights, adding to result. // Horizontally adds 4 adjacent results, making 8x32-bit results. // rep_input is assumed to be an 8x replicated set of 4x8-bit signed integers. @@ -269,8 +276,71 @@ static void PartialMatrixDotVector8(const int8_t* wi, const double* scales, ExtractResults(result0, shift_id, wi, scales, num_out, v); } -const IntSimdMatrix IntSimdMatrix::IntSimdMatrixAVX2 = - IntSimdMatrix(kNumOutputsPerRegister, kMaxOutputRegisters, kNumInputsPerRegister, kNumInputsPerGroup, kNumInputGroups, {PartialMatrixDotVector64, PartialMatrixDotVector32, - PartialMatrixDotVector16, PartialMatrixDotVector8}); +static void matrixDotVector(int dim1, int dim2, const int8_t* wi, + const double* scales, const int8_t* u, double* v) { + const int num_out = dim1; + const int num_in = dim2 - 1; + // Each call to a partial_func_ produces group_size outputs, except the + // last one, which can produce less. + const int rounded_num_in = + IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup); + const int rounded_num_out = + IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister); + int group_size = kNumOutputsPerRegister * kMaxOutputRegisters; + int output = 0; + + int w_step = (rounded_num_in + 1) * group_size; + + // Run with this group size, until it would produce too much output, then + // switch to a smaller size. + for (; output + group_size <= rounded_num_out; output += group_size) { + PartialMatrixDotVector64(wi, scales, u, rounded_num_in, num_out - output, v); + wi += w_step; + scales += group_size; + v += group_size; + } + group_size /= 2; + w_step /= 2; + + for (; output + group_size <= rounded_num_out; output += group_size) { + PartialMatrixDotVector32(wi, scales, u, rounded_num_in, num_out - output, v); + wi += w_step; + scales += group_size; + v += group_size; + } + group_size /= 2; + w_step /= 2; + + for (; output + group_size <= rounded_num_out; output += group_size) { + PartialMatrixDotVector16(wi, scales, u, rounded_num_in, num_out - output, v); + wi += w_step; + scales += group_size; + v += group_size; + } + group_size /= 2; + w_step /= 2; + + for (; output + group_size <= rounded_num_out; output += group_size) { + PartialMatrixDotVector8(wi, scales, u, rounded_num_in, num_out - output, v); + wi += w_step; + scales += group_size; + v += group_size; + } +} + +const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX2 = { + // Number of 32 bit outputs held in each register. + kNumOutputsPerRegister, + // Maximum number of registers that we will use to hold outputs. + kMaxOutputRegisters, + // Number of 8 bit inputs in the inputs register. + kNumInputsPerRegister, + // Number of inputs in each weight group. + kNumInputsPerGroup, + // Number of groups of inputs to be broadcast. + kNumInputGroups, + // Function. + matrixDotVector +}; } // namespace tesseract. diff --git a/src/arch/intsimdmatrixsse.cpp b/src/arch/intsimdmatrixsse.cpp index 4f5d4d0444..8a416ec9b1 100644 --- a/src/arch/intsimdmatrixsse.cpp +++ b/src/arch/intsimdmatrixsse.cpp @@ -35,7 +35,32 @@ static void PartialMatrixDotVector1(const int8_t* wi, const double* scales, *v = (total / INT8_MAX + wi[num_in]) * *scales; } -const IntSimdMatrix IntSimdMatrix::IntSimdMatrixSSE = - IntSimdMatrix(1, 1, 1, 1, 1, {PartialMatrixDotVector1}); +static void matrixDotVector(int dim1, int dim2, const int8_t* wi, + const double* scales, const int8_t* u, double* v) { + const int num_out = dim1; + const int num_in = dim2 - 1; + int output = 0; + + for (; output + 1 <= num_out; output += 1) { + PartialMatrixDotVector1(wi, scales, u, num_in, num_out - output, v); + wi += dim2; + scales += 1; + v += 1; + } +} + +const IntSimdMatrix IntSimdMatrix::intSimdMatrixSSE = { + // Number of 32 bit outputs held in each register. + 1, + // Maximum number of registers that we will use to hold outputs. + 1, + // Number of 8 bit inputs in the inputs register. + 1, + // Number of inputs in each weight group. + 1, + // Number of groups of inputs to be broadcast. + 1, + matrixDotVector +}; } // namespace tesseract. diff --git a/src/arch/simddetect.cpp b/src/arch/simddetect.cpp index f5f70eca67..96783cf1b5 100644 --- a/src/arch/simddetect.cpp +++ b/src/arch/simddetect.cpp @@ -128,12 +128,12 @@ SIMDDetect::SIMDDetect() { #if defined(AVX) } else if (avx_available_) { // AVX detected. - SetDotProduct(DotProductAVX, &IntSimdMatrix::IntSimdMatrixAVX2); + SetDotProduct(DotProductAVX, &IntSimdMatrix::intSimdMatrixAVX2); #endif #if defined(SSE4_1) } else if (sse_available_) { // SSE detected. - SetDotProduct(DotProductSSE, &IntSimdMatrix::IntSimdMatrixSSE); + SetDotProduct(DotProductSSE, &IntSimdMatrix::intSimdMatrixSSE); #endif } } @@ -155,13 +155,13 @@ void SIMDDetect::Update() { #if defined(AVX) } else if (!strcmp(dotproduct.string(), "avx")) { // AVX selected by config variable. - SetDotProduct(DotProductAVX, &IntSimdMatrix::IntSimdMatrixAVX2); + SetDotProduct(DotProductAVX, &IntSimdMatrix::intSimdMatrixAVX2); dotproduct_method = "avx"; #endif #if defined(SSE4_1) } else if (!strcmp(dotproduct.string(), "sse")) { // SSE selected by config variable. - SetDotProduct(DotProductSSE, &IntSimdMatrix::IntSimdMatrixSSE); + SetDotProduct(DotProductSSE, &IntSimdMatrix::intSimdMatrixSSE); dotproduct_method = "sse"; #endif } else { diff --git a/src/lstm/weightmatrix.cpp b/src/lstm/weightmatrix.cpp index df0fbe5706..221c5a8bfd 100644 --- a/src/lstm/weightmatrix.cpp +++ b/src/lstm/weightmatrix.cpp @@ -143,8 +143,9 @@ void WeightMatrix::ConvertToInt() { } wf_.Resize(1, 1, 0.0); int_mode_ = true; - if (IntSimdMatrix::intSimdMatrix) + if (IntSimdMatrix::intSimdMatrix) { IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_); + } } // Allocates any needed memory for running Backward, and zeroes the deltas, @@ -196,8 +197,9 @@ bool WeightMatrix::DeSerialize(bool training, TFile* fp) { if (int_mode_) { if (!wi_.DeSerialize(fp)) return false; if (!scales_.DeSerialize(fp)) return false; - if (IntSimdMatrix::intSimdMatrix) + if (IntSimdMatrix::intSimdMatrix) { IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_); + } } else { if (!wf_.DeSerialize(fp)) return false; if (training) { @@ -245,7 +247,12 @@ void WeightMatrix::MatrixDotVector(const double* u, double* v) const { void WeightMatrix::MatrixDotVector(const int8_t* u, double* v) const { assert(int_mode_); - IntSimdMatrix::intSimdMatrix->MatrixDotVector(wi_, shaped_w_, scales_, u, v); + if (IntSimdMatrix::intSimdMatrix) { + IntSimdMatrix::intSimdMatrix->matrixDotVectorFunction( + wi_.dim1(), wi_.dim2(), &shaped_w_[0], &scales_[0], u, v); + } else { + IntSimdMatrix::MatrixDotVector(wi_, scales_, u, v); + } } // MatrixDotVector for peep weights, MultiplyAccumulate adds the diff --git a/unittest/intsimdmatrix_test.cc b/unittest/intsimdmatrix_test.cc index 17e154d6bd..bbfcdeb416 100644 --- a/unittest/intsimdmatrix_test.cc +++ b/unittest/intsimdmatrix_test.cc @@ -25,8 +25,6 @@ namespace tesseract { namespace { -static const IntSimdMatrix IntSimdMatrixNative = IntSimdMatrix(1, 1, 1, 1, 1, {}); - class IntSimdMatrixTest : public ::testing::Test { protected: // Makes a random weights matrix of the given size. @@ -65,12 +63,16 @@ class IntSimdMatrixTest : public ::testing::Test { std::vector u = RandomVector(num_in, matrix); GenericVector scales = RandomScales(num_out); std::vector base_result(num_out); - std::vector dummy; - IntSimdMatrixNative.MatrixDotVector(w, dummy, scales, u.data(), base_result.data()); + IntSimdMatrix::MatrixDotVector(w, scales, u.data(), base_result.data()); std::vector test_result(num_out); std::vector shaped_wi; matrix.Init(w, shaped_wi); - matrix.MatrixDotVector(w, shaped_wi, scales, u.data(), test_result.data()); + if (matrix.matrixDotVectorFunction) { + matrix.matrixDotVectorFunction(w.dim1(), w.dim2(), &shaped_wi[0], + &scales[0], &u[0], &test_result[0]); + } else { + IntSimdMatrix::MatrixDotVector(w, scales, u.data(), test_result.data()); + } for (int i = 0; i < num_out; ++i) { EXPECT_FLOAT_EQ(base_result[i], test_result[i]) << "i=" << i; total += base_result[i]; @@ -86,7 +88,7 @@ class IntSimdMatrixTest : public ::testing::Test { // Test the C++ implementation without SIMD. TEST_F(IntSimdMatrixTest, C) { - static const IntSimdMatrix matrix(1, 1, 1, 1, 1, {}); + static const IntSimdMatrix matrix = {1, 1, 1, 1, 1, nullptr}; ExpectEqualResults(matrix); } @@ -99,7 +101,7 @@ TEST_F(IntSimdMatrixTest, SSE) { tprintf("No SSE found! Not tested!"); return; } - ExpectEqualResults(IntSimdMatrix::IntSimdMatrixSSE); + ExpectEqualResults(IntSimdMatrix::intSimdMatrixSSE); #else tprintf("SSE unsupported! Not tested!"); #endif @@ -114,7 +116,7 @@ TEST_F(IntSimdMatrixTest, AVX2) { tprintf("No AVX2 found! Not tested!"); return; } - ExpectEqualResults(IntSimdMatrix::IntSimdMatrixAVX2); + ExpectEqualResults(IntSimdMatrix::intSimdMatrixAVX2); #else tprintf("AVX2 unsupported! Not tested!"); #endif