Skip to content

Commit

Permalink
convert keras model into onnx format (closes #143)
Browse files Browse the repository at this point in the history
TF-ONNX packages is only supporting pyhton v3.9 and tensorflow v2.8, so we downgraded to those versions. (Tensorflow 2.9 was recently released)
  • Loading branch information
TjarkMiener committed Jun 30, 2022
1 parent ef1b9b0 commit 81510dc
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
2 changes: 1 addition & 1 deletion ctlearn/default_models/single_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ def single_cnn_model(data, model_params):

model = tf.keras.Model(network_input, output, name=network_name)

return model, network_input
return model, [network_input]
11 changes: 8 additions & 3 deletions ctlearn/run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import tensorflow as tf
from tensorflow.python import debug as tf_debug
import tf2onnx

from dl1_data_handler.reader import DL1DataReaderSTAGE1, DL1DataReaderDL1DH
from ctlearn.data_loader import KerasBatchGenerator
Expand Down Expand Up @@ -356,10 +357,14 @@ def run_model(config, mode="train", debug=False, log_to_file=False):
workers=workers,
use_multiprocessing=use_multiprocessing,
)

model.save(model_dir)

logger.info("Training and evaluating finished succesfully!")
model.save(model_dir)
logger.info("Keras model saved in {}saved_model.pb".format(model_dir))
logger.info("Converting Keras model into ONNX format...")
input_type_spec = [input._type_spec for input in backbone_inputs]
output_path = model_dir + model.name + ".onnx"
tf2onnx.convert.from_keras(model, input_signature=input_type_spec, output_path=output_path)
logger.info("ONNX model saved in {}".format(output_path))

# Plotting training history
training_log = pd.read_csv(model_dir + "/training_log.csv")
Expand Down
5 changes: 3 additions & 2 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ channels:
- cta-observatory
- ctlearn-project
dependencies:
- python=3.10
- python=3.9
- dl1_data_handler=0.10.6
- astropy
- matplotlib
Expand All @@ -17,5 +17,6 @@ dependencies:
- scikit-learn
- pip:
# TensorFlow via pip
- tensorflow
- tensorflow==2.8
- tf2onnx
- pydot

0 comments on commit 81510dc

Please sign in to comment.