-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial version of FBRS interactive segmentation (#2094)
* Initial version of FBRS interactive segmentation * Add min_pos_points for dextr * Fix fbrs serverless function. * Fix codacy issues. * Minor changes * Fix codacy issues. * Fix typo * Update CHANGELOG * Add license header * Fix comments in yaml
- Loading branch information
Nikita Manovich
authored
Sep 9, 2020
1 parent
4e21929
commit 5ebd91a
Showing
6 changed files
with
206 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
metadata: | ||
name: pth.saic-vul.fbrs | ||
namespace: cvat | ||
annotations: | ||
name: f-BRS | ||
type: interactor | ||
spec: | ||
framework: pytorch | ||
min_pos_points: 1 | ||
|
||
spec: | ||
description: f-BRS interactive segmentation | ||
runtime: "python:3.6" | ||
handler: main:handler | ||
eventTimeout: 30s | ||
env: | ||
- name: PYTHONPATH | ||
value: /opt/nuclio/fbrs | ||
|
||
build: | ||
image: cvat/pth.saic-vul.fbrs | ||
baseImage: python:3.6.11 | ||
|
||
directives: | ||
preCopy: | ||
- kind: WORKDIR | ||
value: /opt/nuclio | ||
- kind: RUN | ||
value: git clone https://github.com/saic-vul/fbrs_interactive_segmentation.git fbrs | ||
- kind: WORKDIR | ||
value: /opt/nuclio/fbrs | ||
- kind: ENV | ||
value: fileid=1Z9dQtpWVTobEdmUBntpUU0pJl-pEXUwR | ||
- kind: ENV | ||
value: filename=resnet101_dh256_sbd.pth | ||
- kind: RUN | ||
value: curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" | ||
- kind: RUN | ||
value: curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename} | ||
- kind: RUN | ||
value: apt update && apt install -y libgl1-mesa-glx | ||
- kind: RUN | ||
value: pip3 install -r requirements.txt | ||
- kind: WORKDIR | ||
value: /opt/nuclio | ||
|
||
triggers: | ||
myHttpTrigger: | ||
maxWorkers: 2 | ||
kind: "http" | ||
workerAvailabilityTimeoutMilliseconds: 10000 | ||
attributes: | ||
maxRequestBodySize: 33554432 # 32MB | ||
|
||
platform: | ||
attributes: | ||
restartPolicy: | ||
name: always | ||
maximumRetryCount: 3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Copyright (C) 2020 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
import json | ||
import base64 | ||
from PIL import Image | ||
import io | ||
from model_handler import ModelHandler | ||
|
||
def init_context(context):
    """Build the model handler once and store it on the function context.

    The handler is attached as ``context.user_data.model`` so that later
    invocations of ``handler`` can reuse it instead of rebuilding it.
    """
    logger = context.logger
    logger.info("Init context... 0%")

    # Constructing ModelHandler is the expensive part of start-up.
    context.user_data.model = ModelHandler()

    logger.info("Init context...100%")
|
||
def handler(context, event):
    """Run interactive segmentation for one request and return a polygon.

    The request body (``event.body``) is read as a mapping with:
      * ``"points"``    - positive clicks, forwarded unchanged to the model;
      * ``"image"``     - the image file encoded as a base64 string;
      * ``"threshold"`` - optional probability threshold, default ``0.5``.

    Negative clicks are not read from the request; an empty list is always
    passed to the model.

    Returns:
        A ``context.Response`` with status 200 whose body is the JSON-encoded
        polygon produced by ``context.user_data.model.handle``.
    """
    context.logger.info("call handler")
    data = event.body
    pos_points = data["points"]
    neg_points = []
    threshold = data.get("threshold", 0.5)
    buf = io.BytesIO(base64.b64decode(data["image"].encode('utf-8')))

    # Uploaded files may be RGBA, grayscale or palette based. Force three
    # channels here so the model always receives an RGB image; for images
    # that are already RGB this is only a copy.
    image = Image.open(buf).convert("RGB")

    polygon = context.user_data.model.handle(image, pos_points,
                                             neg_points, threshold)
    return context.Response(body=json.dumps(polygon),
                            headers={},
                            content_type='application/json',
                            status_code=200)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# Copyright (C) 2020 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
