Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add asset zipping functionality to TFJS converter #6915

Merged
merged 8 commits into from
Oct 12, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tfjs-converter/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ py_wheel(
"importlib_resources>=5.9.0",
"jax>=0.3.16",
"protobuf<3.20,>=3.9.2",
"tensorflow>=2.1.0,<3",
"tensorflow>=2.10.0,<3",
"tensorflow-decision-forests>=1.0.1",
"six>=1.12.0,<2",
"tensorflow-hub>=0.7.0,<0.13",
"packaging~=20.9",
Expand Down
1 change: 1 addition & 0 deletions tfjs-converter/python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ jax>=0.3.16
importlib_resources>=5.9.0
protobuf<3.20,>=3.9.2
tensorflow>=2.10.0,<3
tensorflow-decision-forests>=1.0.1
six>=1.12.0,<2
tensorflow-hub>=0.7.0,<0.13; python_version >= "3"
packaging~=20.9
9 changes: 9 additions & 0 deletions tfjs-converter/python/tensorflowjs/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ py_library(
deps = [requirement("tensorflow")],
)

py_library(
name = "expect_tensorflow_decision_forests_installed",
# This is a dummy rule used as a tensorflow dependency in open-source.
# We expect tensorflow-decision-forests to already be installed on
# the system, e.g. via
# `pip install tensorflow-decision-forests`.
deps = [requirement("tensorflow-decision-forests")],
)

py_library(
name = "expect_tensorflow_hub_installed",
# This is a dummy rule used as a tensorflow_hub dependency in open-source.
Expand Down
1 change: 1 addition & 0 deletions tfjs-converter/python/tensorflowjs/converters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ py_library(
":graph_rewrite_util",
"//tfjs-converter/python/tensorflowjs:expect_numpy_installed",
"//tfjs-converter/python/tensorflowjs:expect_packaging_installed",
"//tfjs-converter/python/tensorflowjs:expect_tensorflow_decision_forests_installed",
"//tfjs-converter/python/tensorflowjs:expect_tensorflow_hub_installed",
"//tfjs-converter/python/tensorflowjs:expect_tensorflow_installed",
"//tfjs-converter/python/tensorflowjs:resource_loader",
Expand Down
1 change: 1 addition & 0 deletions tfjs-converter/python/tensorflowjs/converters/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

# File name for the indexing JSON file in an artifact directory.
ARTIFACT_MODEL_JSON_FILE_NAME = 'model.json'
ASSETS_DIRECTORY_NAME = 'assets'

# JSON string keys for fields of the indexing JSON.
ARTIFACT_MODEL_TOPOLOGY_KEY = 'modelTopology'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,19 @@

import json
import os
import shutil
import tempfile
from zipfile import ZipFile

# Required to load saved models that use TFDF.
import tensorflow_decision_forests
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import device_properties_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.io import gfile
from tensorflow.python.checkpoint.trackable_view import TrackableView
from tensorflow.python.eager import context
from tensorflow.python.framework import convert_to_constants
Expand Down Expand Up @@ -399,7 +405,7 @@ def write_artifacts(topology,
assert isinstance(weights_manifest, list)
model_json[common.ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest

with tf.io.gfile.GFile(output_graph, 'w') as f:
with gfile.GFile(output_graph, 'w') as f:
json.dump(model_json, f)

def _remove_unused_control_flow_inputs(input_graph_def):
Expand All @@ -421,6 +427,49 @@ def _check_signature_in_model(saved_model, signature_name):
"are available: %s" % (signature_name,
saved_model.signatures.keys()))

def _copy_assets(saved_model_dir, output_dir):
input_assets_path = os.path.join(saved_model_dir, common.ASSETS_DIRECTORY_NAME)

if gfile.exists(input_assets_path) and gfile.isdir(input_assets_path):

tmp_dir = tempfile.mkdtemp()
zip_path = gfile.join(tmp_dir, common.ASSETS_DIRECTORY_NAME + '.zip')

with ZipFile(zip_path, 'w') as archive:
for (input_dir_path, _, file_names) in gfile.walk(input_assets_path):

relative_dir_path = os.path.relpath(input_dir_path, input_assets_path)

for file_name in file_names:

input_file_path = gfile.join(input_dir_path, file_name)
relative_file_path = gfile.join(relative_dir_path, file_name)

with gfile.GFile(input_file_path, 'rb') as input_file:
with archive.open(relative_file_path, 'w') as relative_file:
shutil.copyfileobj(input_file, relative_file)

output_assets_path = gfile.join(output_dir, common.ASSETS_DIRECTORY_NAME + '.zip')
gfile.copy(zip_path, output_assets_path, overwrite=True)

if gfile.isdir(tmp_dir):
gfile.rmtree(tmp_dir)

# TFDF stores the necessary files for its binary in the assets folder.
ASSET_REQUIRING_OPS = set([
'SimpleMLCreateModelResource'
'SimpleMLLoadModelFromPathWithHandle',
'SimpleMLInferenceOpWithHandle',
])

def _is_assets_required(model_ops):
return not ASSET_REQUIRING_OPS.isdisjoint(model_ops)

def _get_frozen_graph_ops(frozen_graph):
if frozen_graph is None:
return []
return [node.op for node in frozen_graph.as_graph_def().node]


def _freeze_saved_model_v1(saved_model_dir, saved_model_tags,
output_node_names):
Expand Down Expand Up @@ -745,8 +794,8 @@ def _convert_tf_saved_model(output_dir,
if signature_def is None:
signature_def = 'serving_default'

if not tf.io.gfile.exists(output_dir):
tf.io.gfile.makedirs(output_dir)
if not gfile.exists(output_dir):
gfile.makedirs(output_dir)
output_graph = os.path.join(
output_dir, common.ARTIFACT_MODEL_JSON_FILE_NAME)

Expand Down Expand Up @@ -852,6 +901,12 @@ def _convert_tf_saved_model(output_dir,
# tensorflow version.
tf_version = tf.__version__

if saved_model_dir:
model_ops = set(_get_frozen_graph_ops(frozen_graph)) |\
set(_get_frozen_graph_ops(frozen_initializer_graph))
if _is_assets_required(model_ops):
_copy_assets(saved_model_dir, output_dir)

optimize_graph(frozen_graph, signature,
output_graph, tf_version,
quantization_dtype_map=quantization_dtype_map,
Expand Down Expand Up @@ -1137,7 +1192,7 @@ def convert_tf_hub_module(module_handle, output_dir,
# TODO(vbardiovskyg): We can remove this v1 code path once loading of all v1
# modules is fixed on the TF side, or once the modules we cannot load become
# replaced with newer versions.
if tf.io.gfile.exists(os.path.join(module_path, _HUB_V1_MODULE_PB)):
if gfile.exists(os.path.join(module_path, _HUB_V1_MODULE_PB)):
print("Loading the module using TF 1.X interface from %s." % module_path)
convert_tf_hub_module_v1(module_path, output_dir, signature,
quantization_dtype_map,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
import shutil
import tempfile
import unittest
import numpy as np

import tensorflow.compat.v2 as tf
from tensorflow_decision_forests.keras import GradientBoostedTreesModel
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
Expand All @@ -35,6 +37,7 @@
from tensorflowjs import version
from tensorflowjs.converters import graph_rewrite_util
from tensorflowjs.converters import tf_saved_model_conversion_v2
from tensorflowjs.converters.common import ASSETS_DIRECTORY_NAME

SAVED_MODEL_DIR = 'saved_model'
HUB_MODULE_DIR = 'hub_module'
Expand Down Expand Up @@ -246,6 +249,22 @@ def find_next_odd(v):
save_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
save(root, save_dir, to_save)

def _create_saved_model_with_tfdf(self):
"""Test a basic TFDF model."""
P = 5
NUM_EXAMPLES = 10
NUM_FEATURES = 4

x_train = np.random.uniform(size=(NUM_EXAMPLES, NUM_FEATURES))
y_train = np.random.uniform(size=NUM_EXAMPLES) > 0.5
w_train = y_train * (P - 1) + 1 # 1 or p depending on the class.

model = GradientBoostedTreesModel()
model.fit(x=x_train, y=y_train, sample_weight=w_train)

save_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
model.save(save_dir)

def _create_unsupported_saved_model(self):
root = tracking.AutoTrackable()
root.w = variables.Variable(tf.random.uniform([2, 2]))
Expand Down Expand Up @@ -936,6 +955,31 @@ def test_convert_saved_model_with_control_flow_v2(self):
glob.glob(
os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'group*-*')))

def test_convert_saved_model_with_tfdf(self):
self._create_saved_model_with_tfdf()

tfjs_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
tf_saved_model_conversion_v2.convert_tf_saved_model(
tfjs_path, tfjs_path, skip_op_check=True
)

# Check model.json and weights manifest.
with open(os.path.join(tfjs_path, 'model.json'), 'rt') as f:
model_json = json.load(f)

# Check TFDF ops are present.
model_ops = [node['op'] for node in model_json['modelTopology']['node']]
self.assertTrue('SimpleMLInferenceOpWithHandle' in model_ops)

initializer_ops = [node['op'] for node in model_json['modelInitializer']['node']]
self.assertTrue('SimpleMLCreateModelResource' in initializer_ops)
self.assertTrue('SimpleMLLoadModelFromPathWithHandle' in initializer_ops)

# Check assets containing TFDF files were copied over.
self.assertTrue(
os.path.exists(
os.path.join(tfjs_path, ASSETS_DIRECTORY_NAME + '.zip')))

def test_convert_saved_model_sharded(self):
self._create_saved_model()
model_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
Expand Down