diff --git a/Colorful/src/utils/general_utils.py b/Colorful/src/utils/general_utils.py index 1cf7490..03b7f9b 100644 --- a/Colorful/src/utils/general_utils.py +++ b/Colorful/src/utils/general_utils.py @@ -69,7 +69,7 @@ def plot_batch(color_model, q_ab, X_batch_black, X_batch_color, batch_size, h, w list_img.append(arr) plt.figure(figsize=(20,20)) - list_img = [np.concatenate(list_img[4 * i: 4 * (i + 1)], axis=2) for i in range(len(list_img) / 4)] + list_img = [np.concatenate(list_img[4 * i: 4 * (i + 1)], axis=2) for i in range(len(list_img) // 4)] arr = np.concatenate(list_img, axis=1) plt.imshow(arr.transpose(1,2,0)) ax = plt.gca() @@ -111,7 +111,7 @@ def plot_batch_eval(color_model, q_ab, X_batch_black, X_batch_color, batch_size, list_img.append(arr) plt.figure(figsize=(20,20)) - list_img = [np.concatenate(list_img[4 * i: 4 * (i + 1)], axis=2) for i in range(len(list_img) / 4)] + list_img = [np.concatenate(list_img[4 * i: 4 * (i + 1)], axis=2) for i in range(len(list_img) // 4)] arr = np.concatenate(list_img, axis=1) plt.imshow(arr.transpose(1,2,0)) ax = plt.gca()