diff --git a/.gitignore b/.gitignore index 5f3c56f..a571bc8 100644 --- a/.gitignore +++ b/.gitignore @@ -84,3 +84,5 @@ venv/ **/_version.py tmp.py + +*.mypy* \ No newline at end of file diff --git a/src/napari_sam/_live_overlay.py b/src/napari_sam/_live_overlay.py new file mode 100644 index 0000000..c786f0c --- /dev/null +++ b/src/napari_sam/_live_overlay.py @@ -0,0 +1,149 @@ +from napari._vispy.overlays.base import ( + LayerOverlayMixin, + VispyBaseOverlay, +) +from napari._vispy.overlays.labels_polygon import _only_when_enabled +from napari._vispy.utils.visual import overlay_to_visual +from napari._vispy.visuals.image import Image as ImageNode +from vispy.visuals.transforms import TransformSystem +from vispy.visuals.transforms import ( + STTransform, +) + +from napari.components.overlays import SceneOverlay +from napari.layers import Layer +from napari.utils.events import Event + +import napari +import numpy as np +from time import time +from copy import copy +import warnings + +from typing import Tuple + +MIN_TIME_S: float = 0.08 + + +class ImageOverlay(SceneOverlay): + enabled: bool = False + + +class VispyImageOverlay(LayerOverlayMixin, VispyBaseOverlay): + def __init__( + self, *, layer: Layer, overlay: SceneOverlay, parent=None + ) -> None: + self.node = ImageNode((None), method="auto", texture_format=None) + super().__init__( + node=self.node, layer=layer, overlay=overlay, parent=None + ) + self.node.visible = True + self.widget = None + self.prev_t = time() + + #self.layer.mouse_move_callbacks.append(self._on_mouse_move) + self.reset() + + def _add_widget(self, widget) -> None: + self.widget = widget + + def _get_cropped_mask( + self, whole_mask: np.ndarray + ) -> Tuple[np.ndarray, Tuple[int, int]]: + """Draw bbox round whole mask and crop to it. 
Return offset as well to translate node later.""" + y_nonzero, x_nonzero = np.nonzero(whole_mask) + # in case of no response from SAM mask, draw nothing + if len(y_nonzero) == 0 or len(x_nonzero) == 0: + return np.zeros((10, 10)), (0, 0) + + x_min, x_max = np.amin(x_nonzero), np.amax(x_nonzero) + y_min, y_max = np.amin(y_nonzero), np.amax(y_nonzero) + return whole_mask[y_min:y_max, x_min:x_max], (x_min, y_min) # type: ignore + + def _update_img_from_mask( + self, mask: np.ndarray, offset: Tuple[int, int], color: Tuple[int, int, int, int] = (255, 0, 0, 100) + ) -> None: + mask = np.expand_dims(mask, -1) + cmapped = np.where(mask == 1, color, (0, 0, 0, 0)).astype( + np.uint8 + ) + self.node.set_data(cmapped) + x, y = offset + self.node.transform = STTransform(translate=[x, y]) + self.node.update() + + def remove_current(self) -> None: + self.draw_mask(np.zeros((4, 4)), (0, 0, 0, 0)) + + def draw_mask(self, mask: np.ndarray, color: Tuple[int, int, int, int] = (255, 0, 0, 100)) -> None: + cropped_mask, offset = self._get_cropped_mask(mask) + self._update_img_from_mask(cropped_mask, offset, color) + + def _on_mouse_move(self, event: Event) -> None: + """If enough time passed, request a SAM mask from our widget, crop it and draw to the overlay""" + current_t = time() + enough_time_passed = (current_t - self.prev_t) > MIN_TIME_S + if self.widget is None: + return + img_set = self.widget.img_set + if not enough_time_passed or not img_set: + return + y, x = int(event.value[0]), int(event.value[1]) # napari events are (y, x) + whole_mask: np.ndarray = self.widget._live_sam_prompt(x, y) + cropped_mask, offset = self._get_cropped_mask(whole_mask) + self._update_img_from_mask(cropped_mask, offset) + self.prev_t = current_t + + + +def add_custom_overlay( + layer: Layer, viewer: napari.Viewer +) -> Tuple[ImageOverlay, VispyImageOverlay]: + """Init live SAM overlay and add it to the given layer. 
This involves creating + custom overlay, an associated custom vispy overlay and updating the relevant + overlay -> visual mappings in the layer and the viewer. We then need to copy + all the transforms that belong to the vispy overlay node's parent and apply + them to the vispy node to prevent an error on startup and ensure the node + is aligned as the canvas is transformed. + + :param layer: napari layer we want the live overlay on, usually the image layer + :type layer: Layer + :param viewer: parent viewer + :type viewer: napari.Viewer + :return: the overlay model (napari layer) and associated vispy overlay + :rtype: Tuple[ImageOverlay, VispyImageOverlay] + """ + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=FutureWarning) + vispy_layer = viewer.window._qt_viewer.canvas.layer_to_visual[layer] + custom_overlay_model = ImageOverlay() + custom_overlay_visual = VispyImageOverlay( + layer=layer, overlay=custom_overlay_model, parent=viewer + ) + vispy_layer.overlays[custom_overlay_model] = custom_overlay_visual + viewer.window._qt_viewer.canvas._overlay_to_visual[ + custom_overlay_model + ] = custom_overlay_visual + layer._overlays["live_SAM"] = custom_overlay_model + custom_overlay_visual.node.parent = vispy_layer.node + custom_overlay_model.enabled = True + custom_overlay_model.visible = True + custom_overlay_visual.reset() # this is necessary + vispy_layer._on_matrix_change() + + # per attribute transform copying - just using deepcopy or copy on the TransformSystem + # itself doesn't work (canvas will follow the node transforms) + p_sys: TransformSystem = custom_overlay_visual.node.parent.transforms + c_sys = TransformSystem(p_sys.canvas) + c_sys.canvas_transform = copy(p_sys.canvas_transform) + c_sys.scene_transform = copy(p_sys.scene_transform) + c_sys.visual_transform = copy(p_sys.visual_transform) + c_sys.dpi = p_sys.dpi + c_sys.framebuffer_transform = copy(p_sys.framebuffer_transform) + c_sys.document_transform = 
copy(p_sys.document_transform) + custom_overlay_visual.node.transforms = c_sys + return custom_overlay_model, custom_overlay_visual + +# when we add the custom overlay, this will trigger an event in the layer that tries to add a vispy overlay from +# the overlay_to_visual dict, which will fail for our custom overlay unless we modify it at runtime +overlay_to_visual[ImageOverlay] = VispyImageOverlay diff --git a/src/napari_sam/_widget.py b/src/napari_sam/_widget.py index 6918e3a..2635570 100644 --- a/src/napari_sam/_widget.py +++ b/src/napari_sam/_widget.py @@ -22,6 +22,9 @@ import os from os.path import join +from napari.components.overlays.interaction_box import SelectionBoxOverlay +from napari_sam._live_overlay import add_custom_overlay +from time import time class AnnotatorMode(Enum): NONE = 0 @@ -40,6 +43,14 @@ class BboxState(Enum): DRAG = 1 RELEASE = 2 + +class Backend(Enum): + GPU = 0 + MPS = 1 + CPU = 2 +BACKEND = Backend.GPU + + SAM_MODELS = { "default": {"filename": "sam_vit_h_4b8939.pth", "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "model": build_sam_vit_h}, "vit_h": {"filename": "sam_vit_h_4b8939.pth", "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "model": build_sam_vit_h}, @@ -57,13 +68,13 @@ def __init__(self, napari_viewer): self.annotator_mode = AnnotatorMode.NONE self.segmentation_mode = SegmentationMode.SEMANTIC - if not torch.cuda.is_available(): - if not torch.backends.mps.is_available(): - self.device = "cpu" - else: - self.device = "mps" - else: + + if BACKEND == Backend.GPU and torch.cuda.is_available(): self.device = "cuda" + elif BACKEND == Backend.MPS and torch.mps.is_available(): + self.device = "mps" + else: + self.device = "cpu" main_layout = QVBoxLayout() @@ -187,6 +198,12 @@ def __init__(self, napari_viewer): self.check_auto_inc_bbox.setChecked(True) main_layout.addWidget(self.check_auto_inc_bbox) + self.check_live_view = QCheckBox('Live View') + 
self.check_live_view.setEnabled(False) + self.check_live_view.setChecked(True) + self.check_live_view.clicked.connect(self._toggle_live_view) + main_layout.addWidget(self.check_live_view) + container_widget_info = QWidget() container_layout_info = QVBoxLayout(container_widget_info) @@ -264,8 +281,9 @@ def __init__(self, napari_viewer): self.old_points = np.zeros(0) self.point_size = 10 self.le_point_size.setText(str(self.point_size)) - self.bbox_layer = None - self.bbox_layer_name = "Ignore this layer2" + self.bbox_overlay = None + self.bbox_node = None + self.bbox_edge_width = 10 self.le_bbox_edge_width.setText(str(self.bbox_edge_width)) @@ -281,6 +299,9 @@ def __init__(self, napari_viewer): self.bboxes = defaultdict(list) + self.live_overlay_t = time() + self.live_timeout_s = 0.8 + # self.viewer.window.qt_viewer.layers.model().filterAcceptsRow = self._myfilter def init_auto_mode_settings(self): @@ -578,6 +599,10 @@ def _activate(self): self.image_name = self.cb_image_layers.currentText() self.image_layer = self.viewer.layers[self.cb_image_layers.currentText()] self.label_layer = self.viewer.layers[self.cb_label_layers.currentText()] + self.bbox_overlay, self.bbox_node = self._get_bbox_overlay_and_node() + self.live_overlay_model, self.live_overlay_visual = add_custom_overlay(self.image_layer, self.viewer) + self.live_overlay_visual._add_widget(self) + self.label_layer_changes = None # Fixes shape adjustment by napari if self.image_layer.ndim == 3: @@ -638,12 +663,13 @@ def _activate(self): self.check_prev_mask.setEnabled(True) self.check_auto_inc_bbox.setEnabled(True) self.check_auto_inc_bbox.setChecked(True) + self.check_live_view.setEnabled(True) + self.check_live_view.setChecked(True) self.btn_mode_switch.setText("Switch to BBox Mode") self.annotator_mode = AnnotatorMode.CLICK selected_layer = None if self.viewer.layers.selection.active != self.points_layer: selected_layer = self.viewer.layers.selection.active - self.bbox_layer = 
self.viewer.add_shapes(name=self.bbox_layer_name) if selected_layer is not None: self.viewer.layers.selection.active = selected_layer if self.image_layer.ndim == 3: @@ -653,7 +679,6 @@ def _activate(self): self.update_bbox_layer({}, bbox_tmp=None) self.viewer.dims.set_point(0, 0) self.viewer.dims.set_point(0, pos[0]) - self.bbox_layer.editable = False self.bbox_first_coords = None self.prev_segmentation_mode = SegmentationMode.SEMANTIC @@ -693,6 +718,7 @@ def _activate(self): self.update_points_layer(None) self.viewer.mouse_drag_callbacks.append(self.callback_click) + self.viewer.mouse_move_callbacks.append(self.callback_move) self.viewer.keymap['Delete'] = self.on_delete self.label_layer.keymap['Control-Z'] = self.on_undo self.label_layer.keymap['Control-Shift-Z'] = self.on_redo @@ -728,6 +754,8 @@ def _deactivate(self): self.btn_mode_switch.setText("Switch to BBox Mode") self.check_prev_mask.setEnabled(False) self.check_auto_inc_bbox.setEnabled(False) + self.check_live_view.setEnabled(False) + self.check_live_view.setChecked(False) self.prev_segmentation_mode = SegmentationMode.SEMANTIC self.annotator_mode = AnnotatorMode.CLICK @@ -746,14 +774,12 @@ def _deactivate(self): self.remove_all_widget_callbacks(self.label_layer) if self.points_layer is not None and self.points_layer in self.viewer.layers: self.viewer.layers.remove(self.points_layer) - if self.bbox_layer is not None and self.bbox_layer in self.viewer.layers: - self.viewer.layers.remove(self.bbox_layer) + self.live_overlay_model.visible = False self.image_name = None self.image_layer = None self.label_layer = None self.label_layer_changes = None self.points_layer = None - self.bbox_layer = None self.bbox_first_coords = None self.annotator_mode = AnnotatorMode.NONE self.points = defaultdict(list) @@ -769,6 +795,29 @@ def _deactivate(self): self.rb_semantic.setStyleSheet("") self.rb_instance.setStyleSheet("") self._reset_history() + + def _get_bbox_overlay_and_node(self): + """Get inbuilt napari 
selection-box overlay for self.label_layer and the box node from + the associated vispy overlay (so we can edit its colour and width later). + + :return: napari overlay model and vispy node + :rtype: tuple + """ + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=FutureWarning) + # private attribute access for overlays which causes warning + overlay = self.label_layer._overlays['selection_box'] + overlay.visible = True + # as far as I can tell this is the only way to get the vispy layer for a napari layer + vispy_layer = self.viewer.window._qt_viewer.canvas.layer_to_visual[self.label_layer] + # we loop over each vispy overlay in the vispy layers and return the one that is a SelectionBoxOverlay + for key, value in vispy_layer.overlays.items(): + if type(key) == SelectionBoxOverlay: + vispy_overlay = value + # the node is the bbox drawn on the vispy layer + node = vispy_overlay.node + node.line.set_data(color="steelblue", width=self.bbox_edge_width) + return overlay, node def _switch_mode(self): if self.annotator_mode == AnnotatorMode.CLICK: @@ -837,7 +886,33 @@ def callback_click(self, layer, event): data_coordinates = self.image_layer.world_to_data(event.position) coords = np.round(data_coordinates).astype(int) self.do_bbox_click(coords, BboxState.RELEASE) - + + def callback_move(self, layer, event): + current_t = time() + # avoid over-requesting to SAM + if current_t - self.live_overlay_t < self.live_timeout_s: + return + if self.image_layer.ndim != 2: + return + y, x = int(event.position[0]), int(event.position[1]) + if self.annotator_mode == AnnotatorMode.CLICK and self.check_live_view.isChecked(): + # this is (lazily) copied from predict_click + points = copy.deepcopy(self.points) + points[1].append((y, x)) + points_flattened = [] + labels_flattended = [] + for label, label_points in points.items(): + points_flattened.extend(label_points) + label = int(label == 1) + labels = [label] * len(label_points) + 
labels_flattended.extend(labels) + prediction = self.predict_sam(points=copy.deepcopy(points_flattened), labels=copy.deepcopy(labels_flattended), bbox=None, x_coord=copy.deepcopy(x)) # TODO: doesn't work in ndim=3 + label_n = self.label_layer.selected_label + color = self.label_layer.colormap.colors[label_n] + color[3] = 0.4 + color = tuple([int(255* i) for i in color]) + self.live_overlay_visual.draw_mask(prediction, color) + def on_delete(self, layer): selected_points = list(self.points_layer.selected_data) if len(selected_points) > 0: @@ -980,6 +1055,21 @@ def do_bbox_click(self, coords, bbox_state): raise RuntimeError("Only 2D and 3D images are supported.") bbox_tmp = np.rint(bbox_tmp).astype(np.int32) self.update_bbox_layer(self.bboxes, bbox_tmp=bbox_tmp) + + current_t = time() + if current_t - self.live_overlay_t < self.live_timeout_s or len(bbox_tmp) < 4: + return + if not self.check_live_view.isChecked(): + return + label_n = self.label_layer.selected_label + color = self.label_layer.colormap.colors[label_n] + color[3] = 0.4 + # overlay expects uint8 format - this assumes colors are rgba tuples and not strings + color = tuple([int(255* i) for i in color]) + x_coord = self.bbox_first_coords[0] + prediction = self.predict_sam(points=None, labels=None, bbox=copy.deepcopy(bbox_tmp), x_coord=x_coord) + self.live_overlay_visual.draw_mask(prediction, color) + else: self._save_history({"mode": AnnotatorMode.BBOX, "points": copy.deepcopy(self.points), "bboxes": copy.deepcopy(self.bboxes), "logits": self.sam_logits, "point_label": self.point_label}) if self.image_layer.ndim == 2: @@ -998,7 +1088,6 @@ def do_bbox_click(self, coords, bbox_state): bbox_final = np.rint(bbox_final).astype(np.int32) self.bboxes[new_label].append(bbox_final) - self.update_bbox_layer(self.bboxes) prediction = self.predict_sam(points=None, labels=None, bbox=copy.deepcopy(bbox_final), x_coord=x_coord) @@ -1159,7 +1248,7 @@ def update_points_layer(self, points): with warnings.catch_warnings(): 
warnings.filterwarnings("ignore", category=UserWarning) - self.points_layer = self.viewer.add_points(name=self.points_layer_name, data=np.asarray(points_flattened), face_color=colors_flattended, edge_color="white", size=self.point_size) + self.points_layer = self.viewer.add_points(name=self.points_layer_name, data=np.asarray(points_flattened), face_color=colors_flattended, size=self.point_size) self.points_layer.editable = False if selected_layer is not None: @@ -1168,18 +1257,11 @@ def update_points_layer(self, points): def update_bbox_layer(self, bboxes, bbox_tmp=None): self.bbox_edge_width = int(self.le_bbox_edge_width.text()) - bboxes_flattened = [] - edge_colors = [] - for _, bbox in bboxes.items(): - bboxes_flattened.extend(bbox) - edge_colors.extend(['skyblue'] * len(bbox)) + self.bbox_node.line.set_data(color="steelblue", width=self.bbox_edge_width) if bbox_tmp is not None: - bboxes_flattened.append(bbox_tmp) - edge_colors.append('steelblue') - self.bbox_layer.data = bboxes_flattened - self.bbox_layer.edge_width = [self.bbox_edge_width] * len(bboxes_flattened) - self.bbox_layer.edge_color = edge_colors - self.bbox_layer.face_color = [(0, 0, 0, 0)] * len(bboxes_flattened) + p0 = bbox_tmp[0] + p1 = bbox_tmp[2] + self.bbox_overlay.bounds = ((p0[0], p0[1]), (p1[0], p1[1])) def find_changed_point(self, old_points, new_points): if len(new_points) == 0: @@ -1373,5 +1455,9 @@ def get_cached_weight_types(self, model_types): return cached_weight_types + def _toggle_live_view(self): + if not self.check_live_view.isChecked(): + self.live_overlay_visual.remove_current() + # def _myfilter(self, row, parent): # return "" not in self.viewer.layers[row].name \ No newline at end of file