Skip to content

Commit

Permalink
Generate tflite model format as well along with .h5
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitijrajsharma committed Oct 19, 2023
1 parent 9e44525 commit f712827
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions backend/core/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,17 @@ def train_model(
# Save the model in HDF5 format
model.save(os.path.join(output_path, "checkpoint.h5"))

logger.info(model.inputs)
logger.info(model.outputs)

# Convert the model to tflite for android/ios.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Save the model.
with open(os.path.join(output_path, "checkpoint.tflite"), 'wb') as f:
f.write(tflite_model)

# now remove the ramp-data all our outputs are copied to our training workspace
shutil.rmtree(base_path)
training_instance.accuracy = float(final_accuracy)
Expand Down

0 comments on commit f712827

Please sign in to comment.