Skip to content

Commit

Permalink
Clean code for IntSimdMatrix
Browse files Browse the repository at this point in the history
Signed-off-by: Stefan Weil <[email protected]>
  • Loading branch information
stweil committed Jan 14, 2019
1 parent 7fc7d28 commit 9560639
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 97 deletions.
41 changes: 8 additions & 33 deletions src/arch/intsimdmatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,42 +77,17 @@ void IntSimdMatrix::Init(const GENERIC_2D_ARRAY<int8_t>& w,
// u is imagined to have an extra element at the end with value 1, to
// implement the bias, but it doesn't actually have it.
void IntSimdMatrix::MatrixDotVector(const GENERIC_2D_ARRAY<int8_t>& w,
const std::vector<int8_t>& shaped_w,
const GenericVector<double>& scales,
const int8_t* u, double* v) const {
const int8_t* u, double* v) {
int num_out = w.dim1();
int num_in = w.dim2() - 1;
if (partial_funcs_.empty()) {
// Base implementation.
for (int i = 0; i < num_out; ++i) {
const int8_t* wi = w[i];
int total = 0;
for (int j = 0; j < num_in; ++j) total += wi[j] * u[j];
// Add in the bias and correct for integer values.
v[i] = (static_cast<double>(total) / INT8_MAX + wi[num_in]) * scales[i];
}
} else {
const int8_t* w_data = shaped_w.data();
const double* scales_data = &scales[0];
// Each call to a partial_func_ produces group_size outputs, except the
// last one, which can produce less.
int group_size = num_outputs_per_register_ * max_output_registers_;
int rounded_num_in = Roundup(num_in, num_inputs_per_group_);
int rounded_num_out = RoundOutputs(num_out);
int output = 0;
for (auto fn : partial_funcs_) {
// The amount of w_data consumed by each call to fn.
int w_step = (rounded_num_in + 1) * group_size;
// Run with this group size, until it would produce too much output, then
// switch to a smaller size.
for (; output + group_size <= rounded_num_out; output += group_size) {
(*fn)(w_data, scales_data, u, rounded_num_in, num_out - output, v);
w_data += w_step;
scales_data += group_size;
v += group_size;
}
group_size /= 2;
}
// Base implementation.
for (int i = 0; i < num_out; ++i) {
const int8_t* wi = w[i];
int total = 0;
for (int j = 0; j < num_in; ++j) total += wi[j] * u[j];
// Add in the bias and correct for integer values.
v[i] = (static_cast<double>(total) / INT8_MAX + wi[num_in]) * scales[i];
}
}

Expand Down
67 changes: 23 additions & 44 deletions src/arch/intsimdmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,35 +58,8 @@ namespace tesseract {
// NOTE that, although the subclasses execute on different SIMD hardware, no
// virtual methods are needed, as the constructor sets up everything that
// is required to allow the base class implementation to do all the work.
class IntSimdMatrix {
public:
// Function to compute part of a matrix.vector multiplication. The weights
// are in a very specific order (see above) in w, which is multiplied by
// u of length num_in, to produce output v after scaling the integer results
// by the corresponding member of scales.
// The amount of w and scales consumed is fixed and not available to the
// caller. The number of outputs written to v will be at most num_out.
typedef void (*PartialFunc)(const int8_t* w, const double* scales,
const int8_t* u, int num_in, int num_out,
double* v);

IntSimdMatrix(int num_outputs_per_register, int max_output_registers, int num_inputs_per_register, int num_inputs_per_group, int num_input_groups, std::vector<PartialFunc> partial_funcs) :
// Number of 32 bit outputs held in each register.
num_outputs_per_register_(num_outputs_per_register),
// Maximum number of registers that we will use to hold outputs.
max_output_registers_(max_output_registers),
// Number of 8 bit inputs in the inputs register.
num_inputs_per_register_(num_inputs_per_register),
// Number of inputs in each weight group.
num_inputs_per_group_(num_inputs_per_group),
// Number of groups of inputs to be broadcast.
num_input_groups_(num_input_groups),
// A series of functions to compute a partial result.
partial_funcs_(partial_funcs)
{}

// Computes a reshaped copy of the weight matrix w. If there are no
// partial_funcs_, it does nothing.
struct IntSimdMatrix {
// Computes a reshaped copy of the weight matrix w.
void Init(const GENERIC_2D_ARRAY<int8_t>& w, std::vector<int8_t>& shaped_w) const;

// Rounds the size up to a multiple of the input register size (in int8_t).
Expand All @@ -102,20 +75,11 @@ class IntSimdMatrix {
// u is of size W.dim2() - 1 and the output v is of size W.dim1().
// u is imagined to have an extra element at the end with value 1, to
// implement the bias, but it doesn't actually have it.
// Computes the base C++ implementation, if there are no partial_funcs_.
// NOTE: The size of the input vector (u) must be padded using
// RoundInputs above.
// The input will be over-read to the extent of the padding. There are no
// alignment requirements.
void MatrixDotVector(const GENERIC_2D_ARRAY<int8_t>& w, const std::vector<int8_t>& shaped_w,
const GenericVector<double>& scales, const int8_t* u,
double* v) const;

static const IntSimdMatrix* intSimdMatrix;
static const IntSimdMatrix IntSimdMatrixAVX2;
static const IntSimdMatrix IntSimdMatrixSSE;
// Computes the base C++ implementation.
static void MatrixDotVector(const GENERIC_2D_ARRAY<int8_t>& w,
const GenericVector<double>& scales, const int8_t* u,
double* v);

protected:
// Rounds the input up to a multiple of the given factor.
static int Roundup(int input, int factor) {
return (input + factor - 1) / factor * factor;
Expand All @@ -131,8 +95,23 @@ class IntSimdMatrix {
int num_inputs_per_group_;
// Number of groups of inputs to be broadcast.
int num_input_groups_;
// A series of functions to compute a partial result.
std::vector<PartialFunc> partial_funcs_;

// Computes matrix.vector v = Wu.
// u is of size W.dim2() - 1 and the output v is of size W.dim1().
// u is imagined to have an extra element at the end with value 1, to
// implement the bias, but it doesn't actually have it.
// Uses an optimized implementation with partial funcs.
// NOTE: The size of the input vector (u) must be padded using
// RoundInputs above.
// The input will be over-read to the extent of the padding. There are no
// alignment requirements.
typedef void (*MatrixDotVectorFunction)(int dim1, int dim2,
const int8_t* wi, const double* scales, const int8_t* u, double* v);
MatrixDotVectorFunction matrixDotVectorFunction;

static const IntSimdMatrix* intSimdMatrix;
static const IntSimdMatrix intSimdMatrixAVX2;
static const IntSimdMatrix intSimdMatrixSSE;
};

} // namespace tesseract
Expand Down
76 changes: 73 additions & 3 deletions src/arch/intsimdmatrixavx2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ constexpr int kNumInputsPerGroup = 4;
// Number of groups of inputs to be broadcast.
constexpr int kNumInputGroups = kNumInputsPerRegister / kNumInputsPerGroup;

// Functions to compute part of a matrix.vector multiplication. The weights
// are in a very specific order (see above) in w, which is multiplied by
// u of length num_in, to produce output v after scaling the integer results
// by the corresponding member of scales.
// The amount of w and scales consumed is fixed and not available to the
// caller. The number of outputs written to v will be at most num_out.

// Computes one set of 4x8 products of inputs and weights, adding to result.
// Horizontally adds 4 adjacent results, making 8x32-bit results.
// rep_input is assumed to be an 8x replicated set of 4x8-bit signed integers.
Expand Down Expand Up @@ -269,8 +276,71 @@ static void PartialMatrixDotVector8(const int8_t* wi, const double* scales,
ExtractResults(result0, shift_id, wi, scales, num_out, v);
}

const IntSimdMatrix IntSimdMatrix::IntSimdMatrixAVX2 =
IntSimdMatrix(kNumOutputsPerRegister, kMaxOutputRegisters, kNumInputsPerRegister, kNumInputsPerGroup, kNumInputGroups, {PartialMatrixDotVector64, PartialMatrixDotVector32,
PartialMatrixDotVector16, PartialMatrixDotVector8});
static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
const double* scales, const int8_t* u, double* v) {
const int num_out = dim1;
const int num_in = dim2 - 1;
// Each call to a partial_func_ produces group_size outputs, except the
// last one, which can produce less.
const int rounded_num_in =
IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
const int rounded_num_out =
IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
int output = 0;

int w_step = (rounded_num_in + 1) * group_size;

// Run with this group size, until it would produce too much output, then
// switch to a smaller size.
for (; output + group_size <= rounded_num_out; output += group_size) {
PartialMatrixDotVector64(wi, scales, u, rounded_num_in, num_out - output, v);
wi += w_step;
scales += group_size;
v += group_size;
}
group_size /= 2;
w_step /= 2;

for (; output + group_size <= rounded_num_out; output += group_size) {
PartialMatrixDotVector32(wi, scales, u, rounded_num_in, num_out - output, v);
wi += w_step;
scales += group_size;
v += group_size;
}
group_size /= 2;
w_step /= 2;

for (; output + group_size <= rounded_num_out; output += group_size) {
PartialMatrixDotVector16(wi, scales, u, rounded_num_in, num_out - output, v);
wi += w_step;
scales += group_size;
v += group_size;
}
group_size /= 2;
w_step /= 2;

for (; output + group_size <= rounded_num_out; output += group_size) {
PartialMatrixDotVector8(wi, scales, u, rounded_num_in, num_out - output, v);
wi += w_step;
scales += group_size;
v += group_size;
}
}

const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX2 = {
// Number of 32 bit outputs held in each register.
kNumOutputsPerRegister,
// Maximum number of registers that we will use to hold outputs.
kMaxOutputRegisters,
// Number of 8 bit inputs in the inputs register.
kNumInputsPerRegister,
// Number of inputs in each weight group.
kNumInputsPerGroup,
// Number of groups of inputs to be broadcast.
kNumInputGroups,
// Function.
matrixDotVector
};

} // namespace tesseract.
29 changes: 27 additions & 2 deletions src/arch/intsimdmatrixsse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,32 @@ static void PartialMatrixDotVector1(const int8_t* wi, const double* scales,
*v = (total / INT8_MAX + wi[num_in]) * *scales;
}

const IntSimdMatrix IntSimdMatrix::IntSimdMatrixSSE =
IntSimdMatrix(1, 1, 1, 1, 1, {PartialMatrixDotVector1});
static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
const double* scales, const int8_t* u, double* v) {
const int num_out = dim1;
const int num_in = dim2 - 1;
int output = 0;

for (; output + 1 <= num_out; output += 1) {
PartialMatrixDotVector1(wi, scales, u, num_in, num_out - output, v);
wi += dim2;
scales += 1;
v += 1;
}
}

const IntSimdMatrix IntSimdMatrix::intSimdMatrixSSE = {
// Number of 32 bit outputs held in each register.
1,
// Maximum number of registers that we will use to hold outputs.
1,
// Number of 8 bit inputs in the inputs register.
1,
// Number of inputs in each weight group.
1,
// Number of groups of inputs to be broadcast.
1,
matrixDotVector
};

} // namespace tesseract.
8 changes: 4 additions & 4 deletions src/arch/simddetect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ SIMDDetect::SIMDDetect() {
#if defined(AVX)
} else if (avx_available_) {
// AVX detected.
SetDotProduct(DotProductAVX, &IntSimdMatrix::IntSimdMatrixAVX2);
SetDotProduct(DotProductAVX, &IntSimdMatrix::intSimdMatrixAVX2);
#endif
#if defined(SSE4_1)
} else if (sse_available_) {
// SSE detected.
SetDotProduct(DotProductSSE, &IntSimdMatrix::IntSimdMatrixSSE);
SetDotProduct(DotProductSSE, &IntSimdMatrix::intSimdMatrixSSE);
#endif
}
}
Expand All @@ -155,13 +155,13 @@ void SIMDDetect::Update() {
#if defined(AVX)
} else if (!strcmp(dotproduct.string(), "avx")) {
// AVX selected by config variable.
SetDotProduct(DotProductAVX, &IntSimdMatrix::IntSimdMatrixAVX2);
SetDotProduct(DotProductAVX, &IntSimdMatrix::intSimdMatrixAVX2);
dotproduct_method = "avx";
#endif
#if defined(SSE4_1)
} else if (!strcmp(dotproduct.string(), "sse")) {
// SSE selected by config variable.
SetDotProduct(DotProductSSE, &IntSimdMatrix::IntSimdMatrixSSE);
SetDotProduct(DotProductSSE, &IntSimdMatrix::intSimdMatrixSSE);
dotproduct_method = "sse";
#endif
} else {
Expand Down
13 changes: 10 additions & 3 deletions src/lstm/weightmatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,9 @@ void WeightMatrix::ConvertToInt() {
}
wf_.Resize(1, 1, 0.0);
int_mode_ = true;
if (IntSimdMatrix::intSimdMatrix)
if (IntSimdMatrix::intSimdMatrix) {
IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_);
}
}

