Skip to content

Commit

Permalink
Add --reset_learning_rate option to lstmtraining (#3470)
Browse files Browse the repository at this point in the history
When the --reset_learning_rate option is specified,
it resets the learning rate stored in each layer of the network
loaded with --continue_from to the value specified by the --learning_rate option.
If checkpoint is available, it does nothing.
  • Loading branch information
nagadomi authored Jun 28, 2021
1 parent d8bd78f commit ff1062d
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 0 deletions.
20 changes: 20 additions & 0 deletions src/lstm/lstmrecognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,26 @@ class TESS_API LSTMRecognizer {
series->ScaleLayerLearningRate(&id[1], factor);
}

// Set the all the learning rate(s) to the given value.
void SetLearningRate(float learning_rate)
{
ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
learning_rate_ = learning_rate;
if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
for (auto &id : EnumerateLayers()) {
SetLayerLearningRate(id, learning_rate);
}
}
}
// Set the learning rate of the layer with id, by the given value.
void SetLayerLearningRate(const std::string &id, float learning_rate)
{
ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
ASSERT_HOST(id.length() > 1 && id[0] == ':');
auto *series = static_cast<Series *>(network_);
series->SetLayerLearningRate(&id[1], learning_rate);
}

// Converts the network to int if not already.
void ConvertToInt() {
if ((training_flags_ & TF_INT_MODE) == 0) {
Expand Down
8 changes: 8 additions & 0 deletions src/lstm/plumbing.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,14 @@ class Plumbing : public Network {
ASSERT_HOST(lr_ptr != nullptr);
*lr_ptr *= factor;
}

// Set the learning rate for a specific layer of the stack to the given value.
void SetLayerLearningRate(const char *id, float learning_rate) {
float *lr_ptr = LayerLearningRatePtr(id);
ASSERT_HOST(lr_ptr != nullptr);
*lr_ptr = learning_rate;
}

// Returns a pointer to the learning rate for the given layer id.
TESS_API
float *LayerLearningRatePtr(const char *id);
Expand Down
6 changes: 6 additions & 0 deletions src/training/lstmtraining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ static INT_PARAM_FLAG(perfect_sample_delay, 0, "How many imperfect samples betwe
static DOUBLE_PARAM_FLAG(target_error_rate, 0.01, "Final error rate in percent.");
static DOUBLE_PARAM_FLAG(weight_range, 0.1, "Range of initial random weights.");
static DOUBLE_PARAM_FLAG(learning_rate, 10.0e-4, "Weight factor for new deltas.");
static BOOL_PARAM_FLAG(reset_learning_rate, false,
"Resets all stored learning rates to the value specified by --learning_rate.");
static DOUBLE_PARAM_FLAG(momentum, 0.5, "Decay factor for repeating deltas.");
static DOUBLE_PARAM_FLAG(adam_beta, 0.999, "Decay factor for repeating deltas.");
static INT_PARAM_FLAG(max_image_MB, 6000, "Max memory to use for images.");
Expand Down Expand Up @@ -157,6 +159,10 @@ int main(int argc, char **argv) {
return EXIT_FAILURE;
}
tprintf("Continuing from %s\n", FLAGS_continue_from.c_str());
if (FLAGS_reset_learning_rate) {
trainer.SetLearningRate(FLAGS_learning_rate);
tprintf("Set learning rate to %f\n", static_cast<float>(FLAGS_learning_rate));
}
trainer.InitIterations();
}
if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) {
Expand Down

0 comments on commit ff1062d

Please sign in to comment.