From 4e80eed45ed17aa7a6248739efc99aa7a55ddf79 Mon Sep 17 00:00:00 2001 From: Matthew Soulanille Date: Wed, 5 Oct 2022 10:08:37 -0700 Subject: [PATCH] Revert "Add resource initializer support (#6826)" This is causing regression e2e tests to fail: 1) saved_model_v1_with_hashtable. #REGRESSION convert_predict webgl {"WEBGL_VERSION":2,"WEBGL_CPU_FORWARD":false,"WEBGL_SIZE_UPLOAD_UNIFORM":0} Error: Arrays differ: actual[0] = -1, expected[0] = 3. To reproduce this, use node 16 in e2e/ and run `NIGHTLY=true ./scripts/test-ci.sh`, or, after running that to generate the required files, run `yarn karma start --tags '#REGRESSION'`. This reverts commit 42dee166e0c8681c129d0c952dd7f9540e0ed4cb. --- e2e/integration_tests/constants.ts | 3 +- e2e/integration_tests/convert_predict.py | 43 -- e2e/yarn.lock | 5 + tfjs-converter/python/requirements.txt | 2 +- .../python/tensorflowjs/converters/common.py | 2 - .../tf_saved_model_conversion_v2.py | 163 +----- .../tf_saved_model_conversion_v2_test.py | 116 ----- tfjs-converter/src/data/compiled_api.ts | 5 +- tfjs-converter/src/executor/graph_model.ts | 89 +--- .../src/executor/graph_model_test.ts | 54 +- .../test_data/hash_table_v2_model_loader.ts | 342 ------------- tfjs-converter/yarn.lock | 27 +- tfjs-core/src/io/io_utils.ts | 6 - tfjs-core/src/io/local_storage.ts | 6 - tfjs-core/src/io/local_storage_test.ts | 463 +++++++++--------- tfjs-core/src/io/types.ts | 10 - tfjs-inference/src/file_handler.ts | 5 - tfjs-inference/src/file_handler_test.ts | 7 +- 18 files changed, 256 insertions(+), 1092 deletions(-) delete mode 100644 tfjs-converter/src/executor/test_data/hash_table_v2_model_loader.ts diff --git a/e2e/integration_tests/constants.ts b/e2e/integration_tests/constants.ts index 7fea3bdd272..33817990d25 100644 --- a/e2e/integration_tests/constants.ts +++ b/e2e/integration_tests/constants.ts @@ -37,8 +37,7 @@ export const CONVERT_PREDICT_MODELS = { 'saved_model_v1', 'saved_model_v2', 'saved_model_v2_with_control_flow', 'saved_model_with_conv2d', 'saved_model_with_prelu', 'saved_model_v2_complex64', 'saved_model_v2_with_control_flow_v2', - 'saved_model_v2_with_tensorlist_ops', 'saved_model_v1_with_hashtable', - 'saved_model_v2_with_hashtable' + 'saved_model_v2_with_tensorlist_ops', 'saved_model_v1_with_hashtable' ], layers_model: ['mobilenet'] }; diff --git a/e2e/integration_tests/convert_predict.py b/e2e/integration_tests/convert_predict.py index 51793100702..93987352ea0 100644 --- a/e2e/integration_tests/convert_predict.py +++ b/e2e/integration_tests/convert_predict.py @@ -427,47 +427,6 @@ def _create_saved_model_v1_with_hashtable(save_dir): } } -def _create_saved_model_v2_with_hashtable(save_dir): - """Test a TF V2 model with HashTable Ops. - - Args: - save_dir: directory name of where the saved model will be stored. - """ - class Table(tf.Module): - def __init__(self): - super(Table, self).__init__() - keys = tf.constant(['a', 'b']) - vals= tf.constant([0, 1]) - init = tf.lookup.KeyValueTensorInitializer(keys, vals) - self.table = tf.lookup.StaticHashTable(init, -1) - - def initializeTable(self): - @tf.function - def lookup(input): - return self.table.lookup(input) - - return lookup - - model = Table() - concrete_fn = model.initializeTable().get_concrete_function( - input=tf.TensorSpec([None], tf.string)) - - tf.saved_model.save(model, save_dir, signatures={"serving_default": concrete_fn}) - - return { - "async": False, - "inputs": { - "Placeholder:0": { - "value": ["a", "b", "c"], "shape": [3], "dtype": "string" - } - }, - "outputs": { - "StatefulPartitionedCall/None_Lookup/LookupTableFindV2:0": { - "value": [0, 1, -1], "shape": [3], "dtype": "int32" - } - } - } - def _layers_mobilenet(): model = tf.keras.applications.MobileNetV2() model_path = 'mobilenet' @@ -512,8 +471,6 @@ def main(): 'saved_model_v2_with_tensorlist_ops', control_flow_v2=True) _save_and_convert_model(_create_saved_model_v1_with_hashtable, 'saved_model_v1_with_hashtable') - _save_and_convert_model(_create_saved_model_v2_with_hashtable, - 'saved_model_v2_with_hashtable') _layers_mobilenet() if __name__ == '__main__': diff --git a/e2e/yarn.lock b/e2e/yarn.lock index f8391226e9b..98ce530eec9 100644 --- a/e2e/yarn.lock +++ b/e2e/yarn.lock @@ -1012,6 +1012,11 @@ dependencies: detect-browser "*" +"@types/emscripten@~0.0.34": + version "0.0.34" + resolved "https://registry.yarnpkg.com/@types/emscripten/-/emscripten-0.0.34.tgz#12b4a344274fb102ff2f6c877b37587bc3e46008" + integrity sha512-QSb9ojDincskc+uKMI0KXp8e1NALFINCrMlp8VGKGcTSxeEyRTTKyjWw75NYrCZHUsVEEEpr1tYHpbtaC++/sQ== + "@types/jasmine@~3.0.0": version "3.0.0" resolved "https://registry.yarnpkg.com/@types/jasmine/-/jasmine-3.0.0.tgz#9a6b6755a02fcd6baa088a767557709c79728f98" diff --git a/tfjs-converter/python/requirements.txt b/tfjs-converter/python/requirements.txt index 8f834d9e9ec..4d284b276f5 100644 --- a/tfjs-converter/python/requirements.txt +++ b/tfjs-converter/python/requirements.txt @@ -2,7 +2,7 @@ flax>=0.5.3 jax>=0.3.16 importlib_resources>=5.9.0 protobuf<3.20,>=3.9.2 -tensorflow>=2.10.0,<3 +tensorflow>=2.1.0,<3 six>=1.12.0,<2 tensorflow-hub>=0.7.0,<0.13; python_version >= "3" packaging~=20.9 diff --git a/tfjs-converter/python/tensorflowjs/converters/common.py b/tfjs-converter/python/tensorflowjs/converters/common.py index fbbfe74b08a..8a81c1b6e21 100644 --- a/tfjs-converter/python/tensorflowjs/converters/common.py +++ b/tfjs-converter/python/tensorflowjs/converters/common.py @@ -31,10 +31,8 @@ CONVERTED_BY_KEY = 'convertedBy' SIGNATURE_KEY = 'signature' -INITIALIZER_SIGNATURE_KEY = 'initializerSignature' USER_DEFINED_METADATA_KEY = 'userDefinedMetadata' STRUCTURED_OUTPUTS_KEYS_KEY = 'structuredOutputKeys' -RESOURCE_ID_KEY = 'resourceId' # Model formats. KERAS_SAVED_MODEL = 'keras_saved_model' diff --git a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py index 0d39d4baa38..8be0f26a128 100644 --- a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py +++ b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py @@ -27,7 +27,6 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import device_properties_pb2 from tensorflow.core.protobuf import meta_graph_pb2 -from tensorflow.python.checkpoint.trackable_view import TrackableView from tensorflow.python.eager import context from tensorflow.python.framework import convert_to_constants from tensorflow.python.grappler import cluster as gcluster @@ -39,7 +38,6 @@ from tensorflow.python.saved_model import loader from tensorflow.python.training.saver import export_meta_graph from tensorflow.python.tools.saved_model_cli import get_signature_def_map -from tensorflow.saved_model.experimental import TrackableResource from google.protobuf.json_format import MessageToDict import tensorflow_hub as hub from packaging import version @@ -127,7 +125,6 @@ def optimize_graph(graph, signature_def, output_graph, weight_shard_size_bytes=1024 * 1024 * 4, experiments=False, initializer_graph=None, - resource_ids_maps=None, metadata=None): """Takes a Python Graph object and optimizes the graph. @@ -144,9 +141,6 @@ def optimize_graph(graph, signature_def, output_graph, weight_shard_size_bytes: Shard size (in bytes) of the weight files. The size of each weight file will be <= this value. initializer_graph: The frozen graph for initializers. - resource_ids_maps: Tuple of two dictionaries, one - mapping inference input names to resource id, and the other - mapping initializer output names to resource id. metadata: User defined metadata map. """ @@ -217,17 +211,13 @@ def optimize_graph(graph, signature_def, output_graph, ', '.join(unsupported)) initializer_graph_def = None - initializer_signature_def = None if initializer_graph: initializer_graph_def = initializer_graph.as_graph_def() - if hasattr(initializer_graph, 'outputs'): - initializer_signature_def = _build_signature_def(initializer_graph, [], initializer_graph.outputs) extract_weights( optimized_graph, output_graph, tf_version, signature_def, quantization_dtype_map, weight_shard_size_bytes, - initializer_graph_def, initializer_signature_def, - resource_ids_maps=resource_ids_maps, metadata=metadata) + initializer_graph_def, metadata=metadata) def extract_const_nodes(nodes): """Takes a list of nodes and extract the weights. Return weight manifest @@ -266,8 +256,6 @@ def extract_weights(graph_def, quantization_dtype_map=None, weight_shard_size_bytes=1024 * 1024 * 4, initializer_graph_def=None, - initializer_signature_def=None, - resource_ids_maps=None, metadata=None): """Takes a Python GraphDef object and extract the weights. @@ -283,10 +271,6 @@ def extract_weights(graph_def, weight_shard_size_bytes: Shard size (in bytes) of the weight files. The size of each weight file will be <= this value. initializer_graph_def: tf.GraphDef proto object for initializer graph. - initializer_signature_def: the SignatureDef of the initializer graph. - resource_ids_maps: Tuple of two dictionaries, one - mapping inference input names to resource id, and the other - mapping initializer output names to resource id. metadata: User defined metadata map. """ global_manifest = extract_const_nodes(graph_def.node) @@ -314,8 +298,6 @@ def extract_weights(graph_def, quantization_dtype_map=quantization_dtype_map, weight_shard_size_bytes=weight_shard_size_bytes, initializer_graph_def=initializer_graph_def, - initializer_signature_def=initializer_signature_def, - resource_ids_maps=resource_ids_maps, metadata=metadata) def write_artifacts(topology, @@ -326,8 +308,6 @@ def write_artifacts(topology, quantization_dtype_map=None, weight_shard_size_bytes=1024 * 1024 * 4, initializer_graph_def=None, - initializer_signature_def=None, - resource_ids_maps=None, metadata=None): """Writes weights and topology to the output_dir. @@ -346,10 +326,6 @@ def write_artifacts(topology, weight_shard_size_bytes: Shard size (in bytes) of the weight files. The size of each weight file will be <= this value. initializer_graph_def: tf.GraphDef proto object for initializer graph. - initializer_signature_def: the SignatureDef of the initializer graph. - resource_ids_maps: Tuple of two dictionaries, one - mapping inference input names to resource id, and the other - mapping initializer output names to resource id. metadata: User defined metadata map. """ model_json = { @@ -367,30 +343,6 @@ def write_artifacts(topology, if initializer_graph_def and initializer_graph_def.node: model_json[common.ARTIFACT_MODEL_INITIALIZER] = MessageToDict( initializer_graph_def) - if initializer_signature_def: - model_json[common.INITIALIZER_SIGNATURE_KEY] = MessageToDict( - initializer_signature_def) - - # Assign resource ids to inference inputs and initializer outputs. In - # TensorFlow, both inference and initializer graphs have a reference - # to the common resource (so initializer runs on reference, and then inference - # graph uses it). We are doing something similar but instead of assigning - # a reference to the resource in the serialized graph, we assign the id - # of the resource, and then we can recreate the common reference in javascript - # by matching resource ids. - if resource_ids_maps is not None: - model_input_to_resource_id, init_output_to_resource_id = resource_ids_maps - signature_inputs = model_json[common.SIGNATURE_KEY]['inputs'] - initializer_signature_outputs = model_json[common.INITIALIZER_SIGNATURE_KEY]['outputs'] - - for (input, resource_id) in model_input_to_resource_id.items(): - if input in signature_inputs: - signature_inputs[input][common.RESOURCE_ID_KEY] = resource_id - - for (output, resource_id) in init_output_to_resource_id.items(): - if output in initializer_signature_outputs: - initializer_signature_outputs[output][common.RESOURCE_ID_KEY] = resource_id - weights_manifest = write_weights.write_weights( weights, os.path.dirname(output_graph), write_manifest=False, @@ -598,108 +550,6 @@ def _find_signature(saved_model_dir, saved_model_tags, signature_def): return signature_def_map[signature_def] -def _get_resource_initializer_concrete_function(model): - """Create a tf.function that creates and initializes all the resources used by the model. - For more information on resources, please see the TensorFlow code: - https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/trackable/resource.py#L232 - Args: - model: Loaded saved model. - - Returns: - Nullable. A concrete function. - """ - trackable_view = TrackableView(model) - model_resources = [obj for obj in trackable_view.descendants() if isinstance(obj, TrackableResource)] - - if not model_resources: - return None - - # A list holding tuples of (TrackableResource, captured_input_index) where - # TrackableResource represents one resource in the model - # (a hash table for example), and captured_input_index is the resource - # initialization function's captured input index corresponding - # to the TrackableResource. Captured inputs are simply inputs not provided - # directly be user, but by the model. - model_resources_with_captured_input_index = [] - for model_resource in model_resources: - # A runtime id that is unique across different resources, and constant - # across graphs. - resource_handle_id = model_resource.resource_handle._id - # the _initialize function initializes the resource, so one of its captured - # inputs must be the resource, so search for that input. - captured_inputs = model_resource._initialize.get_concrete_function()._captured_inputs - for captured_input_index in range(len(captured_inputs)): - if captured_inputs[captured_input_index]._id == resource_handle_id: - model_resources_with_captured_input_index.append((model_resource, captured_input_index)) - - @tf.function() - def resource_initializer(): - # Recreate resources to capture them in this tf.function. - new_resources = [] - for (model_resource, captured_input_index) in model_resources_with_captured_input_index: - # Make a new resource (that is identical to the old, but captured in - # this functon only). - new_resource = model_resource._create_resource() - new_resources.append(new_resource) - - # Since we precomputed the captured input corresponding to this resource, - # we can directly replace it with the copy new_resource. If we don't do - # this, then _initialize will not get capture in this graph since the - # old resource was already initialized in TF model load. - model_resource._initialize.get_concrete_function()._captured_inputs[captured_input_index] = new_resource - model_resource._initialize() - - return new_resources - - # Add resource_initializer to the output graph. - return resource_initializer.get_concrete_function() - -def _get_resource_ids_maps(model, concrete_func, resource_init_concrete_func): - """Generates dictionaries that map tensor names to the loaded saved model resource id, - allowing for matching of initializer outputs to inference inputs. - - Args: - model: Loaded saved model. - concrete_func: Concrete function of the inference graph. - resource_init_concrete_func: Concrete function of the initializer graph. - - Returns: - A dictionary mapping inference input names to resource id. - A dictionary mapping initializer output names to resource id. - """ - trackable_view = TrackableView(model) - model_resources = [obj for obj in trackable_view.descendants() if isinstance(obj, TrackableResource)] - - - # Each resource has a unique runtime resource id associated with it which - # can be used across graphs, so we extract it here from inference - # graph for use later. - resource_id_to_captured_input_index = { - captured_input._id : captured_input_index for \ - captured_input_index, captured_input in \ - enumerate(concrete_func._captured_inputs) - } - # Captured inputs always come after user provided inputs. - captured_input_index_offset = len(concrete_func.inputs) - len(concrete_func._captured_inputs) - - model_input_to_resource_id = {} - init_output_to_resource_id = {} - for i, resource in enumerate(model_resources): - _id = resource.resource_handle._id - # Get input from inference graph corresponding to this resource. - captured_input_index = resource_id_to_captured_input_index[_id] - model_input = concrete_func.inputs[captured_input_index + captured_input_index_offset] - - # Get output from initializer graph corresponding to this resource. - init_output = resource_init_concrete_func.outputs[i] - - # Match both with the same id (initializer output will be passed in to - # corresponding input in inference input). - model_input_to_resource_id[model_input.name] = _id - init_output_to_resource_id[init_output.name] = _id - - return (model_input_to_resource_id, init_output_to_resource_id) - def _convert_tf_saved_model(output_dir, saved_model_dir=None, keras_model=None, @@ -813,15 +663,8 @@ def _convert_tf_saved_model(output_dir, # reliable way. Try to freeze the graph using V2 utils. If that fails, freeze # the graph using V1 utils. frozen_initializer_graph = None - resource_ids_maps = None try: frozen_graph = _freeze_saved_model_v2(concrete_func, control_flow_v2) - resource_initializer_concrete_func = _get_resource_initializer_concrete_function(model) - - if resource_initializer_concrete_func: - frozen_initializer_graph = _freeze_saved_model_v2(resource_initializer_concrete_func, control_flow_v2) - resource_ids_maps = _get_resource_ids_maps(model, concrete_func, resource_initializer_concrete_func) - except BaseException: if saved_model_dir: (frozen_graph, @@ -839,8 +682,9 @@ def _convert_tf_saved_model(output_dir, with tf.compat.v1.gfile.GFile(frozen_file, 'wb') as f: f.write(frozen_graph.as_graph_def().SerializeToString()) + inputs = [x for x in concrete_func.inputs if not x.dtype == 'resource'] signature = _build_signature_def( - frozen_graph, concrete_func.inputs, concrete_func.outputs, saved_model_sigature) + frozen_graph, inputs, concrete_func.outputs, saved_model_sigature) define_transform_graph_func() @@ -860,7 +704,6 @@ def _convert_tf_saved_model(output_dir, weight_shard_size_bytes=weight_shard_size_bytes, experiments=experiments, initializer_graph=frozen_initializer_graph, - resource_ids_maps=resource_ids_maps, metadata=metadata) def define_transform_graph_func(): diff --git a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py index fa1074e9523..36ad807fe5d 100644 --- a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py +++ b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py @@ -121,31 +121,6 @@ def _create_saved_model_v1_with_hashtable(self): builder.save() - def _create_saved_model_v2_with_hashtable(self): - """Create a TensorFlow SavedModel V2 with hash table for testing.""" - - class Table(tf.Module): - def __init__(self): - super(Table, self).__init__() - keys = tf.constant(['a', 'b']) - vals= tf.constant([0, 1]) - init = tf.lookup.KeyValueTensorInitializer(keys, vals) - self.table = tf.lookup.StaticHashTable(init, -1) - - def initializeTable(self): - @tf.function - def lookup(input): - return self.table.lookup(input) - - return lookup - - model = Table() - concrete_fn = model.initializeTable().get_concrete_function( - input=tf.TensorSpec([None], tf.string)) - - save_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR) - tf.saved_model.save(model, save_dir, signatures={"serving_default": concrete_fn}) - def _create_saved_model_with_fusable_conv2d(self, use_bias): """Test a basic model with fusable conv2d.""" layers = [ @@ -450,97 +425,6 @@ def test_convert_saved_model_v1_with_hashtable(self): tf.__version__) self.assertTrue(glob.glob(os.path.join(output_dir, 'group*-*'))) - def test_convert_saved_model_v2_with_hashtable(self): - self._create_saved_model_v2_with_hashtable() - - input_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR) - output_dir = os.path.join(input_dir, 'js') - tf_saved_model_conversion_v2.convert_tf_saved_model( - input_dir, - output_dir - ) - - expected_signature = { - 'inputs': { - 'input': { - 'name': 'input:0', - 'dtype': 'DT_STRING', - 'tensorShape': {'dim': [{'size': '-1'}]} - }, - 'unknown:0': { - 'name': 'unknown:0', - 'dtype': 'DT_RESOURCE', - 'tensorShape': {}, - 'resourceId': None - } - }, - 'outputs': { - 'output_0': { - 'name': 'Identity:0', - 'dtype': 'DT_INT32', - 'tensorShape': {'dim': [{'size': '-1'}]} - } - } - } - - expected_initializer_signature = { - 'outputs': { - 'Identity:0': { - 'name': 'Identity:0', - 'dtype': 'DT_RESOURCE', - 'tensorShape': {}, - 'resourceId': None - } - } - } - - expected_weights_manifest = [{ - 'paths': ['group1-shard1of1.bin'], - 'weights': [ - {'name': 'unknown_0', 'shape': [], 'dtype': 'int32'}, - {'name': '4609', 'shape': [2], 'dtype': 'string'}, - {'name': '4611', 'shape': [2], 'dtype': 'int32'} - ]}] - - tfjs_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'js') - # Check model.json and weights manifest. - with open(os.path.join(tfjs_path, 'model.json'), 'rt') as f: - model_json = json.load(f) - - # Check resource ids match which indicates the initializer output is mapped - # to the inference input. - signature_resource_id = model_json['signature']['inputs']['unknown:0']['resourceId'] - initializer_resource_id = model_json['initializerSignature']['outputs']['Identity:0']['resourceId'] - self.assertTrue(signature_resource_id) - self.assertEqual(signature_resource_id, initializer_resource_id) - - # Update expected signatures with resourceId since it is a runtime value. - expected_signature['inputs']['unknown:0']['resourceId'] = signature_resource_id - expected_initializer_signature['outputs']['Identity:0']['resourceId'] = signature_resource_id - self.assertEqual(model_json['signature'], expected_signature) - self.assertEqual(model_json['initializerSignature'], expected_initializer_signature) - - self.assertTrue(model_json['modelTopology']) - self.assertIsNot(model_json['modelTopology']['versions'], None) - model_ops = [node['op'] for node in model_json['modelTopology']['node']] - self.assertTrue('LookupTableFindV2' in model_ops) - - self.assertTrue(model_json['modelInitializer']) - initializer_ops = [node['op'] for node in model_json['modelInitializer']['node']] - self.assertTrue('HashTableV2' in initializer_ops) - self.assertTrue('LookupTableImportV2' in initializer_ops) - - weights_manifest = model_json['weightsManifest'] - self.assertEqual(weights_manifest, expected_weights_manifest) - # Check meta-data in the artifact JSON. - self.assertEqual(model_json['format'], 'graph-model') - self.assertEqual( - model_json['convertedBy'], - 'TensorFlow.js Converter v%s' % version.version) - self.assertEqual(model_json['generatedBy'], - tf.__version__) - self.assertTrue(glob.glob(os.path.join(output_dir, 'group*-*'))) - def test_convert_saved_model_v1_with_metadata(self): self._create_saved_model_v1() diff --git a/tfjs-converter/src/data/compiled_api.ts b/tfjs-converter/src/data/compiled_api.ts index d5d0b4f1786..735525e0e1f 100644 --- a/tfjs-converter/src/data/compiled_api.ts +++ b/tfjs-converter/src/data/compiled_api.ts @@ -349,13 +349,10 @@ export declare interface ITensorInfo { cooSparse?: (TensorInfo.ICooSparse|null); /** TensorInfo dtype */ - dtype?: (DataType|string|null); + dtype?: (DataType|null); /** TensorInfo tensorShape */ tensorShape?: (ITensorShape|null); - - /** Resource id tensor was originally assigned to. */ - resourceId?: (number|null); } export namespace TensorInfo { diff --git a/tfjs-converter/src/executor/graph_model.ts b/tfjs-converter/src/executor/graph_model.ts index 632452e5db9..a260100697e 100644 --- a/tfjs-converter/src/executor/graph_model.ts +++ b/tfjs-converter/src/executor/graph_model.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {dispose, InferenceModel, io, ModelPredictConfig, NamedTensorMap, Tensor, util} from '@tensorflow/tfjs-core'; +import {InferenceModel, io, ModelPredictConfig, NamedTensorMap, Tensor, util} from '@tensorflow/tfjs-core'; import * as tensorflow from '../data/compiled_api'; import {NamedTensorsMap, TensorInfo} from '../data/types'; @@ -46,10 +46,8 @@ export class GraphModel implements private handler: UrlIOHandler; private artifacts: io.ModelArtifacts; private initializer: GraphExecutor; - private resourceIdToCapturedInput: {[key: number]: Tensor}; private resourceManager: ResourceManager; private signature: tensorflow.ISignatureDef; - private initializerSignature: tensorflow.ISignatureDef; private structuredOutputKeys: string[]; private readonly io: typeof io; @@ -162,7 +160,7 @@ export class GraphModel implements /** * Synchronously construct the in memory weight map and - * compile the inference graph. + * compile the inference graph. Also initialize hashtable if any. * * @doc {heading: 'Models', subheading: 'Classes', ignoreCI: true} */ @@ -203,7 +201,7 @@ export class GraphModel implements // hashTables created from when executing the initializer will be stored // in the resourceManager. this.initializer.resourceManager = this.resourceManager; - this.initializerSignature = artifacts.initializerSignature; + this.initializer.executeAsync({}, []); } return true; @@ -336,36 +334,17 @@ export class GraphModel implements NamedTensorMap): NamedTensorMap { if (!(inputs instanceof Tensor) && !Array.isArray(inputs)) { // The input is already a NamedTensorMap. - if (this.signature != null && this.signature.inputs != null) { - for (const input in this.signature.inputs) { - const tensor = this.signature.inputs[input]; - if (tensor.resourceId != null) { - inputs[input] = this.resourceIdToCapturedInput[tensor.resourceId]; - } - } - } return inputs; } inputs = Array.isArray(inputs) ? inputs : [inputs]; - - const numCapturedInputs = - Object.keys(this.resourceIdToCapturedInput).length; - if (inputs.length + numCapturedInputs !== this.inputNodes.length) { - throw new Error(`Input tensor count mismatch, the graph model has ${ - this.inputNodes.length - - numCapturedInputs} non-resource placeholders, while there are ${ - inputs.length} input tensors provided.`); + if (inputs.length !== this.inputNodes.length) { + throw new Error( + 'Input tensor count mismatch,' + + `the graph model has ${this.inputNodes.length} placeholders, ` + + `while there are ${inputs.length} input tensors.`); } - - let inputIndex = 0; - return this.inputNodes.reduce((map, inputName) => { - const signature = - this.signature ? this.signature.inputs[inputName] : null; - if (signature != null && signature.resourceId != null) { - map[inputName] = this.resourceIdToCapturedInput[signature.resourceId]; - } else { - map[inputName] = (inputs as Tensor[])[inputIndex++]; - } + return this.inputNodes.reduce((map, inputName, i) => { + map[inputName] = (inputs as Tensor[])[i]; return map; }, {} as NamedTensorMap); } @@ -375,43 +354,6 @@ export class GraphModel implements return !Array.isArray(outputs) ? [outputs] : outputs; } - private executeInitializerGraph() { - if (this.initializer == null) { - return []; - } - if (this.initializerSignature == null) { - return this.initializer.execute({}, []); - } else { - return this.initializer.execute( - {}, Object.keys(this.initializerSignature.outputs)); - } - } - - private async executeInitializerGraphAsync() { - if (this.initializer == null) { - return []; - } - if (this.initializerSignature == null) { - return this.initializer.executeAsync({}, []); - } else { - return this.initializer.executeAsync( - {}, Object.keys(this.initializerSignature.outputs)); - } - } - - private setResourceIdToCapturedInput(outputs: Tensor[]) { - this.resourceIdToCapturedInput = {}; - - if (this.initializerSignature) { - const outputNames = Object.keys(this.initializerSignature.outputs); - for (let i = 0; i < outputNames.length; i++) { - const outputName = outputNames[i]; - const tensorInfo = this.initializerSignature.outputs[outputName]; - this.resourceIdToCapturedInput[tensorInfo.resourceId] = outputs[i]; - } - } - } - /** * Executes inference for the model for given input tensors. * @param inputs tensor, tensor array or tensor map of the inputs for the @@ -430,15 +372,11 @@ export class GraphModel implements */ execute(inputs: Tensor|Tensor[]|NamedTensorMap, outputs?: string|string[]): Tensor|Tensor[] { - if (this.resourceIdToCapturedInput == null) { - this.setResourceIdToCapturedInput(this.executeInitializerGraph()); - } inputs = this.normalizeInputs(inputs); outputs = this.normalizeOutputs(outputs); const result = this.executor.execute(inputs, outputs); return result.length > 1 ? result : result[0]; } - /** * Executes inference for the model for given input tensors in async * fashion, use this method when your model contains control flow ops. @@ -458,10 +396,6 @@ export class GraphModel implements async executeAsync( inputs: Tensor|Tensor[]|NamedTensorMap, outputs?: string|string[]): Promise { - if (this.resourceIdToCapturedInput == null) { - this.setResourceIdToCapturedInput( - await this.executeInitializerGraphAsync()); - } inputs = this.normalizeInputs(inputs); outputs = this.normalizeOutputs(outputs); const result = await this.executor.executeAsync(inputs, outputs); @@ -505,9 +439,6 @@ export class GraphModel implements if (this.initializer) { this.initializer.dispose(); - if (this.resourceIdToCapturedInput) { - dispose(this.resourceIdToCapturedInput); - } } this.resourceManager.dispose(); diff --git a/tfjs-converter/src/executor/graph_model_test.ts b/tfjs-converter/src/executor/graph_model_test.ts index 9800f4b62ac..76263f4d915 100644 --- a/tfjs-converter/src/executor/graph_model_test.ts +++ b/tfjs-converter/src/executor/graph_model_test.ts @@ -15,7 +15,7 @@ * ============================================================================= */ import * as tfc from '@tensorflow/tfjs-core'; -import {io, scalar, Tensor} from '@tensorflow/tfjs-core'; +import {io, scalar} from '@tensorflow/tfjs-core'; import * as tensorflow from '../data/compiled_api'; import {deregisterOp, registerOp} from '../operations/custom_op/register'; @@ -23,7 +23,6 @@ import {RecursiveSpy, spyOnAllFunctions} from '../operations/executors/spy_ops'; import {GraphNode} from '../operations/types'; import {GraphModel, loadGraphModel, loadGraphModelSync} from './graph_model'; -import {HASH_TABLE_MODEL_V2} from './test_data/hash_table_v2_model_loader'; import {STRUCTURED_OUTPUTS_MODEL} from './test_data/structured_outputs_model_loader'; const HOST = 'http://example.org'; @@ -300,7 +299,7 @@ const HASH_TABLE_SIGNATURE: tensorflow.ISignatureDef = { values: {name: 'LookupTableFindV2:0', dtype: tensorflow.DataType.DT_FLOAT} } }; -const HASHTABLE_V1_HTTP_MODEL_LOADER = { +const HASHTABLE_HTTP_MODEL_LOADER = { load: async () => { return { modelTopology: HASH_TABLE_MODEL, @@ -315,12 +314,6 @@ const HASHTABLE_V1_HTTP_MODEL_LOADER = { } }; -const HASHTABLE_V2_MODEL_LOADER = { - load: async () => { - return HASH_TABLE_MODEL_V2; - } -}; - class IOHandlerForTest implements tfc.io.IOHandler { savedArtifacts: tfc.io.ModelArtifacts; @@ -989,10 +982,10 @@ describe('Model', () => { }); }); - describe('Hashtable V1 model', () => { + describe('Hashtable model', () => { beforeEach(() => { - spyIo.getLoadHandlers.and.returnValue([HASHTABLE_V1_HTTP_MODEL_LOADER]); - spyIo.browserHTTPRequest.and.returnValue(HASHTABLE_V1_HTTP_MODEL_LOADER); + spyIo.getLoadHandlers.and.returnValue([HASHTABLE_HTTP_MODEL_LOADER]); + spyIo.browserHTTPRequest.and.returnValue(HASHTABLE_HTTP_MODEL_LOADER); }); it('should be successful if call executeAsync', async () => { await model.load(); @@ -1002,43 +995,6 @@ describe('Model', () => { expect(res).not.toBeNull(); }); }); - - describe('Hashtable V2 model', () => { - beforeEach(() => { - spyIo.getLoadHandlers.and.returnValue([HASHTABLE_V2_MODEL_LOADER]); - spyIo.browserHTTPRequest.and.returnValue(HASHTABLE_V2_MODEL_LOADER); - }); - it('load', async () => { - const loaded = await model.load(); - expect(loaded).toBe(true); - }); - - describe('execute', () => { - it('should be successful if call executeAsync', async () => { - await model.load(); - const res = await model.executeAsync( - {'input': tfc.tensor1d(['a', 'b', 'c'])}) as Tensor; - expect(Array.from(res.dataSync())).toEqual([0, 1, -1]); - }); - }); - - describe('dispose', () => { - it('should dispose the weights', async () => { - const startTensors = tfc.memory().numTensors; - - await model.load(); - - const input = tfc.tensor1d(['a', 'b', 'c']); - const output = await model.executeAsync({input}) as Tensor; - input.dispose(); - output.dispose(); - - model.dispose(); - - expect(tfc.memory().numTensors).toEqual(startTensors); - }); - }); - }); }); describe('Graph execution gives actionable errors', () => { diff --git a/tfjs-converter/src/executor/test_data/hash_table_v2_model_loader.ts b/tfjs-converter/src/executor/test_data/hash_table_v2_model_loader.ts deleted file mode 100644 index 622bda4e61f..00000000000 --- a/tfjs-converter/src/executor/test_data/hash_table_v2_model_loader.ts +++ /dev/null @@ -1,342 +0,0 @@ -/** - * @license - * Copyright 2022 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ -export const HASH_TABLE_MODEL_V2 = { - modelTopology: { - node: [ - { - name: 'unknown_0', - op: 'Const', - attr: { - value: {tensor: {dtype: 'DT_INT32', tensorShape: {}}}, - dtype: {type: 'DT_INT32'} - } - }, - { - name: 'input', - op: 'Placeholder', - attr: - {shape: {shape: {dim: [{size: '-1'}]}}, dtype: {type: 'DT_STRING'}} - }, - { - name: 'unknown', - op: 'Placeholder', - attr: {shape: {shape: {}}, dtype: {type: 'DT_RESOURCE'}} - }, - { - name: 'StatefulPartitionedCall/None_Lookup/LookupTableFindV2', - op: 'LookupTableFindV2', - input: ['unknown', 'input', 'unknown_0'], - attr: { - Tout: {type: 'DT_INT32'}, - Tin: {type: 'DT_STRING'}, - _has_manual_control_dependencies: {b: true} - } - }, - { - name: 'Identity', - op: 'Identity', - input: ['StatefulPartitionedCall/None_Lookup/LookupTableFindV2'], - attr: {T: {type: 'DT_INT32'}} - } - ], - library: {}, - versions: {producer: 1240} - }, - format: 'graph-model', - generatedBy: '2.11.0-dev20220822', - convertedBy: 'TensorFlow.js Converter v1.7.0', - weightSpecs: [ - {name: 'unknown_0', shape: [], dtype: 'int32'}, - {name: '114', shape: [2], dtype: 'string'}, - {name: '116', shape: [2], dtype: 'int32'} - ], - 'weightData': - new Uint8Array([ - 0xff, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x61, 0x01, 0x00, - 0x00, 0x00, 0x62, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00 - ]).buffer, - - signature: { - inputs: { - input: { - name: 'input:0', - dtype: 'DT_STRING', - tensorShape: {dim: [{size: '-1'}]} - }, - 'unknown:0': { - name: 'unknown:0', - dtype: 'DT_RESOURCE', - tensorShape: {}, - resourceId: 66 - } - }, - outputs: { - output_0: { - name: 'Identity:0', - dtype: 'DT_INT32', - tensorShape: {dim: [{size: '-1'}]} - } - } - }, - modelInitializer: { - node: [ - { - name: 'Func/StatefulPartitionedCall/input_control_node/_0', - op: 'NoOp', - attr: {_has_manual_control_dependencies: {b: true}} - }, - { - name: '114', - op: 'Const', - attr: { - value: - {tensor: {dtype: 'DT_STRING', tensorShape: {dim: [{size: '2'}]}}}, - _has_manual_control_dependencies: {b: true}, - dtype: {type: 'DT_STRING'} - } - }, - { - name: '116', - op: 'Const', - attr: { - _has_manual_control_dependencies: {b: true}, - dtype: {type: 'DT_INT32'}, - value: - {tensor: {dtype: 'DT_INT32', tensorShape: {dim: [{size: '2'}]}}} - } - }, - { - name: - 'Func/StatefulPartitionedCall/StatefulPartitionedCall/input_control_node/_9', - op: 'NoOp', - input: ['^Func/StatefulPartitionedCall/input_control_node/_0'], - attr: {_has_manual_control_dependencies: {b: true}} - }, - { - name: 'StatefulPartitionedCall/StatefulPartitionedCall/hash_table', - op: 'HashTableV2', - input: [ - '^Func/StatefulPartitionedCall/StatefulPartitionedCall/input_control_node/_9' - ], - attr: { - container: {s: ''}, - use_node_name_sharing: {b: true}, - _has_manual_control_dependencies: {b: true}, - shared_name: {s: 'OTVfbG9hZF8xXzUy'}, - value_dtype: {type: 'DT_INT32'}, - key_dtype: {type: 'DT_STRING'} - } - }, - { - name: - 'Func/StatefulPartitionedCall/StatefulPartitionedCall/output_control_node/_11', - op: 'NoOp', - input: ['^StatefulPartitionedCall/StatefulPartitionedCall/hash_table'], - attr: {_has_manual_control_dependencies: {b: true}} - }, - { - name: 'Func/StatefulPartitionedCall/output_control_node/_2', - op: 'NoOp', - input: [ - '^Func/StatefulPartitionedCall/StatefulPartitionedCall/output_control_node/_11' - ], - attr: {_has_manual_control_dependencies: {b: true}} - }, - { - name: 'StatefulPartitionedCall/StatefulPartitionedCall/NoOp', - op: 'NoOp', - input: ['^StatefulPartitionedCall/StatefulPartitionedCall/hash_table'], - attr: { - _acd_function_control_output: {b: true}, - _has_manual_control_dependencies: {b: true} - } - }, - { - name: 'StatefulPartitionedCall/StatefulPartitionedCall/Identity', - op: 'Identity', - input: [ - 'StatefulPartitionedCall/StatefulPartitionedCall/hash_table', - '^StatefulPartitionedCall/StatefulPartitionedCall/NoOp' - ], - attr: {T: {type: 'DT_RESOURCE'}} - }, - { - name: 'Func/StatefulPartitionedCall/StatefulPartitionedCall/output/_10', - op: 'Identity', - input: ['StatefulPartitionedCall/StatefulPartitionedCall/Identity'], - attr: {T: {type: 'DT_RESOURCE'}} - }, - { - name: 'StatefulPartitionedCall/NoOp', - op: 'NoOp', - input: [ - '^Func/StatefulPartitionedCall/StatefulPartitionedCall/output_control_node/_11' - ], - attr: { - _has_manual_control_dependencies: {b: true}, - _acd_function_control_output: {b: true} - } - }, - { - name: 'StatefulPartitionedCall/Identity', - op: 'Identity', - input: [ - 'Func/StatefulPartitionedCall/StatefulPartitionedCall/output/_10', - '^StatefulPartitionedCall/NoOp' - ], - attr: {T: {type: 'DT_RESOURCE'}} - }, - { - name: 'Func/StatefulPartitionedCall/output/_1', - op: 'Identity', - input: ['StatefulPartitionedCall/Identity'], - attr: { - T: {type: 'DT_RESOURCE'}, - _has_manual_control_dependencies: {b: true} - } - }, - { - name: 'Func/StatefulPartitionedCall_1/input_control_node/_3', - op: 'NoOp', - input: ['^114', '^116', '^Func/StatefulPartitionedCall/output/_1'], - attr: {_has_manual_control_dependencies: {b: true}} - }, - { - name: 'Func/StatefulPartitionedCall_1/input/_4', - op: 'Identity', - input: [ - 'Func/StatefulPartitionedCall/output/_1', - '^Func/StatefulPartitionedCall_1/input_control_node/_3' - ], - attr: {T: {type: 'DT_RESOURCE'}} - }, - { - name: - 'Func/StatefulPartitionedCall_1/StatefulPartitionedCall/input_control_node/_12', - op: 'NoOp', - input: ['^Func/StatefulPartitionedCall_1/input_control_node/_3'], - attr: {_has_manual_control_dependencies: {b: true}} - }, - { - name: - 'Func/StatefulPartitionedCall_1/StatefulPartitionedCall/input/_13', - op: 'Identity', - input: [ - 'Func/StatefulPartitionedCall_1/input/_4', - '^Func/StatefulPartitionedCall_1/StatefulPartitionedCall/input_control_node/_12' - ], - attr: {T: {type: 'DT_RESOURCE'}} - }, - { - name: 'Func/StatefulPartitionedCall_1/input/_5', - op: 'Identity', - input: ['114', '^Func/StatefulPartitionedCall_1/input_control_node/_3'], - attr: {T: {type: 'DT_STRING'}} - }, - { - name: - 'Func/StatefulPartitionedCall_1/StatefulPartitionedCall/input/_14', - op: 'Identity', - input: [ - 'Func/StatefulPartitionedCall_1/input/_5', - '^Func/StatefulPartitionedCall_1/StatefulPartitionedCall/input_control_node/_12' - ], - attr: {T: {type: 'DT_STRING'}} - }, - { - name: 'Func/StatefulPartitionedCall_1/input/_6', - op: 'Identity', - input: ['116', '^Func/StatefulPartitionedCall_1/input_control_node/_3'], - attr: {T: {type: 'DT_INT32'}} - }, - { - name: - 'Func/StatefulPartitionedCall_1/StatefulPartitionedCall/input/_15', - op: 'Identity', - input: [ - 'Func/StatefulPartitionedCall_1/input/_6', - '^Func/StatefulPartitionedCall_1/StatefulPartitionedCall/input_control_node/_12' - ], - attr: {T: {type: 'DT_INT32'}} - }, - { - name: - 'StatefulPartitionedCall_1/StatefulPartitionedCall/key_value_init94/LookupTableImportV2', - op: 'LookupTableImportV2', - input: [ - 'Func/StatefulPartitionedCall_1/StatefulPartitionedCall/input/_13', - 'Func/StatefulPartitionedCall_1/StatefulPartitionedCall/input/_14', - 'Func/StatefulPartitionedCall_1/StatefulPartitionedCall/input/_15' - ], - attr: { - Tout: {type: 'DT_INT32'}, - Tin: {type: 'DT_STRING'}, - _has_manual_control_dependencies: {b: true} - } - }, - { - name: - 'Func/StatefulPartitionedCall_1/StatefulPartitionedCall/output_control_node/_17', - op: 'NoOp', - input: [ - '^StatefulPartitionedCall_1/StatefulPartitionedCall/key_value_init94/LookupTableImportV2' - ], - attr: {_has_manual_control_dependencies: {b: true}} - }, - { - name: 'Func/StatefulPartitionedCall_1/output_control_node/_8', - op: 'NoOp', - input: [ - '^Func/StatefulPartitionedCall_1/StatefulPartitionedCall/output_control_node/_17' - ], - attr: {_has_manual_control_dependencies: {b: true}} - }, - { - name: 'NoOp', - op: 'NoOp', - input: [ - '^Func/StatefulPartitionedCall/output_control_node/_2', - '^Func/StatefulPartitionedCall_1/output_control_node/_8' - ], - attr: { - _has_manual_control_dependencies: {b: true}, - _acd_function_control_output: {b: true} - } - }, - { - name: 'Identity', - op: 'Identity', - input: [ - 'Func/StatefulPartitionedCall/output/_1', - '^Func/StatefulPartitionedCall_1/output_control_node/_8', '^NoOp' - ], - attr: {T: {type: 'DT_RESOURCE'}} - } - ], - versions: {producer: 1240} - }, - initializerSignature: { - outputs: { - 'Identity:0': { - name: 'Identity:0', - dtype: 'DT_RESOURCE', - tensorShape: {}, - resourceId: 66 - } - } - } -}; diff --git a/tfjs-converter/yarn.lock b/tfjs-converter/yarn.lock index 6c3eb2abf0f..b0970fa577d 100644 --- a/tfjs-converter/yarn.lock +++ b/tfjs-converter/yarn.lock @@ -100,26 +100,6 @@ resolved "https://registry.yarnpkg.com/@types/node/-/node-17.0.38.tgz#f8bb07c371ccb1903f3752872c89f44006132947" integrity sha512-5jY9RhV7c0Z4Jy09G+NIDTsCZ5G0L5n+Z+p+Y7t5VJHM30bgwzSjVtlcBxqAj+6L/swIlvtOSzr8rBk/aNyV2g== -"@types/offscreencanvas@~2019.3.0": - version "2019.3.0" - resolved "https://registry.yarnpkg.com/@types/offscreencanvas/-/offscreencanvas-2019.3.0.tgz#3336428ec7e9180cf4566dfea5da04eb586a6553" - integrity sha512-esIJx9bQg+QYF0ra8GnvfianIY8qWB0GBx54PK5Eps6m+xTj86KLavHv6qDhzKcu5UUOgNfJ2pWaIIV7TRUd9Q== - -"@types/seedrandom@^2.4.28": - version "2.4.30" - resolved "https://registry.yarnpkg.com/@types/seedrandom/-/seedrandom-2.4.30.tgz#d2efe425869b84163c2d56e779dddadb9372cbfa" - integrity sha512-AnxLHewubLVzoF/A4qdxBGHCKifw8cY32iro3DQX9TPcetE95zBeVt3jnsvtvAUf1vwzMfwzp4t/L2yqPlnjkQ== - -"@types/webgl-ext@0.0.30": - version "0.0.30" - resolved "https://registry.yarnpkg.com/@types/webgl-ext/-/webgl-ext-0.0.30.tgz#0ce498c16a41a23d15289e0b844d945b25f0fb9d" - integrity sha512-LKVgNmBxN0BbljJrVUwkxwRYqzsAEPcZOe6S2T6ZaBDIrFp0qu4FNlpc5sM1tGbXUYFgdVQIoeLk1Y1UoblyEg== - -"@webgpu/types@^0.1.16": - version "0.1.21" - resolved "https://registry.yarnpkg.com/@webgpu/types/-/types-0.1.21.tgz#b181202daec30d66ccd67264de23814cfd176d3a" - integrity sha512-pUrWq3V5PiSGFLeLxoGqReTZmiiXwY3jRkIG5sLLKjyqNxrwm/04b4nw7LSmGWJcKk59XOM/YRTUwOzo4MMlow== - ansi-regex@^5.0.1: version "5.0.1" resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-5.0.1.tgz#082cb2c89c9fe8659a311a53bd6a4dc5301db304" @@ -304,7 +284,7 @@ jsonfile@^4.0.0: optionalDependencies: graceful-fs "^4.1.6" -long@4.0.0, long@^4.0.0: +long@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/long/-/long-4.0.0.tgz#9a7b71cfb7d361a194ea555241c92f7468d5bf28" integrity sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA== @@ -398,11 +378,6 @@ require-directory@^2.1.1: resolved "https://registry.yarnpkg.com/require-directory/-/require-directory-2.1.1.tgz#8c64ad5fd30dab1c976e2344ffe7f792a6a6df42" integrity sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q== -seedrandom@^3.0.5: - version "3.0.5" - resolved "https://registry.yarnpkg.com/seedrandom/-/seedrandom-3.0.5.tgz#54edc85c95222525b0c7a6f6b3543d8e0b3aa0a7" - integrity sha512-8OwmbklUNzwezjGInmZ+2clQmExQPvomqjL7LFqOYqtmuxRgQYqOD3mHaU+MvZn5FLUeVxVfQjwLZW/n/JFuqg== - source-map-support@^0.5.6: version "0.5.19" resolved "https://registry.yarnpkg.com/source-map-support/-/source-map-support-0.5.19.tgz#a98b62f86dcaf4f67399648c085291ab9e8fed61" diff --git a/tfjs-core/src/io/io_utils.ts b/tfjs-core/src/io/io_utils.ts index ce20b7c9fee..9a3756b652b 100644 --- a/tfjs-core/src/io/io_utils.ts +++ b/tfjs-core/src/io/io_utils.ts @@ -396,9 +396,6 @@ export function getModelJSONForModelArtifacts( if (artifacts.modelInitializer != null) { result.modelInitializer = artifacts.modelInitializer; } - if (artifacts.initializerSignature != null) { - result.initializerSignature = artifacts.initializerSignature; - } if (artifacts.trainingConfig != null) { result.trainingConfig = artifacts.trainingConfig; } @@ -449,9 +446,6 @@ export function getModelArtifactsForJSONSync( if (modelJSON.modelInitializer != null) { modelArtifacts.modelInitializer = modelJSON.modelInitializer; } - if (modelJSON.initializerSignature != null) { - modelArtifacts.initializerSignature = modelJSON.initializerSignature; - } return modelArtifacts; } diff --git a/tfjs-core/src/io/local_storage.ts b/tfjs-core/src/io/local_storage.ts index 2de2639c35c..019d6bf5d6e 100644 --- a/tfjs-core/src/io/local_storage.ts +++ b/tfjs-core/src/io/local_storage.ts @@ -198,9 +198,6 @@ export class BrowserLocalStorage implements IOHandler { modelInitializer: modelArtifacts.modelInitializer != null ? modelArtifacts.modelInitializer : undefined, - initializerSignature: modelArtifacts.initializerSignature != null ? - modelArtifacts.initializerSignature : - undefined, trainingConfig: modelArtifacts.trainingConfig != null ? modelArtifacts.trainingConfig : undefined @@ -280,9 +277,6 @@ export class BrowserLocalStorage implements IOHandler { if (metadata.modelInitializer != null) { out.modelInitializer = metadata.modelInitializer; } - if (metadata.initializerSignature != null) { - out.initializerSignature = metadata.initializerSignature; - } if (metadata.trainingConfig != null) { out.trainingConfig = metadata.trainingConfig; } diff --git a/tfjs-core/src/io/local_storage_test.ts b/tfjs-core/src/io/local_storage_test.ts index 7d2a351d958..0936c770aad 100644 --- a/tfjs-core/src/io/local_storage_test.ts +++ b/tfjs-core/src/io/local_storage_test.ts @@ -83,7 +83,6 @@ describeWithFlags('LocalStorage', BROWSER_ENVS, () => { signature: null, userDefinedMetadata: {}, modelInitializer: {}, - initializerSignature: null, trainingConfig: trainingConfig1, }; @@ -125,183 +124,171 @@ describeWithFlags('LocalStorage', BROWSER_ENVS, () => { }); it('Save artifacts succeeds', runWithLock(async () => { - const testStartDate = new Date(); - const handler = tf.io.getSaveHandlers('localstorage://foo/FooModel')[0]; - const saveResult = await handler.save(artifacts1); - - expect(saveResult.modelArtifactsInfo.dateSaved.getTime()) - .toBeGreaterThanOrEqual(testStartDate.getTime()); - // Note: The following two assertions work only because there is no - // non-ASCII characters in `modelTopology1` and `weightSpecs1`. - expect(saveResult.modelArtifactsInfo.modelTopologyBytes) - .toEqual(JSON.stringify(modelTopology1).length); - expect(saveResult.modelArtifactsInfo.weightSpecsBytes) - .toEqual(JSON.stringify(weightSpecs1).length); - expect(saveResult.modelArtifactsInfo.weightDataBytes).toEqual(16); - - // Check the content of the saved items in local storage. - const LS = window.localStorage; - const info = - JSON.parse(LS.getItem('tensorflowjs_models/foo/FooModel/info')); - expect(Date.parse(info.dateSaved)) - .toEqual(saveResult.modelArtifactsInfo.dateSaved.getTime()); - expect(info.modelTopologyBytes) - .toEqual(saveResult.modelArtifactsInfo.modelTopologyBytes); - expect(info.weightSpecsBytes) - .toEqual(saveResult.modelArtifactsInfo.weightSpecsBytes); - expect(info.weightDataBytes) - .toEqual(saveResult.modelArtifactsInfo.weightDataBytes); - - const topologyString = - LS.getItem('tensorflowjs_models/foo/FooModel/model_topology'); - expect(JSON.stringify(modelTopology1)).toEqual(topologyString); - - const weightSpecsString = - LS.getItem('tensorflowjs_models/foo/FooModel/weight_specs'); - expect(JSON.stringify(weightSpecs1)).toEqual(weightSpecsString); - - const weightDataBase64String = - LS.getItem('tensorflowjs_models/foo/FooModel/weight_data'); - expect(base64StringToArrayBuffer(weightDataBase64String)) - .toEqual(weightData1); - })); + const testStartDate = new Date(); + const handler = tf.io.getSaveHandlers('localstorage://foo/FooModel')[0]; + const saveResult = await handler.save(artifacts1); + + expect(saveResult.modelArtifactsInfo.dateSaved.getTime()) + .toBeGreaterThanOrEqual(testStartDate.getTime()); + // Note: The following two assertions work only because there is no + // non-ASCII characters in `modelTopology1` and `weightSpecs1`. + expect(saveResult.modelArtifactsInfo.modelTopologyBytes) + .toEqual(JSON.stringify(modelTopology1).length); + expect(saveResult.modelArtifactsInfo.weightSpecsBytes) + .toEqual(JSON.stringify(weightSpecs1).length); + expect(saveResult.modelArtifactsInfo.weightDataBytes).toEqual(16); + + // Check the content of the saved items in local storage. + const LS = window.localStorage; + const info = + JSON.parse(LS.getItem('tensorflowjs_models/foo/FooModel/info')); + expect(Date.parse(info.dateSaved)) + .toEqual(saveResult.modelArtifactsInfo.dateSaved.getTime()); + expect(info.modelTopologyBytes) + .toEqual(saveResult.modelArtifactsInfo.modelTopologyBytes); + expect(info.weightSpecsBytes) + .toEqual(saveResult.modelArtifactsInfo.weightSpecsBytes); + expect(info.weightDataBytes) + .toEqual(saveResult.modelArtifactsInfo.weightDataBytes); + + const topologyString = + LS.getItem('tensorflowjs_models/foo/FooModel/model_topology'); + expect(JSON.stringify(modelTopology1)).toEqual(topologyString); + + const weightSpecsString = + LS.getItem('tensorflowjs_models/foo/FooModel/weight_specs'); + expect(JSON.stringify(weightSpecs1)).toEqual(weightSpecsString); + + const weightDataBase64String = + LS.getItem('tensorflowjs_models/foo/FooModel/weight_data'); + expect(base64StringToArrayBuffer(weightDataBase64String)) + .toEqual(weightData1); + })); it('Save-load round trip succeeds', runWithLock(async () => { - const handler1 = tf.io.getSaveHandlers('localstorage://FooModel')[0]; - - await handler1.save(artifacts1); - const handler2 = tf.io.getLoadHandlers('localstorage://FooModel')[0]; - const loaded = await handler2.load(); - expect(loaded.modelTopology).toEqual(modelTopology1); - expect(loaded.weightSpecs).toEqual(weightSpecs1); - expect(loaded.weightData).toEqual(weightData1); - expect(loaded.format).toEqual('layers-model'); - expect(loaded.generatedBy).toEqual('TensorFlow.js v0.0.0'); - expect(loaded.convertedBy).toEqual('1.13.1'); - expect(loaded.userDefinedMetadata).toEqual({}); - expect(loaded.modelInitializer).toEqual({}); - expect(loaded.initializerSignature).toBeUndefined(); - expect(loaded.trainingConfig).toEqual(trainingConfig1); - })); + const handler1 = tf.io.getSaveHandlers('localstorage://FooModel')[0]; + + await handler1.save(artifacts1); + const handler2 = tf.io.getLoadHandlers('localstorage://FooModel')[0]; + const loaded = await handler2.load(); + expect(loaded.modelTopology).toEqual(modelTopology1); + expect(loaded.weightSpecs).toEqual(weightSpecs1); + expect(loaded.weightData).toEqual(weightData1); + expect(loaded.format).toEqual('layers-model'); + expect(loaded.generatedBy).toEqual('TensorFlow.js v0.0.0'); + expect(loaded.convertedBy).toEqual('1.13.1'); + expect(loaded.userDefinedMetadata).toEqual({}); + expect(loaded.modelInitializer).toEqual({}); + expect(loaded.trainingConfig).toEqual(trainingConfig1); + })); it('Save-load round trip succeeds: v0 format', runWithLock(async () => { - const handler1 = tf.io.getSaveHandlers('localstorage://FooModel')[0]; - - await handler1.save(artifactsV0); - const handler2 = tf.io.getLoadHandlers('localstorage://FooModel')[0]; - const loaded = await handler2.load(); - expect(loaded.modelTopology).toEqual(modelTopology1); - expect(loaded.weightSpecs).toEqual(weightSpecs1); - expect(loaded.weightData).toEqual(weightData1); - expect(loaded.format).toBeUndefined(); - expect(loaded.generatedBy).toBeUndefined(); - expect(loaded.convertedBy).toBeUndefined(); - expect(loaded.userDefinedMetadata).toBeUndefined(); - expect(loaded.trainingConfig).toBeUndefined(); - expect(loaded.modelInitializer).toBeUndefined(); - expect(loaded.initializerSignature).toBeUndefined(); - expect(loaded.trainingConfig).toBeUndefined(); - })); + const handler1 = tf.io.getSaveHandlers('localstorage://FooModel')[0]; + + await handler1.save(artifactsV0); + const handler2 = tf.io.getLoadHandlers('localstorage://FooModel')[0]; + const loaded = await handler2.load(); + expect(loaded.modelTopology).toEqual(modelTopology1); + expect(loaded.weightSpecs).toEqual(weightSpecs1); + expect(loaded.weightData).toEqual(weightData1); + expect(loaded.format).toBeUndefined(); + expect(loaded.generatedBy).toBeUndefined(); + expect(loaded.convertedBy).toBeUndefined(); + expect(loaded.userDefinedMetadata).toBeUndefined(); + expect(loaded.trainingConfig).toBeUndefined(); + })); it('Loading nonexistent model fails.', runWithLock(async () => { - const handler = - tf.io.getSaveHandlers('localstorage://NonexistentModel')[0]; - try { - await handler.load(); - } catch (err) { - expect(err.message) - .toEqual( - 'In local storage, there is no model with name ' + - '\'NonexistentModel\''); - return; // Success - } - fail('Loading nonexistent model succeeded unexpectedly.'); - })); + const handler = tf.io.getSaveHandlers('localstorage://NonexistentModel')[0]; + try { + await handler.load(); + } catch (err) { + expect(err.message) + .toEqual( + 'In local storage, there is no model with name ' + + '\'NonexistentModel\''); + return; // Success + } + fail('Loading nonexistent model succeeded unexpectedly.'); + })); it('Loading model with missing topology fails.', runWithLock(async () => { - const handler1 = tf.io.getSaveHandlers('localstorage://FooModel')[0]; - await handler1.save(artifacts1); - // Manually remove the topology item from local storage. - window.localStorage.removeItem( - 'tensorflowjs_models/FooModel/model_topology'); - - const handler2 = tf.io.getLoadHandlers('localstorage://FooModel')[0]; - try { - await handler2.load(); - } catch (err) { - expect(err.message) - .toEqual( - 'In local storage, the topology of model ' + - '\'FooModel\' is missing.'); - return; // Success - } - fail('Loading of model with missing topology succeeded unexpectedly.'); - })); + const handler1 = tf.io.getSaveHandlers('localstorage://FooModel')[0]; + await handler1.save(artifacts1); + // Manually remove the topology item from local storage. + window.localStorage.removeItem( + 'tensorflowjs_models/FooModel/model_topology'); + + const handler2 = tf.io.getLoadHandlers('localstorage://FooModel')[0]; + try { + await handler2.load(); + } catch (err) { + expect(err.message) + .toEqual( + 'In local storage, the topology of model ' + + '\'FooModel\' is missing.'); + return; // Success + } + fail('Loading of model with missing topology succeeded unexpectedly.'); + })); it('Loading model with missing weight specs fails.', runWithLock(async () => { - const handler1 = tf.io.getSaveHandlers('localstorage://FooModel')[0]; - await handler1.save(artifacts1); - // Manually remove the weight specs item from local storage. - window.localStorage.removeItem( - 'tensorflowjs_models/FooModel/weight_specs'); - - const handler2 = tf.io.getLoadHandlers('localstorage://FooModel')[0]; - try { - await handler2.load(); - } catch (err) { - expect(err.message) - .toEqual( - 'In local storage, the weight specs of model ' + - '\'FooModel\' are missing.'); - return; // Success - } - fail( - 'Loading of model with missing weight specs ' + - 'succeeded unexpectedly.'); - })); + const handler1 = tf.io.getSaveHandlers('localstorage://FooModel')[0]; + await handler1.save(artifacts1); + // Manually remove the weight specs item from local storage. + window.localStorage.removeItem('tensorflowjs_models/FooModel/weight_specs'); + + const handler2 = tf.io.getLoadHandlers('localstorage://FooModel')[0]; + try { + await handler2.load(); + } catch (err) { + expect(err.message) + .toEqual( + 'In local storage, the weight specs of model ' + + '\'FooModel\' are missing.'); + return; // Success + } + fail('Loading of model with missing weight specs succeeded unexpectedly.'); + })); it('Loading model with missing weight data fails.', runWithLock(async () => { - const handler1 = tf.io.getSaveHandlers('localstorage://FooModel')[0]; - await handler1.save(artifacts1); - - // Manually remove the weight data item from local storage. - window.localStorage.removeItem( - 'tensorflowjs_models/FooModel/weight_data'); - - const handler2 = tf.io.getLoadHandlers('localstorage://FooModel')[0]; - try { - await handler2.load(); - fail( - 'Loading of model with missing weight data ' + - 'succeeded unexpectedly.'); - } catch (err) { - expect(err.message) - .toEqual( - 'In local storage, the binary weight values of model ' + - '\'FooModel\' are missing.'); - } - })); + const handler1 = tf.io.getSaveHandlers('localstorage://FooModel')[0]; + await handler1.save(artifacts1); + + // Manually remove the weight data item from local storage. + window.localStorage.removeItem('tensorflowjs_models/FooModel/weight_data'); + + const handler2 = tf.io.getLoadHandlers('localstorage://FooModel')[0]; + try { + await handler2.load(); + fail('Loading of model with missing weight data succeeded unexpectedly.'); + } catch (err) { + expect(err.message) + .toEqual( + 'In local storage, the binary weight values of model ' + + '\'FooModel\' are missing.'); + } + })); it('Data size too large leads to error thrown', runWithLock(async () => { - const overflowByteSize = findOverflowingByteSize(); - const overflowArtifacts: tf.io.ModelArtifacts = { - modelTopology: modelTopology1, - weightSpecs: weightSpecs1, - weightData: new ArrayBuffer(overflowByteSize), - }; - const handler1 = tf.io.getSaveHandlers('localstorage://FooModel')[0]; - try { - await handler1.save(overflowArtifacts); - fail( - 'Saving of model of overflowing-size weight data succeeded ' + - 'unexpectedly.'); - } catch (err) { - expect( - (err.message as string) - .indexOf('Failed to save model \'FooModel\' to local storage')) - .toEqual(0); - } - })); + const overflowByteSize = findOverflowingByteSize(); + const overflowArtifacts: tf.io.ModelArtifacts = { + modelTopology: modelTopology1, + weightSpecs: weightSpecs1, + weightData: new ArrayBuffer(overflowByteSize), + }; + const handler1 = tf.io.getSaveHandlers('localstorage://FooModel')[0]; + try { + await handler1.save(overflowArtifacts); + fail('Saving of model of overflowing-size weight data succeeded ' + + 'unexpectedly.'); + } catch (err) { + expect((err.message as string) + .indexOf( + 'Failed to save model \'FooModel\' to local storage')) + .toEqual(0); + } + })); it('Null, undefined or empty modelPath throws Error', () => { expect(() => browserLocalStorage(null)) @@ -324,81 +311,87 @@ describeWithFlags('LocalStorage', BROWSER_ENVS, () => { }); it('Manager: List models: 0 result', runWithLock(async () => { - // Before any model is saved, listModels should return empty result. - const out = await new BrowserLocalStorageManager().listModels(); - expect(out).toEqual({}); - })); + // Before any model is saved, listModels should return empty result. + const out = await new BrowserLocalStorageManager().listModels(); + expect(out).toEqual({}); + })); it('Manager: List models: 1 result', runWithLock(async () => { - const handler = tf.io.getSaveHandlers('localstorage://baz/QuxModel')[0]; - const saveResult = await handler.save(artifacts1); - - // After successful saving, there should be one model. - const out = await new BrowserLocalStorageManager().listModels(); - if (Object.keys(out).length !== 1) { - console.log(JSON.stringify(out, null, 2)); - } - - expect(Object.keys(out).length).toEqual(1); - expect(out['baz/QuxModel'].modelTopologyType) - .toEqual(saveResult.modelArtifactsInfo.modelTopologyType); - expect(out['baz/QuxModel'].modelTopologyBytes) - .toEqual(saveResult.modelArtifactsInfo.modelTopologyBytes); - expect(out['baz/QuxModel'].weightSpecsBytes) - .toEqual(saveResult.modelArtifactsInfo.weightSpecsBytes); - expect(out['baz/QuxModel'].weightDataBytes) - .toEqual(saveResult.modelArtifactsInfo.weightDataBytes); - })); + const handler = tf.io.getSaveHandlers('localstorage://baz/QuxModel')[0]; + const saveResult = await handler.save(artifacts1); + + // After successful saving, there should be one model. + const out = await new BrowserLocalStorageManager().listModels(); + if (Object.keys(out).length !== 1) { + console.log(JSON.stringify(out, null, 2)); + } + + expect(Object.keys(out).length).toEqual(1); + expect(out['baz/QuxModel'].modelTopologyType) + .toEqual(saveResult.modelArtifactsInfo.modelTopologyType); + expect(out['baz/QuxModel'].modelTopologyBytes) + .toEqual(saveResult.modelArtifactsInfo.modelTopologyBytes); + expect(out['baz/QuxModel'].weightSpecsBytes) + .toEqual(saveResult.modelArtifactsInfo.weightSpecsBytes); + expect(out['baz/QuxModel'].weightDataBytes) + .toEqual(saveResult.modelArtifactsInfo.weightDataBytes); + })); it('Manager: List models: 2 results', runWithLock(async () => { - // First, save a model. - const handler1 = tf.io.getSaveHandlers('localstorage://QuxModel')[0]; - const saveResult1 = await handler1.save(artifacts1); - - // Then, save the model under another path. - const handler2 = - tf.io.getSaveHandlers('localstorage://repeat/QuxModel')[0]; - const saveResult2 = await handler2.save(artifacts1); - - // After successful saving, there should be two models. - const out = await new BrowserLocalStorageManager().listModels(); - if (Object.keys(out).length !== 2) { - console.log(JSON.stringify(out, null, 2)); - } - expect(Object.keys(out).length).toEqual(2); - expect(out['QuxModel'].modelTopologyType) - .toEqual(saveResult1.modelArtifactsInfo.modelTopologyType); - expect(out['QuxModel'].modelTopologyBytes) - .toEqual(saveResult1.modelArtifactsInfo.modelTopologyBytes); - expect(out['QuxModel'].weightSpecsBytes) - .toEqual(saveResult1.modelArtifactsInfo.weightSpecsBytes); - expect(out['QuxModel'].weightDataBytes) - .toEqual(saveResult1.modelArtifactsInfo.weightDataBytes); - expect(out['repeat/QuxModel'].modelTopologyType) - .toEqual(saveResult2.modelArtifactsInfo.modelTopologyType); - expect(out['repeat/QuxModel'].modelTopologyBytes) - .toEqual(saveResult2.modelArtifactsInfo.modelTopologyBytes); - expect(out['repeat/QuxModel'].weightSpecsBytes) - .toEqual(saveResult2.modelArtifactsInfo.weightSpecsBytes); - expect(out['repeat/QuxModel'].weightDataBytes) - .toEqual(saveResult2.modelArtifactsInfo.weightDataBytes); - })); + // First, save a model. + const handler1 = tf.io.getSaveHandlers('localstorage://QuxModel')[0]; + const saveResult1 = await handler1.save(artifacts1); + + // Then, save the model under another path. + const handler2 = tf.io.getSaveHandlers('localstorage://repeat/QuxModel')[0]; + const saveResult2 = await handler2.save(artifacts1); + + // After successful saving, there should be two models. + const out = await new BrowserLocalStorageManager().listModels(); + if (Object.keys(out).length !== 2) { + console.log(JSON.stringify(out, null, 2)); + } + expect(Object.keys(out).length).toEqual(2); + expect(out['QuxModel'].modelTopologyType) + .toEqual( + saveResult1.modelArtifactsInfo.modelTopologyType); + expect(out['QuxModel'].modelTopologyBytes) + .toEqual(saveResult1.modelArtifactsInfo + .modelTopologyBytes); + expect(out['QuxModel'].weightSpecsBytes) + .toEqual( + saveResult1.modelArtifactsInfo.weightSpecsBytes); + expect(out['QuxModel'].weightDataBytes) + .toEqual( + saveResult1.modelArtifactsInfo.weightDataBytes); + expect(out['repeat/QuxModel'].modelTopologyType) + .toEqual( + saveResult2.modelArtifactsInfo.modelTopologyType); + expect(out['repeat/QuxModel'].modelTopologyBytes) + .toEqual(saveResult2.modelArtifactsInfo + .modelTopologyBytes); + expect(out['repeat/QuxModel'].weightSpecsBytes) + .toEqual( + saveResult2.modelArtifactsInfo.weightSpecsBytes); + expect(out['repeat/QuxModel'].weightDataBytes) + .toEqual( + saveResult2.modelArtifactsInfo.weightDataBytes); + })); it('Manager: Successful deleteModel', runWithLock(async () => { - // First, save a model. - const handler1 = tf.io.getSaveHandlers('localstorage://QuxModel')[0]; - await handler1.save(artifacts1); - - // Then, save the model under another path. - const handler2 = - tf.io.getSaveHandlers('localstorage://repeat/QuxModel')[0]; - await handler2.save(artifacts1); - - // After successful saving, delete the first save, and then - // `listModel` should give only one result. - const manager = new BrowserLocalStorageManager(); - await manager.removeModel('QuxModel'); - const out = await manager.listModels(); - expect(Object.keys(out)).toEqual(['repeat/QuxModel']); - })); + // First, save a model. + const handler1 = tf.io.getSaveHandlers('localstorage://QuxModel')[0]; + await handler1.save(artifacts1); + + // Then, save the model under another path. + const handler2 = tf.io.getSaveHandlers('localstorage://repeat/QuxModel')[0]; + await handler2.save(artifacts1); + + // After successful saving, delete the first save, and then + // `listModel` should give only one result. + const manager = new BrowserLocalStorageManager(); + await manager.removeModel('QuxModel'); + const out = await manager.listModels(); + expect(Object.keys(out)).toEqual(['repeat/QuxModel']); + })); }); diff --git a/tfjs-core/src/io/types.ts b/tfjs-core/src/io/types.ts index ece6b97e8b5..ec4456f4e79 100644 --- a/tfjs-core/src/io/types.ts +++ b/tfjs-core/src/io/types.ts @@ -292,11 +292,6 @@ export declare interface ModelArtifacts { * Initializer for the model. */ modelInitializer?: {}; - - /** - * Inputs and outputs signature for model initializer. - */ - initializerSignature?: {}; } /** @@ -367,11 +362,6 @@ export declare interface ModelJSON { * Initializer for the model. */ modelInitializer?: {}; - - /** - * Inputs and outputs signature for model initializer. - */ - initializerSignature?: {}; } /** diff --git a/tfjs-inference/src/file_handler.ts b/tfjs-inference/src/file_handler.ts index fb63f8edfc5..423e32e992f 100644 --- a/tfjs-inference/src/file_handler.ts +++ b/tfjs-inference/src/file_handler.ts @@ -75,11 +75,6 @@ export class FileHandler implements tf.io.IOHandler { modelArtifacts.modelInitializer = modelJSON.modelInitializer; } - // TODO: Uncomment once table initializers are supported in TFJS. - // if (modelJSON.initializerSignature != null) { - // modelArtifacts.initializerSignature = modelJSON.initializerSignature; - // } - if (modelJSON.weightsManifest != null) { const [weightSpecs, weightData] = this.loadWeights(modelJSON.weightsManifest, path); diff --git a/tfjs-inference/src/file_handler_test.ts b/tfjs-inference/src/file_handler_test.ts index cf1f39892a8..83d34eea22e 100644 --- a/tfjs-inference/src/file_handler_test.ts +++ b/tfjs-inference/src/file_handler_test.ts @@ -100,9 +100,7 @@ describe('File Handler', () => { weightsManifest, signature: {}, userDefinedMetadata: {}, - modelInitializer: {}, - // TODO: Uncomment once table initializers are supported in TFJS. - // initializerSignature: {} + modelInitializer: {} }; // Write model.json file. @@ -126,9 +124,6 @@ describe('File Handler', () => { expect(modelArtifacts.modelTopology).toEqual(modelTopology1); expect(modelArtifacts.signature).toEqual({}); expect(modelArtifacts.userDefinedMetadata).toEqual({}); - expect(modelArtifacts.modelInitializer).toEqual({}); - // TODO: Uncomment once table initializers are supported in TFJS. - // expect(modelArtifacts.initializerSignature).toEqual({}); expect(modelArtifacts.weightSpecs).toEqual([ { name: 'dense/kernel',