From bc83f52640de3a710b5233f66f13a7318022c8bc Mon Sep 17 00:00:00 2001 From: 9xEzreaL Date: Wed, 26 Apr 2023 00:22:24 +0800 Subject: [PATCH] Refactor --- .gitignore | 2 -- salt/dataset_explorer.py | 34 +++++----------------------------- salt/editor.py | 11 ----------- salt/interface.py | 24 ++---------------------- 4 files changed, 7 insertions(+), 64 deletions(-) diff --git a/.gitignore b/.gitignore index 4fb423d..51a7132 100644 --- a/.gitignore +++ b/.gitignore @@ -134,6 +134,4 @@ dmypy.json # Dataset dataset/ models/ -sample/ -sam_vit_h_4b8939.pth flyer_pages/ diff --git a/salt/dataset_explorer.py b/salt/dataset_explorer.py index 7cc75ae..2e9ee44 100644 --- a/salt/dataset_explorer.py +++ b/salt/dataset_explorer.py @@ -1,7 +1,6 @@ from pycocotools import mask from skimage import measure import json -import tifffile as tiff import shutil import itertools import numpy as np @@ -102,7 +101,9 @@ def __init__(self, dataset_folder, categories=None, coco_json_path=None): self.dataset_folder = dataset_folder self.image_names = os.listdir(os.path.join(self.dataset_folder, "images")) self.image_names = [ - os.path.split(name)[1] for name in self.image_names if name.endswith(".jpg") or name.endswith(".png") or name.endswith(".tif") + os.path.split(name)[1] + for name in self.image_names + if name.endswith(".jpg") or name.endswith(".png") ] self.coco_json_path = coco_json_path if not os.path.exists(coco_json_path): @@ -148,9 +149,6 @@ def get_categories(self, get_colors=False): return self.categories, self.category_colors return self.categories - def get_review_status(self): - return ['not label', 'primary label', 'secondary label', 'final label'] - def get_num_images(self): return len(self.image_names) @@ -162,15 +160,9 @@ def get_image_data(self, image_id): "embeddings", os.path.splitext(os.path.split(image_name)[1])[0] + ".npy", ) - - # image = cv2.imread(image_path) - # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - - image = np.expand_dims(tiff.imread(image_path), 2) - image = np.concatenate([image, image, image], 2) - image = ((image - image.min()) / (image.max() - image.min()) * 255).astype(np.uint8) + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) image_bgr = copy.deepcopy(image) - image_embedding = np.load(embedding_path) return image, image_bgr, image_embedding, image_name @@ -187,7 +179,6 @@ def get_annotations(self, image_id, return_colors=False): colors = [self.category_colors[c] for c in cats] if return_colors: return self.annotations_by_image_id[image_id], colors - return self.annotations_by_image_id[image_id] def delete_annotations(self, image_id, annotation_id): @@ -211,22 +202,7 @@ def add_annotation(self, image_id, category_id, mask, poly=True): self.__add_to_our_annotation_dict(annotation) self.coco_json["annotations"].append(annotation) self.global_annotation_id += 1 - # print(self.coco_json["annotations"]) - - def del_annotation(self, image_id): - for single_seg in self.coco_json["annotations"].copy(): - if single_seg["image_id"] == image_id: - self.coco_json["annotations"].remove(single_seg) - self.save_annotation() def save_annotation(self): with open(self.coco_json_path, "w") as f: json.dump(self.coco_json, f) - - # def save_status(self, file_name, status): - # self.coco_json["images"][file_name]["status"] = status - # self.save_annotation() - - def delete_annotation(self, image_id): - self.annotations_by_image_id[image_id] = [] - self.del_annotation(image_id) diff --git a/salt/editor.py b/salt/editor.py index 1d0a6db..dfe0ca5 100644 --- a/salt/editor.py +++ b/salt/editor.py @@ -194,8 +194,6 @@ def save_ann(self): elif self.curr_inputs.curr_point_mask is not None: tmp_combination = self.curr_inputs.curr_point_mask - # tmp_combination[tmp_combination > 0] = True - # tmp_combination[tmp_combination < 0] = False else: tmp_combination = None @@ -203,14 +201,12 @@ def save_ann(self): self.dataset_explorer.add_annotation( self.image_id, self.category_id, tmp_combination ) - # print(self.curr_inputs.curr_point_mask) # self.dataset_explorer.add_annotation( # self.image_id, self.category_id, self.curr_inputs.curr_point_mask # ) def save(self): self.dataset_explorer.save_annotation() - # self.dataset_explorer.save_status(self.image_id, self.status) def next_image(self): if self.image_id == self.dataset_explorer.get_num_images() - 1: @@ -274,11 +270,4 @@ def select_category(self, category_name): category_id = self.categories.index(category_name) self.category_id = category_id - # def select_status(self, category_name): - # status_id = self.status_categories.index(category_name) - # self.status = status_id - - # def get_status_name(self): - # status_name = self.status_categories[self.status] - # return status_name diff --git a/salt/interface.py b/salt/interface.py index 7e2d4c8..1f8186c 100644 --- a/salt/interface.py +++ b/salt/interface.py @@ -155,10 +155,6 @@ def __init__(self, app, editor, panel_size=(1920, 1080)): self.layout.addLayout(self.main_window) - # self.status_panel = self.get_review_side_panel() - # self.main_window.addWidget(self.status_panel) - # self.layout.addLayout(self.main_window) - self.label = QLabel() self.label.resize(200, 100) self.label.setText(f'{self.editor.name} ... 1/{self.editor.dataset_explorer.get_num_images()}') @@ -301,35 +297,19 @@ def annotation_list_item_clicked(self, item): selected_annotations.remove(int(item.text().split(" ")[0])) self.editor.draw_selected_annotations(selected_annotations) self.graphics_view.imshow(self.editor.display) - - # def get_review_side_panel(self): - # panel = QWidget() - # panel_layout = QVBoxLayout(panel) - # label_array = [] - # review_categories = ['not label', 'primary label', 'secondary label', 'final label'] - # for i, _ in enumerate(review_categories): - # label_array.append(QRadioButton(review_categories[i])) - # label_array[i].clicked.connect( - # lambda state, x=review_categories[i]: self.editor.select_status(x) - # ) - # panel_layout.addWidget(label_array[i]) - - # return panel - + def _update_label(self, name, image_id): self.label.setText(f'{name} ... {image_id+1}/{self.editor.dataset_explorer.get_num_images()}') self.layout.addWidget(self.label) self.setLayout(self.layout) - # def delete_annotations(self): - # self.editor.delete_annotations() - def jump2slice(self): self.editor.jump2image(int(self.box.currentText())) self._update_label(self.editor.name, self.editor.image_id) self.graphics_view.imshow(self.editor.display) def execute_mode(self): + # Here is for change point, paint, eraser mode self.graphics_view.update_PPE_mode(self.point_paint_era.currentText()) # print(self.point_paint_era.currentText()) # pass