diff --git a/backend/core/tasks.py b/backend/core/tasks.py index b15c2475..3069da39 100644 --- a/backend/core/tasks.py +++ b/backend/core/tasks.py @@ -222,6 +222,17 @@ def train_model( # Save the model in HDF5 format model.save(os.path.join(output_path, "checkpoint.h5")) + logger.info(model.inputs) + logger.info(model.outputs) + + # Convert the model to tflite for android/ios. + converter = tf.lite.TFLiteConverter.from_keras_model(model) + tflite_model = converter.convert() + + # Save the model. + with open(os.path.join(output_path, "checkpoint.tflite"), 'wb') as f: + f.write(tflite_model) + # now remove the ramp-data all our outputs are copied to our training workspace shutil.rmtree(base_path) training_instance.accuracy = float(final_accuracy)