Skip to content

Commit

Permalink
[FIX](handler) Fix model graph generation
Browse files Browse the repository at this point in the history
  • Loading branch information
wbenbihi committed Aug 23, 2022
1 parent 046c809 commit 0ab746e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 17 deletions.
24 changes: 11 additions & 13 deletions hourglass_tensorflow/models/hourglass.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from doctest import OutputChecker

import tensorflow as tf
from keras.models import Model

Expand Down Expand Up @@ -62,16 +60,16 @@ def __init__(

def call(self, inputs: tf.Tensor, training=True):
x = self.downsampling(inputs)
outputs = []
outputs_list = []
for layer in self.hourglasses:
x, y = layer(x)
if self._intermediate_supervision:
outputs.append(y)
outputs_list.append(y)
if self._intermediate_supervision:
self.output = tf.stack(outputs, axis=1, name="NetworkStackedOutput")
self._outputs = tf.stack(outputs_list, axis=1, name="NetworkStackedOutput")
else:
self.output = y
return self.output
self._outputs = y
return self._outputs


def model_as_layers(
Expand Down Expand Up @@ -114,18 +112,18 @@ def model_as_layers(
]

x = downsampling(inputs)
outputs = []
output_list = []
for layer in hourglasses:
x, y = layer(x)
if intermediate_supervision:
outputs.append(y)
output_list.append(y)
if intermediate_supervision:
output = tf.stack(outputs, axis=1, name="NetworkStackedOutput")
outputs = tf.stack(output_list, axis=1, name="NetworkStackedOutput")
else:
output = y
outputs = y

model = Model(inputs=inputs, outputs=output)
model = Model(inputs=inputs, outputs=outputs)

return HTFModelAsLayers(
downsampling=downsampling, hourglasses=hourglasses, outputs=output, model=model
downsampling=downsampling, hourglasses=hourglasses, outputs=outputs, model=model
)
5 changes: 1 addition & 4 deletions hourglass_tensorflow/types/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,12 @@ class HTFModelHandlerReturnObject(TypedDict):
model: keras.models.Model


class HTFModelAsLayers(BaseModel):
class HTFModelAsLayers(TypedDict):
downsampling: keras.layers.Layer
hourglasses: List[keras.layers.Layer]
outputs: keras.layers.Layer
model: keras.models.Model

class Config:
arbitrary_types_allowed = True


class HTFModelParams(HTFConfigField):
name: str = "HourglassNetwork"
Expand Down

0 comments on commit 0ab746e

Please sign in to comment.