diff --git a/ctlearn/default_models/single_cnn.py b/ctlearn/default_models/single_cnn.py index 9b3c6f05..e721a752 100644 --- a/ctlearn/default_models/single_cnn.py +++ b/ctlearn/default_models/single_cnn.py @@ -49,4 +49,4 @@ def single_cnn_model(data, model_params): model = tf.keras.Model(network_input, output, name=network_name) - return model, network_input + return model, [network_input] diff --git a/ctlearn/run_model.py b/ctlearn/run_model.py index a48c1c90..1209e193 100644 --- a/ctlearn/run_model.py +++ b/ctlearn/run_model.py @@ -16,6 +16,7 @@ import tensorflow as tf from tensorflow.python import debug as tf_debug +import tf2onnx from dl1_data_handler.reader import DL1DataReaderSTAGE1, DL1DataReaderDL1DH from ctlearn.data_loader import KerasBatchGenerator @@ -356,10 +357,14 @@ def run_model(config, mode="train", debug=False, log_to_file=False): workers=workers, use_multiprocessing=use_multiprocessing, ) - - model.save(model_dir) - logger.info("Training and evaluating finished succesfully!") + model.save(model_dir) + logger.info("Keras model saved in {}saved_model.pb".format(model_dir)) + logger.info("Converting Keras model into ONNX format...") + input_type_spec = [input._type_spec for input in backbone_inputs] + output_path = model_dir + model.name + ".onnx" + tf2onnx.convert.from_keras(model, input_signature=input_type_spec, output_path=output_path) + logger.info("ONNX model saved in {}".format(output_path)) # Plotting training history training_log = pd.read_csv(model_dir + "/training_log.csv") diff --git a/environment.yml b/environment.yml index f99f284b..779aac65 100644 --- a/environment.yml +++ b/environment.yml @@ -6,7 +6,7 @@ channels: - cta-observatory - ctlearn-project dependencies: - - python=3.10 + - python=3.9 - dl1_data_handler=0.10.6 - astropy - matplotlib @@ -17,5 +17,6 @@ dependencies: - scikit-learn - pip: # TensorFlow via pip - - tensorflow + - tensorflow==2.8 + - tf2onnx - pydot