Skip to content

Commit

Permalink
Add siammask serverless function (it doesn't work, need to serialize …
Browse files Browse the repository at this point in the history
…state)
  • Loading branch information
Nikita Manovich committed Jul 27, 2020
1 parent 0c6e033 commit 1973a07
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 0 deletions.
57 changes: 57 additions & 0 deletions serverless/pytorch/foolwood/siammask/nuclio/function.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
metadata:
name: pth.foolwood.siammask
namespace: cvat
annotations:
name: SiamMask
type: tracker
spec:
framework: pytorch

spec:
description: Fast Online Object Tracking and Segmentation
runtime: "python:3.6"
handler: main:handler
eventTimeout: 30s
env:
- name: PYTHONPATH
value: /opt/nuclio/SiamMask:/opt/nuclio/SiamMask/experiments/siammask_sharp

build:
image: cvat/pth.foolwood.siammask
baseImage: continuumio/miniconda3

directives:
preCopy:
- kind: WORKDIR
value: /opt/nuclio
- kind: RUN
value: conda create -y -n siammask python=3.6
- kind: RUN
value: source activate siammask
- kind: RUN
value: git clone https://github.com/foolwood/SiamMask.git
- kind: RUN
value: pip install -r SiamMask/requirements.txt
- kind: RUN
value: conda install -y gcc_linux-64
- kind: RUN
value: cd SiamMask && bash make.sh && cd -
- kind: RUN
value: wget -P SiamMask/experiments/siammask_sharp http://www.robots.ox.ac.uk/~qwang/SiamMask_DAVIS.pth

- kind: WORKDIR
value: /opt/nuclio/pysot

triggers:
myHttpTrigger:
maxWorkers: 2
kind: "http"
workerAvailabilityTimeoutMilliseconds: 10000
attributes:
maxRequestBodySize: 33554432 # 32MB

platform:
attributes:
restartPolicy:
name: always
maximumRetryCount: 3
27 changes: 27 additions & 0 deletions serverless/pytorch/foolwood/siammask/nuclio/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import json
import base64
from PIL import Image
import io
from model_handler import ModelHandler

def init_context(context):
context.logger.info("Init context... 0%")

# Read the DL model
model = ModelHandler()
setattr(context.user_data, 'model', model)

context.logger.info("Init context...100%")

def handler(context, event):
context.logger.info("Run SiamMask model")
data = event.body
buf = io.BytesIO(base64.b64decode(data["image"].encode('utf-8')))
shape = data.get("shape")
state = data.get("state")
image = Image.open(buf)

results = context.user_data.model.infer(image, shape, state)

return context.Response(body=json.dumps(results), headers={},
content_type='application/json', status_code=200)
38 changes: 38 additions & 0 deletions serverless/pytorch/foolwood/siammask/nuclio/model_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (C) 2020 Intel Corporation
#
# SPDX-License-Identifier: MIT

from tools.test import *
import os

class ModelHandler:
def __init__(self):
# Setup device
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True

base_dir = "/opt/nuclio/SiamMask/experiments/siammask_sharp"
class configPath:
config = os.path.join(base_dir, "config_davis.json")

self.config = load_config(configPath)
from custom import Custom
siammask = Custom(anchors=self.config['anchors'])
self.siammask = load_pretrain(siammask, os.path.join(base_dir, "SiamMask_DAVIS.pth"))
self.siammask.eval().to(self.device)


def infer(self, image, shape, state):
if state is None: # init tracking
x, y, w, h = shape
target_pos = np.array([x + w / 2, y + h / 2])
target_sz = np.array([w, h])
state = siamese_init(image, target_pos, target_sz, self.siammask,
self.config['hp'], device=self.device)
else: # track
state = siamese_track(state, image, mask_enable=True, refine_enable=True,
device=self.device)
shape = state['ploygon'].flatten()

return {"shape": shape, "state": state}

0 comments on commit 1973a07

Please sign in to comment.