Revert "Add resource initializer support (tensorflow#6826)"
This is causing e2e regression tests to fail:

1) saved_model_v1_with_hashtable.
     #REGRESSION convert_predict webgl {"WEBGL_VERSION":2,"WEBGL_CPU_FORWARD":false,"WEBGL_SIZE_UPLOAD_UNIFORM":0}
     Error: Arrays differ: actual[0] = -1, expected[0] = 3.

To reproduce this, use Node 16 in e2e/ and run `NIGHTLY=true ./scripts/test-ci.sh`; or, after running that once to generate the required files, run `yarn karma start --tags '#REGRESSION'`.

This reverts commit 42dee16.
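
The -1 in the failure is consistent with a `tf.lookup.StaticHashTable` lookup falling through to the table's default value: the v2 hashtable test model removed below builds its table with a default of -1 (the v1 model presumably does the same), and that default is what a lookup returns when the table was never initialized. A minimal standalone sketch of that behavior, not the regression test itself:

    import tensorflow as tf

    # A StaticHashTable returns default_value for any missing key -- and for
    # every key if the table's initializer never ran, which is what an
    # actual[0] of -1 (vs. an expected 3) points to.
    init = tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(['a', 'b']), values=tf.constant([0, 1]))
    table = tf.lookup.StaticHashTable(init, default_value=-1)

    print(table.lookup(tf.constant(['a', 'b', 'c'])).numpy())  # [ 0  1 -1]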
mattsoulanille committed Oct 5, 2022
1 parent c5eb1d3 commit 4e80eed
Showing 18 changed files with 256 additions and 1,092 deletions.
3 changes: 1 addition & 2 deletions e2e/integration_tests/constants.ts
@@ -37,8 +37,7 @@ export const CONVERT_PREDICT_MODELS = {
     'saved_model_v1', 'saved_model_v2', 'saved_model_v2_with_control_flow',
     'saved_model_with_conv2d', 'saved_model_with_prelu',
     'saved_model_v2_complex64', 'saved_model_v2_with_control_flow_v2',
-    'saved_model_v2_with_tensorlist_ops', 'saved_model_v1_with_hashtable',
-    'saved_model_v2_with_hashtable'
+    'saved_model_v2_with_tensorlist_ops', 'saved_model_v1_with_hashtable'
   ],
   layers_model: ['mobilenet']
 };
43 changes: 0 additions & 43 deletions e2e/integration_tests/convert_predict.py
@@ -427,47 +427,6 @@ def _create_saved_model_v1_with_hashtable(save_dir):
     }
   }
 
-def _create_saved_model_v2_with_hashtable(save_dir):
-  """Test a TF V2 model with HashTable Ops.
-
-  Args:
-    save_dir: directory name of where the saved model will be stored.
-  """
-  class Table(tf.Module):
-    def __init__(self):
-      super(Table, self).__init__()
-      keys = tf.constant(['a', 'b'])
-      vals = tf.constant([0, 1])
-      init = tf.lookup.KeyValueTensorInitializer(keys, vals)
-      self.table = tf.lookup.StaticHashTable(init, -1)
-
-    def initializeTable(self):
-      @tf.function
-      def lookup(input):
-        return self.table.lookup(input)
-
-      return lookup
-
-  model = Table()
-  concrete_fn = model.initializeTable().get_concrete_function(
-      input=tf.TensorSpec([None], tf.string))
-
-  tf.saved_model.save(model, save_dir, signatures={"serving_default": concrete_fn})
-
-  return {
-      "async": False,
-      "inputs": {
-          "Placeholder:0": {
-              "value": ["a", "b", "c"], "shape": [3], "dtype": "string"
-          }
-      },
-      "outputs": {
-          "StatefulPartitionedCall/None_Lookup/LookupTableFindV2:0": {
-              "value": [0, 1, -1], "shape": [3], "dtype": "int32"
-          }
-      }
-  }
-
 def _layers_mobilenet():
   model = tf.keras.applications.MobileNetV2()
   model_path = 'mobilenet'
@@ -512,8 +471,6 @@ def main():
       'saved_model_v2_with_tensorlist_ops', control_flow_v2=True)
   _save_and_convert_model(_create_saved_model_v1_with_hashtable,
                           'saved_model_v1_with_hashtable')
-  _save_and_convert_model(_create_saved_model_v2_with_hashtable,
-                          'saved_model_v2_with_hashtable')
 
   _layers_mobilenet()
 if __name__ == '__main__':
5 changes: 5 additions & 0 deletions e2e/yarn.lock
@@ -1012,6 +1012,11 @@
   dependencies:
     detect-browser "*"
 
+"@types/emscripten@~0.0.34":
+  version "0.0.34"
+  resolved "https://registry.yarnpkg.com/@types/emscripten/-/emscripten-0.0.34.tgz#12b4a344274fb102ff2f6c877b37587bc3e46008"
+  integrity sha512-QSb9ojDincskc+uKMI0KXp8e1NALFINCrMlp8VGKGcTSxeEyRTTKyjWw75NYrCZHUsVEEEpr1tYHpbtaC++/sQ==
+
 "@types/jasmine@~3.0.0":
   version "3.0.0"
   resolved "https://registry.yarnpkg.com/@types/jasmine/-/jasmine-3.0.0.tgz#9a6b6755a02fcd6baa088a767557709c79728f98"
2 changes: 1 addition & 1 deletion tfjs-converter/python/requirements.txt
@@ -2,7 +2,7 @@ flax>=0.5.3
 jax>=0.3.16
 importlib_resources>=5.9.0
 protobuf<3.20,>=3.9.2
-tensorflow>=2.10.0,<3
+tensorflow>=2.1.0,<3
 six>=1.12.0,<2
 tensorflow-hub>=0.7.0,<0.13; python_version >= "3"
 packaging~=20.9
2 changes: 0 additions & 2 deletions tfjs-converter/python/tensorflowjs/converters/common.py
@@ -31,10 +31,8 @@
 CONVERTED_BY_KEY = 'convertedBy'
 
 SIGNATURE_KEY = 'signature'
-INITIALIZER_SIGNATURE_KEY = 'initializerSignature'
 USER_DEFINED_METADATA_KEY = 'userDefinedMetadata'
 STRUCTURED_OUTPUTS_KEYS_KEY = 'structuredOutputKeys'
-RESOURCE_ID_KEY = 'resourceId'
 
 # Model formats.
 KERAS_SAVED_MODEL = 'keras_saved_model'
tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py
@@ -27,7 +27,6 @@
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.protobuf import device_properties_pb2
 from tensorflow.core.protobuf import meta_graph_pb2
-from tensorflow.python.checkpoint.trackable_view import TrackableView
 from tensorflow.python.eager import context
 from tensorflow.python.framework import convert_to_constants
 from tensorflow.python.grappler import cluster as gcluster
@@ -39,7 +38,6 @@
 from tensorflow.python.saved_model import loader
 from tensorflow.python.training.saver import export_meta_graph
 from tensorflow.python.tools.saved_model_cli import get_signature_def_map
-from tensorflow.saved_model.experimental import TrackableResource
 from google.protobuf.json_format import MessageToDict
 import tensorflow_hub as hub
 from packaging import version
@@ -127,7 +125,6 @@ def optimize_graph(graph, signature_def, output_graph,
                    weight_shard_size_bytes=1024 * 1024 * 4,
                    experiments=False,
                    initializer_graph=None,
-                   resource_ids_maps=None,
                    metadata=None):
   """Takes a Python Graph object and optimizes the graph.
 
@@ -144,9 +141,6 @@
     weight_shard_size_bytes: Shard size (in bytes) of the weight files.
       The size of each weight file will be <= this value.
     initializer_graph: The frozen graph for initializers.
-    resource_ids_maps: Tuple of two dictionaries, one
-      mapping inference input names to resource id, and the other
-      mapping initializer output names to resource id.
     metadata: User defined metadata map.
   """
 
@@ -217,17 +211,13 @@
         ', '.join(unsupported))
 
   initializer_graph_def = None
-  initializer_signature_def = None
   if initializer_graph:
     initializer_graph_def = initializer_graph.as_graph_def()
-    if hasattr(initializer_graph, 'outputs'):
-      initializer_signature_def = _build_signature_def(initializer_graph, [], initializer_graph.outputs)
 
   extract_weights(
       optimized_graph, output_graph, tf_version,
       signature_def, quantization_dtype_map, weight_shard_size_bytes,
-      initializer_graph_def, initializer_signature_def,
-      resource_ids_maps=resource_ids_maps, metadata=metadata)
+      initializer_graph_def, metadata=metadata)
 
 def extract_const_nodes(nodes):
   """Takes a list of nodes and extract the weights. Return weight manifest
@@ -266,8 +256,6 @@ def extract_weights(graph_def,
                     quantization_dtype_map=None,
                     weight_shard_size_bytes=1024 * 1024 * 4,
                     initializer_graph_def=None,
-                    initializer_signature_def=None,
-                    resource_ids_maps=None,
                     metadata=None):
   """Takes a Python GraphDef object and extract the weights.
@@ -283,10 +271,6 @@
     weight_shard_size_bytes: Shard size (in bytes) of the weight files.
       The size of each weight file will be <= this value.
     initializer_graph_def: tf.GraphDef proto object for initializer graph.
-    initializer_signature_def: the SignatureDef of the initializer graph.
-    resource_ids_maps: Tuple of two dictionaries, one
-      mapping inference input names to resource id, and the other
-      mapping initializer output names to resource id.
     metadata: User defined metadata map.
   """
   global_manifest = extract_const_nodes(graph_def.node)
@@ -314,8 +298,6 @@
       quantization_dtype_map=quantization_dtype_map,
       weight_shard_size_bytes=weight_shard_size_bytes,
       initializer_graph_def=initializer_graph_def,
-      initializer_signature_def=initializer_signature_def,
-      resource_ids_maps=resource_ids_maps,
       metadata=metadata)
 
 def write_artifacts(topology,
@@ -326,8 +308,6 @@
                     quantization_dtype_map=None,
                     weight_shard_size_bytes=1024 * 1024 * 4,
                     initializer_graph_def=None,
-                    initializer_signature_def=None,
-                    resource_ids_maps=None,
                     metadata=None):
   """Writes weights and topology to the output_dir.
 
@@ -346,10 +326,6 @@
     weight_shard_size_bytes: Shard size (in bytes) of the weight files.
       The size of each weight file will be <= this value.
     initializer_graph_def: tf.GraphDef proto object for initializer graph.
-    initializer_signature_def: the SignatureDef of the initializer graph.
-    resource_ids_maps: Tuple of two dictionaries, one
-      mapping inference input names to resource id, and the other
-      mapping initializer output names to resource id.
     metadata: User defined metadata map.
   """
   model_json = {
@@ -367,30 +343,6 @@
   if initializer_graph_def and initializer_graph_def.node:
     model_json[common.ARTIFACT_MODEL_INITIALIZER] = MessageToDict(
         initializer_graph_def)
-  if initializer_signature_def:
-    model_json[common.INITIALIZER_SIGNATURE_KEY] = MessageToDict(
-        initializer_signature_def)
-
-  # Assign resource ids to inference inputs and initializer outputs. In
-  # TensorFlow, both the inference and initializer graphs hold a reference
-  # to the common resource (the initializer runs on the reference, and the
-  # inference graph then uses it). We do something similar, but instead of
-  # assigning a reference to the resource in the serialized graph, we assign
-  # the id of the resource, and then we can recreate the common reference in
-  # JavaScript by matching resource ids.
-  if resource_ids_maps is not None:
-    model_input_to_resource_id, init_output_to_resource_id = resource_ids_maps
-    signature_inputs = model_json[common.SIGNATURE_KEY]['inputs']
-    initializer_signature_outputs = model_json[common.INITIALIZER_SIGNATURE_KEY]['outputs']
-
-    for (input, resource_id) in model_input_to_resource_id.items():
-      if input in signature_inputs:
-        signature_inputs[input][common.RESOURCE_ID_KEY] = resource_id
-
-    for (output, resource_id) in init_output_to_resource_id.items():
-      if output in initializer_signature_outputs:
-        initializer_signature_outputs[output][common.RESOURCE_ID_KEY] = resource_id
-
 
   weights_manifest = write_weights.write_weights(
       weights, os.path.dirname(output_graph), write_manifest=False,
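
For reference, the deleted block's effect on model.json looked roughly like the sketch below. The tensor names and id are hypothetical; only the 'signature', 'initializerSignature', and 'resourceId' keys come from the constants removed from converters/common.py above:

    # Hypothetical model.json fragment under the reverted scheme.
    model_json = {
        'signature': {
            'inputs': {
                # Inference input backed by a hash-table resource.
                'unknown:0': {'resourceId': 12},
            },
        },
        'initializerSignature': {
            'outputs': {
                # Initializer output that produces the same resource.
                'Identity:0': {'resourceId': 12},
            },
        },
    }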
@@ -598,108 +550,6 @@ def _find_signature(saved_model_dir, saved_model_tags, signature_def):
 
   return signature_def_map[signature_def]
 
-def _get_resource_initializer_concrete_function(model):
-  """Create a tf.function that creates and initializes all the resources used by the model.
-
-  For more information on resources, please see the TensorFlow code:
-  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/trackable/resource.py#L232
-
-  Args:
-    model: Loaded saved model.
-
-  Returns:
-    Nullable. A concrete function.
-  """
-  trackable_view = TrackableView(model)
-  model_resources = [obj for obj in trackable_view.descendants() if isinstance(obj, TrackableResource)]
-
-  if not model_resources:
-    return None
-
-  # A list holding tuples of (TrackableResource, captured_input_index) where
-  # TrackableResource represents one resource in the model
-  # (a hash table for example), and captured_input_index is the resource
-  # initialization function's captured input index corresponding
-  # to the TrackableResource. Captured inputs are simply inputs not provided
-  # directly by the user, but by the model.
-  model_resources_with_captured_input_index = []
-  for model_resource in model_resources:
-    # A runtime id that is unique across different resources, and constant
-    # across graphs.
-    resource_handle_id = model_resource.resource_handle._id
-    # The _initialize function initializes the resource, so one of its captured
-    # inputs must be the resource, so search for that input.
-    captured_inputs = model_resource._initialize.get_concrete_function()._captured_inputs
-    for captured_input_index in range(len(captured_inputs)):
-      if captured_inputs[captured_input_index]._id == resource_handle_id:
-        model_resources_with_captured_input_index.append((model_resource, captured_input_index))
-
-  @tf.function()
-  def resource_initializer():
-    # Recreate resources to capture them in this tf.function.
-    new_resources = []
-    for (model_resource, captured_input_index) in model_resources_with_captured_input_index:
-      # Make a new resource (that is identical to the old, but captured in
-      # this function only).
-      new_resource = model_resource._create_resource()
-      new_resources.append(new_resource)
-
-      # Since we precomputed the captured input corresponding to this resource,
-      # we can directly replace it with the copy new_resource. If we don't do
-      # this, then _initialize will not get captured in this graph since the
-      # old resource was already initialized in TF model load.
-      model_resource._initialize.get_concrete_function()._captured_inputs[captured_input_index] = new_resource
-      model_resource._initialize()
-
-    return new_resources
-
-  # Add resource_initializer to the output graph.
-  return resource_initializer.get_concrete_function()
-
-def _get_resource_ids_maps(model, concrete_func, resource_init_concrete_func):
-  """Generates dictionaries that map tensor names to the loaded saved model resource id,
-  allowing for matching of initializer outputs to inference inputs.
-
-  Args:
-    model: Loaded saved model.
-    concrete_func: Concrete function of the inference graph.
-    resource_init_concrete_func: Concrete function of the initializer graph.
-
-  Returns:
-    A dictionary mapping inference input names to resource id.
-    A dictionary mapping initializer output names to resource id.
-  """
-  trackable_view = TrackableView(model)
-  model_resources = [obj for obj in trackable_view.descendants() if isinstance(obj, TrackableResource)]
-
-  # Each resource has a unique runtime resource id associated with it which
-  # can be used across graphs, so we extract it here from the inference
-  # graph for use later.
-  resource_id_to_captured_input_index = {
-      captured_input._id : captured_input_index for \
-      captured_input_index, captured_input in \
-      enumerate(concrete_func._captured_inputs)
-  }
-  # Captured inputs always come after user provided inputs.
-  captured_input_index_offset = len(concrete_func.inputs) - len(concrete_func._captured_inputs)
-
-  model_input_to_resource_id = {}
-  init_output_to_resource_id = {}
-  for i, resource in enumerate(model_resources):
-    _id = resource.resource_handle._id
-    # Get input from inference graph corresponding to this resource.
-    captured_input_index = resource_id_to_captured_input_index[_id]
-    model_input = concrete_func.inputs[captured_input_index + captured_input_index_offset]
-
-    # Get output from initializer graph corresponding to this resource.
-    init_output = resource_init_concrete_func.outputs[i]
-
-    # Match both with the same id (initializer output will be passed in to
-    # the corresponding input in inference input).
-    model_input_to_resource_id[model_input.name] = _id
-    init_output_to_resource_id[init_output.name] = _id
-
-  return (model_input_to_resource_id, init_output_to_resource_id)
-
 def _convert_tf_saved_model(output_dir,
                             saved_model_dir=None,
                             keras_model=None,
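
To make the two deleted helpers concrete: `_get_resource_initializer_concrete_function` built a single tf.function that re-creates and initializes every `TrackableResource` in the model, and `_get_resource_ids_maps` returned two small dicts keyed by tensor name. A toy sketch of those dicts for a model with one hash table (names and id hypothetical):

    # Runtime id of the table's resource handle (resource_handle._id).
    TABLE_RESOURCE_ID = 12
    # Inference-graph input that expects the table resource:
    model_input_to_resource_id = {'unknown:0': TABLE_RESOURCE_ID}
    # Initializer-graph output that produces the table resource:
    init_output_to_resource_id = {'Identity:0': TABLE_RESOURCE_ID}
    # Equal ids tell the TF.js side to wire the initializer's output into
    # the inference graph's input as one shared resource.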
@@ -813,15 +663,8 @@ def _convert_tf_saved_model(output_dir,
   # reliable way. Try to freeze the graph using V2 utils. If that fails, freeze
   # the graph using V1 utils.
   frozen_initializer_graph = None
-  resource_ids_maps = None
   try:
     frozen_graph = _freeze_saved_model_v2(concrete_func, control_flow_v2)
-    resource_initializer_concrete_func = _get_resource_initializer_concrete_function(model)
-
-    if resource_initializer_concrete_func:
-      frozen_initializer_graph = _freeze_saved_model_v2(resource_initializer_concrete_func, control_flow_v2)
-      resource_ids_maps = _get_resource_ids_maps(model, concrete_func, resource_initializer_concrete_func)
-
   except BaseException:
     if saved_model_dir:
       (frozen_graph,
@@ -839,8 +682,9 @@
   with tf.compat.v1.gfile.GFile(frozen_file, 'wb') as f:
     f.write(frozen_graph.as_graph_def().SerializeToString())
 
+  inputs = [x for x in concrete_func.inputs if not x.dtype == 'resource']
   signature = _build_signature_def(
-      frozen_graph, concrete_func.inputs, concrete_func.outputs, saved_model_sigature)
+      frozen_graph, inputs, concrete_func.outputs, saved_model_sigature)
 
   define_transform_graph_func()
 
@@ -860,7 +704,6 @@
       weight_shard_size_bytes=weight_shard_size_bytes,
       experiments=experiments,
       initializer_graph=frozen_initializer_graph,
-      resource_ids_maps=resource_ids_maps,
       metadata=metadata)
 
 def define_transform_graph_func():