From 77365fd77de39b7adae5473ee3aa3e0ac9a25797 Mon Sep 17 00:00:00 2001 From: Patrick Cleeve Date: Fri, 26 Jul 2024 14:23:06 +1000 Subject: [PATCH] [feat] add onnx-window model to backend (ppliteseg-nysbc-ccameron) --- .gitignore | 3 + fibsem/segmentation/model.py | 11 +- fibsem/segmentation/onnx_model.py | 154 ++++++++++++++++++++++++- scripts/onnx_notebook.ipynb | 100 ++++++++++++++++ scripts/onnx_pred.py | 185 ++++++++++++++++++++++++++++++ 5 files changed, 448 insertions(+), 5 deletions(-) create mode 100644 scripts/onnx_notebook.ipynb create mode 100644 scripts/onnx_pred.py diff --git a/.gitignore b/.gitignore index 1eda6794..48d5b1c6 100644 --- a/.gitignore +++ b/.gitignore @@ -53,3 +53,6 @@ fibsem/config/*.yaml tests/notebook.ipynb fibsem/db/v2/notebook.ipynb scripts/segformer-*/* +*.onnx +scripts/**/*.jpeg +scripts/**/*.png \ No newline at end of file diff --git a/fibsem/segmentation/model.py b/fibsem/segmentation/model.py index 3ae70b07..e49727e4 100644 --- a/fibsem/segmentation/model.py +++ b/fibsem/segmentation/model.py @@ -202,8 +202,15 @@ def load_model( from fibsem.segmentation.nnunet_model import SegmentationModelNNUnet model = SegmentationModelNNUnet(checkpoint=checkpoint) elif backend == "onnx": - from fibsem.segmentation.onnx_model import SegmentationModelONNX - model = SegmentationModelONNX(checkpoint=checkpoint) + from fibsem.segmentation.onnx_model import SegmentationModelONNX, SegmentationModelWindowONNX + for onnx_model in [SegmentationModelWindowONNX, SegmentationModelONNX]: + try: + logging.debug(f"Trying to load {onnx_model}") + model = onnx_model(checkpoint=checkpoint) + break + except Exception as e: + logging.debug(f"Failed to load {type(onnx_model)} for {checkpoint}: {e}") + elif backend == "huggingface": from fibsem.segmentation.hf_segmentation_model import SegmentationModelHuggingFace model = SegmentationModelHuggingFace(checkpoint=checkpoint) diff --git a/fibsem/segmentation/onnx_model.py b/fibsem/segmentation/onnx_model.py index cff30226..11d2864f 100644 --- a/fibsem/segmentation/onnx_model.py +++ b/fibsem/segmentation/onnx_model.py @@ -1,13 +1,18 @@ import os +import logging + +import cv2 import numpy as np import onnx import onnxruntime +import PIL.Image +from onnxruntime import InferenceSession +from skimage.util.shape import view_as_windows from fibsem.segmentation.utils import decode_segmap_v2, download_checkpoint - ### ONNX @@ -52,9 +57,10 @@ def inference(self, img: np.ndarray, rgb: bool = True): def export_model_to_onnx(checkpoint: str, onnx_path: str): - from fibsem.segmentation.model import load_model import torch + from fibsem.segmentation.model import load_model + # get fibsem model model = load_model(checkpoint) model.model.to("cpu") @@ -96,4 +102,146 @@ def to_numpy(tensor): print("Exported model has been tested with ONNXRuntime, and the result looks good!") -# export_model_to_onnx("autolamella-mega-latest.pt", "autolamella-mega-20231230.onnx") \ No newline at end of file +# export_model_to_onnx("autolamella-mega-latest.pt", "autolamella-mega-20231230.onnx") + +## PPLITESEG WINDOWED MODEL +def load_windowed_onnx_model(model_path: str) -> tuple: + """ + Load the ONNX model. + + Args: + model_path (str): File path to the ONNX model. + + Returns: + InferenceSession: The ONNX model session. + str: The name of the input tensor. + Tuple[int, int]: The shape of the input tensor. + str: The name of the output tensor. + """ + session = InferenceSession(model_path) + input_name = session.get_inputs()[0].name + window_shape = session.get_inputs()[0].shape[2:] + output_name = session.get_outputs()[0].name + + return session, input_name, window_shape, output_name + +def standardize(img: object, sigma: float = 24.0) -> np.ndarray: + """ + Standardize the pixel intensities of the provided image. + + Args: + img (np.ndarray): The image to standardize. + sigma (float): The standard deviation of the Gaussian kernel. + + Returns: + np.ndarray: The standardized image. + """ + # subtract local mean + smooth = cv2.GaussianBlur(img, (0, 0), sigmaX=sigma) + img = np.subtract(img, smooth) + # scale pixel intensities + img = img / np.std(img) + del smooth + + return img.astype(np.float32) + + +class SegmentationModelWindowONNX: + def __init__(self, checkpoint: str = None): + if checkpoint is not None: + self.load_model(checkpoint) + self.device = None + + def load_model(self, checkpoint="autolamella-mega.onnx"): + # download checkpoint if needed + # checkpoint = download_checkpoint(checkpoint) + self.checkpoint = os.path.basename(checkpoint) + + # load inference session + session = load_windowed_onnx_model(checkpoint) + self.session, self.input_name, self.window_shape, self.output_name = session + + def pre_process(self, img: np.ndarray) -> np.ndarray: + """Pre-process the image for inference, calculate window parameters""" + ##### PREPROCESSING + if img.ndim == 2: # 2d grayscale image -> 3d rgb (grayscale) + img = np.array(PIL.Image.fromarray(img).convert("RGB")) + + if img.dtype != np.float32: + img = img.astype(np.float32) + + # image and window parameters + h, w, c = img.shape + stride = self.window_shape[0] // 5 # MAGIC_NUMBER + + # standardize image + img = standardize(img) + + # transpose dimensions from HWC to CHW + img = np.transpose(img, (2, 0, 1)) + + # pad image for sliding window + pad_h = max(self.window_shape[0] - h % stride, 0) + pad_w = max(self.window_shape[1] - w % stride, 0) + img = np.pad(img, ((0, 0), (0, pad_h), (0, pad_w)), mode="constant") + _, pad_h, pad_w = img.shape + + # window input image + windows: np.ndarray = view_as_windows( + img, (3, self.window_shape[0], self.window_shape[1]), step=stride + ).squeeze() + + logging.debug(f"pre_process: {img.shape}, {windows.shape}, {h}, {w}, {pad_h}, {pad_w}, {stride}") + + return img, windows, h, w, pad_h, pad_w, stride + + def inference(self, img: np.ndarray, rgb: bool = True) -> np.ndarray: + """Perform inference on the provided image. + Args: + img (np.ndarray): The image to segment. + rgb (bool): Whether to return an RGB image. + Returns: + np.ndarray: The segmented image.""" + + # pre-process image + img, windows, h, w, pad_h, pad_w, stride = self.pre_process(img) + + # inference on each window + container = None + count = np.zeros([1, 1, pad_h, pad_w]) + for i in range(windows.shape[0]): + for j in range(windows.shape[1]): + window = windows[i, j] + h_start = i * stride + w_start = j * stride + h_end = h_start + self.window_shape[0] + w_end = w_start + self.window_shape[1] + + # add batch dimension to window + logits = self.session.run( + [self.output_name], + {self.input_name: window[np.newaxis, ...]} + )[0].squeeze() + del window + + # add logits to container + if container is None: + container = np.zeros([1, logits.shape[0], pad_h, pad_w]) + container[:, :, h_start:h_end, w_start:w_end] += logits + count[:, :, h_start:h_end, w_start:w_end] += 1 + del h_start, w_start, h_end, w_end, logits + + assert ( + np.min(count) == 1 + ), "There are pixels not predicted. Check window and stride size." + + # post-process + # average the predictions across windows + mask = np.argmax(container / count, axis=1).squeeze() # 2d class map + del container, count + # crop image to remove padding + mask = mask[:h, :w] + + if rgb: + mask = decode_segmap_v2(mask) + return mask diff --git a/scripts/onnx_notebook.ipynb b/scripts/onnx_notebook.ipynb new file mode 100644 index 00000000..47d99ef5 --- /dev/null +++ b/scripts/onnx_notebook.ipynb @@ -0,0 +1,100 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### ONNX Windowed Model Integration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "import glob\n", + "import os\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import PIL.Image\n", + "\n", + "from fibsem.segmentation.model import load_model\n", + "from fibsem.structures import FibsemImage\n", + "\n", + "# image filenames\n", + "PATH = \"example_imgs/input\"\n", + "filenames = glob.glob(PATH + \"/*.jpeg\")\n", + "\n", + "# PATH = \"/home/patrick/github/data/autolamella-paper/model-development/train/waffle/test\"\n", + "# filenames = glob.glob(PATH + \"/*.tif\")\n", + "\n", + "# load model\n", + "MODEL_PATH = \"ppliteseg_fibsem_07022024_512x512_128k.onnx\"\n", + "model = load_model(checkpoint=MODEL_PATH)\n", + "\n", + "os.makedirs(\"example_imgs/output/test\", exist_ok=True)\n", + "\n", + "for i, filename in enumerate(filenames):\n", + " print(f\"Processing {i+1}/{len(filenames)}: {filename}\")\n", + "\n", + " # load image\n", + " if \"tif\" in filename:\n", + " image = FibsemImage.load(filename)\n", + " else:\n", + " image = FibsemImage(data=np.asarray(PIL.Image.open(filename)))\n", + " \n", + " # inference\n", + " rgb = model.inference(image.data)\n", + "\n", + " fig = plt.figure(figsize=(10, 10))\n", + " plt.title(f\"Predicted: {os.path.basename(filename)}\", fontsize=10)\n", + " plt.imshow(image.data, cmap=\"gray\")\n", + " plt.imshow(rgb, alpha=0.5)\n", + " plt.axis(\"off\")\n", + " plt.show()\n", + "\n", + " # save figure\n", + " fig.savefig(f\"example_imgs/output/test/{os.path.basename(filename)}\".replace(\".tif\", \".png\"), bbox_inches=\"tight\")\n", + " plt.close(fig)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "fibsem", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/scripts/onnx_pred.py b/scripts/onnx_pred.py new file mode 100644 index 00000000..24fb03b4 --- /dev/null +++ b/scripts/onnx_pred.py @@ -0,0 +1,185 @@ +#!/usr/bin/python3 +# +# onnx_pred.py - inference using ONNX model +# author: Christopher JF Cameron +# + +import argparse +import cv2 # type: ignore +import numpy as np # type: ignore +import os + +from skimage.util.shape import view_as_windows # type: ignore +from onnxruntime import InferenceSession # type: ignore + + +def load_onnx_model(model_path: str): + """ + Load the ONNX model. + + Args: + model_path (str): File path to the ONNX model. + + Returns: + InferenceSession: The ONNX model session. + str: The name of the input tensor. + Tuple[int, int]: The shape of the input tensor. + str: The name of the output tensor. + """ + session = InferenceSession(model_path) + input_name = session.get_inputs()[0].name + window_shape = session.get_inputs()[0].shape[2:] + output_name = session.get_outputs()[0].name + + return session, input_name, window_shape, output_name + + +def parse_args(): + """ + Parse command line arguments. + + Returns: + argparse.Namespace: The parsed command line arguments + """ + parser = argparse.ArgumentParser(description="Lamella prediction with ONNX model") + parser.add_argument("model_path", type=str, help="Path to the ONNX model file") + parser.add_argument("img_path", type=str, help="Path to the input image file") + parser.add_argument( + "-o", + "--output_dir", + type=str, + help="Path to save the output image (default: img directory)", + ) + + return parser.parse_args() + + +def standardize(img: object, sigma: float = 24.0): + """ + Standardize the pixel intensities of the provided image. + + Args: + img (np.ndarray): The image to standardize. + sigma (float): The standard deviation of the Gaussian kernel. + + Returns: + np.ndarray: The standardized image. + """ + # subtract local mean + smooth = cv2.GaussianBlur(img, (0, 0), sigmaX=sigma) + img = np.subtract(img, smooth) + # scale pixel intensities + img = img / np.std(img) + del smooth + + return img.astype(np.float32) + + +def main(args): + """ + Main function for ONNX prediction. + # usage + # python onnx_pred.py /path/to/model.onnx /path/to/image.jpeg -o /path/to/output + + Args: + args (argparse.Namespace): The command line arguments. + + Returns: + None + """ + + print("Loading ONNX model ... ", end="", flush=True) + # load ONNX model + session, input_name, window_shape, output_name = load_onnx_model(args.model_path) + stride = window_shape[0] // 5 + print("done") + + print("Loading image ... ", end="", flush=True) + # load image + basename = os.path.basename(args.img_path).replace(".jpeg", "") + img = cv2.imread(args.img_path).astype(np.float32) + # save original image shape + h, w, c = img.shape + # PaddleSeg models expect 3 channel input image + assert c == 3, f"image must have 3 channels. Found: {c}" + del c + print("done") + + # standardize image + img = standardize(img) + + # transpose dimensions from HWC to CHW + img = np.transpose(img, (2, 0, 1)) + + # pad image for sliding window + pad_h = max(window_shape[0] - h % stride, 0) + pad_w = max(window_shape[1] - w % stride, 0) + img = np.pad(img, ((0, 0), (0, pad_h), (0, pad_w)), mode="constant") + _, pad_h, pad_w = img.shape + del _ + + # window input image + windows = view_as_windows( + img, (3, window_shape[0], window_shape[1]), step=stride + ).squeeze() + + print("Predicting ... ", end="", flush=True) + # predict each window + container = None + count = np.zeros([1, 1, pad_h, pad_w]) + for i in range(windows.shape[0]): + for j in range(windows.shape[1]): + + window = windows[i, j] + h_start = i * stride + w_start = j * stride + h_end = h_start + window_shape[0] + w_end = w_start + window_shape[1] + + # add batch dimension to window + logits = session.run([output_name], {input_name: window[np.newaxis, ...]})[ + 0 + ].squeeze() + del window + + # add logits to container + if container is None: + container = np.zeros([1, logits.shape[0], pad_h, pad_w]) + container[:, :, h_start:h_end, w_start:w_end] += logits + count[:, :, h_start:h_end, w_start:w_end] += 1 + del h_start, w_start, h_end, w_end, logits + del img, pad_h, pad_w, i, j + assert ( + np.min(count) == 1 + ), "There are pixels not predicted. Check window and stride size." + + # average the predictions + container = np.argmax(container / count, axis=1).squeeze() * 127 + del count + + # crop image to remove padding + container = container[:h, :w] + del h, w + print("done") + + import matplotlib.pyplot as plt + plt.imshow(container, cmap="gray") + plt.show() + + # write to storage + out_path = os.path.join(args.output_dir, f"{basename}.png") + cv2.imwrite(out_path, container) + del basename + + +if __name__ == "__main__": + args = parse_args() + + # validate arguments + assert os.path.exists(args.model_path), "model path does not exist" + assert os.path.exists(args.img_path), "image path does not exist" + if args.output_dir is not None: + print(f"warning - setting output directory to image directory: {args.img_path}") + args.output_dir = os.path.dirname(args.img_path) + + main(args)