From 34d1e7331deb494bda0eadc3fd9abe705878cf2f Mon Sep 17 00:00:00 2001 From: Stefan Weil Date: Sun, 4 Jun 2017 17:41:20 +0200 Subject: [PATCH] LSTMTrainer: Catch empty vectors The new test in LSTMTrainer::UpdateErrorGraph fixes an assertion (see issues #644, #792). The new test in LSTMTrainer::ReadTrainingDump was added to improve the robustness of the code. Signed-off-by: Stefan Weil --- lstm/lstmtrainer.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/lstm/lstmtrainer.cpp b/lstm/lstmtrainer.cpp index 036199694a..5678c5b87f 100644 --- a/lstm/lstmtrainer.cpp +++ b/lstm/lstmtrainer.cpp @@ -918,6 +918,10 @@ bool LSTMTrainer::SaveTrainingDump(SerializeAmount serialize_amount, // Reads previously saved trainer from memory. bool LSTMTrainer::ReadTrainingDump(const GenericVector& data, LSTMTrainer* trainer) { + if (data.size() == 0) { + tprintf("Warning: data size is zero in LSTMTrainer::ReadTrainingDump\n"); + return false; + } return trainer->ReadSizedTrainingDump(&data[0], data.size()); } @@ -1298,8 +1302,9 @@ STRING LSTMTrainer::UpdateErrorGraph(int iteration, double error_rate, if (error_rate < best_error_rate_) { // This is a new (global) minimum. if (tester != NULL) { - result = tester->Run(worst_iteration_, worst_error_rates_, - worst_model_data_, CurrentTrainingStage()); + if (worst_model_data_.size() != 0) + result = tester->Run(worst_iteration_, worst_error_rates_, + worst_model_data_, CurrentTrainingStage()); worst_model_data_.truncate(0); best_model_data_ = model_data; }