diff --git a/sleap/nn/peak_finding.py b/sleap/nn/peak_finding.py index 84dca00ae..e1fb43a6e 100644 --- a/sleap/nn/peak_finding.py +++ b/sleap/nn/peak_finding.py @@ -221,7 +221,7 @@ def find_global_peaks_rough( channels = tf.cast(tf.shape(cms)[-1], tf.int64) total_peaks = tf.cast(tf.shape(argmax_cols)[0], tf.int64) sample_subs = tf.range(total_peaks, dtype=tf.int64) // channels - channel_subs = tf.range(total_peaks, dtype=tf.int64) % channels + channel_subs = tf.math.mod(tf.range(total_peaks, dtype=tf.int64), channels) # Gather subscripts. peak_subs = tf.stack([sample_subs, argmax_rows, argmax_cols, channel_subs], axis=1)