Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add resource initializer support #6826

Merged
merged 6 commits into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion e2e/integration_tests/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ export const CONVERT_PREDICT_MODELS = {
'saved_model_v1', 'saved_model_v2', 'saved_model_v2_with_control_flow',
'saved_model_with_conv2d', 'saved_model_with_prelu',
'saved_model_v2_complex64', 'saved_model_v2_with_control_flow_v2',
'saved_model_v2_with_tensorlist_ops', 'saved_model_v1_with_hashtable'
'saved_model_v2_with_tensorlist_ops', 'saved_model_v1_with_hashtable',
'saved_model_v2_with_hashtable'
],
layers_model: ['mobilenet']
};
Expand Down
43 changes: 43 additions & 0 deletions e2e/integration_tests/convert_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,47 @@ def _create_saved_model_v1_with_hashtable(save_dir):
}
}

def _create_saved_model_v2_with_hashtable(save_dir):
  """Save a TF V2 model that uses HashTable ops and describe its test I/O.

  Args:
    save_dir: directory name of where the saved model will be stored.

  Returns:
    A dict holding the "async" flag together with the "inputs" and
    "outputs" entries (value, shape and dtype keyed by tensor name) that
    go with the saved model.
  """
  class Table(tf.Module):
    def __init__(self):
      super().__init__()
      initializer = tf.lookup.KeyValueTensorInitializer(
          tf.constant(['a', 'b']), tf.constant([0, 1]))
      # -1 is the table's default value; the expected outputs below use it
      # for the key 'c', which is not in the table.
      self.table = tf.lookup.StaticHashTable(initializer, -1)

    def initializeTable(self):
      # The nested function and its argument keep their original names
      # because they may end up in the names of the traced graph tensors.
      @tf.function
      def lookup(input):
        return self.table.lookup(input)

      return lookup

  model = Table()
  lookup_fn = model.initializeTable()
  serving_fn = lookup_fn.get_concrete_function(
      input=tf.TensorSpec([None], tf.string))

  tf.saved_model.save(
      model, save_dir, signatures={"serving_default": serving_fn})

  model_inputs = {
      "Placeholder:0": {
          "value": ["a", "b", "c"], "shape": [3], "dtype": "string"
      }
  }
  model_outputs = {
      "StatefulPartitionedCall/None_Lookup/LookupTableFindV2:0": {
          "value": [0, 1, -1], "shape": [3], "dtype": "int32"
      }
  }
  return {
      "async": False,
      "inputs": model_inputs,
      "outputs": model_outputs
  }

def _layers_mobilenet():
model = tf.keras.applications.MobileNetV2()
model_path = 'mobilenet'
Expand Down Expand Up @@ -471,6 +512,8 @@ def main():
'saved_model_v2_with_tensorlist_ops', control_flow_v2=True)
_save_and_convert_model(_create_saved_model_v1_with_hashtable,
'saved_model_v1_with_hashtable')
_save_and_convert_model(_create_saved_model_v2_with_hashtable,
'saved_model_v2_with_hashtable')

_layers_mobilenet()
if __name__ == '__main__':
Expand Down
5 changes: 0 additions & 5 deletions e2e/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -1012,11 +1012,6 @@
dependencies:
detect-browser "*"

"@types/emscripten@~0.0.34":
version "0.0.34"
resolved "https://registry.yarnpkg.com/@types/emscripten/-/emscripten-0.0.34.tgz#12b4a344274fb102ff2f6c877b37587bc3e46008"
integrity sha512-QSb9ojDincskc+uKMI0KXp8e1NALFINCrMlp8VGKGcTSxeEyRTTKyjWw75NYrCZHUsVEEEpr1tYHpbtaC++/sQ==

"@types/jasmine@~3.0.0":
version "3.0.0"
resolved "https://registry.yarnpkg.com/@types/jasmine/-/jasmine-3.0.0.tgz#9a6b6755a02fcd6baa088a767557709c79728f98"
Expand Down
2 changes: 1 addition & 1 deletion tfjs-converter/python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ flax>=0.5.3
jax>=0.3.16
importlib_resources>=5.9.0
protobuf<3.20,>=3.9.2
tensorflow>=2.1.0,<3
tensorflow>=2.10.0,<3
six>=1.12.0,<2
tensorflow-hub>=0.7.0,<0.13; python_version >= "3"
packaging~=20.9
2 changes: 2 additions & 0 deletions tfjs-converter/python/tensorflowjs/converters/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@
CONVERTED_BY_KEY = 'convertedBy'

