Skip to content

Commit

Permalink
Merge pull request #1455 from stweil/cov
Browse files Browse the repository at this point in the history
Overload method ForwardTimeStep (CID 1385636 Explicit null dereferenced)
  • Loading branch information
zdenop authored Apr 9, 2018
2 parents 437bf85 + 7cf2e2a commit 4b50f3f
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 18 deletions.
33 changes: 19 additions & 14 deletions lstm/fullyconnected.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,12 @@ void FullyConnected::Forward(bool debug, const NetworkIO& input,
int thread_id = 0;
#endif
double* temp_line = temp_lines[thread_id];
const double* d_input = nullptr;
const int8_t* i_input = nullptr;
if (input.int_mode()) {
i_input = input.i(t);
ForwardTimeStep(input.i(t), t, temp_line);
} else {
input.ReadTimeStep(t, curr_input[thread_id]);
d_input = curr_input[thread_id];
ForwardTimeStep(curr_input[thread_id], t, temp_line);
}
ForwardTimeStep(d_input, i_input, t, temp_line);
output->WriteTimeStep(t, temp_line);
if (IsTraining() && type_ != NT_SOFTMAX) {
acts_.CopyTimeStepFrom(t, *output, t);
Expand Down Expand Up @@ -188,15 +185,7 @@ void FullyConnected::SetupForward(const NetworkIO& input,
}
}

void FullyConnected::ForwardTimeStep(const double* d_input, const int8_t* i_input,
int t, double* output_line) {
// input is copied to source_ line-by-line for cache coherency.
if (IsTraining() && external_source_ == nullptr && d_input != nullptr)
source_t_.WriteStrided(t, d_input);
if (d_input != nullptr)
weights_.MatrixDotVector(d_input, output_line);
else
weights_.MatrixDotVector(i_input, output_line);
void FullyConnected::ForwardTimeStep(int t, double* output_line) {
if (type_ == NT_TANH) {
FuncInplace<GFunc>(no_, output_line);
} else if (type_ == NT_LOGISTIC) {
Expand All @@ -214,6 +203,22 @@ void FullyConnected::ForwardTimeStep(const double* d_input, const int8_t* i_inpu
}
}

void FullyConnected::ForwardTimeStep(const double* d_input,
int t, double* output_line) {
// input is copied to source_ line-by-line for cache coherency.
if (IsTraining() && external_source_ == NULL)
source_t_.WriteStrided(t, d_input);
weights_.MatrixDotVector(d_input, output_line);
ForwardTimeStep(t, output_line);
}

void FullyConnected::ForwardTimeStep(const int8_t* i_input,
int t, double* output_line) {
// input is copied to source_ line-by-line for cache coherency.
weights_.MatrixDotVector(i_input, output_line);
ForwardTimeStep(t, output_line);
}

// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
bool FullyConnected::Backward(bool debug, const NetworkIO& fwd_deltas,
Expand Down
5 changes: 3 additions & 2 deletions lstm/fullyconnected.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,9 @@ class FullyConnected : public Network {
// Components of Forward so FullyConnected can be reused inside LSTM.
void SetupForward(const NetworkIO& input,
const TransposedArray* input_transpose);
void ForwardTimeStep(const double* d_input, const int8_t* i_input, int t,
double* output_line);
void ForwardTimeStep(int t, double* output_line);
void ForwardTimeStep(const double* d_input, int t, double* output_line);
void ForwardTimeStep(const int8_t* i_input, int t, double* output_line);

// Runs backward propagation of errors on the deltas line.
// See Network for a detailed discussion of the arguments.
Expand Down
4 changes: 2 additions & 2 deletions lstm/lstm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,9 @@ void LSTM::Forward(bool debug, const NetworkIO& input,
if (softmax_ != nullptr) {
if (input.int_mode()) {
int_output->WriteTimeStepPart(0, 0, ns_, curr_output);
softmax_->ForwardTimeStep(nullptr, int_output->i(0), t, softmax_output);
softmax_->ForwardTimeStep(int_output->i(0), t, softmax_output);
} else {
softmax_->ForwardTimeStep(curr_output, nullptr, t, softmax_output);
softmax_->ForwardTimeStep(curr_output, t, softmax_output);
}
output->WriteTimeStep(t, softmax_output);
if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
Expand Down

0 comments on commit 4b50f3f

Please sign in to comment.