
fix comments and unit tests
comaniac committed Oct 22, 2019
1 parent be8656d commit 1d6a62d
Showing 4 changed files with 34 additions and 20 deletions.
2 changes: 0 additions & 2 deletions src/pass/vectorize_loop.cc
@@ -368,7 +368,6 @@ class Vectorizer : public IRMutator {
     CHECK(!op->extent.type().is_vector());
     Expr extent = Mutate(op->extent);
     if (extent.type().is_vector()) {
-      // LOG(WARNING) << "Detect vectorized extent type, scalarizing...";
       return Scalarize(s);
     }
     Stmt body = Mutate(op->body);
@@ -386,7 +385,6 @@ class Vectorizer : public IRMutator {
     CHECK(!op->condition.type().is_vector());
     Expr condition = this->Mutate(op->condition);
     if (condition.type().is_vector()) {
-      // LOG(WARNING) << "Detect vector condition in Vectorized Loop, scalarizing...";
       return Scalarize(s);
     }
     Stmt then_case = this->Mutate(op->then_case);
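Both hunks drop commented-out LOG(WARNING) lines; the surviving behavior is the Scalarize(s) fallback, taken whenever a vector-typed extent or branch condition survives mutation. For context, a minimal sketch of driving this pass from Python, assuming a 2019-era TVM build (tensor names and shapes are illustrative):

import tvm

# Illustrative element-wise op over a vectorizable axis.
n = 16
A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda i: A[i] + 1, name='B')

s = tvm.create_schedule(B.op)
# vectorize() marks the loop for the VectorizeLoop pass touched above; if a
# vector-typed extent or condition appeared during mutation, the pass would
# fall back to Scalarize() rather than fail.
s[B].vectorize(B.op.axis[0])
print(tvm.lower(s, [A, B], simple_mode=True))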
2 changes: 1 addition & 1 deletion topi/python/topi/cuda/conv2d_hwcn.py
@@ -82,7 +82,7 @@ def schedule(Apad, W, B):
         # Scheduling
         step = 8
 
-        bz = sch[Out].fuse(hi, wi) # FIXME: Does it assume square images?
+        bz = sch[Out].fuse(hi, wi)
         by, tyz, ty, fi = cfg['tile_fi'].apply(sch, Out, fi)
         bx, txz, tx, ni = cfg['tile_ni'].apply(sch, Out, ni)
         sch[Out].reorder(bz, by, bx, tyz, txz, ty, tx, fi, ni)
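The deleted FIXME asked whether fusing hi and wi assumes square images. fuse itself makes no such assumption: it yields a single iteration variable whose extent is the product of the fused extents. A minimal sketch, assuming the same 2019-era TVM API (the non-square shape is illustrative):

import tvm

# Deliberately non-square extents.
h, w = 14, 28
A = tvm.placeholder((h, w), name='A')
B = tvm.compute((h, w), lambda i, j: A[i, j] * 2, name='B')

s = tvm.create_schedule(B.op)
# The fused axis has extent h * w (here 392); nothing requires h == w.
fused = s[B].fuse(B.op.axis[0], B.op.axis[1])
print(tvm.lower(s, [A, B], simple_mode=True))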
49 changes: 33 additions & 16 deletions topi/tests/python/test_topi_conv2d_hwcn.py
@@ -29,41 +29,58 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
 
     A = tvm.placeholder((in_height, in_width, in_channel, batch), name='A')
     W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
-    B = topi.nn.conv2d_hwcn(A, W, stride, padding, dilation)
-    C = topi.nn.relu(B)
-    s1 = topi.cuda.schedule_conv2d_hwcn([B])
-    s2 = topi.cuda.schedule_conv2d_hwcn([C])
+    B = tvm.placeholder((1, num_filter, 1), name='bias')
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
+    b_shape = get_const_tuple(B.shape)
     dtype = A.dtype
 
     @memoize("topi.tests.test_topi_conv2d_hwcn.verify_hwcn")
     def get_ref_data():
         a_np = np.random.uniform(size=a_shape).astype(dtype)
         w_np = np.random.uniform(size=w_shape).astype(dtype)
+        b_np = np.random.uniform(size=b_shape).astype(dtype)
         dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
-        b_np = topi.testing.conv2d_hwcn_python(a_np, dw_np, stride, padding)
-        c_np = np.maximum(b_np, 0)
-        return a_np, w_np, b_np, c_np
-    a_np, w_np, b_np, c_np = get_ref_data()
+        c1_np = topi.testing.conv2d_hwcn_python(a_np, dw_np, stride, padding)
+        c2_np = c1_np + b_np
+        c3_np = np.maximum(c2_np, 0)
+        return a_np, w_np, b_np, c1_np, c2_np, c3_np
+
+    a_np, w_np, b_np, c1_np, c2_np, c3_np = get_ref_data()
 
     def check_device(device):
         ctx = tvm.context(device, 0)
         if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
+            t_conv = topi.nn.conv2d(A, W, stride, padding, dilation, layout='HWCN')
+            t_bias = topi.add(t_conv, B)
+            t_relu = topi.nn.relu(t_bias)
+            s1 = topi.generic.schedule_conv2d_hwcn([t_conv])
+            s2 = topi.generic.schedule_conv2d_hwcn([t_bias])
+            s3 = topi.generic.schedule_conv2d_hwcn([t_relu])
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
-        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
-        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
-        func1 = tvm.build(s1, [A, W, B], device)
-        func2 = tvm.build(s2, [A, W, C], device)
-        func1(a, w, b)
-        func2(a, w, c)
-        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
-        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+        b = tvm.nd.array(b_np, ctx)
+
+        conv_out = tvm.nd.array(
+            np.zeros(get_const_tuple(t_conv.shape), dtype=t_conv.dtype), ctx)
+        bias_out = tvm.nd.array(
+            np.zeros(get_const_tuple(t_bias.shape), dtype=t_bias.dtype), ctx)
+        relu_out = tvm.nd.array(
+            np.zeros(get_const_tuple(t_relu.shape), dtype=t_relu.dtype), ctx)
+        func1 = tvm.build(s1, [A, W, t_conv], device)
+        func2 = tvm.build(s2, [A, W, B, t_bias], device)
+        func3 = tvm.build(s3, [A, W, B, t_relu], device)
+        func1(a, w, conv_out)
+        func2(a, w, b, bias_out)
+        func3(a, w, b, relu_out)
+        tvm.testing.assert_allclose(conv_out.asnumpy(), c1_np, rtol=1e-5)
+        tvm.testing.assert_allclose(bias_out.asnumpy(), c2_np, rtol=1e-5)
+        tvm.testing.assert_allclose(relu_out.asnumpy(), c3_np, rtol=1e-5)
+
     for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
         check_device(device)
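The rewritten test creates its schedules inside tvm.target.create(device): topi.generic.schedule_conv2d_hwcn is a generic function that dispatches to whichever implementation is registered for the current target, which is how one test body covers cuda, opencl, metal, rocm, vulkan, and nvptx. A condensed sketch of the pattern for a single target, assuming a CUDA-enabled 2019-era TVM build (shapes are illustrative):

import tvm
import topi

# Illustrative HWCN tensors: 7x7 spatial, 64 in / 128 out channels, batch 1.
A = tvm.placeholder((7, 7, 64, 1), name='A')
W = tvm.placeholder((3, 3, 64, 128), name='W')

with tvm.target.create('cuda'):
    # Inside the target scope, the generic entry points resolve to the
    # CUDA implementations registered for this op.
    conv = topi.nn.conv2d(A, W, 1, 1, 1, layout='HWCN')
    s = topi.generic.schedule_conv2d_hwcn([conv])

# Building requires the target's code generator to be enabled in this TVM build.
func = tvm.build(s, [A, W, conv], 'cuda')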
1 change: 0 additions & 1 deletion topi/tests/python/test_topi_conv2d_nchw.py
@@ -48,7 +48,6 @@ def get_ref_data():
         dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
         c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
         if add_bias:
-            b_np = np.random.uniform(size=bias_shape).astype(dtype)
             c_np += b_np
         if add_relu:
             c_np = np.maximum(c_np, 0)
