forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
214 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
# pylint: disable=wildcard-import | ||
"""Faster R-CNN and Mask R-CNN operators""" | ||
from .roi_align import * | ||
from .roi_align_v2 import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# pylint: disable=invalid-name | ||
"""Roi align v2 in | ||
https://github.com/TuSimple/mxnet/blob/master/src/operator/contrib/roi_align_v2-inl.h""" | ||
|
||
import tvm | ||
from ...util import get_const_tuple | ||
|
||
|
||
@tvm.target.generic_func | ||
def roi_align_v2(data, rois, pooled_size, spatial_scale): | ||
"""ROI align operator in NCHW layout. | ||
Parameters | ||
---------- | ||
data : tvm.Tensor | ||
4-D with shape [batch, channel, height, width] | ||
rois : tvm.Tensor | ||
2-D with shape [num_roi, 5]. The last dimension should be in format of | ||
[batch_index, w_start, h_start, w_end, h_end] | ||
pooled_size : int or list/tuple of two ints | ||
output size, or [out_height, out_width] | ||
spatial_scale : float | ||
Ratio of input feature map height (or w) to raw image height (or w). Equals the reciprocal | ||
of total stride in convolutional layers, which should be in range (0.0, 1.0] | ||
Returns | ||
------- | ||
output : tvm.Tensor | ||
4-D with shape [num_roi, channel, pooled_size, pooled_size] | ||
""" | ||
_, channel, height, width = get_const_tuple(data.shape) | ||
num_roi, _ = get_const_tuple(rois.shape) | ||
|
||
def _bilinear(i, c, y, x): | ||
y_low = y.astype('int32') | ||
x_low = x.astype('int32') | ||
y_high = tvm.min(tvm.ceil(y).astype('int32'), height - 1) | ||
x_high = tvm.min(tvm.ceil(x).astype('int32'), width - 1) | ||
y_lerp = y - y_low | ||
x_lerp = x - x_low | ||
bottom = x_lerp * data[i, c, y_high, x_high] + \ | ||
(1-x_lerp) * data[i, c, y_high, x_low] | ||
top = x_lerp * data[i, c, y_low, x_high] + \ | ||
(1-x_lerp) * data[i, c, y_low, x_low] | ||
return y_lerp * bottom + (1-y_lerp) * top | ||
|
||
def _sample(i, c, ph, pw): | ||
roi = rois[i] | ||
batch_index = roi[0].astype('int32') | ||
roi_start_w = roi[1] * spatial_scale | ||
roi_start_h = roi[2] * spatial_scale | ||
roi_end_w = roi[3] * spatial_scale | ||
roi_end_h = roi[4] * spatial_scale | ||
|
||
roi_h = roi_end_h - roi_start_h | ||
roi_w = roi_end_w - roi_start_w | ||
roi_h = roi_h | ||
roi_w = roi_w | ||
bin_h = roi_h / pooled_size | ||
bin_w = roi_w / pooled_size | ||
|
||
hstart = ph * bin_h | ||
wstart = pw * bin_w | ||
hend = (ph + 1) * bin_h | ||
wend = (pw + 1) * bin_w | ||
hstart = tvm.min(tvm.max(hstart + roi_start_h, 0), height-1) | ||
wstart = tvm.min(tvm.max(wstart + roi_start_w, 0), width-1) | ||
hend = tvm.min(tvm.max(hend + roi_start_h, 0), height-1) | ||
wend = tvm.min(tvm.max(wend + roi_start_w, 0), width-1) | ||
non_empty = tvm.all(hstart < hend, wstart < wend) | ||
|
||
def min_value(dtype): | ||
return tvm.select(non_empty, tvm.min_value(dtype), tvm.const(0.0, dtype)) | ||
|
||
stride_h = (hend - hstart) / 3.0 | ||
stride_w = (wend - wstart) / 3.0 | ||
hstart += stride_h | ||
wstart += stride_w | ||
stride_h = tvm.max(0.01, stride_h) | ||
stride_w = tvm.max(0.01, stride_w) | ||
_max = tvm.comm_reducer(lambda x, y: tvm.make._OpMax(x, y), min_value, name='max') | ||
rh = tvm.reduce_axis((0, tvm.select(non_empty, 2, 0)), 'rh') | ||
rw = tvm.reduce_axis((0, tvm.select(non_empty, 2, 0)), 'rw') | ||
return _max(_bilinear(batch_index, c, hstart + rh*stride_h, wstart+rw*stride_w), | ||
axis=[rh, rw]) | ||
|
||
return tvm.compute((num_roi, channel, pooled_size, pooled_size), _sample, | ||
tag='pool,roi_align_v2') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
"""Test code for vision package""" | ||
import numpy as np | ||
import tvm | ||
import topi | ||
import math | ||
|
||
from topi.vision import ssd, nms | ||
|
||
|
||
def verify_roi_align(batch_size, in_channel, size, num_roi, pooled_size, spatial_scale): | ||
data_shape = (batch_size, in_channel, size, size) | ||
rois_shape = (num_roi, 5) | ||
data=tvm.placeholder(data_shape) | ||
rois=tvm.placeholder(rois_shape) | ||
np_data = np.random.uniform(size=data_shape).reshape(data_shape).astype('float32') * size | ||
np_rois = np.random.uniform(size=rois_shape).astype('float32') * size | ||
np_rois[:, 0] = 0 | ||
|
||
def check_device(device): | ||
ctx = tvm.context(device, 0) | ||
if not ctx.exist: | ||
print("Skip because %s is not enabled" % device) | ||
return | ||
print("Running on target: %s" % device) | ||
with tvm.target.create(device): | ||
out = topi.vision.rcnn.roi_align_v2(data, rois, pooled_size=pooled_size, spatial_scale=spatial_scale) | ||
s = topi.generic.schedule_roi_align_v2(out) | ||
|
||
tvm_data = tvm.nd.array(np_data, ctx) | ||
tvm_rois = tvm.nd.array(np_rois, ctx) | ||
tvm_out = tvm.nd.array(np.zeros((num_roi, in_channel, pooled_size, pooled_size)).astype(out.dtype), ctx=ctx) | ||
f = tvm.build(s, [data, rois, out], device) | ||
f(tvm_data, tvm_rois, tvm_out) | ||
|
||
import mxnet | ||
mx_ctx = mxnet.gpu(0) | ||
mx_data = mxnet.nd.array(np_data, mx_ctx) | ||
mx_rois = mxnet.nd.array(np_rois, mx_ctx) | ||
mx_out = mxnet.nd.contrib.ROIAlign_v2(mx_data, mx_rois, pooled_size=(pooled_size, pooled_size), spatial_scale=spatial_scale) | ||
mx_out = mx_out.asnumpy() | ||
|
||
tvm_out = tvm_out.asnumpy() | ||
|
||
np.testing.assert_allclose(tvm_out, mx_out, rtol=1e-3) | ||
|
||
for device in ['cuda', 'llvm']: | ||
check_device(device) | ||
|
||
|
||
def test_roi_align_v2(): | ||
verify_roi_align(1, 1, 14, 64, 7, 1.) | ||
verify_roi_align(1, 1, 14, 64, 7, 0.5) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_roi_align_v2() |