SIGNATURE_KEY = 'signature'
INITIALIZER_SIGNATURE_KEY = 'initializerSignature'
USER_DEFINED_METADATA_KEY = 'userDefinedMetadata'
STRUCTURED_OUTPUTS_KEYS_KEY = 'structuredOutputKeys'
RESOURCE_ID_KEY = 'resourceId'

# Model formats.
KERAS_SAVED_MODEL = 'keras_saved_model'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import device_properties_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.checkpoint.trackable_view import TrackableView
from tensorflow.python.eager import context
from tensorflow.python.framework import convert_to_constants
from tensorflow.python.grappler import cluster as gcluster
Expand All @@ -38,6 +39,7 @@
from tensorflow.python.saved_model import loader
from tensorflow.python.training.saver import export_meta_graph
from tensorflow.python.tools.saved_model_cli import get_signature_def_map
from tensorflow.saved_model.experimental import TrackableResource
from google.protobuf.json_format import MessageToDict
import tensorflow_hub as hub
from packaging import version
Expand Down Expand Up @@ -125,6 +127,7 @@ def optimize_graph(graph, signature_def, output_graph,
weight_shard_size_bytes=1024 * 1024 * 4,
experiments=False,
initializer_graph=None,
resource_ids_maps=None,
metadata=None):
"""Takes a Python Graph object and optimizes the graph.

Expand All @@ -141,6 +144,9 @@ def optimize_graph(graph, signature_def, output_graph,
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
The size of each weight file will be <= this value.
initializer_graph: The frozen graph for initializers.
resource_ids_maps: Tuple of two dictionaries, one
mapping inference input names to resource id, and the other
mapping initializer output names to resource id.
metadata: User defined metadata map.
"""

Expand Down Expand Up @@ -211,13 +217,17 @@ def optimize_graph(graph, signature_def, output_graph,
', '.join(unsupported))

initializer_graph_def = None
initializer_signature_def = None
if initializer_graph:
initializer_graph_def = initializer_graph.as_graph_def()
if hasattr(initializer_graph, 'outputs'):
initializer_signature_def = _build_signature_def(initializer_graph, [], initializer_graph.outputs)

extract_weights(
optimized_graph, output_graph, tf_version,
signature_def, quantization_dtype_map, weight_shard_size_bytes,
initializer_graph_def, metadata=metadata)
initializer_graph_def, initializer_signature_def,
resource_ids_maps=resource_ids_maps, metadata=metadata)

def extract_const_nodes(nodes):
"""Takes a list of nodes and extract the weights. Return weight manifest
Expand Down Expand Up @@ -256,6 +266,8 @@ def extract_weights(graph_def,
quantization_dtype_map=None,
weight_shard_size_bytes=1024 * 1024 * 4,
initializer_graph_def=None,
initializer_signature_def=None,
resource_ids_maps=None,
metadata=None):
"""Takes a Python GraphDef object and extract the weights.

