Update the overridden `_lookup_dependency` methods in Keras.
A change in TF will be submitted after this to use the new `cached_dependencies` argument, which will vastly decrease TensorFlow-format checkpoint loading times.

PiperOrigin-RevId: 544131874
k-w-w authored and tensorflower-gardener committed Jun 28, 2023
1 parent d0efc1d · commit 4829ddf
Showing 3 changed files with 13 additions and 4 deletions.
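
For context, a minimal sketch of the loader-side pattern the new argument enables. Only the `_lookup_dependency(name, cached_dependencies=...)` signature comes from this commit; the loader function and its names are hypothetical stand-ins for the follow-up TF change mentioned above:

    def restore_checkpoint(trackable, names_to_restore):
        # One slow pass up front: enumerate every dependency of this
        # trackable (hypothetical use of the existing _trackable_children).
        cached = trackable._trackable_children(save_type="checkpoint")
        restored = {}
        for name in names_to_restore:
            # Each per-name lookup is now a dict .get() instead of a
            # fresh walk over the object graph.
            restored[name] = trackable._lookup_dependency(
                name, cached_dependencies=cached
            )
        return restored
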
keras/engine/base_layer.py (4 changes: 3 additions & 1 deletion)
@@ -3835,11 +3835,13 @@ def _trackable_children(self, save_type="checkpoint", **kwargs):
         children.update(super()._trackable_children(save_type, **kwargs))
         return children

-    def _lookup_dependency(self, name):
+    def _lookup_dependency(self, name, cached_dependencies=None):
         # When loading from a Keras SavedModel, make sure that the loader
         # can find the random generator, otherwise the loader will assume that
         # it does not exist, and will try to create a new generator.
         if name == "_random_generator":
             return self._random_generator
+        elif cached_dependencies is not None:
+            return cached_dependencies.get(name)
         else:
             return super()._lookup_dependency(name)
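
A note on the ordering in this override: the `_random_generator` branch runs before the cache lookup, so a SavedModel load still finds the generator even when the precomputed cache omits it. A hypothetical illustration, assuming `layer` is a built Keras layer with a `_random_generator` attribute:

    cache = {"kernel": layer.kernel}  # a cache that omits the generator
    gen = layer._lookup_dependency(
        "_random_generator", cached_dependencies=cache
    )
    assert gen is layer._random_generator  # the special case still wins
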
keras/engine/functional.py (6 changes: 5 additions & 1 deletion)
@@ -461,7 +461,11 @@ def _trackable_children(self, save_type="checkpoint", **kwargs):
         dependencies.update(super()._trackable_children(save_type, **kwargs))
         return dependencies

-    def _lookup_dependency(self, name):
+    def _lookup_dependency(self, name, cached_dependencies=None):
+        if cached_dependencies:
+            return cached_dependencies.get(name)
+        # Fall back to slow lookup (`layer_checkpoint_dependencies` does a
+        # thorough check of all layers to see if they contain weights).
         layer_dependencies = self._layer_checkpoint_dependencies
         if name in layer_dependencies:
             return layer_dependencies[name]
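
The comment above states the performance motivation: `_layer_checkpoint_dependencies` re-inspects every layer on each call, so restoring N names costs roughly N full scans, while a prebuilt cache answers each name in constant time. A self-contained toy sketch of that difference (sizes and names are illustrative, not Keras code):

    import time

    layers = [f"layer_{i}" for i in range(10_000)]

    def slow_lookup(name):
        # Mimics rescanning every layer on each call.
        for layer in layers:
            if layer == name:
                return layer
        return None

    cached = {layer: layer for layer in layers}  # built once, reused

    t0 = time.perf_counter()
    for i in range(1_000):
        slow_lookup(f"layer_{i}")
    t1 = time.perf_counter()
    for i in range(1_000):
        cached.get(f"layer_{i}")
    t2 = time.perf_counter()
    print(f"scan: {t1 - t0:.3f}s  cache: {t2 - t1:.3f}s")
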
keras/mixed_precision/loss_scale_optimizer.py (7 changes: 5 additions & 2 deletions)
@@ -209,9 +209,12 @@ def _trackable_children(self, save_type="checkpoint", **kwargs):
         weights.update(super()._trackable_children(save_type, **kwargs))
         return weights

-    def _lookup_dependency(self, name):
+    def _lookup_dependency(self, name, cached_dependencies=None):
         """From Trackable. Find a weight in the current graph."""
-        unconditional = super()._lookup_dependency(name)
+        if cached_dependencies is not None:
+            unconditional = cached_dependencies.get(name)
+        else:
+            unconditional = super()._lookup_dependency(name)
         if unconditional is not None:
             return unconditional
         if tf.executing_eagerly():
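
One subtlety across the three files: functional.py guards the cache with a truthiness test (`if cached_dependencies:`), while base_layer.py and loss_scale_optimizer.py use `is not None`, so the two styles treat an empty cache differently. A standalone sketch of the behavioral difference (helper names are illustrative):

    def truthiness_guard(name, cached_dependencies=None):
        # functional.py style: an empty dict is falsy, so the lookup
        # falls through to the slow path.
        if cached_dependencies:
            return cached_dependencies.get(name)
        return "slow path"

    def is_not_none_guard(name, cached_dependencies=None):
        # base_layer.py / loss_scale_optimizer.py style: an empty dict
        # is not None, so the cache is treated as authoritative.
        if cached_dependencies is not None:
            return cached_dependencies.get(name)
        return "slow path"

    print(truthiness_guard("kernel", {}))    # -> slow path
    print(is_not_none_guard("kernel", {}))   # -> None
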