// Allocates any needed memory for running Backward, and zeroes the deltas,
Expand Down Expand Up @@ -196,8 +197,9 @@ bool WeightMatrix::DeSerialize(bool training, TFile* fp) {
if (int_mode_) {
if (!wi_.DeSerialize(fp)) return false;
if (!scales_.DeSerialize(fp)) return false;
if (IntSimdMatrix::intSimdMatrix)
if (IntSimdMatrix::intSimdMatrix) {
IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_);
}
} else {
if (!wf_.DeSerialize(fp)) return false;
if (training) {
Expand Down Expand Up @@ -245,7 +247,12 @@ void WeightMatrix::MatrixDotVector(const double* u, double* v) const {

void WeightMatrix::MatrixDotVector(const int8_t* u, double* v) const {
assert(int_mode_);
IntSimdMatrix::intSimdMatrix->MatrixDotVector(wi_, shaped_w_, scales_, u, v);
if (IntSimdMatrix::intSimdMatrix) {
IntSimdMatrix::intSimdMatrix->matrixDotVectorFunction(
wi_.dim1(), wi_.dim2(), &shaped_w_[0], &scales_[0], u, v);
} else {
IntSimdMatrix::MatrixDotVector(wi_, scales_, u, v);
}
}

// MatrixDotVector for peep weights, MultiplyAccumulate adds the
Expand Down
18 changes: 10 additions & 8 deletions unittest/intsimdmatrix_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
namespace tesseract {
namespace {

static const IntSimdMatrix IntSimdMatrixNative = IntSimdMatrix(1, 1, 1, 1, 1, {});

class IntSimdMatrixTest : public ::testing::Test {
protected:
// Makes a random weights matrix of the given size.
Expand Down Expand Up @@ -65,12 +63,16 @@ class IntSimdMatrixTest : public ::testing::Test {
std::vector<int8_t> u = RandomVector(num_in, matrix);
GenericVector<double> scales = RandomScales(num_out);
std::vector<double> base_result(num_out);
std::vector<int8_t> dummy;
IntSimdMatrixNative.MatrixDotVector(w, dummy, scales, u.data(), base_result.data());
IntSimdMatrix::MatrixDotVector(w, scales, u.data(), base_result.data());
std::vector<double> test_result(num_out);
std::vector<int8_t> shaped_wi;
matrix.Init(w, shaped_wi);
matrix.MatrixDotVector(w, shaped_wi, scales, u.data(), test_result.data());
if (matrix.matrixDotVectorFunction) {
matrix.matrixDotVectorFunction(w.dim1(), w.dim2(), &shaped_wi[0],
&scales[0], &u[0], &test_result[0]);
} else {
IntSimdMatrix::MatrixDotVector(w, scales, u.data(), test_result.data());
}
for (int i = 0; i < num_out; ++i) {
EXPECT_FLOAT_EQ(base_result[i], test_result[i]) << "i=" << i;
total += base_result[i];
Expand All @@ -86,7 +88,7 @@ class IntSimdMatrixTest : public ::testing::Test {

// Test the C++ implementation without SIMD.
TEST_F(IntSimdMatrixTest, C) {
static const IntSimdMatrix matrix(1, 1, 1, 1, 1, {});
static const IntSimdMatrix matrix = {1, 1, 1, 1, 1, nullptr};
ExpectEqualResults(matrix);
}

Expand All @@ -99,7 +101,7 @@ TEST_F(IntSimdMatrixTest, SSE) {
tprintf("No SSE found! Not tested!");
return;
}
ExpectEqualResults(IntSimdMatrix::IntSimdMatrixSSE);
ExpectEqualResults(IntSimdMatrix::intSimdMatrixSSE);
#else
tprintf("SSE unsupported! Not tested!");
#endif
Expand All @@ -114,7 +116,7 @@ TEST_F(IntSimdMatrixTest, AVX2) {
tprintf("No AVX2 found! Not tested!");
return;
}
ExpectEqualResults(IntSimdMatrix::IntSimdMatrixAVX2);
ExpectEqualResults(IntSimdMatrix::intSimdMatrixAVX2);
#else
tprintf("AVX2 unsupported! Not tested!");
#endif
Expand Down

0 comments on commit 9560639

Please sign in to comment.