Expand All @@ -271,6 +283,10 @@ def extract_weights(graph_def,
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
The size of each weight file will be <= this value.
initializer_graph_def: tf.GraphDef proto object for initializer graph.
initializer_signature_def: the SignatureDef of the initializer graph.
resource_ids_maps: Tuple of two dictionaries, one
mapping inference input names to resource id, and the other
mapping initializer output names to resource id.
metadata: User defined metadata map.
"""
global_manifest = extract_const_nodes(graph_def.node)
Expand Down Expand Up @@ -298,6 +314,8 @@ def extract_weights(graph_def,
quantization_dtype_map=quantization_dtype_map,
weight_shard_size_bytes=weight_shard_size_bytes,
initializer_graph_def=initializer_graph_def,
initializer_signature_def=initializer_signature_def,
resource_ids_maps=resource_ids_maps,
metadata=metadata)

def write_artifacts(topology,
Expand All @@ -308,6 +326,8 @@ def write_artifacts(topology,
quantization_dtype_map=None,
weight_shard_size_bytes=1024 * 1024 * 4,
initializer_graph_def=None,
initializer_signature_def=None,
resource_ids_maps=None,
metadata=None):
"""Writes weights and topology to the output_dir.

Expand All @@ -326,6 +346,10 @@ def write_artifacts(topology,
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
The size of each weight file will be <= this value.
initializer_graph_def: tf.GraphDef proto object for initializer graph.
initializer_signature_def: the SignatureDef of the initializer graph.
resource_ids_maps: Tuple of two dictionaries, one
mapping inference input names to resource id, and the other
mapping initializer output names to resource id.
metadata: User defined metadata map.
"""
model_json = {
Expand All @@ -343,6 +367,30 @@ def write_artifacts(topology,
if initializer_graph_def and initializer_graph_def.node:
model_json[common.ARTIFACT_MODEL_INITIALIZER] = MessageToDict(
initializer_graph_def)
if initializer_signature_def:
model_json[common.INITIALIZER_SIGNATURE_KEY] = MessageToDict(
initializer_signature_def)

# Assign resource ids to inference inputs and initializer outputs. In
# TensorFlow, both the inference and initializer graphs hold a reference
# to the common resource (the initializer runs on the reference, and then
# the inference graph uses it). We do something similar, but instead of
# assigning a reference to the resource in the serialized graph, we assign
# the id of the resource; the common reference can then be recreated in
# JavaScript by matching resource ids.
if resource_ids_maps is not None:
model_input_to_resource_id, init_output_to_resource_id = resource_ids_maps
signature_inputs = model_json[common.SIGNATURE_KEY]['inputs']
initializer_signature_outputs = model_json[common.INITIALIZER_SIGNATURE_KEY]['outputs']

for (input, resource_id) in model_input_to_resource_id.items():
if input in signature_inputs:
signature_inputs[input][common.RESOURCE_ID_KEY] = resource_id

for (output, resource_id) in init_output_to_resource_id.items():
if output in initializer_signature_outputs:
initializer_signature_outputs[output][common.RESOURCE_ID_KEY] = resource_id


weights_manifest = write_weights.write_weights(
weights, os.path.dirname(output_graph), write_manifest=False,
Expand Down Expand Up @@ -550,6 +598,108 @@ def _find_signature(saved_model_dir, saved_model_tags, signature_def):

return signature_def_map[signature_def]

def _get_resource_initializer_concrete_function(model):
  """Build a function that re-creates and initializes the model's resources.

  For more information on resources, please see the TensorFlow code:
  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/trackable/resource.py#L232

  Args:
    model: Loaded saved model.

  Returns:
    A concrete function returning the list of re-created resources, or
    None when the model has no TrackableResource descendants.
  """
  tracked_resources = [
      obj for obj in TrackableView(model).descendants()
      if isinstance(obj, TrackableResource)
  ]
  if not tracked_resources:
    return None

  # (resource, slot) pairs. `resource` is one resource of the model (a hash
  # table, for example) and `slot` is the position, among the captured
  # inputs of that resource's initialization function, that holds the
  # resource itself. Captured inputs are inputs that are not provided
  # directly by the user, but by the model.
  resource_slots = []
  for resource in tracked_resources:
    # A runtime id that is unique across different resources, and constant
    # across graphs.
    handle_id = resource.resource_handle._id
    # _initialize initializes the resource, so one of its captured inputs
    # must be the resource; search for that input.
    init_concrete_fn = resource._initialize.get_concrete_function()
    init_captures = init_concrete_fn._captured_inputs
    resource_slots.extend(
        (resource, slot)
        for slot, captured in enumerate(init_captures)
        if captured._id == handle_id)

  @tf.function()
  def resource_initializer():
    # Recreate the resources so that they are captured by this tf.function.
    recreated = []
    for resource, slot in resource_slots:
      # A new resource, identical to the old one but captured in this
      # function only.
      fresh = resource._create_resource()
      recreated.append(fresh)

      # The slot of the captured input was precomputed, so it can be
      # overwritten with the fresh copy directly. Without this, _initialize
      # would not get captured in this graph, since the old resource was
      # already initialized when the TF model was loaded.
      init_concrete_fn = resource._initialize.get_concrete_function()
      init_concrete_fn._captured_inputs[slot] = fresh
      resource._initialize()

    return recreated

  # Trace resource_initializer so it can be added to the output graph.
  return resource_initializer.get_concrete_function()

def _get_resource_ids_maps(model, concrete_func, resource_init_concrete_func):
  """Map tensor names to the resource ids of the loaded saved model.

  The two maps allow initializer outputs to be matched with the inference
  inputs that consume the same resource.

  Args:
    model: Loaded saved model.
    concrete_func: Concrete function of the inference graph.
    resource_init_concrete_func: Concrete function of the initializer graph.

  Returns:
    A tuple of two dictionaries: the first maps inference input names to
    resource id, the second maps initializer output names to resource id.

  Raises:
    KeyError: if a resource of the model is not among the captured inputs
      of `concrete_func`.
  """
  tracked_resources = [
      obj for obj in TrackableView(model).descendants()
      if isinstance(obj, TrackableResource)
  ]

  # Each resource has a unique runtime resource id associated with it which
  # can be used across graphs, so index the captured inputs of the inference
  # graph by that id.
  inference_captures = concrete_func._captured_inputs
  capture_index_by_id = {}
  for capture_index, captured in enumerate(inference_captures):
    capture_index_by_id[captured._id] = capture_index

  # Captured inputs always come after user provided inputs.
  first_capture_position = len(concrete_func.inputs) - len(inference_captures)

  input_name_to_id = {}
  output_name_to_id = {}
  for position, resource in enumerate(tracked_resources):
    resource_id = resource.resource_handle._id

    # Input of the inference graph corresponding to this resource.
    input_position = capture_index_by_id[resource_id] + first_capture_position
    inference_input = concrete_func.inputs[input_position]

    # Output of the initializer graph corresponding to this resource.
    initializer_output = resource_init_concrete_func.outputs[position]

    # Both get the same id (the initializer output will be passed in to the
    # corresponding input of the inference graph).
    input_name_to_id[inference_input.name] = resource_id
    output_name_to_id[initializer_output.name] = resource_id

  return (input_name_to_id, output_name_to_id)

def _convert_tf_saved_model(output_dir,
saved_model_dir=None,
keras_model=None,
Expand Down Expand Up @@ -663,8 +813,15 @@ def _convert_tf_saved_model(output_dir,
# reliable way. Try to freeze the graph using V2 utils. If that fails, freeze
# the graph using V1 utils.
frozen_initializer_graph = None
resource_ids_maps = None
try:
frozen_graph = _freeze_saved_model_v2(concrete_func, control_flow_v2)
resource_initializer_concrete_func = _get_resource_initializer_concrete_function(model)

if resource_initializer_concrete_func:
frozen_initializer_graph = _freeze_saved_model_v2(resource_initializer_concrete_func, control_flow_v2)
resource_ids_maps = _get_resource_ids_maps(model, concrete_func, resource_initializer_concrete_func)

except BaseException:
if saved_model_dir:
(frozen_graph,
Expand All @@ -682,9 +839,8 @@ def _convert_tf_saved_model(output_dir,
with tf.compat.v1.gfile.GFile(frozen_file, 'wb') as f:
f.write(frozen_graph.as_graph_def().SerializeToString())

inputs = [x for x in concrete_func.inputs if not x.dtype == 'resource']
signature = _build_signature_def(
frozen_graph, inputs, concrete_func.outputs, saved_model_sigature)
frozen_graph, concrete_func.inputs, concrete_func.outputs, saved_model_sigature)

define_transform_graph_func()

Expand All @@ -704,6 +860,7 @@ def _convert_tf_saved_model(output_dir,
weight_shard_size_bytes=weight_shard_size_bytes,
experiments=experiments,
initializer_graph=frozen_initializer_graph,
resource_ids_maps=resource_ids_maps,
metadata=metadata)

def define_transform_graph_func():
Expand Down
Loading