diff --git a/src/typings/models.py b/src/typings/models.py index 250e118..0bf6bc2 100644 --- a/src/typings/models.py +++ b/src/typings/models.py @@ -31,8 +31,13 @@ def __init__( self.__input_shape = input_shape - self.__model: keras.Model = keras.Sequential( - [keras.Input(self.__input_shape), *self._model().layers], name=self._name + self.__model: keras.Model = ( + keras.Sequential( + [keras.Input(self.__input_shape), *self._model().layers], + name=self._name, + ) + if len(self._model().layers) != 0 + else self._model() ) self.intensity: float = intensity @@ -115,7 +120,7 @@ def compile(self): ) def predict(self, inputs): - outs = self.__model.predict(inputs) + outs = self.__model(inputs) return inputs + (outs - inputs) * self.intensity if self.intensity < 1 else outs def train(self, epochs: int = 100):