[WIP] V2 refactor #149 (Draft)

continue-revolution wants to merge 6 commits into master.

Changes from 1 commit: "WIP refactor", committed by continue-revolution on Jun 25, 2023.

Verified: this commit was created on GitHub.com and signed with GitHub's verified signature.
commit 6718b4c76b093c9515733e0c134d5f1b511fa4eb
28 changes: 21 additions & 7 deletions mam/m2m.py
@@ -7,23 +7,31 @@

 import m2ms
 from modules.devices import get_device_for
+from scripts.sam_state import sam_extension_dir
+from scripts.sam_log import logger

 class SamM2M(Module):
-    def __init__(self, m2m='sam_decoder_deep', ckpt_path=None, device=None):
+
+    def __init__(self):
         super(SamM2M, self).__init__()
+        self.m2m_device = get_device_for("sam")
+
+
+    def load_m2m(self, m2m='sam_decoder_deep', ckpt_path=None):
         if m2m not in m2ms.__all__:
             raise NotImplementedError(f"Unknown M2M {m2m}")
         self.m2m: Module = m2ms.__dict__[m2m](nc=256)
         if ckpt_path is None:
-            ckpt_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'models/mam')
+            ckpt_path = os.path.join(sam_extension_dir, 'models/mam')
         try:
-            state_dict = torch.load(os.path.join(ckpt_path, 'mam.pth'), map_location=device)
+            logger.info(f"Loading mam from path: {ckpt_path}/mam.pth to device: {self.m2m_device}")
+            state_dict = torch.load(os.path.join(ckpt_path, 'mam.pth'), map_location=self.m2m_device)
         except:
-            state_dict = torch.hub.load_state_dict_from_url(
-                "https://huggingface.co/conrevo/Matting-Anything-diff/resolve/main/mam.pth", ckpt_path, device)
+            mam_url = "https://huggingface.co/conrevo/SAM4WebUI-Extension-Models/resolve/main/mam.pth"
+            logger.info(f"Loading mam from url: {mam_url} to path: {ckpt_path}, device: {self.m2m_device}")
+            state_dict = torch.hub.load_state_dict_from_url(mam_url, ckpt_path, self.m2m_device)
         self.m2m.load_state_dict(state_dict)
         self.m2m.eval()
-        self.m2m_device = get_device_for("sam") if device is None else device

@@ -32,5 +40,11 @@ def forward(self, feas, image, masks):
         return pred


+    def clear(self):
+        del self.m2m
+        self.m2m = None
+
+
     def unload_model(self):
-        self.m2m.to('cpu')
+        if self.m2m is not None:
+            self.m2m.cpu()
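
The refactor separates construction from weight loading, so the caller now decides when the checkpoint is fetched. A minimal usage sketch against the refactored API above (hypothetical caller, not part of the diff; the import path assumes the file layout mam/m2m.py):

# Hypothetical caller of the refactored SamM2M lifecycle.
from mam.m2m import SamM2M

m2m = SamM2M()       # cheap: only resolves the target device
m2m.load_m2m()       # downloads mam.pth on first use, then loads and evals it
# ... run matting ...
m2m.unload_model()   # move weights to CPU but keep them cached
m2m.clear()          # drop the module entirely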
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,2 +1,3 @@
 segment_anything
-supervision
+supervision
+ultralytics
14 changes: 8 additions & 6 deletions sam_hq/build_sam_hq.py
@@ -12,6 +12,7 @@
 from .modeling.image_encoder import ImageEncoderViTHQ
 from segment_anything.modeling import PromptEncoder, Sam, TwoWayTransformer
 from segment_anything import build_sam_vit_h, build_sam_vit_l, build_sam_vit_b
+from ultralytics import YOLO


 def build_sam_hq_vit_h(checkpoint=None):
@@ -45,12 +46,13 @@ def build_sam_hq_vit_b(checkpoint=None):


 sam_model_registry = {
-    "sam_vit_h": build_sam_vit_h,
-    "sam_vit_l": build_sam_vit_l,
-    "sam_vit_b": build_sam_vit_b,
-    "sam_hq_vit_h": build_sam_hq_vit_h,
-    "sam_hq_vit_l": build_sam_hq_vit_l,
-    "sam_hq_vit_b": build_sam_hq_vit_b,
+    "sam_vit_h_4b8939.pth (Meta, 2.56GB)" : build_sam_vit_h,
+    "sam_vit_l_0b3195.pth (Meta, 1.25GB)" : build_sam_vit_l,
+    "sam_vit_b_01ec64.pth (Meta, 375MB)" : build_sam_vit_b,
+    "sam_hq_vit_h.pth (SysCV, 2.57GB)" : build_sam_hq_vit_h,
+    "sam_hq_vit_l.pth (SysCV, 1.25GB)" : build_sam_hq_vit_l,
+    "sam_hq_vit_b.pth (SysCV, 379MB)" : build_sam_hq_vit_b,
+    "FastSAM-x.pt (CASIA-IVA-Lab, 138MB)" : YOLO,
 }
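
The registry keys are now the exact strings shown in the model dropdown, so a UI selection dispatches straight to a builder, and FastSAM routes to ultralytics' YOLO class instead of a SAM builder. A hypothetical lookup (the checkpoint path below is an assumption):

# Hypothetical dispatch through the renamed registry keys.
from sam_hq.build_sam_hq import sam_model_registry

name = "sam_hq_vit_b.pth (SysCV, 379MB)"  # dropdown string doubles as registry key
build = sam_model_registry[name]
model = build(checkpoint="models/sam/sam_hq_vit_b.pth")  # assumed local path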


10 changes: 1 addition & 9 deletions sam_hq/predictor.py
@@ -2,15 +2,13 @@
 import torch
 from segment_anything import SamPredictor
 from segment_anything.modeling import Sam
-from modules.devices import get_device_for

 class SamPredictorHQ(SamPredictor):

     def __init__(
         self,
         sam_model: Sam,
         sam_is_hq: bool = False,
-        sam_device: str = None,
     ) -> None:
         """
         Uses SAM to calculate the image embedding for an image, and then
@@ -21,13 +19,7 @@ def __init__(
         """
         super().__init__(sam_model=sam_model)
         self.is_hq = sam_is_hq
-        self.sam_device = get_device_for('sam') if sam_device is None else sam_device
-        self.model = self.model.eval().to(self.sam_device)
-
-
-    def unload_model(self):
-        if self.model is not None:
-            self.model.cpu()
+        self.model = self.model.eval()


     @torch.no_grad()
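
With device handling stripped out of the predictor, the caller is expected to place the model before wrapping it (as scripts/sam_tmp.py does below). A minimal sketch, assuming a registry key from build_sam_hq.py and hypothetical ckpt_path and sam_device values:

# Device placement now happens before SamPredictorHQ is constructed.
sam = sam_model_registry["sam_hq_vit_b.pth (SysCV, 379MB)"](checkpoint=ckpt_path).to(sam_device)
predictor = SamPredictorHQ(sam, sam_is_hq=True)  # __init__ only calls eval() now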
131 changes: 2 additions & 129 deletions scripts/sam.py
@@ -7,30 +7,19 @@
 import torch
 import gradio as gr
 from collections import OrderedDict
-from scipy.ndimage import binary_dilation
 from modules import scripts, shared, script_callbacks
 from modules.ui import gr_show
 from modules.ui_components import FormRow
-from modules.safe import unsafe_torch_load, load
 from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessing
-from modules.devices import device, torch_gc, cpu
-from modules.paths import models_path
+from modules.devices import device, torch_gc

-from sam_hq.predictor import SamPredictorHQ
-from sam_hq.build_sam_hq import sam_model_registry
 from scripts.sam_dino import dino_model_list, dino_predict_internal, show_boxes, clear_dino_cache, dino_install_issue_text
 from scripts.sam_auto import clear_sem_sam_cache, register_auto_sam, semantic_segmentation, sem_sam_garbage_collect, image_layer_internal, categorical_mask_image
 from scripts.sam_process import SAMProcessUnit, max_cn_num


 refresh_symbol = '\U0001f504' # 🔄
 sam_model_cache = OrderedDict()
-scripts_sam_model_dir = os.path.join(scripts.basedir(), "models/sam")
-sd_sam_model_dir = os.path.join(models_path, "sam")
-sam_model_dir = sd_sam_model_dir if os.path.exists(sd_sam_model_dir) else scripts_sam_model_dir
-sam_model_list = [f for f in os.listdir(sam_model_dir) if os.path.isfile(os.path.join(sam_model_dir, f)) and f.split('.')[-1] != 'txt']
-sam_device = device


 txt2img_width: gr.Slider = None
 txt2img_height: gr.Slider = None
 img2img_width: gr.Slider = None
@@ -45,43 +34,6 @@ def __init__(self, **kwargs):

     def get_block_name(self):
         return "button"
-
-
-def show_masks(image_np, masks: np.ndarray, alpha=0.5):
-    image = copy.deepcopy(image_np)
-    np.random.seed(0)
-    for mask in masks:
-        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
-        image[mask] = image[mask] * (1 - alpha) + 255 * color.reshape(1, 1, -1) * alpha
-    return image.astype(np.uint8)
-
-
-def update_mask(mask_gallery, chosen_mask, dilation_amt, input_image):
-    print("Dilation Amount: ", dilation_amt)
-    if isinstance(mask_gallery, list):
-        mask_image = Image.open(mask_gallery[chosen_mask + 3]['name'])
-    else:
-        mask_image = mask_gallery
-    binary_img = np.array(mask_image.convert('1'))
-    if dilation_amt:
-        mask_image, binary_img = dilate_mask(binary_img, dilation_amt)
-    blended_image = Image.fromarray(show_masks(np.array(input_image), binary_img.astype(np.bool_)[None, ...]))
-    matted_image = np.array(input_image)
-    matted_image[~binary_img] = np.array([0, 0, 0, 0])
-    return [blended_image, mask_image, Image.fromarray(matted_image)]
-
-
-def load_sam_model(sam_checkpoint):
-    model_type = sam_checkpoint.split('.')[0]
-    if 'hq' not in model_type:
-        model_type = '_'.join(model_type.split('_')[:-1])
-    sam_checkpoint_path = os.path.join(sam_model_dir, sam_checkpoint)
-    torch.load = unsafe_torch_load
-    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint_path)
-    sam.to(device=sam_device)
-    sam.eval()
-    torch.load = load
-    return sam


 def clear_sam_cache():
@@ -103,86 +55,6 @@ def garbage_collect(sam):
     torch_gc()


-def refresh_sam_models(*inputs):
-    global sam_model_list
-    sam_model_list = [f for f in os.listdir(sam_model_dir) if os.path.isfile(
-        os.path.join(sam_model_dir, f)) and f.split('.')[-1] != 'txt']
-    dd = inputs[0]
-    if dd in sam_model_list:
-        selected = dd
-    elif len(sam_model_list) > 0:
-        selected = sam_model_list[0]
-    else:
-        selected = None
-    return gr.Dropdown.update(choices=sam_model_list, value=selected)
-
-
-def init_sam_model(sam_model_name):
-    print(f"Initializing SAM to {sam_device}")
-    if sam_model_name in sam_model_cache:
-        sam = sam_model_cache[sam_model_name]
-        if shared.cmd_opts.lowvram or (str(sam_device) not in str(sam.device)):
-            sam.to(device=sam_device)
-        return sam
-    elif sam_model_name in sam_model_list:
-        clear_sam_cache()
-        sam_model_cache[sam_model_name] = load_sam_model(sam_model_name)
-        return sam_model_cache[sam_model_name]
-    else:
-        raise Exception(
-            f"{sam_model_name} not found, please download model to models/sam.")
-
-
-def dilate_mask(mask, dilation_amt):
-    x, y = np.meshgrid(np.arange(dilation_amt), np.arange(dilation_amt))
-    center = dilation_amt // 2
-    dilation_kernel = ((x - center)**2 + (y - center)**2 <= center**2).astype(np.uint8)
-    dilated_binary_img = binary_dilation(mask, dilation_kernel)
-    dilated_mask = Image.fromarray(dilated_binary_img.astype(np.uint8) * 255)
-    return dilated_mask, dilated_binary_img
-
-
-def create_mask_output(image_np, masks, boxes_filt):
-    print("Creating output image")
-    mask_images, masks_gallery, matted_images = [], [], []
-    boxes_filt = boxes_filt.numpy().astype(int) if boxes_filt is not None else None
-    for mask in masks:
-        masks_gallery.append(Image.fromarray(np.any(mask, axis=0)))
-        blended_image = show_masks(show_boxes(image_np, boxes_filt), mask)
-        mask_images.append(Image.fromarray(blended_image))
-        image_np_copy = copy.deepcopy(image_np)
-        image_np_copy[~np.any(mask, axis=0)] = np.array([0, 0, 0, 0])
-        matted_images.append(Image.fromarray(image_np_copy))
-    return mask_images + masks_gallery + matted_images
-
-
-def create_mask_batch_output(
-    input_image_file, dino_batch_dest_dir,
-    image_np, masks, boxes_filt, batch_dilation_amt,
-    dino_batch_save_image, dino_batch_save_mask, dino_batch_save_background, dino_batch_save_image_with_mask):
-    print("Creating batch output image")
-    filename, ext = os.path.splitext(os.path.basename(input_image_file))
-    ext = ".png" # JPEG not compatible with RGBA
-    for idx, mask in enumerate(masks):
-        blended_image = show_masks(show_boxes(image_np, boxes_filt), mask)
-        merged_mask = np.any(mask, axis=0)
-        if dino_batch_save_background:
-            merged_mask = ~merged_mask
-        if batch_dilation_amt:
-            _, merged_mask = dilate_mask(merged_mask, batch_dilation_amt)
-        image_np_copy = copy.deepcopy(image_np)
-        image_np_copy[~merged_mask] = np.array([0, 0, 0, 0])
-        if dino_batch_save_image:
-            output_image = Image.fromarray(image_np_copy)
-            output_image.save(os.path.join(dino_batch_dest_dir, f"{filename}_{idx}_output{ext}"))
-        if dino_batch_save_mask:
-            output_mask = Image.fromarray(merged_mask)
-            output_mask.save(os.path.join(dino_batch_dest_dir, f"{filename}_{idx}_mask{ext}"))
-        if dino_batch_save_image_with_mask:
-            output_blend = Image.fromarray(blended_image)
-            output_blend.save(os.path.join(dino_batch_dest_dir, f"{filename}_{idx}_blend{ext}"))


 def sam_predict(sam_model_name, input_image, positive_points, negative_points,
                 dino_checkbox, dino_model_name, text_prompt, box_threshold,
                 dino_preview_checkbox, dino_preview_boxes_selection):
@@ -812,6 +684,7 @@ def on_after_component(component, **_kwargs):
 def on_ui_settings():
     section = ('segment_anything', "Segment Anything")
     shared.opts.add_option("sam_use_local_groundingdino", shared.OptionInfo(False, "Use local groundingdino to bypass C++ problem", section=section))
+    shared.opts.add_option("sam_model_path", shared.OptionInfo("", "Specify SAM model path", section=section))


 script_callbacks.on_ui_settings(on_ui_settings)
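
The deleted helpers are not gone: the mask utilities reappear in scripts/sam_util.py and the model bookkeeping in scripts/sam_tmp.py, both added later in this diff. A sketch of the imports a caller would use after the move (module names per this PR; still WIP, so they may change):

# Where the removed sam.py helpers now live in this PR.
from scripts.sam_util import show_masks, show_boxes, dilate_mask, update_mask, create_mask_output
from scripts.sam_tmp import Segmentation  # subsumes load_sam_model / init_sam_model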
22 changes: 11 additions & 11 deletions scripts/sam_auto.py
@@ -1,3 +1,4 @@
+from typing import List, Tuple
 import os
 import glob
 import copy
@@ -14,12 +15,12 @@ def __init__(self) -> None:
         self.auto_sam: SamAutomaticMaskGeneratorHQ = None


-    def blend_image_and_seg(self, image: np.ndarray, seg: np.ndarray, alpha=0.5):
+    def blend_image_and_seg(self, image: np.ndarray, seg: np.ndarray, alpha=0.5) -> Image.Image:
         image_blend = image * (1 - alpha) + np.array(seg) * alpha
         return Image.fromarray(image_blend.astype(np.uint8))


-    def strengthen_semmantic_seg(self, class_ids: np.ndarray, img: np.ndarray):
+    def strengthen_semmantic_seg(self, class_ids: np.ndarray, img: np.ndarray) -> np.ndarray:
         logger.info("AutoSAM strengthening semantic segmentation")
         import pycocotools.mask as maskUtils
         semantc_mask = copy.deepcopy(class_ids)
@@ -39,11 +40,10 @@ def strengthen_semmantic_seg(self, class_ids: np.ndarray, img: np.ndarray):
         return semantc_mask


-    def random_segmentation(self, img: Image.Image):
+    def random_segmentation(self, img: Image.Image) -> Tuple[List[Image.Image], str]:
         logger.info("AutoSAM generating random segmentation for EditAnything")
         img_np = np.array(img.convert("RGB"))
         annotations = self.auto_sam.generate(img_np)
-        # annotations = sorted(annotations, key=lambda x: x['area'], reverse=True)
         logger.info(f"AutoSAM generated {len(annotations)} masks")
         H, W, _ = img_np.shape
         color_map = np.zeros((H, W, 3), dtype=np.uint8)
@@ -65,7 +65,7 @@ def random_segmentation(self, img: Image.Image):
             "Random segmentation done. Left above (0) is blended image, right above (1) is random segmentation, left below (2) is Edit-Anything control input."


-    def layer_single_image(self, layout_input_image: Image.Image, layout_output_path: str):
+    def layer_single_image(self, layout_input_image: Image.Image, layout_output_path: str) -> None:
         img_np = np.array(layout_input_image.convert("RGB"))
         annotations = self.auto_sam.generate(img_np)
         logger.info(f"AutoSAM generated {len(annotations)} annotations")
@@ -80,7 +80,7 @@ def layer_single_image(self, layout_input_image: Image.Image, layout_output_path
         img_np.save(os.path.join(layout_output_path, "leftover.png"))


-    def image_layer(self, layout_input_image_or_path, layout_output_path: str):
+    def image_layer(self, layout_input_image_or_path, layout_output_path: str) -> str:
         if isinstance(layout_input_image_or_path, str):
             logger.info("Image layer division batch processing")
             all_files = glob.glob(os.path.join(layout_input_image_or_path, "*"))
@@ -101,7 +101,7 @@ def image_layer(self, layout_input_image_or_path, layout_output_path: str):


     def semantic_segmentation(self, input_image: Image.Image, annotator_name: str, processor_res: int,
-                              use_pixel_perfect: bool, resize_mode: int, target_W: int, target_H: int):
+                              use_pixel_perfect: bool, resize_mode: int, target_W: int, target_H: int) -> Tuple[List[Image.Image], str]:
         if input_image is None:
             return [], "No input image."
         if "seg" in annotator_name:
@@ -131,8 +131,8 @@ def semantic_segmentation(self, input_image: Image.Image, annotator_name: str, p
             return self.random_segmentation(input_image)


-    def categorical_mask_image(self, crop_processor: str, crop_processor_res: int, crop_category_input: str, crop_input_image: Image.Image,
-                               crop_pixel_perfect: bool, crop_resize_mode: int, target_W: int, target_H: int):
+    def categorical_mask_image(self, crop_processor: str, crop_processor_res: int, crop_category_input: List[int], crop_input_image: Image.Image,
+                               crop_pixel_perfect: bool, crop_resize_mode: int, target_W: int, target_H: int) -> Tuple[np.ndarray, Image.Image]:
         if crop_input_image is None:
             return "No input image."
         try:
@@ -144,7 +144,7 @@ def categorical_mask_image(self, crop_processor: str, crop_processor_res: int, c
             }
         except:
             return [], "ControlNet extension not found."
-        filter_classes = crop_category_input.split('+')
+        filter_classes = crop_category_input
         if len(filter_classes) == 0:
             return "No class selected."
         try:
@@ -163,7 +163,7 @@ def categorical_mask_image(self, crop_processor: str, crop_processor_res: int, c
             original_semantic = oneformers[dataset](crop_input_image_np, crop_processor_res)
             sam_semantic = self.strengthen_semmantic_seg(np.array(original_semantic), crop_input_image_np)
             mask = np.zeros(sam_semantic.shape, dtype=np.bool_)
-            from scripts.semantic.category import SEMANTIC_CATEGORIES
+            from scripts.sam_config import SEMANTIC_CATEGORIES
             for i in filter_classes:
                 mask[np.equal(sam_semantic, SEMANTIC_CATEGORIES[crop_processor][i])] = True
             return mask, crop_input_image_copy
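
Note the signature change: crop_category_input is now a list of class indices rather than a '+'-separated string, so the parsing moves to the caller. Sketch with hypothetical values:

# Before: categories arrived as a '+'-joined string; after: a list of ints.
old_input = "12+56"
new_input = [int(c) for c in old_input.split('+')]  # -> [12, 56]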
File renamed without changes.
330 changes: 158 additions & 172 deletions scripts/sam_dino.py
@@ -1,187 +1,173 @@
+from typing import Tuple
 import os
 import gc
-import cv2
-import copy
 from PIL import Image
 import torch
-from collections import OrderedDict

-from modules import scripts, shared
-from modules.devices import device, torch_gc, cpu
+from modules import shared
+from modules.devices import device, torch_gc
+from scripts.sam_log import logger
+import local_groundingdino

+class GroundingDINO4SAM:
+
+    def __init__(self) -> None:
+        self.dino_model = None
+        self.dino_model_type = ""
+        from scripts.sam_state import sam_extension_dir
+        self.dino_model_dir = os.path.join(sam_extension_dir, "models/grounding-dino")
+        self.dino_model_list = ["GroundingDINO_SwinT_OGC (694MB)", "GroundingDINO_SwinB (938MB)"]
+        self.dino_model_info = {
+            "GroundingDINO_SwinT_OGC (694MB)": {
+                "checkpoint": "groundingdino_swint_ogc.pth",
+                "config": os.path.join(self.dino_model_dir, "GroundingDINO_SwinT_OGC.py"),
+                "url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth",
+            },
+            "GroundingDINO_SwinB (938MB)": {
+                "checkpoint": "groundingdino_swinb_cogcoor.pth",
+                "config": os.path.join(self.dino_model_dir, "GroundingDINO_SwinB.cfg.py"),
+                "url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth"
+            },
+        }
+        self.dino_install_issue_text = "Please permanently switch to local GroundingDINO on Settings/Segment Anything or submit an issue to https://github.com/IDEA-Research/Grounded-Segment-Anything/issues."
+
+
+    def _install_goundingdino(self) -> bool:
+        if shared.opts.data.get("sam_use_local_groundingdino", False):
+            logger.info("Using local groundingdino.")
+            return False

-dino_model_cache = OrderedDict()
-sam_extension_dir = scripts.basedir()
-dino_model_dir = os.path.join(sam_extension_dir, "models/grounding-dino")
-dino_model_list = ["GroundingDINO_SwinT_OGC (694MB)", "GroundingDINO_SwinB (938MB)"]
-dino_model_info = {
-    "GroundingDINO_SwinT_OGC (694MB)": {
-        "checkpoint": "groundingdino_swint_ogc.pth",
-        "config": os.path.join(dino_model_dir, "GroundingDINO_SwinT_OGC.py"),
-        "url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth",
-    },
-    "GroundingDINO_SwinB (938MB)": {
-        "checkpoint": "groundingdino_swinb_cogcoor.pth",
-        "config": os.path.join(dino_model_dir, "GroundingDINO_SwinB.cfg.py"),
-        "url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth"
-    },
-}
-dino_install_issue_text = "permanently switch to local groundingdino on Settings/Segment Anything or submit an issue to https://github.com/IDEA-Research/Grounded-Segment-Anything/issues."
-
-
-def install_goundingdino():
-    if shared.opts.data.get("sam_use_local_groundingdino", False):
-        print("Using local groundingdino.")
-        return False
-
-    def verify_dll(install_local=True):
-        try:
-            from groundingdino import _C
-            print("GroundingDINO dynamic library have been successfully built.")
-            return True
-        except Exception:
-            import traceback
-            traceback.print_exc()
-            def run_pip_uninstall(command, desc=None):
-                from launch import python, run
-                default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
-                return run(f'"{python}" -m pip uninstall -y {command}', desc=f"Uninstalling {desc}", errdesc=f"Couldn't uninstall {desc}", live=default_command_live)
-            if install_local:
-                print(f"Failed to build dymanic library. Will uninstall GroundingDINO from pip and fall back to local groundingdino this time. Please {dino_install_issue_text}")
-                run_pip_uninstall(
-                    f"groundingdino",
-                    f"sd-webui-segment-anything requirement: groundingdino")
-            else:
-                print(f"Failed to build dymanic library. Will uninstall GroundingDINO from pip and re-try installing from GitHub source code. Please {dino_install_issue_text}")
-                run_pip_uninstall(
-                    f"uninstall groundingdino",
-                    f"sd-webui-segment-anything requirement: groundingdino")
-            return False
-
-    import launch
-    if launch.is_installed("groundingdino"):
-        print("Found GroundingDINO in pip. Verifying if dynamic library build success.")
-        if verify_dll(install_local=False):
-            return True
-    try:
-        launch.run_pip(
-            f"install git+https://github.com/IDEA-Research/GroundingDINO",
-            f"sd-webui-segment-anything requirement: groundingdino")
-        print("GroundingDINO install success. Verifying if dynamic library build success.")
-        return verify_dll()
-    except Exception:
-        import traceback
-        traceback.print_exc()
-        print(f"GroundingDINO install failed. Will fall back to local groundingdino this time. Please {dino_install_issue_text}")
-        return False
+        def verify_dll(install_local=True):
+            try:
+                from groundingdino import _C
+                logger.info("GroundingDINO dynamic library have been successfully built.")
+                return True
+            except Exception:
+                import traceback
+                traceback.print_exc()
+                def run_pip_uninstall(command, desc=None):
+                    from launch import python, run
+                    default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
+                    return run(f'"{python}" -m pip uninstall -y {command}', desc=f"Uninstalling {desc}", errdesc=f"Couldn't uninstall {desc}", live=default_command_live)
+                if install_local:
+                    logger.warn(f"Failed to build dymanic library. Will uninstall GroundingDINO from pip and fall back to local GroundingDINO this time. {self.dino_install_issue_text}")
+                    run_pip_uninstall(
+                        f"groundingdino",
+                        f"sd-webui-segment-anything requirement: groundingdino")
+                else:
+                    logger.warn(f"Failed to build dymanic library. Will uninstall GroundingDINO from pip and re-try installing from GitHub source code. {self.dino_install_issue_text}")
+                    run_pip_uninstall(
+                        f"uninstall groundingdino",
+                        f"sd-webui-segment-anything requirement: groundingdino")
+                return False
+
+        import launch
+        if launch.is_installed("groundingdino"):
+            logger.info("Found GroundingDINO in pip. Verifying if dynamic library build success.")
+            if verify_dll(install_local=False):
+                return True
+        try:
+            launch.run_pip(
+                f"install git+https://github.com/IDEA-Research/GroundingDINO",
+                f"sd-webui-segment-anything requirement: groundingdino")
+            logger.info("GroundingDINO install success. Verifying if dynamic library build success.")
+            return verify_dll()
+        except Exception:
+            import traceback
+            traceback.print_exc()
+            logger.warn(f"GroundingDINO install failed. Will fall back to local groundingdino this time. {self.dino_install_issue_text}")
+            return False


-def show_boxes(image_np, boxes, color=(255, 0, 0, 255), thickness=2, show_index=False):
-    if boxes is None:
-        return image_np
-
-    image = copy.deepcopy(image_np)
-    for idx, box in enumerate(boxes):
-        x, y, w, h = box
-        cv2.rectangle(image, (x, y), (w, h), color, thickness)
-        if show_index:
-            font = cv2.FONT_HERSHEY_SIMPLEX
-            text = str(idx)
-            textsize = cv2.getTextSize(text, font, 1, 2)[0]
-            cv2.putText(image, text, (x, y+textsize[1]), font, 1, color, thickness)
-
-    return image
-
-
-def clear_dino_cache():
-    dino_model_cache.clear()
-    gc.collect()
-    torch_gc()
-
-
-def load_dino_model(dino_checkpoint, dino_install_success):
-    print(f"Initializing GroundingDINO {dino_checkpoint}")
-    if dino_checkpoint in dino_model_cache:
-        dino = dino_model_cache[dino_checkpoint]
-        if shared.cmd_opts.lowvram:
-            dino.to(device=device)
-    else:
-        clear_dino_cache()
-        if dino_install_success:
-            from groundingdino.models import build_model
-            from groundingdino.util.slconfig import SLConfig
-            from groundingdino.util.utils import clean_state_dict
-        else:
-            from local_groundingdino.models import build_model
-            from local_groundingdino.util.slconfig import SLConfig
-            from local_groundingdino.util.utils import clean_state_dict
-        args = SLConfig.fromfile(dino_model_info[dino_checkpoint]["config"])
-        dino = build_model(args)
-        checkpoint = torch.hub.load_state_dict_from_url(
-            dino_model_info[dino_checkpoint]["url"], dino_model_dir)
-        dino.load_state_dict(clean_state_dict(
-            checkpoint['model']), strict=False)
-        dino.to(device=device)
-        dino_model_cache[dino_checkpoint] = dino
-    dino.eval()
-    return dino
+    def _load_dino_model(self, dino_checkpoint: str, dino_install_success: bool) -> torch.nn.Module:
+        logger.info(f"Initializing GroundingDINO {dino_checkpoint}")
+        if self.dino_model is None or dino_checkpoint != self.dino_model_type:
+            self.clear()
+            if dino_install_success:
+                from groundingdino.models import build_model
+                from groundingdino.util.slconfig import SLConfig
+                from groundingdino.util.utils import clean_state_dict
+            else:
+                from local_groundingdino.models import build_model
+                from local_groundingdino.util.slconfig import SLConfig
+                from local_groundingdino.util.utils import clean_state_dict
+            args = SLConfig.fromfile(self.dino_model_info[dino_checkpoint]["config"])
+            dino = build_model(args)
+            checkpoint = torch.hub.load_state_dict_from_url(
+                self.dino_model_info[dino_checkpoint]["url"], self.dino_model_dir)
+            dino.load_state_dict(clean_state_dict(
+                checkpoint['model']), strict=False)
+            dino.eval()
+            self.dino_model = dino
+            self.dino_model_type = dino_checkpoint
+        self.dino_model.to(device=device)


-def load_dino_image(image_pil, dino_install_success):
-    if dino_install_success:
-        import groundingdino.datasets.transforms as T
-    else:
-        from local_groundingdino.datasets import transforms as T
-    transform = T.Compose(
-        [
-            T.RandomResize([800], max_size=1333),
-            T.ToTensor(),
-            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
-        ]
-    )
-    image, _ = transform(image_pil, None)  # 3, h, w
-    return image
+    def _load_dino_image(self, image_pil: Image.Image, dino_install_success: bool) -> torch.Tensor:
+        if dino_install_success:
+            import groundingdino.datasets.transforms as T
+        else:
+            from local_groundingdino.datasets import transforms as T
+        transform = T.Compose(
+            [
+                T.RandomResize([800], max_size=1333),
+                T.ToTensor(),
+                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+            ]
+        )
+        image, _ = transform(image_pil, None)  # 3, h, w
+        return image


-def get_grounding_output(model, image, caption, box_threshold):
-    caption = caption.lower()
-    caption = caption.strip()
-    if not caption.endswith("."):
-        caption = caption + "."
-    image = image.to(device)
-    with torch.no_grad():
-        outputs = model(image[None], captions=[caption])
-    if shared.cmd_opts.lowvram:
-        model.to(cpu)
-    logits = outputs["pred_logits"].sigmoid()[0]  # (nq, 256)
-    boxes = outputs["pred_boxes"][0]  # (nq, 4)
-
-    # filter output
-    logits_filt = logits.clone()
-    boxes_filt = boxes.clone()
-    filt_mask = logits_filt.max(dim=1)[0] > box_threshold
-    logits_filt = logits_filt[filt_mask]  # num_filt, 256
-    boxes_filt = boxes_filt[filt_mask]  # num_filt, 4
-
-    return boxes_filt.cpu()
+    def _get_grounding_output(self, image: torch.Tensor, caption: str, box_threshold: float):
+        caption = caption.lower()
+        caption = caption.strip()
+        if not caption.endswith("."):
+            caption = caption + "."
+        image = image.to(device)
+        with torch.no_grad():
+            outputs = self.dino_model(image[None], captions=[caption])
+        if shared.cmd_opts.lowvram:
+            self.dino_model.cpu()
+        logits = outputs["pred_logits"].sigmoid()[0]  # (nq, 256)
+        boxes = outputs["pred_boxes"][0]  # (nq, 4)
+
+        # filter output
+        logits_filt = logits.clone()
+        boxes_filt = boxes.clone()
+        filt_mask = logits_filt.max(dim=1)[0] > box_threshold
+        logits_filt = logits_filt[filt_mask]  # num_filt, 256
+        boxes_filt = boxes_filt[filt_mask]  # num_filt, 4
+
+        return boxes_filt.cpu()


-def dino_predict_internal(input_image, dino_model_name, text_prompt, box_threshold):
-    install_success = install_goundingdino()
-    print("Running GroundingDINO Inference")
-    dino_image = load_dino_image(input_image.convert("RGB"), install_success)
-    dino_model = load_dino_model(dino_model_name, install_success)
-    install_success = install_success or shared.opts.data.get("sam_use_local_groundingdino", False)
-
-    boxes_filt = get_grounding_output(
-        dino_model, dino_image, text_prompt, box_threshold
-    )
-
-    H, W = input_image.size[1], input_image.size[0]
-    for i in range(boxes_filt.size(0)):
-        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
-        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
-        boxes_filt[i][2:] += boxes_filt[i][:2]
-    gc.collect()
-    torch_gc()
-    return boxes_filt, install_success
+    def dino_predict_internal(self, input_image: Image.Image, dino_model_name: str, text_prompt: str, box_threshold: float) -> Tuple[torch.Tensor, bool]:
+        install_success = self._install_goundingdino()
+        logger.info("Running GroundingDINO Inference")
+        dino_image = self._load_dino_image(input_image.convert("RGB"), install_success)
+        self._load_dino_model(dino_model_name, install_success)
+        using_groundingdino = install_success or shared.opts.data.get("sam_use_local_groundingdino", False)
+
+        boxes_filt = self._get_grounding_output(
+            dino_image, text_prompt, box_threshold
+        )
+
+        H, W = input_image.size[1], input_image.size[0]
+        for i in range(boxes_filt.size(0)):
+            boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
+            boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
+            boxes_filt[i][2:] += boxes_filt[i][:2]
+        gc.collect()
+        torch_gc()
+        return boxes_filt, using_groundingdino
+
+
+    def clear(self) -> None:
+        del self.dino_model
+        self.dino_model = None
+
+
+    def unload_model(self) -> None:
+        if self.dino_model is not None:
+            self.dino_model.cpu()
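
The module-level DINO cache is gone; the class holds at most one model and swaps it on checkpoint change. A hypothetical usage sketch of the new class-based API (the image path is an assumption):

# PIL image in, pixel-space xyxy boxes out; the bool flags whether a
# GroundingDINO backend (pip or local) was available.
from PIL import Image

dino = GroundingDINO4SAM()
img = Image.open("example.jpg")  # assumed input
boxes, dino_available = dino.dino_predict_internal(
    img, "GroundingDINO_SwinT_OGC (694MB)", "bench", 0.3)
dino.unload_model()  # keep weights, free VRAM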
3 changes: 3 additions & 0 deletions scripts/sam_state.py
@@ -0,0 +1,3 @@
+from modules import scripts
+
+sam_extension_dir = scripts.basedir()
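
Keeping scripts.basedir() in a tiny leaf module gives every other script one import for the extension root, presumably to avoid circular imports between the UI and model modules. For example, mirroring how mam/m2m.py consumes it above:

import os
from scripts.sam_state import sam_extension_dir

mam_dir = os.path.join(sam_extension_dir, 'models/mam')  # as in mam/m2m.py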
57 changes: 57 additions & 0 deletions scripts/sam_tmp.py
@@ -0,0 +1,57 @@
+import os
+import torch
+from modules import shared
+from modules.safe import unsafe_torch_load, load
+from modules.devices import get_device_for, cpu
+from modules.paths import models_path
+from scripts.sam_state import sam_extension_dir
+from scripts.sam_log import logger
+from sam_hq.build_sam_hq import sam_model_registry
+from sam_hq.predictor import SamPredictorHQ
+
+class Segmentation:
+
+    def __init__(self) -> None:
+        self.sam_model_info = {
+            "sam_vit_h_4b8939.pth (Meta, 2.56GB)" : "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
+            "sam_vit_l_0b3195.pth (Meta, 1.25GB)" : "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
+            "sam_vit_b_01ec64.pth (Meta, 375MB)" : "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
+            "sam_hq_vit_h.pth (SysCV, 2.57GB)" : "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth",
+            "sam_hq_vit_l.pth (SysCV, 1.25GB)" : "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth",
+            "sam_hq_vit_b.pth (SysCV, 379MB)" : "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth",
+            "FastSAM-x.pt (CASIA-IVA-Lab, 138MB)" : "https://huggingface.co/conrevo/SAM4WebUI-Extension-Models/resolve/main/FastSAM-x.pt"
+        }
+        self.sam_model = None
+        self.sam_model_type = ""
+        self.sam_model_wrapper = None
+        self.sam_device = get_device_for("sam")
+
+
+    def load_sam_model(self, sam_checkpoint_name: str) -> None:
+        if sam_checkpoint_name not in self.sam_model_info.keys():
+            logger.error(f"Invalid SAM model checkpoint name: {sam_checkpoint_name}")
+        elif self.sam_model is None or self.sam_model_type != sam_checkpoint_name:
+            logger.info(f"Initializing {sam_checkpoint_name} to {self.sam_device}")
+            user_sam_model_dir = shared.opts.data.get("sam_model_path", "")
+            sd_sam_model_dir = os.path.join(models_path, "sam")
+            scripts_sam_model_dir = os.path.join(sam_extension_dir, "models/sam")
+            sam_model_dir = user_sam_model_dir if user_sam_model_dir != "" else (sd_sam_model_dir if os.path.exists(sd_sam_model_dir) else scripts_sam_model_dir)
+            sam_checkpoint_path = os.path.join(sam_model_dir, sam_checkpoint_name)
+            if not os.path.exists(sam_checkpoint_path):
+                sam_url = self.sam_model_info[sam_checkpoint_name]
+                logger.info(f"Downloading SAM model from {sam_url} to {sam_checkpoint_path}")
+                torch.hub.download_url_to_file(sam_url, sam_model_dir)
+            logger.info(f"Loading SAM model from {sam_checkpoint_path}")
+            torch.load = unsafe_torch_load
+            self.sam_model = sam_model_registry[sam_checkpoint_name](checkpoint=sam_checkpoint_path).to(self.sam_device)
+            torch.load = load
+            self.sam_model_type = sam_checkpoint_name
+            self.sam_model_wrapper = SamPredictorHQ(self.sam_model, 'hq' in sam_checkpoint_name) if "Fast" not in sam_checkpoint_name else self.sam_model
+
+
+    def change_device(self, use_cpu: bool) -> None:
+        self.sam_device = cpu if use_cpu else get_device_for("sam")
+
+
+    def __call__(self, point_coords=None, point_labels=None, boxes=None, multimask_output=True, global_point=False, use_numpy=False, use_mam=False):
+        pass
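
A hypothetical caller for the new Segmentation wrapper (the checkpoint name must be one of the sam_model_info keys; the download happens on first use):

seg = Segmentation()
seg.load_sam_model("sam_vit_b_01ec64.pth (Meta, 375MB)")
predictor = seg.sam_model_wrapper  # SamPredictorHQ, or the raw YOLO model for FastSAM
seg.change_device(use_cpu=True)    # affects subsequent loads, not the loaded model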
20 changes: 20 additions & 0 deletions scripts/sam_ui.py
@@ -0,0 +1,20 @@
+import gradio as gr
+
+class SamUI:
+
+    def __init__(self) -> None:
+        pass
+
+    def render_sam_model(self) -> None:
+        with gr.Row():
+            with gr.Column(scale=10):
+                with gr.Row():
+                    sam_model_name = gr.Dropdown(label="SAM Model", choices=sam_model_list, value=sam_model_list[0] if len(sam_model_list) > 0 else None)
+                    sam_refresh_models = ToolButton(value=refresh_symbol)
+                    sam_refresh_models.click(refresh_sam_models, sam_model_name, sam_model_name)
+            with gr.Column(scale=1):
+                sam_use_cpu = gr.Checkbox(value=False, label="Use CPU for SAM")
+                def change_sam_device(use_cpu=False):
+                    global sam_device
+                    sam_device = "cpu" if use_cpu else device
+                sam_use_cpu.change(fn=change_sam_device, inputs=[sam_use_cpu], show_progress=False)
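
The checkbox wiring follows gradio's standard event pattern, though this fragment still leans on globals (sam_model_list, ToolButton, refresh_symbol, device) from the old sam.py. A self-contained sketch of the same toggle idea outside the WebUI (hypothetical, for illustration only):

import gradio as gr

def change_device(use_cpu=False):
    # Mirrors change_sam_device above: report the device the toggle selects.
    return "cpu" if use_cpu else "cuda"

with gr.Blocks() as demo:
    use_cpu = gr.Checkbox(value=False, label="Use CPU for SAM")
    current = gr.Textbox(label="Device")
    use_cpu.change(fn=change_device, inputs=[use_cpu], outputs=[current], show_progress=False)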
93 changes: 93 additions & 0 deletions scripts/sam_util.py
@@ -0,0 +1,93 @@
+from typing import Tuple, List
+import os
+import numpy as np
+import cv2
+import copy
+from scipy.ndimage import binary_dilation
+from PIL import Image
+
+
+def show_boxes(image_np: np.ndarray, boxes: np.ndarray, color=(255, 0, 0, 255), thickness=2, show_index=False) -> np.ndarray:
+    if boxes is None:
+        return image_np
+    image = copy.deepcopy(image_np)
+    for idx, box in enumerate(boxes):
+        x, y, w, h = box
+        cv2.rectangle(image, (x, y), (w, h), color, thickness)
+        if show_index:
+            font = cv2.FONT_HERSHEY_SIMPLEX
+            text = str(idx)
+            textsize = cv2.getTextSize(text, font, 1, 2)[0]
+            cv2.putText(image, text, (x, y+textsize[1]), font, 1, color, thickness)
+    return image
+
+
+def show_masks(image_np: np.ndarray, masks: np.ndarray, alpha=0.5) -> np.ndarray:
+    image = copy.deepcopy(image_np)
+    np.random.seed(0)
+    for mask in masks:
+        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+        image[mask] = image[mask] * (1 - alpha) + 255 * color.reshape(1, 1, -1) * alpha
+    return image.astype(np.uint8)
+
+
+def dilate_mask(mask: np.ndarray, dilation_amt: int) -> Tuple[Image.Image, np.ndarray]:
+    x, y = np.meshgrid(np.arange(dilation_amt), np.arange(dilation_amt))
+    center = dilation_amt // 2
+    dilation_kernel = ((x - center)**2 + (y - center)**2 <= center**2).astype(np.uint8)
+    dilated_binary_img = binary_dilation(mask, dilation_kernel)
+    dilated_mask = Image.fromarray(dilated_binary_img.astype(np.uint8) * 255)
+    return dilated_mask, dilated_binary_img
+
+
+def update_mask(mask_gallery, chosen_mask, dilation_amt, input_image):
+    if isinstance(mask_gallery, list):
+        mask_image = Image.open(mask_gallery[chosen_mask + 3]['name'])
+    else:
+        mask_image = mask_gallery
+    binary_img = np.array(mask_image.convert('1'))
+    if dilation_amt:
+        mask_image, binary_img = dilate_mask(binary_img, dilation_amt)
+    blended_image = Image.fromarray(show_masks(np.array(input_image), binary_img.astype(np.bool_)[None, ...]))
+    matted_image = np.array(input_image)
+    matted_image[~binary_img] = np.array([0, 0, 0, 0])
+    return [blended_image, mask_image, Image.fromarray(matted_image)]
+
+
+def create_mask_output(image_np: np.ndarray, masks: np.ndarray, boxes_filt: np.ndarray) -> List[Image.Image]:
+    mask_images, masks_gallery, matted_images = [], [], []
+    boxes_filt = boxes_filt.astype(int) if boxes_filt is not None else None
+    for mask in masks:
+        masks_gallery.append(Image.fromarray(np.any(mask, axis=0)))
+        blended_image = show_masks(show_boxes(image_np, boxes_filt), mask)
+        mask_images.append(Image.fromarray(blended_image))
+        image_np_copy = copy.deepcopy(image_np)
+        image_np_copy[~np.any(mask, axis=0)] = np.array([0, 0, 0, 0])
+        matted_images.append(Image.fromarray(image_np_copy))
+    return mask_images + masks_gallery + matted_images
+
+
+def create_mask_batch_output(
+    input_image_filename: str, dest_dir: str,
+    image_np: np.ndarray, masks: np.ndarray, boxes_filt: np.ndarray, dilation_amt: float,
+    save_image: bool, save_mask: bool, save_background: bool, save_image_with_mask: bool):
+    filename, ext = os.path.splitext(os.path.basename(input_image_filename))
+    ext = ".png" # JPEG not compatible with RGBA
+    for idx, mask in enumerate(masks):
+        blended_image = show_masks(show_boxes(image_np, boxes_filt), mask)
+        merged_mask = np.any(mask, axis=0)
+        if save_background:
+            merged_mask = ~merged_mask
+        if dilation_amt:
+            _, merged_mask = dilate_mask(merged_mask, dilation_amt)
+        image_np_copy = copy.deepcopy(image_np)
+        image_np_copy[~merged_mask] = np.array([0, 0, 0, 0])
+        if save_image:
+            output_image = Image.fromarray(image_np_copy)
+            output_image.save(os.path.join(dest_dir, f"{filename}_{idx}_output{ext}"))
+        if save_mask:
+            output_mask = Image.fromarray(merged_mask)
+            output_mask.save(os.path.join(dest_dir, f"{filename}_{idx}_mask{ext}"))
+        if save_image_with_mask:
+            output_blend = Image.fromarray(blended_image)
+            output_blend.save(os.path.join(dest_dir, f"{filename}_{idx}_blend{ext}"))
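
For reference, dilate_mask builds a circular structuring element, so dilation grows a mask into a disc of radius dilation_amt // 2. A small self-checking sketch (hypothetical values):

import numpy as np
from scipy.ndimage import binary_dilation

mask = np.zeros((7, 7), dtype=bool)
mask[3, 3] = True  # single foreground pixel
amt = 5
x, y = np.meshgrid(np.arange(amt), np.arange(amt))
center = amt // 2
kernel = ((x - center) ** 2 + (y - center) ** 2 <= center ** 2).astype(np.uint8)
grown = binary_dilation(mask, kernel)
print(grown.sum())  # 13 pixels: a radius-2 disc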