[TOPI] Tunable Template for Conv2D HWCN on CUDA (apache#4168)
* support conv2d HWCN in AutoTVM and Relay

* fix lint

* fix comments and unit tests
comaniac authored and kevinthesun committed Oct 30, 2019
1 parent 7f66bd5 commit ccdd47a
Showing 9 changed files with 115 additions and 65 deletions.
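For context, a minimal usage sketch of the new path (shapes are illustrative and not part of the patch): declare an HWCN conv2d through `topi.nn.conv2d`, fetch the now-tunable CUDA schedule through the generic dispatcher, and build, mirroring the updated unit test below.

```python
# Minimal sketch, assuming TVM 0.6-era APIs; shapes are illustrative only.
import tvm
import topi

A = tvm.placeholder((56, 56, 64, 1), name='A')    # activations in H, W, C, N
W = tvm.placeholder((3, 3, 64, 64), name='W')     # weights in H, W, I, O

with tvm.target.create('cuda'):
    conv = topi.nn.conv2d(A, W, (1, 1), (1, 1), (1, 1), layout='HWCN')
    # Dispatches to the AutoTVM template registered for cuda/gpu in this patch;
    # without a tuning log it falls back to the default SplitEntity config.
    s = topi.generic.schedule_conv2d_hwcn([conv])
    func = tvm.build(s, [A, W, conv], 'cuda')
```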
2 changes: 1 addition & 1 deletion python/tvm/autotvm/task/task.py
@@ -226,7 +226,7 @@ def args_to_workload(x, topi_compute_func=None):
elif x is None:
workload = 0
else:
raise RuntimeError('Do not support type "%s" in argument. Consider to use'
raise RuntimeError('Do not support type "%s" in argument. Consider to use '
'primitive types only' % type(x))
return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload

7 changes: 5 additions & 2 deletions python/tvm/autotvm/task/topi_integration.py
@@ -176,9 +176,12 @@ def _topi_nn_conv2d(*args, **kwargs):
args = deserialize_args(args)
A, W = args[:2]
layout = args[-2]
assert layout == 'NCHW', "only support NCHW currently"
assert layout == 'NCHW' or layout == 'HWCN', "only support NCHW/HWCN currently"
C = topi.nn.conv2d(*args, **kwargs)
s = topi.generic.schedule_conv2d_nchw([C])
if layout == 'NCHW':
s = topi.generic.schedule_conv2d_nchw([C])
else:
s = topi.generic.schedule_conv2d_hwcn([C])
return s, [A, W, C]

@register("topi_nn_depthwise_conv2d_nchw")
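With the `_topi_nn_conv2d` wrapper now accepting HWCN, a standalone tuning task can be created against the registered template. A hypothetical sketch (the shapes, trial count, and log file name are illustrative assumptions, not part of the patch):

```python
# Hypothetical sketch, assuming the 0.6-era autotvm.task.create API.
from tvm import autotvm

task = autotvm.task.create(
    'topi_nn_conv2d',
    args=(('TENSOR', (56, 56, 64, 1), 'float32'),   # data, HWCN
          ('TENSOR', (3, 3, 64, 128), 'float32'),   # kernel, HWIO
          (1, 1), (1, 1), (1, 1), 'HWCN', 'float32'),
    target='cuda', template_key='direct')
print(task.config_space)  # tile_fi / tile_ni splits defined in the CUDA schedule

measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(number=5))
tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(n_trial=20, measure_option=measure_option,
           callbacks=[autotvm.callback.log_to_file('conv2d_hwcn.log')])
```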
12 changes: 7 additions & 5 deletions python/tvm/relay/op/nn/_nn.py
@@ -153,14 +153,14 @@ def compute_conv2d(attrs, inputs, out_type, target):
out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
else out_dtype)

assert layout in ["NCHW", "NHWC", "NCHW4c"]
assert layout in ["NCHW", "NHWC", "NCHW4c", "HWCN"]
(dilation_h, dilation_w) = dilation
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")

def _get_out_depth():
weight_shape = get_const_tuple(inputs[1].shape)
if kernel_layout == "HWOI":
if kernel_layout.startswith("HW"):
return weight_shape[2] * weight_shape[3]
return weight_shape[0] * weight_shape[1]

@@ -192,11 +192,13 @@ def schedule_conv2d(attrs, outs, target):
with target:
if groups == 1 and layout == "NCHW":
return topi.generic.schedule_conv2d_nchw(outs)
if groups == 1 and layout == "NCHW4c":
elif groups == 1 and layout == "NCHW4c":
return topi.generic.schedule_conv2d_nchw(outs)
if groups == 1 and layout == "NHWC":
elif groups == 1 and layout == "NHWC":
return topi.generic.schedule_conv2d_nhwc(outs)
if groups != 1:
elif groups == 1 and layout == "HWCN":
return topi.generic.schedule_conv2d_hwcn(outs)
elif groups != 1:
# collect in_channels to distinguish depthwise and group conv2d
op = _find_conv2d_op(outs[0].op)
assert op is not None
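At the Relay level, the new branch means a conv2d carried in HWCN layout can now be compiled directly. A hypothetical end-to-end sketch (shapes, channel count, and the HWIO kernel layout are illustrative assumptions):

```python
# Hypothetical sketch of a Relay conv2d in HWCN layout, assuming 0.6-era APIs.
import tvm
from tvm import relay

data = relay.var('data', shape=(224, 224, 3, 1))    # H, W, C, N
weight = relay.var('weight', shape=(3, 3, 3, 64))   # H, W, I, O
conv = relay.nn.conv2d(data, weight, channels=64, kernel_size=(3, 3),
                       data_layout='HWCN', kernel_layout='HWIO')
mod = relay.Module.from_expr(relay.Function([data, weight], conv))

with relay.build_config(opt_level=3):
    # schedule_conv2d above now routes groups == 1 with HWCN to schedule_conv2d_hwcn
    graph, lib, params = relay.build(mod, target='cuda')
```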
2 changes: 0 additions & 2 deletions src/pass/vectorize_loop.cc
@@ -368,7 +368,6 @@ class Vectorizer : public IRMutator {
CHECK(!op->extent.type().is_vector());
Expr extent = Mutate(op->extent);
if (extent.type().is_vector()) {
LOG(WARNING) << "Detect vectorized extent type, scalarizing...";
return Scalarize(s);
}
Stmt body = Mutate(op->body);
@@ -386,7 +385,6 @@
CHECK(!op->condition.type().is_vector());
Expr condition = this->Mutate(op->condition);
if (condition.type().is_vector()) {
LOG(WARNING) << "Detect vector condition in Vectorized Loop, scalarizing...";
return Scalarize(s);
}
Stmt then_case = this->Mutate(op->then_case);
85 changes: 49 additions & 36 deletions topi/python/topi/cuda/conv2d_hwcn.py
@@ -17,9 +17,14 @@
# pylint: disable=invalid-name, too-many-locals, too-many-statements
"""Schedule for conv2d_hwcn with auto fusion"""
import tvm
from .. import tag
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity

def schedule_conv2d_hwcn(outs):
from .. import generic, tag


@autotvm.register_topi_schedule(generic.schedule_conv2d_hwcn, ["cuda", "gpu"], ["direct"])
def schedule_conv2d_hwcn(cfg, outs):
"""Schedule for conv2d_hwcn and any element-wise operations.
Parameters
@@ -51,36 +56,44 @@ def schedule(Apad, W, B):
sch[B].set_scope("local")
BL = B

tile = 8
num_thread = 8
block_factor = tile * num_thread
step = 8
vthread = 2
hi, wi, fi, ni = sch[Out].op.axis

block_x = tvm.thread_axis("blockIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
block_z = tvm.thread_axis("blockIdx.z")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")
# Create tuning space
n_thread_cand = [1, 2, 4, 8, 16, 32]
vthread_cand = [1, 2, 4, 8]

cfg.define_split(
'tile_fi',
fi,
num_outputs=4,
filter=lambda x:
(x.size[1] in vthread_cand and x.size[2] in n_thread_cand))
cfg.define_split(
'tile_ni',
ni,
num_outputs=4,
filter=lambda x:
(x.size[1] in vthread_cand and x.size[2] in n_thread_cand))

if cfg.is_fallback:
cfg['tile_fi'] = SplitEntity([-1, 2, 8, 4])
cfg['tile_ni'] = SplitEntity([-1, 2, 8, 4])

# Scheduling
step = 8

hi, wi, fi, ni = sch[Out].op.axis
bz = sch[Out].fuse(hi, wi)
by, fi = sch[Out].split(fi, factor=block_factor)
bx, ni = sch[Out].split(ni, factor=block_factor)
tyz, fi = sch[Out].split(fi, nparts=vthread)
txz, ni = sch[Out].split(ni, nparts=vthread)
ty, fi = sch[Out].split(fi, nparts=num_thread)
tx, ni = sch[Out].split(ni, nparts=num_thread)
by, tyz, ty, fi = cfg['tile_fi'].apply(sch, Out, fi)
bx, txz, tx, ni = cfg['tile_ni'].apply(sch, Out, ni)
sch[Out].reorder(bz, by, bx, tyz, txz, ty, tx, fi, ni)
sch[Out].bind(bz, block_z)
sch[Out].bind(by, block_y)
sch[Out].bind(bx, block_x)
sch[Out].bind(tyz, thread_yz)
sch[Out].bind(txz, thread_xz)
sch[Out].bind(ty, thread_y)
sch[Out].bind(tx, thread_x)

sch[Out].bind(bz, tvm.thread_axis('blockIdx.z'))
sch[Out].bind(by, tvm.thread_axis('blockIdx.y'))
sch[Out].bind(bx, tvm.thread_axis('blockIdx.x'))
sch[Out].bind(tyz, tvm.thread_axis('vthread'))
sch[Out].bind(txz, tvm.thread_axis('vthread'))
sch[Out].bind(ty, tvm.thread_axis('threadIdx.y'))
sch[Out].bind(tx, tvm.thread_axis('threadIdx.x'))

# Schedule BL local write
sch[BL].compute_at(sch[Out], tx)
@@ -98,21 +111,21 @@ def schedule(Apad, W, B):
sch[WL].compute_at(sch[BL], rci)
# Schedule for A's shared memory load
yi, xi, ci, ni = sch[AA].op.axis
ty, ci = sch[AA].split(ci, nparts=num_thread)
tx, ni = sch[AA].split(ni, nparts=num_thread)
ty, ci = sch[AA].split(ci, nparts=cfg['tile_fi'].size[2])
tx, ni = sch[AA].split(ni, nparts=cfg['tile_ni'].size[2])
_, ni = sch[AA].split(ni, factor=4)
sch[AA].reorder(ty, tx, yi, xi, ci, ni)
sch[AA].bind(ty, thread_y)
sch[AA].bind(tx, thread_x)
sch[AA].bind(ty, tvm.thread_axis('threadIdx.y'))
sch[AA].bind(tx, tvm.thread_axis('threadIdx.x'))
sch[AA].vectorize(ni)
# Schedule for W's shared memory load
yi, xi, ci, fi = sch[WW].op.axis
ty, ci = sch[WW].split(ci, nparts=num_thread)
tx, fi = sch[WW].split(fi, nparts=num_thread)
ty, ci = sch[WW].split(ci, nparts=cfg['tile_fi'].size[2])
tx, fi = sch[WW].split(fi, nparts=cfg['tile_ni'].size[2])
_, fi = sch[WW].split(fi, factor=4)
sch[WW].reorder(ty, tx, yi, xi, ci, fi)
sch[WW].bind(ty, thread_y)
sch[WW].bind(tx, thread_x)
sch[WW].bind(ty, tvm.thread_axis('threadIdx.y'))
sch[WW].bind(tx, tvm.thread_axis('threadIdx.x'))
sch[WW].vectorize(fi)

scheduled_ops = []
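The rewrite above replaces the fixed `tile = 8`, `num_thread = 8`, `vthread = 2` constants with two four-way splits (`tile_fi`, `tile_ni`) over the output-filter and batch axes, reuses their thread-level factors (`size[2]`) for the shared-memory loads, and pins the fallback to `[-1, 2, 8, 4]`, which reproduces the old static tiling. A minimal standalone sketch of the same `define_split` / `apply` idiom (the toy compute and names are illustrative, not part of the patch):

```python
# Minimal sketch of the 4-way define_split/apply pattern, assuming 0.6-era AutoTVM.
import tvm
from tvm import autotvm

@autotvm.template
def split_demo(n):
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute((n,), lambda i: A[i] + 1, name='B')
    s = tvm.create_schedule(B.op)

    cfg = autotvm.get_config()
    i = s[B].op.axis[0]
    # Four factors per split: blockIdx / vthread / threadIdx / inner,
    # mirroring tile_fi and tile_ni in the schedule above.
    cfg.define_split('tile_i', i, num_outputs=4)
    bi, vi, ti, ii = cfg['tile_i'].apply(s, B, i)
    s[B].bind(bi, tvm.thread_axis('blockIdx.x'))
    s[B].bind(vi, tvm.thread_axis('vthread'))
    s[B].bind(ti, tvm.thread_axis('threadIdx.x'))
    return s, [A, B]
```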
18 changes: 18 additions & 0 deletions topi/python/topi/generic/nn.py
@@ -34,6 +34,24 @@ def _default_schedule(outs, auto_inline):
return s


@tvm.target.generic_func
def schedule_conv2d_hwcn(outs):
"""Schedule for conv2d_hwcn
Parameters
----------
outs: Array of Tensor
The computation graph description of conv2d_hwcn
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)


@tvm.target.generic_func
def schedule_conv2d_nchw(outs):
"""Schedule for conv2d_nchw
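The generic entry only provides a fallback: outside the cuda/gpu registration it resolves to `_default_schedule`. A small sketch of the dispatch behaviour under a plain LLVM target (shapes are illustrative assumptions):

```python
# Sketch: without a target-specific registration, the generic function
# returns the naive default schedule (assumed behaviour of _default_schedule).
import tvm
import topi

A = tvm.placeholder((7, 7, 16, 1), name='A')     # HWCN
W = tvm.placeholder((3, 3, 16, 32), name='W')    # HWIO
conv = topi.nn.conv2d_hwcn(A, W, (1, 1), (1, 1), (1, 1))

with tvm.target.create('llvm'):
    s = topi.generic.schedule_conv2d_hwcn([conv])   # falls back to _default_schedule
    func = tvm.build(s, [A, W, conv], 'llvm')
```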
4 changes: 2 additions & 2 deletions topi/python/topi/nn/conv2d.py
@@ -64,9 +64,9 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
# default declaration
if layout == 'NCHW':
return conv2d_nchw(input, filter, strides, padding, dilation, out_dtype)
if layout == 'HWCN':
elif layout == 'HWCN':
return conv2d_hwcn(input, filter, strides, padding, dilation, out_dtype)
if layout == 'NHWC':
elif layout == 'NHWC':
return conv2d_nhwc(input, filter, strides, padding, dilation, out_dtype)
raise ValueError("not support this layout {} yet".format(layout))

49 changes: 33 additions & 16 deletions topi/tests/python/test_topi_conv2d_hwcn.py
@@ -29,41 +29,58 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p

A = tvm.placeholder((in_height, in_width, in_channel, batch), name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
B = topi.nn.conv2d_hwcn(A, W, stride, padding, dilation)
C = topi.nn.relu(B)
s1 = topi.cuda.schedule_conv2d_hwcn([B])
s2 = topi.cuda.schedule_conv2d_hwcn([C])
B = tvm.placeholder((1, num_filter, 1), name='bias')

a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
b_shape = get_const_tuple(B.shape)
dtype = A.dtype

@memoize("topi.tests.test_topi_conv2d_hwcn.verify_hwcn")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = np.random.uniform(size=b_shape).astype(dtype)
dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
b_np = topi.testing.conv2d_hwcn_python(a_np, dw_np, stride, padding)
c_np = np.maximum(b_np, 0)
return a_np, w_np, b_np, c_np
a_np, w_np, b_np, c_np = get_ref_data()
c1_np = topi.testing.conv2d_hwcn_python(a_np, dw_np, stride, padding)
c2_np = c1_np + b_np
c3_np = np.maximum(c2_np, 0)
return a_np, w_np, b_np, c1_np, c2_np, c3_np

a_np, w_np, b_np, c1_np, c2_np, c3_np = get_ref_data()

def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
t_conv = topi.nn.conv2d(A, W, stride, padding, dilation, layout='HWCN')
t_bias = topi.add(t_conv, B)
t_relu = topi.nn.relu(t_bias)
s1 = topi.generic.schedule_conv2d_hwcn([t_conv])
s2 = topi.generic.schedule_conv2d_hwcn([t_bias])
s3 = topi.generic.schedule_conv2d_hwcn([t_relu])
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
func1 = tvm.build(s1, [A, W, B], device)
func2 = tvm.build(s2, [A, W, C], device)
func1(a, w, b)
func2(a, w, c)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
b = tvm.nd.array(b_np, ctx)

conv_out = tvm.nd.array(
np.zeros(get_const_tuple(t_conv.shape), dtype=t_conv.dtype), ctx)
bias_out = tvm.nd.array(
np.zeros(get_const_tuple(t_bias.shape), dtype=t_bias.dtype), ctx)
relu_out = tvm.nd.array(
np.zeros(get_const_tuple(t_relu.shape), dtype=t_relu.dtype), ctx)
func1 = tvm.build(s1, [A, W, t_conv], device)
func2 = tvm.build(s2, [A, W, B, t_bias], device)
func3 = tvm.build(s3, [A, W, B, t_relu], device)
func1(a, w, conv_out)
func2(a, w, b, bias_out)
func3(a, w, b, relu_out)
tvm.testing.assert_allclose(conv_out.asnumpy(), c1_np, rtol=1e-5)
tvm.testing.assert_allclose(bias_out.asnumpy(), c2_np, rtol=1e-5)
tvm.testing.assert_allclose(relu_out.asnumpy(), c3_np, rtol=1e-5)

for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
check_device(device)
1 change: 0 additions & 1 deletion topi/tests/python/test_topi_conv2d_nchw.py
@@ -48,7 +48,6 @@ def get_ref_data():
dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
if add_bias:
b_np = np.random.uniform(size=bias_shape).astype(dtype)
c_np += b_np
if add_relu:
c_np = np.maximum(c_np, 0)