Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
9xEzreaL committed Apr 25, 2023
1 parent bd85198 commit bc83f52
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 64 deletions.
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,4 @@ dmypy.json
# Dataset
dataset/
models/
sample/
sam_vit_h_4b8939.pth
flyer_pages/
34 changes: 5 additions & 29 deletions salt/dataset_explorer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from pycocotools import mask
from skimage import measure
import json
import tifffile as tiff
import shutil
import itertools
import numpy as np
Expand Down Expand Up @@ -102,7 +101,9 @@ def __init__(self, dataset_folder, categories=None, coco_json_path=None):
self.dataset_folder = dataset_folder
self.image_names = os.listdir(os.path.join(self.dataset_folder, "images"))
self.image_names = [
os.path.split(name)[1] for name in self.image_names if name.endswith(".jpg") or name.endswith(".png") or name.endswith(".tif")
os.path.split(name)[1]
for name in self.image_names
if name.endswith(".jpg") or name.endswith(".png")
]
self.coco_json_path = coco_json_path
if not os.path.exists(coco_json_path):
Expand Down Expand Up @@ -148,9 +149,6 @@ def get_categories(self, get_colors=False):
return self.categories, self.category_colors
return self.categories

def get_review_status(self):
return ['not label', 'primary label', 'secondary label', 'final label']

def get_num_images(self):
return len(self.image_names)

Expand All @@ -162,15 +160,9 @@ def get_image_data(self, image_id):
"embeddings",
os.path.splitext(os.path.split(image_name)[1])[0] + ".npy",
)

# image = cv2.imread(image_path)
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

image = np.expand_dims(tiff.imread(image_path), 2)
image = np.concatenate([image, image, image], 2)
image = ((image - image.min()) / (image.max() - image.min()) * 255).astype(np.uint8)
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image_bgr = copy.deepcopy(image)

image_embedding = np.load(embedding_path)
return image, image_bgr, image_embedding, image_name

Expand All @@ -187,7 +179,6 @@ def get_annotations(self, image_id, return_colors=False):
colors = [self.category_colors[c] for c in cats]
if return_colors:
return self.annotations_by_image_id[image_id], colors

return self.annotations_by_image_id[image_id]

def delete_annotations(self, image_id, annotation_id):
Expand All @@ -211,22 +202,7 @@ def add_annotation(self, image_id, category_id, mask, poly=True):
self.__add_to_our_annotation_dict(annotation)
self.coco_json["annotations"].append(annotation)
self.global_annotation_id += 1
# print(self.coco_json["annotations"])

def del_annotation(self, image_id):
for single_seg in self.coco_json["annotations"].copy():
if single_seg["image_id"] == image_id:
self.coco_json["annotations"].remove(single_seg)
self.save_annotation()

def save_annotation(self):
with open(self.coco_json_path, "w") as f:
json.dump(self.coco_json, f)

# def save_status(self, file_name, status):
# self.coco_json["images"][file_name]["status"] = status
# self.save_annotation()

def delete_annotation(self, image_id):
self.annotations_by_image_id[image_id] = []
self.del_annotation(image_id)
11 changes: 0 additions & 11 deletions salt/editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,23 +194,19 @@ def save_ann(self):

elif self.curr_inputs.curr_point_mask is not None:
tmp_combination = self.curr_inputs.curr_point_mask
# tmp_combination[tmp_combination > 0] = True
# tmp_combination[tmp_combination < 0] = False

else:
tmp_combination = None

self.dataset_explorer.add_annotation(
self.image_id, self.category_id, tmp_combination
)
# print(self.curr_inputs.curr_point_mask)
# self.dataset_explorer.add_annotation(
# self.image_id, self.category_id, self.curr_inputs.curr_point_mask
# )

def save(self):
self.dataset_explorer.save_annotation()
# self.dataset_explorer.save_status(self.image_id, self.status)

def next_image(self):
if self.image_id == self.dataset_explorer.get_num_images() - 1:
Expand Down Expand Up @@ -274,11 +270,4 @@ def select_category(self, category_name):
category_id = self.categories.index(category_name)
self.category_id = category_id

# def select_status(self, category_name):
# status_id = self.status_categories.index(category_name)
# self.status = status_id

# def get_status_name(self):
# status_name = self.status_categories[self.status]
# return status_name

24 changes: 2 additions & 22 deletions salt/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,6 @@ def __init__(self, app, editor, panel_size=(1920, 1080)):

self.layout.addLayout(self.main_window)

# self.status_panel = self.get_review_side_panel()
# self.main_window.addWidget(self.status_panel)
# self.layout.addLayout(self.main_window)

self.label = QLabel()
self.label.resize(200, 100)
self.label.setText(f'{self.editor.name} ... 1/{self.editor.dataset_explorer.get_num_images()}')
Expand Down Expand Up @@ -301,35 +297,19 @@ def annotation_list_item_clicked(self, item):
selected_annotations.remove(int(item.text().split(" ")[0]))
self.editor.draw_selected_annotations(selected_annotations)
self.graphics_view.imshow(self.editor.display)

# def get_review_side_panel(self):
# panel = QWidget()
# panel_layout = QVBoxLayout(panel)
# label_array = []
# review_categories = ['not label', 'primary label', 'secondary label', 'final label']
# for i, _ in enumerate(review_categories):
# label_array.append(QRadioButton(review_categories[i]))
# label_array[i].clicked.connect(
# lambda state, x=review_categories[i]: self.editor.select_status(x)
# )
# panel_layout.addWidget(label_array[i])

# return panel


def _update_label(self, name, image_id):
self.label.setText(f'{name} ... {image_id+1}/{self.editor.dataset_explorer.get_num_images()}')
self.layout.addWidget(self.label)
self.setLayout(self.layout)

# def delete_annotations(self):
# self.editor.delete_annotations()

def jump2slice(self):
self.editor.jump2image(int(self.box.currentText()))
self._update_label(self.editor.name, self.editor.image_id)
self.graphics_view.imshow(self.editor.display)

def execute_mode(self):
# Here is for change point, paint, eraser mode
self.graphics_view.update_PPE_mode(self.point_paint_era.currentText())
# print(self.point_paint_era.currentText())
# pass
Expand Down

0 comments on commit bc83f52

Please sign in to comment.