import torch | ||
import numpy as np | ||
from torchvision import transforms | ||
import cv2 | ||
import os | ||
|
||
from isegm.inference.predictors import get_predictor | ||
from isegm.inference.utils import load_deeplab_is_model, load_hrnet_is_model | ||
from isegm.inference.clicker import Clicker, Click | ||
|
||
def convert_mask_to_polygon(mask):
    """Convert a binary mask into the polygon of its largest external contour.

    Args:
        mask: Anything ``np.array(mask, dtype=np.uint8)`` accepts; non-zero
            entries are treated as the object.

    Returns:
        A list of ``[int, int]`` pairs, one per contour point, in the
        coordinate order OpenCV reports them (x, y).

    Raises:
        ValueError: If no contour is found at all (for example an empty mask).
        Exception: If the largest contour has fewer than three points.
    """
    mask = np.array(mask, dtype=np.uint8)
    cv2.normalize(mask, mask, 0, 255, cv2.NORM_MINMAX)

    # findContours returns (contours, hierarchy) in OpenCV 2.x and 4.x but
    # (image, contours, hierarchy) in 3.x; the contours are always the
    # second-to-last element, so no version check is needed.
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                cv2.CHAIN_APPROX_TC89_KCOS)[-2]
    if len(contours) == 0:
        # Fail with a clear message instead of max()'s "empty sequence" error.
        raise ValueError('No contours have been detected. Can not build a polygon.')

    largest = max(contours, key=lambda arr: arr.size)
    # Contours come as (n, 1, 2); flatten to (n, 2) regardless of n.
    points = largest.reshape(-1, 2)
    if len(points) < 3:
        raise Exception('Less then three point have been detected. Can not build a polygon.')

    return [[int(x), int(y)] for x, y in points]
|
||
class ModelHandler:
    """Holds the loaded segmentation network and runs click-based inference."""

    def __init__(self):
        """Load the checkpoint and build the matching network on the device.

        The checkpoint directory is taken from the ``MODEL_PATH`` environment
        variable and falls back to ``/opt/nuclio/fbrs``.
        """
        torch.backends.cudnn.deterministic = True

        checkpoint_dir = os.environ.get("MODEL_PATH", "/opt/nuclio/fbrs")
        checkpoint_file = os.path.join(checkpoint_dir, "resnet101_dh256_sbd.pth")
        state_dict = torch.load(checkpoint_file, map_location='cpu')

        self.net = None
        backbone = 'auto'
        if torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'

        # A key containing this fragment selects the HRNet loader; any other
        # checkpoint is loaded with the DeepLab loader.
        uses_hrnet = any(
            'feature_extractor.stage2.0.branches' in key
            for key in state_dict.keys()
        )
        if uses_hrnet:
            self.net = load_hrnet_is_model(state_dict, self.device, backbone)

        if self.net is None:
            self.net = load_deeplab_is_model(state_dict, self.device, backbone)
        self.net.to(self.device)

    def handle(self, image, pos_points, neg_points, threshold):
        """Segment the clicked object and return it as a polygon.

        Args:
            image: Input accepted by ``transforms.ToTensor``.
            pos_points: Iterable of ``(x, y)`` pairs added as positive clicks.
            neg_points: Iterable of ``(x, y)`` pairs added as negative clicks.
            threshold: Probability cut-off, used both as the predictor's
                ``prob_thresh`` and to binarize its output.

        Returns:
            The result of ``convert_mask_to_polygon`` on the thresholded
            prediction.
        """
        normalize = transforms.Normalize([.485, .456, .406], [.229, .224, .225])
        preprocess = transforms.Compose([transforms.ToTensor(), normalize])
        image_nd = preprocess(image).to(self.device)

        # Positive clicks are registered first, then negative ones. Incoming
        # pairs are (x, y) but each Click is created with coords=(y, x).
        clicker = Clicker()
        for is_positive, points in ((True, pos_points), (False, neg_points)):
            for x, y in points:
                clicker.add_click(Click(is_positive=is_positive, coords=(y, x)))

        predictor = get_predictor(
            self.net,
            device=self.device,
            brs_mode='f-BRS-B',
            brs_opt_func_params={'min_iou_diff': 0.001},
            lbfgs_params={'maxfun': 20},
            predictor_params={'max_size': 800, 'net_clicks_limit': 8},
            prob_thresh=threshold,
            zoom_in_params={'expansion_ratio': 1.4, 'skip_clicks': 1, 'target_size': 480},
        )
        predictor.set_input_image(image_nd)

        object_prob = predictor.get_prediction(clicker)
        if self.device == 'cuda':
            # Release cached GPU memory once the prediction is done.
            torch.cuda.empty_cache()

        return convert_mask_to_polygon(object_prob > threshold)
|
||
|