Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for Sequential model with multiple inputs. #823

Merged
merged 1 commit into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tf_keras/engine/base_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2321,7 +2321,7 @@ def build_from_config(self, config):
"""
input_shape = config["input_shape"]
if input_shape is not None:
self.build(input_shape)
self.build(tf_utils.convert_shapes(input_shape, to_tuples=False))

############################################################################
# Methods & attributes below are all private and only used by the framework.
Expand Down
41 changes: 29 additions & 12 deletions tf_keras/engine/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,12 +285,16 @@ def _build_graph_network_for_inferred_shape(
):
# Determine whether the input shape is novel, i.e. whether the model
# should be rebuilt.
input_shape = tuple(input_shape)
input_shape = tf_utils.convert_shapes(input_shape)
if self._inferred_input_shape is None:
new_shape = input_shape
else:
new_shape = relax_input_shape(
self._inferred_input_shape, input_shape
new_shape = tf.nest.map_structure(
_relax_input_shape,
tf_utils.convert_shapes(
self._inferred_input_shape, to_tuples=False
),
tf_utils.convert_shapes(input_shape, to_tuples=False),
)
if (
new_shape is not None
Expand All @@ -299,10 +303,13 @@ def _build_graph_network_for_inferred_shape(
# A novel shape has been received: we need to rebuild the model.
# In case we are inside a graph function, we step out of it.
with tf.init_scope():
inputs = input_layer.Input(
batch_shape=new_shape,
dtype=input_dtype,
name=self.layers[0].name + "_input",
inputs = tf.nest.map_structure(
lambda s: input_layer.Input(
batch_shape=tf_utils.convert_shapes(s),
dtype=input_dtype,
name=self.layers[0].name + "_input",
),
tf_utils.convert_shapes(new_shape, to_tuples=False),
)
layer_input = inputs
created_nodes = set()
Expand Down Expand Up @@ -370,7 +377,7 @@ def build(self, input_shape=None):
raise ValueError("You must provide an `input_shape` argument.")
self._build_graph_network_for_inferred_shape(input_shape)
if not self.built:
input_shape = tuple(input_shape)
input_shape = tf_utils.convert_shapes(input_shape)
self._build_input_shape = input_shape
super().build(input_shape)
self.built = True
Expand Down Expand Up @@ -435,7 +442,8 @@ def compute_mask(self, inputs, mask):
def get_config(self):
layer_configs = []
serialize_obj_fn = serialization_lib.serialize_keras_object
if getattr(self, "use_legacy_config", None):
use_legacy_config = getattr(self, "use_legacy_config", False)
if use_legacy_config:
serialize_obj_fn = legacy_serialization.serialize_keras_object
for layer in super().layers:
# `super().layers` include the InputLayer if available (it is
Expand All @@ -446,7 +454,11 @@ def get_config(self):
config = training.Model.get_config(self)
config["name"] = self.name
config["layers"] = copy.deepcopy(layer_configs)
if not self._is_graph_network and self._build_input_shape is not None:
if (
use_legacy_config
and not self._is_graph_network
and self._build_input_shape
):
config["build_input_shape"] = self._build_input_shape
return config

Expand All @@ -458,6 +470,7 @@ def from_config(cls, config, custom_objects=None):
layer_configs = config["layers"]
else:
name = None
build_input_shape = None
layer_configs = config
model = cls(name=name)
for layer_config in layer_configs:
Expand Down Expand Up @@ -519,11 +532,15 @@ def _get_shape_tuple(t):
return None


def relax_input_shape(shape_1, shape_2):
def _relax_input_shape(shape_1, shape_2):
if shape_1 is None or shape_2 is None:
return None
if len(shape_1) != len(shape_2):
if shape_1.rank is None or shape_2.rank is None:
return None
if shape_1.rank != shape_2.rank:
return None
shape_1 = shape_1.as_list()
shape_2 = shape_2.as_list()
return tuple(None if d1 != d2 else d1 for d1, d2 in zip(shape_1, shape_2))


Expand Down
27 changes: 27 additions & 0 deletions tf_keras/engine/sequential_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from absl.testing import parameterized

import tf_keras as keras
from tf_keras.saving import object_registration
from tf_keras.testing_infra import test_combinations
from tf_keras.testing_infra import test_utils

Expand Down Expand Up @@ -574,6 +575,22 @@ def test_multi_inputs_outputs(self):
model(image_inputs)
model.fit(x=image_inputs, y=image_inputs, steps_per_epoch=1)

@test_combinations.run_all_keras_modes(always_skip_v1=True)
def test_multi_inputs_build(self):
    """Build a dict-input Sequential model and round-trip it via JSON.

    The model is built from a dict of input shapes, called on matching
    all-ones tensors, serialized with `to_json`, and restored with
    `model_from_json`. The restored model must reproduce the original
    model's output for the same inputs.
    """
    build_shapes = {"images": (None, 512, 512, 3), "weights": (None, 3)}
    model = keras.Sequential([ImageMultiplyLayer()])
    model.build(build_shapes)

    # The same tensors are fed to both the original and the restored model.
    features = {
        "images": tf.ones((2, 512, 512, 3)),
        "weights": tf.ones((2, 3)),
    }
    expected = model(features)

    restored = keras.models.model_from_json(model.to_json())
    self.assertAllClose(expected, restored(features))


class TestSequentialEagerIntegration(test_combinations.TestCase):
@test_combinations.run_all_keras_modes
Expand Down Expand Up @@ -642,10 +659,20 @@ def test_build_empty_network(self):
self.assertTrue(model.built)


@object_registration.register_keras_serializable()
class ImageAugmentLayer(keras.layers.Layer):
    """Pass-through layer used by the Sequential tests in this module."""

    def call(self, inputs):
        """Return `inputs` unchanged."""
        return inputs


@object_registration.register_keras_serializable()
class ImageMultiplyLayer(keras.layers.Layer):
    """Test layer that multiplies the "images" input by the "weights" input.

    `call` expects a dict with the keys "images" and "weights".
    """

    def call(self, inputs):
        """Return `inputs["images"]`, reshaped, times `inputs["weights"]`.

        Args:
            inputs: dict holding tensors under "images" and "weights".

        Returns:
            The product of the images reshaped to `(-1, 1, 1, 3)` and the
            weights.
        """
        images, weights = inputs["images"], inputs["weights"]
        # NOTE(review): the reshape is applied to `images`, not `weights`,
        # so batch and spatial dimensions of the images are flattened
        # together before the multiply. Presumably the intent was to
        # reshape `weights` so it broadcasts per image -- confirm before
        # changing, since the output shape would differ.
        return tf.reshape(images, (-1, 1, 1, 3)) * weights


# Run this module's tests via `tf.test.main()` when executed as a script.
if __name__ == "__main__":
tf.test.main()
Loading