Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrating 3D image classification from CT scans example to Keras 3 #1725

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 32 additions & 37 deletions examples/vision/3D_image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
Title: 3D image classification from CT scans
Author: [Hasib Zunair](https://twitter.com/hasibzunair)
Date created: 2020/09/23
Last modified: 2020/09/23
Last modified: 2024/01/11
Description: Train a 3D convolutional neural network to predict presence of pneumonia.
Accelerator: GPU
Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
"""
"""
## Introduction
Expand All @@ -27,12 +28,15 @@
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"
import zipfile
import numpy as np
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers
import keras
from keras import ops
from keras import layers
import tensorflow as tf

"""
## Downloading the MosMedData: Chest CT Scans with COVID-19 Related Findings
Expand Down Expand Up @@ -177,19 +181,19 @@ def process_scan(path):

# Read and process the scans.
# Each scan is resized across height, width, and depth and rescaled.
abnormal_scans = np.array([process_scan(path) for path in abnormal_scan_paths])
normal_scans = np.array([process_scan(path) for path in normal_scan_paths])
abnormal_scans = ops.array([process_scan(path) for path in abnormal_scan_paths])
normal_scans = ops.array([process_scan(path) for path in normal_scan_paths])

# For the CT scans having presence of viral pneumonia
# assign 1, for the normal ones assign 0.
abnormal_labels = np.array([1 for _ in range(len(abnormal_scans))])
normal_labels = np.array([0 for _ in range(len(normal_scans))])
abnormal_labels = ops.array([1 for _ in range(len(abnormal_scans))])
normal_labels = ops.array([0 for _ in range(len(normal_scans))])

# Split data in the ratio 70-30 for training and validation.
x_train = np.concatenate((abnormal_scans[:70], normal_scans[:70]), axis=0)
y_train = np.concatenate((abnormal_labels[:70], normal_labels[:70]), axis=0)
x_val = np.concatenate((abnormal_scans[70:], normal_scans[70:]), axis=0)
y_val = np.concatenate((abnormal_labels[70:], normal_labels[70:]), axis=0)
x_train = ops.concatenate((abnormal_scans[:70], normal_scans[:70]), axis=0)
y_train = ops.concatenate((abnormal_labels[:70], normal_labels[:70]), axis=0)
x_val = ops.concatenate((abnormal_scans[70:], normal_scans[70:]), axis=0)
y_val = ops.concatenate((abnormal_labels[70:], normal_labels[70:]), axis=0)
print(
"Number of samples in train and validation are %d and %d."
% (x_train.shape[0], x_val.shape[0])
Expand All @@ -210,37 +214,28 @@ def process_scan(path):

from scipy import ndimage

random_rotation = layers.RandomRotation(factor=(-0.06, 0.06))


@tf.function
def rotate(volume):
"""Rotate the volume by a few degrees"""

def scipy_rotate(volume):
# define some rotation angles
angles = [-20, -10, -5, 5, 10, 20]
# pick angles at random
angle = random.choice(angles)
# rotate volume
volume = ndimage.rotate(volume, angle, reshape=False)
volume[volume < 0] = 0
volume[volume > 1] = 1
return volume

augmented_volume = tf.numpy_function(scipy_rotate, [volume], tf.float32)
return augmented_volume
# rotate volume
volume = random_rotation(volume)
volume = ops.clip(volume, 0, 1)
return volume


def train_preprocessing(volume, label):
"""Process training data by rotating and adding a channel."""
# Rotate volume
volume = rotate(volume)
volume = tf.expand_dims(volume, axis=3)
volume = ops.expand_dims(volume, axis=3)
return volume, label


def validation_preprocessing(volume, label):
"""Process validation data by only adding a channel."""
volume = tf.expand_dims(volume, axis=3)
volume = ops.expand_dims(volume, axis=3)
return volume, label


Expand Down Expand Up @@ -278,10 +273,10 @@ def validation_preprocessing(volume, label):

data = train_dataset.take(1)
images, labels = list(data)[0]
images = images.numpy()
images = ops.convert_to_numpy(images)
image = images[0]
print("Dimension of the CT scan is:", image.shape)
plt.imshow(np.squeeze(image[:, :, 30]), cmap="gray")
plt.imshow(ops.squeeze(image[:, :, 30]), cmap="gray")


"""
Expand All @@ -291,9 +286,9 @@ def validation_preprocessing(volume, label):

def plot_slices(num_rows, num_columns, width, height, data):
"""Plot a montage of 20 CT slices"""
data = np.rot90(np.array(data))
data = np.transpose(data)
data = np.reshape(data, (num_rows, num_columns, width, height))
data = ndimage.rotate(ops.array(data), 90, reshape=False)
data = ops.transpose(data)
data = ops.reshape(data, (num_rows, num_columns, width, height))
rows_data, columns_data = data.shape[0], data.shape[1]
heights = [slc[0].shape[0] for slc in data]
widths = [slc.shape[1] for slc in data[0]]
Expand Down Expand Up @@ -379,7 +374,7 @@ def get_model(width=128, height=128, depth=64):

# Define callbacks.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
"3d_image_classification.h5", save_best_only=True
"3d_image_classification.keras", save_best_only=True
)
early_stopping_cb = keras.callbacks.EarlyStopping(monitor="val_acc", patience=15)

Expand Down Expand Up @@ -426,8 +421,8 @@ def get_model(width=128, height=128, depth=64):
"""

# Load best weights.
model.load_weights("3d_image_classification.h5")
prediction = model.predict(np.expand_dims(x_val[0], axis=0))[0]
model.load_weights("3d_image_classification.keras")
prediction = model.predict(ops.expand_dims(x_val[0], axis=0))[0]
scores = [1 - prediction[0], prediction[0]]

class_names = ["normal", "abnormal"]
Expand Down
Loading
Loading