diff --git a/src/detect.py b/src/detect.py index 67c79ea..178c7f8 100644 --- a/src/detect.py +++ b/src/detect.py @@ -34,7 +34,12 @@ def main(data_dir, model, in_weights_path, visualization_path, batch_size, os.path.join(data_dir, 'label_colors.txt')) # set TensorFlow seed - tf.random.set_seed(seed) + if seed is not None: + import sys + if int(tf.__version__.split('.')[1]) < 4: + tf.random.set_seed(seed) + else: + tf.keras.utils.set_random_seed(seed) model = create_model(model, len(id2code), nr_bands, tensor_shape, backbone=backbone)