Skip to content

Commit

Permalink
(#4) Refactor: Support functional models
Browse files Browse the repository at this point in the history
  • Loading branch information
betarixm committed May 25, 2022
1 parent 33255a3 commit eb70dcd
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions src/typings/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
),
checkpoint_filepath: str = None,
tensorboard_log_path: str = None,
is_functional: bool = False,
):
self._name: str = name

Expand All @@ -31,14 +32,17 @@ def __init__(

self.__input_shape = input_shape

self.__model: keras.Model = (
keras.Sequential(
if is_functional:
inputs = keras.Input(self.__input_shape)
outputs = self._model()(inputs)
self.__model: keras.Model = keras.Model(inputs, outputs, name=self._name)
elif len(self._model().layers) != 0:
self.__model: keras.Model = keras.Sequential(
[keras.Input(self.__input_shape), *self._model().layers],
name=self._name,
)
if len(self._model().layers) != 0
else self._model()
)
else:
self.__model: keras.Model = self._model()

self.intensity: float = intensity

Expand Down

0 comments on commit eb70dcd

Please sign in to comment.