[TOPI] Tunable Template for Conv2D HWCN on CUDA #4168

Merged: 3 commits, Oct 24, 2019
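
This patch turns the hand-written CUDA schedule for conv2d in HWCN layout into an AutoTVM template. A minimal sketch of how the tuned template is consumed follows; the shapes and the `conv2d_hwcn.log` file name are illustrative assumptions, while the calls themselves mirror the test added in this PR plus the standard `autotvm.apply_history_best` context.

```python
import tvm
import topi
from tvm import autotvm

# Illustrative HWCN workload: 224x224 input, 64 -> 128 channels, batch 1.
A = tvm.placeholder((224, 224, 64, 1), name='A')   # data in HWCN layout
W = tvm.placeholder((3, 3, 64, 128), name='W')     # kernel in HWIO layout

# 'conv2d_hwcn.log' is a hypothetical tuning log from a prior AutoTVM run;
# without it, the schedule below uses the fallback split configuration
# defined in this patch.
with autotvm.apply_history_best('conv2d_hwcn.log'):
    with tvm.target.create('cuda'):
        C = topi.nn.conv2d(A, W, 1, 1, 1, layout='HWCN')
        s = topi.generic.schedule_conv2d_hwcn([C])
        func = tvm.build(s, [A, W, C], 'cuda')
```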
2 changes: 1 addition & 1 deletion python/tvm/autotvm/task/task.py
@@ -226,7 +226,7 @@ def args_to_workload(x, topi_compute_func=None):
elif x is None:
workload = 0
else:
raise RuntimeError('Do not support type "%s" in argument. Consider to use'
raise RuntimeError('Do not support type "%s" in argument. Consider to use '
'primitive types only' % type(x))
return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload

7 changes: 5 additions & 2 deletions python/tvm/autotvm/task/topi_integration.py
@@ -176,9 +176,12 @@ def _topi_nn_conv2d(*args, **kwargs):
args = deserialize_args(args)
A, W = args[:2]
layout = args[-2]
assert layout == 'NCHW', "only support NCHW currently"
assert layout == 'NCHW' or layout == 'HWCN', "only support NCHW/HWCN currently"
C = topi.nn.conv2d(*args, **kwargs)
s = topi.generic.schedule_conv2d_nchw([C])
if layout == 'NCHW':
s = topi.generic.schedule_conv2d_nchw([C])
else:
s = topi.generic.schedule_conv2d_hwcn([C])
return s, [A, W, C]

@register("topi_nn_depthwise_conv2d_nchw")
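With the relaxed layout assertion, the registered `topi_nn_conv2d` task can also be instantiated directly for an HWCN workload. A hedged sketch, assuming the `autotvm.task.create` interface of this TVM release and AutoTVM's serialized tensor argument form; the shapes are illustrative:

```python
from tvm import autotvm

# Arguments follow topi.nn.conv2d(data, kernel, strides, padding, dilation,
# layout, out_dtype); tensors are passed in AutoTVM's serialized form so that
# deserialize_args above can rebuild the placeholders.
task = autotvm.task.create(
    'topi_nn_conv2d',
    args=(('TENSOR', (224, 224, 64, 1), 'float32'),   # data, HWCN
          ('TENSOR', (3, 3, 64, 128), 'float32'),     # kernel, HWIO
          (1, 1), (1, 1), (1, 1), 'HWCN', 'float32'),
    target='cuda',
    template_key='direct')
print(task.config_space)   # the tile_fi / tile_ni split space from the schedule
```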
12 changes: 7 additions & 5 deletions python/tvm/relay/op/nn/_nn.py
@@ -153,14 +153,14 @@ def compute_conv2d(attrs, inputs, out_type, target):
out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
else out_dtype)

assert layout in ["NCHW", "NHWC", "NCHW4c"]
assert layout in ["NCHW", "NHWC", "NCHW4c", "HWCN"]
(dilation_h, dilation_w) = dilation
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")

def _get_out_depth():
weight_shape = get_const_tuple(inputs[1].shape)
if kernel_layout == "HWOI":
if kernel_layout.startswith("HW"):
return weight_shape[2] * weight_shape[3]
return weight_shape[0] * weight_shape[1]

@@ -192,11 +192,13 @@ def schedule_conv2d(attrs, outs, target):
with target:
if groups == 1 and layout == "NCHW":
return topi.generic.schedule_conv2d_nchw(outs)
if groups == 1 and layout == "NCHW4c":
elif groups == 1 and layout == "NCHW4c":
return topi.generic.schedule_conv2d_nchw(outs)
if groups == 1 and layout == "NHWC":
elif groups == 1 and layout == "NHWC":
return topi.generic.schedule_conv2d_nhwc(outs)
if groups != 1:
elif groups == 1 and layout == "HWCN":
return topi.generic.schedule_conv2d_hwcn(outs)
elif groups != 1:
# collect in_channels to distinguish depthwise and group conv2d
op = _find_conv2d_op(outs[0].op)
assert op is not None
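At the Relay level these changes let a graph declared in HWCN data layout reach the new schedule. A hypothetical end-to-end sketch, not covered by the tests in this PR and assuming Relay's layout handling accepts HWCN throughout; shapes are illustrative:

```python
import tvm
from tvm import relay

# Hypothetical HWCN Relay graph: 56x56 input, 64 -> 64 channels, batch 1.
data = relay.var("data", shape=(56, 56, 64, 1))      # HWCN
weight = relay.var("weight", shape=(3, 3, 64, 64))   # HWIO
conv = relay.nn.conv2d(data, weight,
                       kernel_size=(3, 3), channels=64, padding=(1, 1),
                       data_layout="HWCN", kernel_layout="HWIO")
func = relay.Function([data, weight], conv)
mod = relay.Module.from_expr(func)

# compute_conv2d/schedule_conv2d above now dispatch this conv2d to
# conv2d_hwcn and schedule_conv2d_hwcn when targeting CUDA.
graph, lib, params = relay.build(mod, target="cuda")
```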
2 changes: 0 additions & 2 deletions src/pass/vectorize_loop.cc
@@ -368,7 +368,6 @@ class Vectorizer : public IRMutator {
CHECK(!op->extent.type().is_vector());
Expr extent = Mutate(op->extent);
if (extent.type().is_vector()) {
LOG(WARNING) << "Detect vectorized extent type, scalarizing...";
return Scalarize(s);
}
Stmt body = Mutate(op->body);
@@ -386,7 +385,6 @@
CHECK(!op->condition.type().is_vector());
Expr condition = this->Mutate(op->condition);
if (condition.type().is_vector()) {
LOG(WARNING) << "Detect vector condition in Vectorized Loop, scalarizing...";
return Scalarize(s);
}
Stmt then_case = this->Mutate(op->then_case);
85 changes: 49 additions & 36 deletions topi/python/topi/cuda/conv2d_hwcn.py
@@ -17,9 +17,14 @@
# pylint: disable=invalid-name, too-many-locals, too-many-statements
"""Schedule for conv2d_hwcn with auto fusion"""
import tvm
from .. import tag
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity

def schedule_conv2d_hwcn(outs):
from .. import generic, tag


@autotvm.register_topi_schedule(generic.schedule_conv2d_hwcn, ["cuda", "gpu"], ["direct"])
def schedule_conv2d_hwcn(cfg, outs):
"""Schedule for conv2d_hwcn and any element-wise operations.

Parameters
@@ -51,36 +56,44 @@ def schedule(Apad, W, B):
sch[B].set_scope("local")
BL = B

tile = 8
num_thread = 8
block_factor = tile * num_thread
step = 8
vthread = 2
hi, wi, fi, ni = sch[Out].op.axis

block_x = tvm.thread_axis("blockIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
block_z = tvm.thread_axis("blockIdx.z")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")
# Create tuning space
n_thread_cand = [1, 2, 4, 8, 16, 32]
vthread_cand = [1, 2, 4, 8]

cfg.define_split(
'tile_fi',
fi,
num_outputs=4,
filter=lambda x:
(x.size[1] in vthread_cand and x.size[2] in n_thread_cand))
cfg.define_split(
'tile_ni',
ni,
num_outputs=4,
filter=lambda x:
(x.size[1] in vthread_cand and x.size[2] in n_thread_cand))

if cfg.is_fallback:
cfg['tile_fi'] = SplitEntity([-1, 2, 8, 4])
cfg['tile_ni'] = SplitEntity([-1, 2, 8, 4])

# Scheduling
step = 8

hi, wi, fi, ni = sch[Out].op.axis
bz = sch[Out].fuse(hi, wi)
by, fi = sch[Out].split(fi, factor=block_factor)
bx, ni = sch[Out].split(ni, factor=block_factor)
tyz, fi = sch[Out].split(fi, nparts=vthread)
txz, ni = sch[Out].split(ni, nparts=vthread)
ty, fi = sch[Out].split(fi, nparts=num_thread)
tx, ni = sch[Out].split(ni, nparts=num_thread)
by, tyz, ty, fi = cfg['tile_fi'].apply(sch, Out, fi)
bx, txz, tx, ni = cfg['tile_ni'].apply(sch, Out, ni)
sch[Out].reorder(bz, by, bx, tyz, txz, ty, tx, fi, ni)
sch[Out].bind(bz, block_z)
sch[Out].bind(by, block_y)
sch[Out].bind(bx, block_x)
sch[Out].bind(tyz, thread_yz)
sch[Out].bind(txz, thread_xz)
sch[Out].bind(ty, thread_y)
sch[Out].bind(tx, thread_x)

sch[Out].bind(bz, tvm.thread_axis('blockIdx.z'))
sch[Out].bind(by, tvm.thread_axis('blockIdx.y'))
sch[Out].bind(bx, tvm.thread_axis('blockIdx.x'))
sch[Out].bind(tyz, tvm.thread_axis('vthread'))
sch[Out].bind(txz, tvm.thread_axis('vthread'))
sch[Out].bind(ty, tvm.thread_axis('threadIdx.y'))
sch[Out].bind(tx, tvm.thread_axis('threadIdx.x'))

# Schedule BL local write
sch[BL].compute_at(sch[Out], tx)
@@ -98,21 +111,21 @@ def schedule(Apad, W, B):
sch[WL].compute_at(sch[BL], rci)
# Schedule for A's shared memory load
yi, xi, ci, ni = sch[AA].op.axis
ty, ci = sch[AA].split(ci, nparts=num_thread)
tx, ni = sch[AA].split(ni, nparts=num_thread)
ty, ci = sch[AA].split(ci, nparts=cfg['tile_fi'].size[2])
tx, ni = sch[AA].split(ni, nparts=cfg['tile_ni'].size[2])
_, ni = sch[AA].split(ni, factor=4)
sch[AA].reorder(ty, tx, yi, xi, ci, ni)
sch[AA].bind(ty, thread_y)
sch[AA].bind(tx, thread_x)
sch[AA].bind(ty, tvm.thread_axis('threadIdx.y'))
sch[AA].bind(tx, tvm.thread_axis('threadIdx.x'))
sch[AA].vectorize(ni)
# Schedule for W's shared memory load
yi, xi, ci, fi = sch[WW].op.axis
ty, ci = sch[WW].split(ci, nparts=num_thread)
tx, fi = sch[WW].split(fi, nparts=num_thread)
ty, ci = sch[WW].split(ci, nparts=cfg['tile_fi'].size[2])
tx, fi = sch[WW].split(fi, nparts=cfg['tile_ni'].size[2])
_, fi = sch[WW].split(fi, factor=4)
sch[WW].reorder(ty, tx, yi, xi, ci, fi)
sch[WW].bind(ty, thread_y)
sch[WW].bind(tx, thread_x)
sch[WW].bind(ty, tvm.thread_axis('threadIdx.y'))
sch[WW].bind(tx, tvm.thread_axis('threadIdx.x'))
sch[WW].vectorize(fi)

scheduled_ops = []
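The two `define_split` knobs replace the previous hard-coded `tile = 8`, `num_thread = 8`, `vthread = 2` constants: each output axis is split into four factors bound to blockIdx, vthread, threadIdx, and an inner serial loop, and the fallback entities `[-1, 2, 8, 4]` reproduce the original tiling exactly (2 * 8 * 4 = 64 = tile * num_thread). A stand-alone sketch of what that fallback split means on a single axis; the toy compute is illustrative only:

```python
import tvm

# SplitEntity([-1, 2, 8, 4]) splits an axis into extents (inferred, 2, 8, 4);
# the hand-written splits below are equivalent for the fallback case.
A = tvm.placeholder((1024,), name='A')
B = tvm.compute((1024,), lambda i: A[i] + 1, name='B')
s = tvm.create_schedule(B.op)

fi = s[B].op.axis[0]
by, fi = s[B].split(fi, factor=2 * 8 * 4)   # "-1" part: outer extent inferred (16)
tyz, fi = s[B].split(fi, factor=8 * 4)      # vthread part: extent 2
ty, fi = s[B].split(fi, factor=4)           # threadIdx part: extent 8
# The remaining fi loop has extent 4; in the real schedule these four loops are
# bound to blockIdx, vthread, threadIdx and left serial, respectively.
print(tvm.lower(s, [A, B], simple_mode=True))
```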
18 changes: 18 additions & 0 deletions topi/python/topi/generic/nn.py
@@ -34,6 +34,24 @@ def _default_schedule(outs, auto_inline):
return s


@tvm.target.generic_func
def schedule_conv2d_hwcn(outs):
"""Schedule for conv2d_hwcn

Parameters
----------
outs: Array of Tensor
The computation graph description of conv2d_hwcn
in the format of an array of tensors.

Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)


@tvm.target.generic_func
def schedule_conv2d_nchw(outs):
"""Schedule for conv2d_nchw
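The `@tvm.target.generic_func` entry point is what the CUDA template above overrides through `autotvm.register_topi_schedule`; other backends can hook in with the plain generic-function override mechanism. A hypothetical sketch for an extra target; the body is a placeholder, not part of this patch:

```python
import tvm
from topi import generic

# Hypothetical override: route schedule_conv2d_hwcn to a custom schedule on
# OpenCL targets. The CUDA/GPU override in this patch is registered the same
# way underneath autotvm.register_topi_schedule.
@generic.schedule_conv2d_hwcn.register(["opencl"])
def schedule_conv2d_hwcn_opencl(outs):
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    return tvm.create_schedule([x.op for x in outs])
```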
4 changes: 2 additions & 2 deletions topi/python/topi/nn/conv2d.py
@@ -64,9 +64,9 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
# default declaration
if layout == 'NCHW':
return conv2d_nchw(input, filter, strides, padding, dilation, out_dtype)
if layout == 'HWCN':
elif layout == 'HWCN':
return conv2d_hwcn(input, filter, strides, padding, dilation, out_dtype)
if layout == 'NHWC':
elif layout == 'NHWC':
return conv2d_nhwc(input, filter, strides, padding, dilation, out_dtype)
raise ValueError("not support this layout {} yet".format(layout))

49 changes: 33 additions & 16 deletions topi/tests/python/test_topi_conv2d_hwcn.py
@@ -29,41 +29,58 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p

A = tvm.placeholder((in_height, in_width, in_channel, batch), name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
B = topi.nn.conv2d_hwcn(A, W, stride, padding, dilation)
C = topi.nn.relu(B)
s1 = topi.cuda.schedule_conv2d_hwcn([B])
s2 = topi.cuda.schedule_conv2d_hwcn([C])
B = tvm.placeholder((1, num_filter, 1), name='bias')

a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
b_shape = get_const_tuple(B.shape)
dtype = A.dtype

@memoize("topi.tests.test_topi_conv2d_hwcn.verify_hwcn")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = np.random.uniform(size=b_shape).astype(dtype)
dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
b_np = topi.testing.conv2d_hwcn_python(a_np, dw_np, stride, padding)
c_np = np.maximum(b_np, 0)
return a_np, w_np, b_np, c_np
a_np, w_np, b_np, c_np = get_ref_data()
c1_np = topi.testing.conv2d_hwcn_python(a_np, dw_np, stride, padding)
c2_np = c1_np + b_np
c3_np = np.maximum(c2_np, 0)
return a_np, w_np, b_np, c1_np, c2_np, c3_np

a_np, w_np, b_np, c1_np, c2_np, c3_np = get_ref_data()

def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
t_conv = topi.nn.conv2d(A, W, stride, padding, dilation, layout='HWCN')
t_bias = topi.add(t_conv, B)
t_relu = topi.nn.relu(t_bias)
s1 = topi.generic.schedule_conv2d_hwcn([t_conv])
s2 = topi.generic.schedule_conv2d_hwcn([t_bias])
s3 = topi.generic.schedule_conv2d_hwcn([t_relu])
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
func1 = tvm.build(s1, [A, W, B], device)
func2 = tvm.build(s2, [A, W, C], device)
func1(a, w, b)
func2(a, w, c)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
b = tvm.nd.array(b_np, ctx)

conv_out = tvm.nd.array(
np.zeros(get_const_tuple(t_conv.shape), dtype=t_conv.dtype), ctx)
bias_out = tvm.nd.array(
np.zeros(get_const_tuple(t_bias.shape), dtype=t_bias.dtype), ctx)
relu_out = tvm.nd.array(
np.zeros(get_const_tuple(t_relu.shape), dtype=t_relu.dtype), ctx)
func1 = tvm.build(s1, [A, W, t_conv], device)
func2 = tvm.build(s2, [A, W, B, t_bias], device)
func3 = tvm.build(s3, [A, W, B, t_relu], device)
func1(a, w, conv_out)
func2(a, w, b, bias_out)
func3(a, w, b, relu_out)
tvm.testing.assert_allclose(conv_out.asnumpy(), c1_np, rtol=1e-5)
tvm.testing.assert_allclose(bias_out.asnumpy(), c2_np, rtol=1e-5)
tvm.testing.assert_allclose(relu_out.asnumpy(), c3_np, rtol=1e-5)

for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
check_device(device)
1 change: 0 additions & 1 deletion topi/tests/python/test_topi_conv2d_nchw.py
@@ -48,7 +48,6 @@ def get_ref_data():
dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
if add_bias:
b_np = np.random.uniform(size=bias_shape).astype(dtype)
c_np += b_np
if add_relu:
c_np = np.maximum(c_np, 0)