Skip to content

Commit

Permalink
Move shaped weights from IntSimdMatrix to WeightMatrix
Browse files Browse the repository at this point in the history
Signed-off-by: Stefan Weil <[email protected]>
  • Loading branch information
stweil committed Jan 14, 2019
1 parent ea4d0d3 commit 7c70147
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 16 deletions.
11 changes: 6 additions & 5 deletions src/arch/intsimdmatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ IntSimdMatrix* IntSimdMatrix::GetFastestMultiplier() {

// Computes a reshaped copy of the weight matrix w. If there are no
// partial_funcs_, it does nothing.
void IntSimdMatrix::Init(const GENERIC_2D_ARRAY<int8_t>& w) {
void IntSimdMatrix::Init(const GENERIC_2D_ARRAY<int8_t>& w, std::vector<int8_t>& shaped_w) const {
if (partial_funcs_.empty()) return;
int num_out = w.dim1();
int num_in = w.dim2() - 1;
// The rounded-up sizes of the reshaped weight matrix, excluding biases.
int rounded_num_in = Roundup(num_in, num_inputs_per_group_);
int rounded_num_out = RoundOutputs(num_out);
// Add the bias and compute the required size.
shaped_w_.resize((rounded_num_in + 1) * rounded_num_out, 0);
shaped_w.resize((rounded_num_in + 1) * rounded_num_out, 0);
int shaped_index = 0;
int output = 0;
// Each number of registers needs a different format! Iterates over the
Expand All @@ -74,15 +74,15 @@ void IntSimdMatrix::Init(const GENERIC_2D_ARRAY<int8_t>& w) {
int8_t weight = 0;
if (output + j < num_out && input + i < num_in)
weight = w(output + j, input + i);
shaped_w_[shaped_index++] = weight;
shaped_w[shaped_index++] = weight;
}
}
}
// Append the bias weights for the register set.
for (int j = 0; j < num_outputs_per_register_set; ++j) {
int8_t weight = 0;
if (output + j < num_out) weight = w(output + j, num_in);
shaped_w_[shaped_index++] = weight;
shaped_w[shaped_index++] = weight;
}
output += num_outputs_per_register_set;
}
Expand All @@ -94,6 +94,7 @@ void IntSimdMatrix::Init(const GENERIC_2D_ARRAY<int8_t>& w) {
// u is imagined to have an extra element at the end with value 1, to
// implement the bias, but it doesn't actually have it.
void IntSimdMatrix::MatrixDotVector(const GENERIC_2D_ARRAY<int8_t>& w,
const std::vector<int8_t>& shaped_w,
const GenericVector<double>& scales,
const int8_t* u, double* v) const {
int num_out = w.dim1();
Expand All @@ -108,7 +109,7 @@ void IntSimdMatrix::MatrixDotVector(const GENERIC_2D_ARRAY<int8_t>& w,
v[i] = (static_cast<double>(total) / INT8_MAX + wi[num_in]) * scales[i];
}
} else {
const int8_t* w_data = shaped_w_.data();
const int8_t* w_data = shaped_w.data();
const double* scales_data = &scales[0];
// Each call to a partial_func_ produces group_size outputs, except the
// last one, which can produce less.
Expand Down
6 changes: 2 additions & 4 deletions src/arch/intsimdmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class IntSimdMatrix {

// Computes a reshaped copy of the weight matrix w. If there are no
// partial_funcs_, it does nothing.
void Init(const GENERIC_2D_ARRAY<int8_t>& w);
void Init(const GENERIC_2D_ARRAY<int8_t>& w, std::vector<int8_t>& shaped_w) const;

// Rounds the size up to a multiple of the input register size (in int8_t).
int RoundInputs(int size) const {
Expand All @@ -95,7 +95,7 @@ class IntSimdMatrix {
// RoundInputs above.
// The input will be over-read to the extent of the padding. There are no
// alignment requirements.
void MatrixDotVector(const GENERIC_2D_ARRAY<int8_t>& w,
void MatrixDotVector(const GENERIC_2D_ARRAY<int8_t>& w, const std::vector<int8_t>& shaped_w,
const GenericVector<double>& scales, const int8_t* u,
double* v) const;

Expand Down Expand Up @@ -125,8 +125,6 @@ class IntSimdMatrix {
int num_inputs_per_group_;
// Number of groups of inputs to be broadcast.
int num_input_groups_;
// The weights matrix reorganized in whatever way suits this instance.
std::vector<int8_t> shaped_w_;
// A series of functions to compute a partial result.
std::vector<PartialFunc> partial_funcs_;
};
Expand Down
6 changes: 3 additions & 3 deletions src/lstm/weightmatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ void WeightMatrix::ConvertToInt() {
wf_.Resize(1, 1, 0.0);
int_mode_ = true;
multiplier_.reset(IntSimdMatrix::GetFastestMultiplier());
multiplier_->Init(wi_);
multiplier_->Init(wi_, shaped_w_);
}

// Allocates any needed memory for running Backward, and zeroes the deltas,
Expand Down Expand Up @@ -197,7 +197,7 @@ bool WeightMatrix::DeSerialize(bool training, TFile* fp) {
if (!wi_.DeSerialize(fp)) return false;
if (!scales_.DeSerialize(fp)) return false;
multiplier_.reset(IntSimdMatrix::GetFastestMultiplier());
multiplier_->Init(wi_);
multiplier_->Init(wi_, shaped_w_);
} else {
if (!wf_.DeSerialize(fp)) return false;
if (training) {
Expand Down Expand Up @@ -246,7 +246,7 @@ void WeightMatrix::MatrixDotVector(const double* u, double* v) const {
void WeightMatrix::MatrixDotVector(const int8_t* u, double* v) const {
assert(int_mode_);
assert(multiplier_ != nullptr);
multiplier_->MatrixDotVector(wi_, scales_, u, v);
multiplier_->MatrixDotVector(wi_, shaped_w_, scales_, u, v);
}

// MatrixDotVector for peep weights, MultiplyAccumulate adds the
Expand Down
2 changes: 2 additions & 0 deletions src/lstm/weightmatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ class WeightMatrix {
// Iff use_adam_, the sum of squares of dw_. The number of samples is
// given to Update(). Serialized iff use_adam_.
GENERIC_2D_ARRAY<double> dw_sq_sum_;
// The weights matrix reorganized in whatever way suits this instance.
std::vector<int8_t> shaped_w_;
// Holds the optimal integer multiplier for this machine.
std::unique_ptr<IntSimdMatrix> multiplier_;
};
Expand Down
10 changes: 6 additions & 4 deletions unittest/intsimdmatrix_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,20 @@ class IntSimdMatrixTest : public ::testing::Test {
return v;
}
// Tests a range of sizes and compares the results against the base_ version.
void ExpectEqualResults(IntSimdMatrix* matrix) {
void ExpectEqualResults(const IntSimdMatrix* matrix) {
double total = 0.0;
for (int num_out = 1; num_out < 130; ++num_out) {
for (int num_in = 1; num_in < 130; ++num_in) {
GENERIC_2D_ARRAY<int8_t> w = InitRandom(num_out, num_in + 1);
matrix->Init(w);
std::vector<int8_t> u = RandomVector(num_in, *matrix);
GenericVector<double> scales = RandomScales(num_out);
std::vector<double> base_result(num_out);
base_.MatrixDotVector(w, scales, u.data(), base_result.data());
std::vector<int8_t> dummy;
base_.MatrixDotVector(w, dummy, scales, u.data(), base_result.data());
std::vector<double> test_result(num_out);
matrix->MatrixDotVector(w, scales, u.data(), test_result.data());
std::vector<int8_t> shaped_wi;
matrix->Init(w, shaped_wi);
matrix->MatrixDotVector(w, shaped_wi, scales, u.data(), test_result.data());
for (int i = 0; i < num_out; ++i) {
EXPECT_FLOAT_EQ(base_result[i], test_result[i]) << "i=" << i;
total += base_result[i];
Expand Down

0 comments on commit 7c70147

Please sign in to comment.