Skip to content

Commit

Permalink
Refactor freezing.
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniel committed Oct 1, 2020
1 parent 2c8d6c3 commit 67082eb
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 51 deletions.
15 changes: 2 additions & 13 deletions training/deepspeech_training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from .evaluate import evaluate
from six.moves import zip, range
from .util.config import Config, initialize_globals
from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation, reload_best_checkpoint
from .util.checkpoints import drop_freeze_number_to_layers, load_or_init_graph_for_training, load_graph_for_evaluation, reload_best_checkpoint
from .util.evaluate_tools import save_samples_json
from .util.feeding import create_dataset, audio_to_features, audiofile_to_features
from .util.flags import create_flags, FLAGS
Expand Down Expand Up @@ -325,18 +325,7 @@ def get_tower_results(iterator, optimizer, dropout_rates):

# Filter out layers if we want to freeze some
if FLAGS.freeze_source_layers > 0:
filter_vars = []
if FLAGS.freeze_source_layers <= 5:
filter_vars.append("layer_1")
if FLAGS.freeze_source_layers <= 4:
filter_vars.append("layer_2")
if FLAGS.freeze_source_layers <= 3:
filter_vars.append("layer_3")
if FLAGS.freeze_source_layers <= 2:
filter_vars.append("lstm")
if FLAGS.freeze_source_layers <= 1:
filter_vars.append("layer_5")

filter_vars = drop_freeze_number_to_layers(FLAGS.freeze_source_layers, "freeze")
new_train_vars = list(train_vars)
for fv in filter_vars:
for tv in train_vars:
Expand Down
78 changes: 41 additions & 37 deletions training/deepspeech_training/util/checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import sys
import tensorflow as tf

import tensorflow.compat.v1 as tfv1

from .flags import FLAGS
from .logging import log_info, log_error, log_warn
from .logging import log_error, log_info, log_warn


def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
Expand All @@ -19,47 +19,33 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
# compatibility with older checkpoints.
lr_var = set(v for v in load_vars if v.op.name == 'learning_rate')
if lr_var and ('learning_rate' not in vars_in_ckpt or
(FLAGS.force_initialize_learning_rate and allow_lr_init)):
(FLAGS.force_initialize_learning_rate and allow_lr_init)):
assert len(lr_var) <= 1
load_vars -= lr_var
init_vars |= lr_var

if FLAGS.load_cudnn:
# Initialize training from a CuDNN RNN checkpoint
# Identify the variables which we cannot load, and set them
# for initialization
missing_vars = set()
for v in load_vars:
if v.op.name not in vars_in_ckpt:
log_warn('CUDNN variable not found: %s' % (v.op.name))
missing_vars.add(v)
# After training with "freeze_source_layers" the Adam moment tensors for the frozen layers
# are missing because they were not used. This might also occur when loading a cudnn checkpoint
# Therefore we have to initialize them again to continue training on such checkpoints
print_msg = False
for v in load_vars:
if v.op.name not in vars_in_ckpt:
if 'Adam' in v.name:
init_vars.add(v)
print_msg = True
if print_msg:
msg = "Some Adam tensors are missing, they will be initialized automatically."
log_info(msg)
load_vars -= init_vars

load_vars -= init_vars

# Check that the only missing variables (i.e. those to be initialised)
# are the Adam moment tensors, if they aren't then we have an issue
missing_var_names = [v.op.name for v in missing_vars]
if any('Adam' not in v for v in missing_var_names):
log_error('Tried to load a CuDNN RNN checkpoint but there were '
'more missing variables than just the Adam moment '
'tensors. Missing variables: {}'.format(missing_var_names))
sys.exit(1)

if FLAGS.load_frozen_graph:
# After training with "freeze_source_layers" the Adam tensors for the frozen layers aren't
# existing anymore because they were not used
# Therefore we have to initialize them again to continue training on such checkpoints
if FLAGS.load_cudnn:
# Check all required tensors are included in the cudnn checkpoint we want to load
for v in load_vars:
if v.op.name not in vars_in_ckpt:
if 'Adam' in v.name:
init_vars.add(v)
else:
msg = "Tried to load a frozen checkpoint but there was a missing " \
"variable other than the Adam tensors: {}"
log_error(msg.format(v))
sys.exit(1)
load_vars -= init_vars
if v.op.name not in vars_in_ckpt or 'Adam' not in v.op.name:
msg = 'Tried to load a CuDNN RNN checkpoint but there was a missing' \
' variable other than an Adam moment tensor: {}'
log_error(msg.format(v.op.name))
sys.exit(1)

if allow_drop_layers and FLAGS.drop_source_layers > 0:
# This transfer learning approach requires supplying
Expand All @@ -74,7 +60,7 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
'dropping only 5 layers.')
FLAGS.drop_source_layers = 5

dropped_layers = ['2', '3', 'lstm', '5', '6'][-1 * int(FLAGS.drop_source_layers):]
dropped_layers = drop_freeze_number_to_layers(FLAGS.drop_source_layers, "drop")
# Initialize all variables needed for DS, but not loaded from ckpt
for v in load_vars:
if any(layer in v.op.name for layer in dropped_layers):
Expand All @@ -90,6 +76,24 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
session.run(v.initializer)


def drop_freeze_number_to_layers(drop_freeze_number, mode):
""" Convert number of layers to drop or freeze into layer names """

if drop_freeze_number >= 6:
log_warn('The checkpoint only has 6 layers, but you are trying '
'to drop or freeze all of them or more. Continuing with 5 layers.')
drop_freeze_number = 5

layer_keys = ["layer_1", "layer_2", "layer_3", "lstm", "layer_5", "layer_6"]
if mode == "drop":
layer_keys = layer_keys[-1 * int(drop_freeze_number):]
elif mode == "freeze":
layer_keys = layer_keys[:-1 * int(drop_freeze_number)]
else:
raise ValueError
return layer_keys


def _checkpoint_path_or_none(checkpoint_filename):
checkpoint = tfv1.train.get_checkpoint_state(FLAGS.load_checkpoint_dir, checkpoint_filename)
if not checkpoint:
Expand Down
1 change: 0 additions & 1 deletion training/deepspeech_training/util/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def create_flags():

f.DEFINE_integer('drop_source_layers', 0, 'single integer for how many layers to drop from source model (to drop just output == 1, drop penultimate and output ==2, etc)')
f.DEFINE_integer('freeze_source_layers', 0, 'use same value as above to freeze the other layers')
f.DEFINE_boolean('load_frozen_graph', False, 'Needed to load a graph checkpoint which was trained with "freeze_source_layers" flag. Allows initialization of missing optimization tensors.')

# Exporting

Expand Down

0 comments on commit 67082eb

Please sign in to comment.