From 0ab746e09ec8b0690fa1aa4fc2749ff7976b78bb Mon Sep 17 00:00:00 2001 From: wbenbihi Date: Tue, 23 Aug 2022 16:53:29 +0800 Subject: [PATCH] [FIX](handler) Fix model graph generation --- hourglass_tensorflow/models/hourglass.py | 24 ++++++++++------------ hourglass_tensorflow/types/config/model.py | 5 +---- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/hourglass_tensorflow/models/hourglass.py b/hourglass_tensorflow/models/hourglass.py index dd6bceb..b821a68 100644 --- a/hourglass_tensorflow/models/hourglass.py +++ b/hourglass_tensorflow/models/hourglass.py @@ -1,5 +1,3 @@ -from doctest import OutputChecker - import tensorflow as tf from keras.models import Model @@ -62,16 +60,16 @@ def __init__( def call(self, inputs: tf.Tensor, training=True): x = self.downsampling(inputs) - outputs = [] + outputs_list = [] for layer in self.hourglasses: x, y = layer(x) if self._intermediate_supervision: - outputs.append(y) + outputs_list.append(y) if self._intermediate_supervision: - self.output = tf.stack(outputs, axis=1, name="NetworkStackedOutput") + self._outputs = tf.stack(outputs_list, axis=1, name="NetworkStackedOutput") else: - self.output = y - return self.output + self._outputs = y + return self._outputs def model_as_layers( @@ -114,18 +112,18 @@ def model_as_layers( ] x = downsampling(inputs) - outputs = [] + output_list = [] for layer in hourglasses: x, y = layer(x) if intermediate_supervision: - outputs.append(y) + output_list.append(y) if intermediate_supervision: - output = tf.stack(outputs, axis=1, name="NetworkStackedOutput") + outputs = tf.stack(output_list, axis=1, name="NetworkStackedOutput") else: - output = y + outputs = y - model = Model(inputs=inputs, outputs=output) + model = Model(inputs=inputs, outputs=outputs) return HTFModelAsLayers( - downsampling=downsampling, hourglasses=hourglasses, outputs=output, model=model + downsampling=downsampling, hourglasses=hourglasses, outputs=outputs, model=model ) diff --git a/hourglass_tensorflow/types/config/model.py b/hourglass_tensorflow/types/config/model.py index 8a71072..f133acc 100644 --- a/hourglass_tensorflow/types/config/model.py +++ b/hourglass_tensorflow/types/config/model.py @@ -27,15 +27,12 @@ class HTFModelHandlerReturnObject(TypedDict): model: keras.models.Model -class HTFModelAsLayers(BaseModel): +class HTFModelAsLayers(TypedDict): downsampling: keras.layers.Layer hourglasses: List[keras.layers.Layer] outputs: keras.layers.Layer model: keras.models.Model - class Config: - arbitrary_types_allowed = True - class HTFModelParams(HTFConfigField): name: str = "HourglassNetwork"