Skip to content

Commit

Permalink
[feat] add onnx-window model to backend (ppliteseg-nysbc-ccameron)
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickcleeve2 committed Jul 26, 2024
1 parent 36b2a42 commit 77365fd
Show file tree
Hide file tree
Showing 5 changed files with 448 additions and 5 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,6 @@ fibsem/config/*.yaml
tests/notebook.ipynb
fibsem/db/v2/notebook.ipynb
scripts/segformer-*/*
*.onnx
scripts/**/*.jpeg
scripts/**/*.png
11 changes: 9 additions & 2 deletions fibsem/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,15 @@ def load_model(
from fibsem.segmentation.nnunet_model import SegmentationModelNNUnet
model = SegmentationModelNNUnet(checkpoint=checkpoint)
elif backend == "onnx":
from fibsem.segmentation.onnx_model import SegmentationModelONNX
model = SegmentationModelONNX(checkpoint=checkpoint)
from fibsem.segmentation.onnx_model import SegmentationModelONNX, SegmentationModelWindowONNX
for onnx_model in [SegmentationModelWindowONNX, SegmentationModelONNX]:
try:
logging.debug(f"Trying to load {onnx_model}")
model = onnx_model(checkpoint=checkpoint)
break
except Exception as e:
logging.debug(f"Failed to load {type(onnx_model)} for {checkpoint}: {e}")

elif backend == "huggingface":
from fibsem.segmentation.hf_segmentation_model import SegmentationModelHuggingFace
model = SegmentationModelHuggingFace(checkpoint=checkpoint)
Expand Down
154 changes: 151 additions & 3 deletions fibsem/segmentation/onnx_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@


import os
import logging

import cv2
import numpy as np
import onnx
import onnxruntime
import PIL.Image
from onnxruntime import InferenceSession
from skimage.util.shape import view_as_windows

from fibsem.segmentation.utils import decode_segmap_v2, download_checkpoint


### ONNX


Expand Down Expand Up @@ -52,9 +57,10 @@ def inference(self, img: np.ndarray, rgb: bool = True):

def export_model_to_onnx(checkpoint: str, onnx_path: str):

from fibsem.segmentation.model import load_model
import torch

from fibsem.segmentation.model import load_model

# get fibsem model
model = load_model(checkpoint)
model.model.to("cpu")
Expand Down Expand Up @@ -96,4 +102,146 @@ def to_numpy(tensor):

print("Exported model has been tested with ONNXRuntime, and the result looks good!")

# export_model_to_onnx("autolamella-mega-latest.pt", "autolamella-mega-20231230.onnx")
# export_model_to_onnx("autolamella-mega-latest.pt", "autolamella-mega-20231230.onnx")

## PPLITESEG WINDOWED MODEL
def load_windowed_onnx_model(model_path: str) -> tuple:
"""
Load the ONNX model.
Args:
model_path (str): File path to the ONNX model.
Returns:
InferenceSession: The ONNX model session.
str: The name of the input tensor.
Tuple[int, int]: The shape of the input tensor.
str: The name of the output tensor.
"""
session = InferenceSession(model_path)
input_name = session.get_inputs()[0].name
window_shape = session.get_inputs()[0].shape[2:]
output_name = session.get_outputs()[0].name

return session, input_name, window_shape, output_name

def standardize(img: object, sigma: float = 24.0) -> np.ndarray:
"""
Standardize the pixel intensities of the provided image.
Args:
img (np.ndarray): The image to standardize.
sigma (float): The standard deviation of the Gaussian kernel.
Returns:
np.ndarray: The standardized image.
"""
# subtract local mean
smooth = cv2.GaussianBlur(img, (0, 0), sigmaX=sigma)
img = np.subtract(img, smooth)
# scale pixel intensities
img = img / np.std(img)
del smooth

return img.astype(np.float32)


class SegmentationModelWindowONNX:
def __init__(self, checkpoint: str = None):
if checkpoint is not None:
self.load_model(checkpoint)
self.device = None

def load_model(self, checkpoint="autolamella-mega.onnx"):
# download checkpoint if needed
# checkpoint = download_checkpoint(checkpoint)
self.checkpoint = os.path.basename(checkpoint)

# load inference session
session = load_windowed_onnx_model(checkpoint)
self.session, self.input_name, self.window_shape, self.output_name = session

def pre_process(self, img: np.ndarray) -> np.ndarray:
"""Pre-process the image for inference, calculate window parameters"""
##### PREPROCESSING
if img.ndim == 2: # 2d grayscale image -> 3d rgb (grayscale)
img = np.array(PIL.Image.fromarray(img).convert("RGB"))

if img.dtype != np.float32:
img = img.astype(np.float32)

# image and window parameters
h, w, c = img.shape
stride = self.window_shape[0] // 5 # MAGIC_NUMBER

# standardize image
img = standardize(img)

# transpose dimensions from HWC to CHW
img = np.transpose(img, (2, 0, 1))

# pad image for sliding window
pad_h = max(self.window_shape[0] - h % stride, 0)
pad_w = max(self.window_shape[1] - w % stride, 0)
img = np.pad(img, ((0, 0), (0, pad_h), (0, pad_w)), mode="constant")
_, pad_h, pad_w = img.shape

# window input image
windows: np.ndarray = view_as_windows(
img, (3, self.window_shape[0], self.window_shape[1]), step=stride
).squeeze()

logging.debug(f"pre_process: {img.shape}, {windows.shape}, {h}, {w}, {pad_h}, {pad_w}, {stride}")

return img, windows, h, w, pad_h, pad_w, stride

def inference(self, img: np.ndarray, rgb: bool = True) -> np.ndarray:
"""Perform inference on the provided image.
Args:
img (np.ndarray): The image to segment.
rgb (bool): Whether to return an RGB image.
Returns:
np.ndarray: The segmented image."""

# pre-process image
img, windows, h, w, pad_h, pad_w, stride = self.pre_process(img)

# inference on each window
container = None
count = np.zeros([1, 1, pad_h, pad_w])
for i in range(windows.shape[0]):
for j in range(windows.shape[1]):
window = windows[i, j]
h_start = i * stride
w_start = j * stride
h_end = h_start + self.window_shape[0]
w_end = w_start + self.window_shape[1]

# add batch dimension to window
logits = self.session.run(
[self.output_name],
{self.input_name: window[np.newaxis, ...]}
)[0].squeeze()
del window

# add logits to container
if container is None:
container = np.zeros([1, logits.shape[0], pad_h, pad_w])
container[:, :, h_start:h_end, w_start:w_end] += logits
count[:, :, h_start:h_end, w_start:w_end] += 1
del h_start, w_start, h_end, w_end, logits

assert (
np.min(count) == 1
), "There are pixels not predicted. Check window and stride size."

# post-process
# average the predictions across windows
mask = np.argmax(container / count, axis=1).squeeze() # 2d class map
del container, count
# crop image to remove padding
mask = mask[:h, :w]

if rgb:
mask = decode_segmap_v2(mask)
return mask
100 changes: 100 additions & 0 deletions scripts/onnx_notebook.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ONNX Windowed Model Integration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"import glob\n",
"import os\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import PIL.Image\n",
"\n",
"from fibsem.segmentation.model import load_model\n",
"from fibsem.structures import FibsemImage\n",
"\n",
"# image filenames\n",
"PATH = \"example_imgs/input\"\n",
"filenames = glob.glob(PATH + \"/*.jpeg\")\n",
"\n",
"# PATH = \"/home/patrick/github/data/autolamella-paper/model-development/train/waffle/test\"\n",
"# filenames = glob.glob(PATH + \"/*.tif\")\n",
"\n",
"# load model\n",
"MODEL_PATH = \"ppliteseg_fibsem_07022024_512x512_128k.onnx\"\n",
"model = load_model(checkpoint=MODEL_PATH)\n",
"\n",
"os.makedirs(\"example_imgs/output/test\", exist_ok=True)\n",
"\n",
"for i, filename in enumerate(filenames):\n",
" print(f\"Processing {i+1}/{len(filenames)}: {filename}\")\n",
"\n",
" # load image\n",
" if \"tif\" in filename:\n",
" image = FibsemImage.load(filename)\n",
" else:\n",
" image = FibsemImage(data=np.asarray(PIL.Image.open(filename)))\n",
" \n",
" # inference\n",
" rgb = model.inference(image.data)\n",
"\n",
" fig = plt.figure(figsize=(10, 10))\n",
" plt.title(f\"Predicted: {os.path.basename(filename)}\", fontsize=10)\n",
" plt.imshow(image.data, cmap=\"gray\")\n",
" plt.imshow(rgb, alpha=0.5)\n",
" plt.axis(\"off\")\n",
" plt.show()\n",
"\n",
" # save figure\n",
" fig.savefig(f\"example_imgs/output/test/{os.path.basename(filename)}\".replace(\".tif\", \".png\"), bbox_inches=\"tight\")\n",
" plt.close(fig)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "fibsem",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit 77365fd

Please sign in to comment.