From 125cb10a34229f7e5024d90bef629f0545628b45 Mon Sep 17 00:00:00 2001 From: rmdocherty Date: Thu, 29 Feb 2024 18:36:47 +0000 Subject: [PATCH 01/13] added variable backend based on global script variable (useful to force CPU if needed). TODO: make it default to GPU --- .gitignore | 2 ++ src/napari_sam/_widget.py | 20 ++++++++++++++------ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 5f3c56f..a571bc8 100644 --- a/.gitignore +++ b/.gitignore @@ -84,3 +84,5 @@ venv/ **/_version.py tmp.py + +*.mypy* \ No newline at end of file diff --git a/src/napari_sam/_widget.py b/src/napari_sam/_widget.py index 6918e3a..36d9d0a 100644 --- a/src/napari_sam/_widget.py +++ b/src/napari_sam/_widget.py @@ -40,6 +40,14 @@ class BboxState(Enum): DRAG = 1 RELEASE = 2 + +class Backend(Enum): + GPU = 0 + MPS = 1 + CPU = 2 +BACKEND = Backend.CPU # TODO: make GPU later + + SAM_MODELS = { "default": {"filename": "sam_vit_h_4b8939.pth", "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "model": build_sam_vit_h}, "vit_h": {"filename": "sam_vit_h_4b8939.pth", "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "model": build_sam_vit_h}, @@ -57,13 +65,13 @@ def __init__(self, napari_viewer): self.annotator_mode = AnnotatorMode.NONE self.segmentation_mode = SegmentationMode.SEMANTIC - if not torch.cuda.is_available(): - if not torch.backends.mps.is_available(): - self.device = "cpu" - else: - self.device = "mps" - else: + + if BACKEND == Backend.GPU and torch.cuda.is_available(): self.device = "cuda" + elif BACKEND == Backend.MPS and torch.mps.is_available(): + self.device = "mps" + else: + self.device = "cpu" main_layout = QVBoxLayout() From 351e5d1031a264e65445dff943738f90d5e474af Mon Sep 17 00:00:00 2001 From: rmdocherty Date: Fri, 1 Mar 2024 14:02:54 +0000 Subject: [PATCH 02/13] seems to be a bug in my napari version - edge_color arg to add_points not working. Change this back aftewards --- src/napari_sam/_widget.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/napari_sam/_widget.py b/src/napari_sam/_widget.py index 36d9d0a..c586d51 100644 --- a/src/napari_sam/_widget.py +++ b/src/napari_sam/_widget.py @@ -22,7 +22,6 @@ import os from os.path import join - class AnnotatorMode(Enum): NONE = 0 CLICK = 1 @@ -1167,7 +1166,7 @@ def update_points_layer(self, points): with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning) - self.points_layer = self.viewer.add_points(name=self.points_layer_name, data=np.asarray(points_flattened), face_color=colors_flattended, edge_color="white", size=self.point_size) + self.points_layer = self.viewer.add_points(name=self.points_layer_name, data=np.asarray(points_flattened), face_color=colors_flattended, size=self.point_size) self.points_layer.editable = False if selected_layer is not None: From afe7e334807c873f47d667406ea4c01ced78e663 Mon Sep 17 00:00:00 2001 From: rmdocherty Date: Fri, 1 Mar 2024 15:36:59 +0000 Subject: [PATCH 03/13] added interaction box overlay for SAM bounding box. NB previous boxes don't persist --- src/napari_sam/_widget.py | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/src/napari_sam/_widget.py b/src/napari_sam/_widget.py index c586d51..2b34d41 100644 --- a/src/napari_sam/_widget.py +++ b/src/napari_sam/_widget.py @@ -22,6 +22,8 @@ import os from os.path import join +from napari.components.overlays.interaction_box import SelectionBoxOverlay + class AnnotatorMode(Enum): NONE = 0 CLICK = 1 @@ -273,6 +275,9 @@ def __init__(self, napari_viewer): self.le_point_size.setText(str(self.point_size)) self.bbox_layer = None self.bbox_layer_name = "Ignore this layer2" + self.bbox_overlay = None + self.bbox_node = None + self.bbox_edge_width = 10 self.le_bbox_edge_width.setText(str(self.bbox_edge_width)) @@ -585,6 +590,10 @@ def _activate(self): self.image_name = self.cb_image_layers.currentText() self.image_layer = self.viewer.layers[self.cb_image_layers.currentText()] self.label_layer = self.viewer.layers[self.cb_label_layers.currentText()] + self.bbox_overlay = self._get_bbox_overlay_and_node() + print(self.bbox_node) + #self.bbox_node._edge_color = (1, 0, 0, 1) + self.label_layer_changes = None # Fixes shape adjustment by napari if self.image_layer.ndim == 3: @@ -776,6 +785,18 @@ def _deactivate(self): self.rb_semantic.setStyleSheet("") self.rb_instance.setStyleSheet("") self._reset_history() + + def _get_bbox_overlay_and_node(self): + overlay = self.label_layer._overlays['selection_box'] + overlay.visible = True + vispy_layer = self.viewer.window._qt_viewer.canvas.layer_to_visual[self.label_layer] + for key, value in vispy_layer.overlays.items(): + if type(key) == SelectionBoxOverlay: + vispy_overlay = value + node = vispy_overlay.node + node.line.set_data(color="steelblue", width=self.bbox_edge_width) + #overlay._line_color = "steelblue" + return overlay def _switch_mode(self): if self.annotator_mode == AnnotatorMode.CLICK: @@ -1183,10 +1204,16 @@ def update_bbox_layer(self, bboxes, bbox_tmp=None): if bbox_tmp is not None: bboxes_flattened.append(bbox_tmp) edge_colors.append('steelblue') - self.bbox_layer.data = bboxes_flattened - self.bbox_layer.edge_width = [self.bbox_edge_width] * len(bboxes_flattened) - self.bbox_layer.edge_color = edge_colors - self.bbox_layer.face_color = [(0, 0, 0, 0)] * len(bboxes_flattened) + p0 = bbox_tmp[0] + p1 = bbox_tmp[2] + #self.bbox_overlay.line_thickness = self.bbox_edge_width + self.bbox_overlay.bounds = ((p0[0], p0[1]), (p1[0], p1[1])) + #self.bbox_layer.data = bboxes_flattened + #self.bbox_layer.edge_width = [self.bbox_edge_width] * len(bboxes_flattened) + #self.bbox_layer.edge_color = edge_colors + # nself.bbox_layer.face_color = [(0, 0, 0, 0)] * len(bboxes_flattened) + + def find_changed_point(self, old_points, new_points): if len(new_points) == 0: From e67623d69733670f5e616af00b940391ce090f7b Mon Sep 17 00:00:00 2001 From: rmdocherty Date: Fri, 1 Mar 2024 15:45:37 +0000 Subject: [PATCH 04/13] removed old bbox layer, keeping the history stuff intact --- src/napari_sam/_widget.py | 28 +++------------------------- 1 file changed, 3 insertions(+), 25 deletions(-) diff --git a/src/napari_sam/_widget.py b/src/napari_sam/_widget.py index 2b34d41..891e215 100644 --- a/src/napari_sam/_widget.py +++ b/src/napari_sam/_widget.py @@ -273,8 +273,6 @@ def __init__(self, napari_viewer): self.old_points = np.zeros(0) self.point_size = 10 self.le_point_size.setText(str(self.point_size)) - self.bbox_layer = None - self.bbox_layer_name = "Ignore this layer2" self.bbox_overlay = None self.bbox_node = None @@ -590,9 +588,7 @@ def _activate(self): self.image_name = self.cb_image_layers.currentText() self.image_layer = self.viewer.layers[self.cb_image_layers.currentText()] self.label_layer = self.viewer.layers[self.cb_label_layers.currentText()] - self.bbox_overlay = self._get_bbox_overlay_and_node() - print(self.bbox_node) - #self.bbox_node._edge_color = (1, 0, 0, 1) + self.bbox_overlay, self.bbox_node = self._get_bbox_overlay_and_node() self.label_layer_changes = None # Fixes shape adjustment by napari @@ -659,7 +655,6 @@ def _activate(self): selected_layer = None if self.viewer.layers.selection.active != self.points_layer: selected_layer = self.viewer.layers.selection.active - self.bbox_layer = self.viewer.add_shapes(name=self.bbox_layer_name) if selected_layer is not None: self.viewer.layers.selection.active = selected_layer if self.image_layer.ndim == 3: @@ -669,7 +664,6 @@ def _activate(self): self.update_bbox_layer({}, bbox_tmp=None) self.viewer.dims.set_point(0, 0) self.viewer.dims.set_point(0, pos[0]) - self.bbox_layer.editable = False self.bbox_first_coords = None self.prev_segmentation_mode = SegmentationMode.SEMANTIC @@ -769,7 +763,6 @@ def _deactivate(self): self.label_layer = None self.label_layer_changes = None self.points_layer = None - self.bbox_layer = None self.bbox_first_coords = None self.annotator_mode = AnnotatorMode.NONE self.points = defaultdict(list) @@ -795,8 +788,7 @@ def _get_bbox_overlay_and_node(self): vispy_overlay = value node = vispy_overlay.node node.line.set_data(color="steelblue", width=self.bbox_edge_width) - #overlay._line_color = "steelblue" - return overlay + return overlay, node def _switch_mode(self): if self.annotator_mode == AnnotatorMode.CLICK: @@ -1026,7 +1018,6 @@ def do_bbox_click(self, coords, bbox_state): bbox_final = np.rint(bbox_final).astype(np.int32) self.bboxes[new_label].append(bbox_final) - self.update_bbox_layer(self.bboxes) prediction = self.predict_sam(points=None, labels=None, bbox=copy.deepcopy(bbox_final), x_coord=x_coord) @@ -1196,24 +1187,11 @@ def update_points_layer(self, points): def update_bbox_layer(self, bboxes, bbox_tmp=None): self.bbox_edge_width = int(self.le_bbox_edge_width.text()) - bboxes_flattened = [] - edge_colors = [] - for _, bbox in bboxes.items(): - bboxes_flattened.extend(bbox) - edge_colors.extend(['skyblue'] * len(bbox)) + self.bbox_node.line.set_data(color="steelblue", width=self.bbox_edge_width) if bbox_tmp is not None: - bboxes_flattened.append(bbox_tmp) - edge_colors.append('steelblue') p0 = bbox_tmp[0] p1 = bbox_tmp[2] - #self.bbox_overlay.line_thickness = self.bbox_edge_width self.bbox_overlay.bounds = ((p0[0], p0[1]), (p1[0], p1[1])) - #self.bbox_layer.data = bboxes_flattened - #self.bbox_layer.edge_width = [self.bbox_edge_width] * len(bboxes_flattened) - #self.bbox_layer.edge_color = edge_colors - # nself.bbox_layer.face_color = [(0, 0, 0, 0)] * len(bboxes_flattened) - - def find_changed_point(self, old_points, new_points): if len(new_points) == 0: From d4b721f800f0f02baa63f78cb6019e0007fefe08 Mon Sep 17 00:00:00 2001 From: rmdocherty Date: Fri, 1 Mar 2024 16:02:47 +0000 Subject: [PATCH 05/13] comments about bbox overlay --- src/napari_sam/_widget.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/napari_sam/_widget.py b/src/napari_sam/_widget.py index 891e215..7e484e7 100644 --- a/src/napari_sam/_widget.py +++ b/src/napari_sam/_widget.py @@ -780,12 +780,22 @@ def _deactivate(self): self._reset_history() def _get_bbox_overlay_and_node(self): + """Get inbuilt napari selection-box overlay for self.labels_layer and the box node from + the associated vispy overlay (so we can edit its colour and width later). + + :return: napari overlay model and vispy node + :rtype: _type_ + """ + # private attribute access for overlays which causes warning overlay = self.label_layer._overlays['selection_box'] overlay.visible = True + # as far as I can tell this is the only way to get the vispy layer for a napari layer vispy_layer = self.viewer.window._qt_viewer.canvas.layer_to_visual[self.label_layer] + # we loop over each vispy overlay in the vispy layers and return the one that is a SelectionBoxOverlay for key, value in vispy_layer.overlays.items(): if type(key) == SelectionBoxOverlay: vispy_overlay = value + # the node is the bbox drawn on the vispy layer node = vispy_overlay.node node.line.set_data(color="steelblue", width=self.bbox_edge_width) return overlay, node From 7939ee30796d9489a450acbacf445b1e40da91c9 Mon Sep 17 00:00:00 2001 From: rmdocherty Date: Fri, 1 Mar 2024 17:00:22 +0000 Subject: [PATCH 06/13] initial implementation of live overview for click annotator mode: prompt SAM with rpevious points + current mouse position as a positive point --- src/napari_sam/_live_overlay.py | 139 ++++++++++++++++++++++++++++++++ src/napari_sam/_widget.py | 34 ++++++++ 2 files changed, 173 insertions(+) create mode 100644 src/napari_sam/_live_overlay.py diff --git a/src/napari_sam/_live_overlay.py b/src/napari_sam/_live_overlay.py new file mode 100644 index 0000000..ec8a122 --- /dev/null +++ b/src/napari_sam/_live_overlay.py @@ -0,0 +1,139 @@ +from napari._vispy.overlays.base import ( + LayerOverlayMixin, + VispyBaseOverlay, +) +from napari._vispy.overlays.labels_polygon import _only_when_enabled +from napari._vispy.utils.visual import overlay_to_visual +from napari._vispy.visuals.image import Image as ImageNode +from vispy.visuals.transforms import TransformSystem +from vispy.visuals.transforms import ( + STTransform, +) + +from napari.components.overlays import SceneOverlay +from napari.layers import Layer +from napari.utils.events import Event + +import napari +import numpy as np +from time import time +from copy import copy + +from typing import Tuple + +MIN_TIME_S: float = 0.08 + + +class ImageOverlay(SceneOverlay): + enabled: bool = False + + +class VispyImageOverlay(LayerOverlayMixin, VispyBaseOverlay): + def __init__( + self, *, layer: Layer, overlay: SceneOverlay, parent=None + ) -> None: + self.node = ImageNode((None), method="auto", texture_format=None) + super().__init__( + node=self.node, layer=layer, overlay=overlay, parent=None + ) + self.node.visible = True + self.widget = None + self.prev_t = time() + + #self.layer.mouse_move_callbacks.append(self._on_mouse_move) + self.reset() + + def _add_widget(self, widget) -> None: + self.widget = widget + + def _get_cropped_mask( + self, whole_mask: np.ndarray + ) -> Tuple[np.ndarray, Tuple[int, int]]: + """Draw bbox round whole mask and crop to it. Return offset as well to translate node later.""" + y_nonzero, x_nonzero = np.nonzero(whole_mask) + x_min, x_max = np.amin(x_nonzero), np.amax(x_nonzero) + y_min, y_max = np.amin(y_nonzero), np.amax(y_nonzero) + return whole_mask[y_min:y_max, x_min:x_max], (x_min, y_min) # type: ignore + + def _update_img_from_mask( + self, mask: np.ndarray, offset: Tuple[int, int], color: Tuple[int, int, int, int] = (255, 0, 0, 100) + ) -> None: + mask = np.expand_dims(mask, -1) + cmapped = np.where(mask == 1, color, (0, 0, 0, 0)).astype( + np.uint8 + ) + self.node.set_data(cmapped) + x, y = offset + self.node.transform = STTransform(translate=[x, y]) + self.node.update() + + def draw_mask(self, mask: np.ndarray, color: Tuple[int, int, int, int] = (255, 0, 0, 100)) -> None: + cropped_mask, offset = self._get_cropped_mask(mask) + self._update_img_from_mask(cropped_mask, offset, color) + + def _on_mouse_move(self, event: Event) -> None: + """If enough time passed, request a SAM mask from our widget, crop it and draw to the overlay""" + current_t = time() + enough_time_passed = (current_t - self.prev_t) > MIN_TIME_S + if self.widget is None: + return + img_set = self.widget.img_set + if not enough_time_passed or not img_set: + return + y, x = int(event.value[0]), int(event.value[1]) # napari events are (y, x) + whole_mask: np.ndarray = self.widget._live_sam_prompt(x, y) + cropped_mask, offset = self._get_cropped_mask(whole_mask) + self._update_img_from_mask(cropped_mask, offset) + self.prev_t = current_t + + + +def add_custom_overlay( + layer: Layer, viewer: napari.Viewer +) -> Tuple[ImageOverlay, VispyImageOverlay]: + """Init live SAM overlay and add it to the given layer. This involves creating + custom overlay, an associated custom vispy overlay and updating the relevant + overlay -> visual mappings in the layer and the viewer. We then need to copy + all the transforms that bleong to the vispy overlay node's parent and apply + them to the vispy node to prevent an error on startup and ensure the node + is aligned as the canvas is transformed. + + :param layer: napari layer we want the live overlay on, usually the image layer + :type layer: Layer + :param viewer: parent viewer + :type viewer: napari.Viewer + :return: the overlay model (napari layer) and associated vispy overlay + :rtype: Tuple[ImageOverlay, VispyImageOverlay] + """ + vispy_layer = viewer.window._qt_viewer.canvas.layer_to_visual[layer] + custom_overlay_model = ImageOverlay() + custom_overlay_visual = VispyImageOverlay( + layer=layer, overlay=custom_overlay_model, parent=viewer + ) + vispy_layer.overlays[custom_overlay_model] = custom_overlay_visual + viewer.window._qt_viewer.canvas._overlay_to_visual[ + custom_overlay_model + ] = custom_overlay_visual + layer._overlays["live_SAM"] = custom_overlay_model + custom_overlay_visual.node.parent = vispy_layer.node + custom_overlay_model.enabled = True + custom_overlay_model.visible = True + custom_overlay_visual.reset() # this is necessary + vispy_layer._on_matrix_change() + + # per attribute transform copying - just using deepcopy or copy on the TransformSystem + # itself doesn't work (canvas will follow the node transforms) + p_sys: TransformSystem = custom_overlay_visual.node.parent.transforms + c_sys = TransformSystem(p_sys.canvas) + c_sys.canvas_transform = copy(p_sys.canvas_transform) + c_sys.scene_transform = copy(p_sys.scene_transform) + c_sys.visual_transform = copy(p_sys.visual_transform) + c_sys.dpi = p_sys.dpi + c_sys.framebuffer_transform = copy(p_sys.framebuffer_transform) + c_sys.document_transform = copy(p_sys.document_transform) + custom_overlay_visual.node.transforms = c_sys + return custom_overlay_model, custom_overlay_visual + +# when we add the custom overlay, this will trigger an event in the layer that tries to add a vispy overlay from +# the overlay_to_visual dict, which will fail for our custom overlay unless we modify it at runtime +overlay_to_visual[ImageOverlay] = VispyImageOverlay diff --git a/src/napari_sam/_widget.py b/src/napari_sam/_widget.py index 7e484e7..53b3555 100644 --- a/src/napari_sam/_widget.py +++ b/src/napari_sam/_widget.py @@ -23,6 +23,8 @@ from os.path import join from napari.components.overlays.interaction_box import SelectionBoxOverlay +from napari_sam._live_overlay import add_custom_overlay +from time import time class AnnotatorMode(Enum): NONE = 0 @@ -291,6 +293,9 @@ def __init__(self, napari_viewer): self.bboxes = defaultdict(list) + self.live_overlay_t = time() + self.live_timeout_s = 0.8 + # self.viewer.window.qt_viewer.layers.model().filterAcceptsRow = self._myfilter def init_auto_mode_settings(self): @@ -589,6 +594,8 @@ def _activate(self): self.image_layer = self.viewer.layers[self.cb_image_layers.currentText()] self.label_layer = self.viewer.layers[self.cb_label_layers.currentText()] self.bbox_overlay, self.bbox_node = self._get_bbox_overlay_and_node() + self.live_overlay_model, self.live_overlay_visual = add_custom_overlay(self.image_layer, self.viewer) + self.live_overlay_visual._add_widget(self) self.label_layer_changes = None # Fixes shape adjustment by napari @@ -703,6 +710,7 @@ def _activate(self): self.update_points_layer(None) self.viewer.mouse_drag_callbacks.append(self.callback_click) + self.viewer.mouse_move_callbacks.append(self.callback_move) self.viewer.keymap['Delete'] = self.on_delete self.label_layer.keymap['Control-Z'] = self.on_undo self.label_layer.keymap['Control-Shift-Z'] = self.on_redo @@ -867,6 +875,29 @@ def callback_click(self, layer, event): data_coordinates = self.image_layer.world_to_data(event.position) coords = np.round(data_coordinates).astype(int) self.do_bbox_click(coords, BboxState.RELEASE) + + def callback_move(self, layer, event): + current_t = time() + # avoid over-requesting to SAM + if current_t - self.live_overlay_t < self.live_timeout_s: + return + y, x = int(event.position[0]), int(event.position[1]) + if self.annotator_mode == AnnotatorMode.CLICK: + points = copy.deepcopy(self.points) + points[1].append((y, x)) + points_flattened = [] + labels_flattended = [] + for label, label_points in points.items(): + points_flattened.extend(label_points) + label = int(label == 1) + labels = [label] * len(label_points) + labels_flattended.extend(labels) + prediction = self.predict_sam(points=copy.deepcopy(points_flattened), labels=copy.deepcopy(labels_flattended), bbox=None, x_coord=copy.deepcopy(x)) + #prediction = self.predict_sam(points=copy.deepcopy(points_flattened), labels=copy.deepcopy(labels_flattended), bbox=None, x_coord=copy.deepcopy(current_point[0])) + self.live_overlay_visual.draw_mask(prediction) + + + def on_delete(self, layer): selected_points = list(self.points_layer.selected_data) @@ -1010,6 +1041,9 @@ def do_bbox_click(self, coords, bbox_state): raise RuntimeError("Only 2D and 3D images are supported.") bbox_tmp = np.rint(bbox_tmp).astype(np.int32) self.update_bbox_layer(self.bboxes, bbox_tmp=bbox_tmp) + #prediction = self.predict_sam(points=None, labels=None, bbox=copy.deepcopy(bbox_final), x_coord=x_coord) + # TODO: add call to custom overlay + else: self._save_history({"mode": AnnotatorMode.BBOX, "points": copy.deepcopy(self.points), "bboxes": copy.deepcopy(self.bboxes), "logits": self.sam_logits, "point_label": self.point_label}) if self.image_layer.ndim == 2: From ff6d42449d8ed2fe53af14910337adeee9982ff0 Mon Sep 17 00:00:00 2001 From: rmdocherty Date: Fri, 1 Mar 2024 17:27:47 +0000 Subject: [PATCH 07/13] added warning filter and made it draw correct label color --- src/napari_sam/_live_overlay.py | 61 ++++++++++++++++++--------------- src/napari_sam/_widget.py | 35 ++++++++++++------- 2 files changed, 57 insertions(+), 39 deletions(-) diff --git a/src/napari_sam/_live_overlay.py b/src/napari_sam/_live_overlay.py index ec8a122..60941c6 100644 --- a/src/napari_sam/_live_overlay.py +++ b/src/napari_sam/_live_overlay.py @@ -18,6 +18,7 @@ import numpy as np from time import time from copy import copy +import warnings from typing import Tuple @@ -51,6 +52,10 @@ def _get_cropped_mask( ) -> Tuple[np.ndarray, Tuple[int, int]]: """Draw bbox round whole mask and crop to it. Return offset as well to translate node later.""" y_nonzero, x_nonzero = np.nonzero(whole_mask) + # in case of no response from SAM mask, draw nothing + if len(y_nonzero) == 0 or len(x_nonzero) == 0: + return np.zeros((10, 10)), (0, 0) + x_min, x_max = np.amin(x_nonzero), np.amax(x_nonzero) y_min, y_max = np.amin(y_nonzero), np.amax(y_nonzero) return whole_mask[y_min:y_max, x_min:x_max], (x_min, y_min) # type: ignore @@ -105,33 +110,35 @@ def add_custom_overlay( :return: the overlay model (napari layer) and associated vispy overlay :rtype: Tuple[ImageOverlay, VispyImageOverlay] """ - vispy_layer = viewer.window._qt_viewer.canvas.layer_to_visual[layer] - custom_overlay_model = ImageOverlay() - custom_overlay_visual = VispyImageOverlay( - layer=layer, overlay=custom_overlay_model, parent=viewer - ) - vispy_layer.overlays[custom_overlay_model] = custom_overlay_visual - viewer.window._qt_viewer.canvas._overlay_to_visual[ - custom_overlay_model - ] = custom_overlay_visual - layer._overlays["live_SAM"] = custom_overlay_model - custom_overlay_visual.node.parent = vispy_layer.node - custom_overlay_model.enabled = True - custom_overlay_model.visible = True - custom_overlay_visual.reset() # this is necessary - vispy_layer._on_matrix_change() - - # per attribute transform copying - just using deepcopy or copy on the TransformSystem - # itself doesn't work (canvas will follow the node transforms) - p_sys: TransformSystem = custom_overlay_visual.node.parent.transforms - c_sys = TransformSystem(p_sys.canvas) - c_sys.canvas_transform = copy(p_sys.canvas_transform) - c_sys.scene_transform = copy(p_sys.scene_transform) - c_sys.visual_transform = copy(p_sys.visual_transform) - c_sys.dpi = p_sys.dpi - c_sys.framebuffer_transform = copy(p_sys.framebuffer_transform) - c_sys.document_transform = copy(p_sys.document_transform) - custom_overlay_visual.node.transforms = c_sys + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=FutureWarning) + vispy_layer = viewer.window._qt_viewer.canvas.layer_to_visual[layer] + custom_overlay_model = ImageOverlay() + custom_overlay_visual = VispyImageOverlay( + layer=layer, overlay=custom_overlay_model, parent=viewer + ) + vispy_layer.overlays[custom_overlay_model] = custom_overlay_visual + viewer.window._qt_viewer.canvas._overlay_to_visual[ + custom_overlay_model + ] = custom_overlay_visual + layer._overlays["live_SAM"] = custom_overlay_model + custom_overlay_visual.node.parent = vispy_layer.node + custom_overlay_model.enabled = True + custom_overlay_model.visible = True + custom_overlay_visual.reset() # this is necessary + vispy_layer._on_matrix_change() + + # per attribute transform copying - just using deepcopy or copy on the TransformSystem + # itself doesn't work (canvas will follow the node transforms) + p_sys: TransformSystem = custom_overlay_visual.node.parent.transforms + c_sys = TransformSystem(p_sys.canvas) + c_sys.canvas_transform = copy(p_sys.canvas_transform) + c_sys.scene_transform = copy(p_sys.scene_transform) + c_sys.visual_transform = copy(p_sys.visual_transform) + c_sys.dpi = p_sys.dpi + c_sys.framebuffer_transform = copy(p_sys.framebuffer_transform) + c_sys.document_transform = copy(p_sys.document_transform) + custom_overlay_visual.node.transforms = c_sys return custom_overlay_model, custom_overlay_visual # when we add the custom overlay, this will trigger an event in the layer that tries to add a vispy overlay from diff --git a/src/napari_sam/_widget.py b/src/napari_sam/_widget.py index 53b3555..f088a7d 100644 --- a/src/napari_sam/_widget.py +++ b/src/napari_sam/_widget.py @@ -794,11 +794,13 @@ def _get_bbox_overlay_and_node(self): :return: napari overlay model and vispy node :rtype: _type_ """ - # private attribute access for overlays which causes warning - overlay = self.label_layer._overlays['selection_box'] - overlay.visible = True - # as far as I can tell this is the only way to get the vispy layer for a napari layer - vispy_layer = self.viewer.window._qt_viewer.canvas.layer_to_visual[self.label_layer] + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=FutureWarning) + # private attribute access for overlays which causes warning + overlay = self.label_layer._overlays['selection_box'] + overlay.visible = True + # as far as I can tell this is the only way to get the vispy layer for a napari layer + vispy_layer = self.viewer.window._qt_viewer.canvas.layer_to_visual[self.label_layer] # we loop over each vispy overlay in the vispy layers and return the one that is a SelectionBoxOverlay for key, value in vispy_layer.overlays.items(): if type(key) == SelectionBoxOverlay: @@ -893,12 +895,12 @@ def callback_move(self, layer, event): labels = [label] * len(label_points) labels_flattended.extend(labels) prediction = self.predict_sam(points=copy.deepcopy(points_flattened), labels=copy.deepcopy(labels_flattended), bbox=None, x_coord=copy.deepcopy(x)) - #prediction = self.predict_sam(points=copy.deepcopy(points_flattened), labels=copy.deepcopy(labels_flattended), bbox=None, x_coord=copy.deepcopy(current_point[0])) - self.live_overlay_visual.draw_mask(prediction) - + label_n = self.label_layer.selected_label + color = self.label_layer.colormap.colors[label_n] + color[3] = 0.4 + color = tuple([int(255* i) for i in color]) + self.live_overlay_visual.draw_mask(prediction, color) - - def on_delete(self, layer): selected_points = list(self.points_layer.selected_data) if len(selected_points) > 0: @@ -1041,8 +1043,17 @@ def do_bbox_click(self, coords, bbox_state): raise RuntimeError("Only 2D and 3D images are supported.") bbox_tmp = np.rint(bbox_tmp).astype(np.int32) self.update_bbox_layer(self.bboxes, bbox_tmp=bbox_tmp) - #prediction = self.predict_sam(points=None, labels=None, bbox=copy.deepcopy(bbox_final), x_coord=x_coord) - # TODO: add call to custom overlay + + current_t = time() + if current_t - self.live_overlay_t < self.live_timeout_s or len(bbox_tmp) < 4: + return + label_n = self.label_layer.selected_label + color = self.label_layer.colormap.colors[label_n] + color[3] = 0.4 + color = tuple([int(255* i) for i in color]) + x_coord = self.bbox_first_coords[0] + prediction = self.predict_sam(points=None, labels=None, bbox=copy.deepcopy(bbox_tmp), x_coord=x_coord) + self.live_overlay_visual.draw_mask(prediction, color) else: self._save_history({"mode": AnnotatorMode.BBOX, "points": copy.deepcopy(self.points), "bboxes": copy.deepcopy(self.bboxes), "logits": self.sam_logits, "point_label": self.point_label}) From fdd1f5746f2fc168ea3c01fd4c678dd1eb57d87e Mon Sep 17 00:00:00 2001 From: rmdocherty Date: Fri, 1 Mar 2024 17:38:11 +0000 Subject: [PATCH 08/13] added checkbox that disables live view (in case it's distracting) --- src/napari_sam/_widget.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/napari_sam/_widget.py b/src/napari_sam/_widget.py index f088a7d..0ac517d 100644 --- a/src/napari_sam/_widget.py +++ b/src/napari_sam/_widget.py @@ -198,6 +198,12 @@ def __init__(self, napari_viewer): self.check_auto_inc_bbox.setChecked(True) main_layout.addWidget(self.check_auto_inc_bbox) + self.check_live_view = QCheckBox('Live View') + self.check_live_view.setEnabled(False) + self.check_live_view.setChecked(True) + self.live_on = True + main_layout.addWidget(self.check_live_view) + container_widget_info = QWidget() container_layout_info = QVBoxLayout(container_widget_info) @@ -657,6 +663,8 @@ def _activate(self): self.check_prev_mask.setEnabled(True) self.check_auto_inc_bbox.setEnabled(True) self.check_auto_inc_bbox.setChecked(True) + self.check_live_view.setEnabled(True) + self.check_live_view.setChecked(True) self.btn_mode_switch.setText("Switch to BBox Mode") self.annotator_mode = AnnotatorMode.CLICK selected_layer = None @@ -746,6 +754,8 @@ def _deactivate(self): self.btn_mode_switch.setText("Switch to BBox Mode") self.check_prev_mask.setEnabled(False) self.check_auto_inc_bbox.setEnabled(False) + self.check_live_view.setEnabled(False) + self.check_live_view.setChecked(False) self.prev_segmentation_mode = SegmentationMode.SEMANTIC self.annotator_mode = AnnotatorMode.CLICK @@ -766,6 +776,7 @@ def _deactivate(self): self.viewer.layers.remove(self.points_layer) if self.bbox_layer is not None and self.bbox_layer in self.viewer.layers: self.viewer.layers.remove(self.bbox_layer) + self.live_overlay_model.visible = False self.image_name = None self.image_layer = None self.label_layer = None @@ -884,7 +895,7 @@ def callback_move(self, layer, event): if current_t - self.live_overlay_t < self.live_timeout_s: return y, x = int(event.position[0]), int(event.position[1]) - if self.annotator_mode == AnnotatorMode.CLICK: + if self.annotator_mode == AnnotatorMode.CLICK and self.check_live_view.isChecked(): points = copy.deepcopy(self.points) points[1].append((y, x)) points_flattened = [] @@ -1047,6 +1058,8 @@ def do_bbox_click(self, coords, bbox_state): current_t = time() if current_t - self.live_overlay_t < self.live_timeout_s or len(bbox_tmp) < 4: return + if not self.check_live_view.isChecked(): + return label_n = self.label_layer.selected_label color = self.label_layer.colormap.colors[label_n] color[3] = 0.4 From 385533a90bc7fd7b2aa721252d813328e82e417c Mon Sep 17 00:00:00 2001 From: rmdocherty Date: Fri, 1 Mar 2024 17:38:32 +0000 Subject: [PATCH 09/13] add reset function in live overlay for OOB or live view toggle --- src/napari_sam/_live_overlay.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/napari_sam/_live_overlay.py b/src/napari_sam/_live_overlay.py index 60941c6..c786f0c 100644 --- a/src/napari_sam/_live_overlay.py +++ b/src/napari_sam/_live_overlay.py @@ -72,6 +72,9 @@ def _update_img_from_mask( self.node.transform = STTransform(translate=[x, y]) self.node.update() + def remove_current(self) -> None: + self.draw_mask(np.zeros((4, 4)), (0, 0, 0, 0)) + def draw_mask(self, mask: np.ndarray, color: Tuple[int, int, int, int] = (255, 0, 0, 100)) -> None: cropped_mask, offset = self._get_cropped_mask(mask) self._update_img_from_mask(cropped_mask, offset, color) From 082cedf123b49f92e77f286cd285127a3f2eca73 Mon Sep 17 00:00:00 2001 From: rmdocherty Date: Fri, 1 Mar 2024 17:45:08 +0000 Subject: [PATCH 10/13] now removes current live view mask if toggled off --- src/napari_sam/_widget.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/napari_sam/_widget.py b/src/napari_sam/_widget.py index 0ac517d..a2dac93 100644 --- a/src/napari_sam/_widget.py +++ b/src/napari_sam/_widget.py @@ -201,7 +201,7 @@ def __init__(self, napari_viewer): self.check_live_view = QCheckBox('Live View') self.check_live_view.setEnabled(False) self.check_live_view.setChecked(True) - self.live_on = True + self.check_live_view.clicked.connect(self._toggle_live_view) main_layout.addWidget(self.check_live_view) container_widget_info = QWidget() @@ -1453,5 +1453,9 @@ def get_cached_weight_types(self, model_types): return cached_weight_types + def _toggle_live_view(self): + if not self.check_live_view.isChecked(): + self.live_overlay_visual.remove_current() + # def _myfilter(self, row, parent): # return "" not in self.viewer.layers[row].name \ No newline at end of file From eafca6d6f2378041878483a91ec730d641343e9d Mon Sep 17 00:00:00 2001 From: rmdocherty Date: Fri, 1 Mar 2024 18:03:03 +0000 Subject: [PATCH 11/13] stopped live view for ndim=3 --- src/napari_sam/_widget.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/napari_sam/_widget.py b/src/napari_sam/_widget.py index a2dac93..5f7c37d 100644 --- a/src/napari_sam/_widget.py +++ b/src/napari_sam/_widget.py @@ -894,6 +894,8 @@ def callback_move(self, layer, event): # avoid over-requesting to SAM if current_t - self.live_overlay_t < self.live_timeout_s: return + if self.image_layer.ndim != 2: + return y, x = int(event.position[0]), int(event.position[1]) if self.annotator_mode == AnnotatorMode.CLICK and self.check_live_view.isChecked(): points = copy.deepcopy(self.points) @@ -905,7 +907,7 @@ def callback_move(self, layer, event): label = int(label == 1) labels = [label] * len(label_points) labels_flattended.extend(labels) - prediction = self.predict_sam(points=copy.deepcopy(points_flattened), labels=copy.deepcopy(labels_flattended), bbox=None, x_coord=copy.deepcopy(x)) + prediction = self.predict_sam(points=copy.deepcopy(points_flattened), labels=copy.deepcopy(labels_flattended), bbox=None, x_coord=copy.deepcopy(x)) # TODO: doesn't work in ndim=3 label_n = self.label_layer.selected_label color = self.label_layer.colormap.colors[label_n] color[3] = 0.4 From 99df586b8ea9b93458057d877a427fa39e4b4b83 Mon Sep 17 00:00:00 2001 From: rmdocherty Date: Fri, 1 Mar 2024 18:18:23 +0000 Subject: [PATCH 12/13] changed backend default to GPU --- src/napari_sam/_widget.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/napari_sam/_widget.py b/src/napari_sam/_widget.py index 5f7c37d..299986a 100644 --- a/src/napari_sam/_widget.py +++ b/src/napari_sam/_widget.py @@ -48,7 +48,7 @@ class Backend(Enum): GPU = 0 MPS = 1 CPU = 2 -BACKEND = Backend.CPU # TODO: make GPU later +BACKEND = Backend.GPU SAM_MODELS = { @@ -898,6 +898,7 @@ def callback_move(self, layer, event): return y, x = int(event.position[0]), int(event.position[1]) if self.annotator_mode == AnnotatorMode.CLICK and self.check_live_view.isChecked(): + # this is (lazily) copied from predict_click points = copy.deepcopy(self.points) points[1].append((y, x)) points_flattened = [] @@ -1065,6 +1066,7 @@ def do_bbox_click(self, coords, bbox_state): label_n = self.label_layer.selected_label color = self.label_layer.colormap.colors[label_n] color[3] = 0.4 + # overlay expects uint8 format - this assumes colors are rgba tuples and not strings color = tuple([int(255* i) for i in color]) x_coord = self.bbox_first_coords[0] prediction = self.predict_sam(points=None, labels=None, bbox=copy.deepcopy(bbox_tmp), x_coord=x_coord) From f1c9e08a0300ab471e96c8e5a69149338d388c2c Mon Sep 17 00:00:00 2001 From: rmdocherty Date: Mon, 4 Mar 2024 15:41:31 +0000 Subject: [PATCH 13/13] Removed unnecessary bbox layer removal in _deactivate which otherwise caused exception --- src/napari_sam/_widget.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/napari_sam/_widget.py b/src/napari_sam/_widget.py index 299986a..2635570 100644 --- a/src/napari_sam/_widget.py +++ b/src/napari_sam/_widget.py @@ -774,8 +774,6 @@ def _deactivate(self): self.remove_all_widget_callbacks(self.label_layer) if self.points_layer is not None and self.points_layer in self.viewer.layers: self.viewer.layers.remove(self.points_layer) - if self.bbox_layer is not None and self.bbox_layer in self.viewer.layers: - self.viewer.layers.remove(self.bbox_layer) self.live_overlay_model.visible = False self.image_name = None self.image_layer = None