Skip to content

Commit

Permalink
Small fix of the Depthwise Convolution example in python3 (#224)
Browse files Browse the repository at this point in the history
* fix for python3

fix for python3

* Update depthwise_conv2d_map_test.py

remove sys.append
  • Loading branch information
sxjscience authored and tqchen committed Jul 7, 2017
1 parent bfe6d95 commit b759d0f
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions topi/recipe/conv/depthwise_conv2d_map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,14 @@ def depthwise_conv2d_map_scipy(input_np, filter_np, scale_np, shift_np):
index_w = pad_left_scipy - pad_left_tvm
for i in range(batch):
for j in range(out_channel):
depthwise_conv2d_scipy[i,j,:,:] = signal.convolve2d(input_np[i,j/channel_multiplier,:,:], np.rot90(filter_np[j/channel_multiplier,j%channel_multiplier,:,:], 2),
depthwise_conv2d_scipy[i,j,:,:] = signal.convolve2d(input_np[i,j // channel_multiplier,:,:],
np.rot90(filter_np[j // channel_multiplier,j%channel_multiplier,:,:], 2),
mode='same')[index_h:in_height:stride_h, index_w:in_width:stride_w]
if padding == 'VALID':
for i in range(batch):
for j in range(out_channel):
depthwise_conv2d_scipy[i,j,:,:] = signal.convolve2d(input_np[i,j/channel_multiplier,:,:], np.rot90(filter_np[j/channel_multiplier,j%channel_multiplier,:,:], 2),
depthwise_conv2d_scipy[i,j,:,:] = signal.convolve2d(input_np[i,j // channel_multiplier,:,:],
np.rot90(filter_np[j // channel_multiplier,j%channel_multiplier,:,:], 2),
mode='valid')[0:(in_height - filter_height + 1):stride_h, 0:(in_width - filter_height + 1):stride_w]
for c in range(out_channel):
scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
Expand Down Expand Up @@ -132,7 +134,7 @@ def check_device(device):
np.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
print "success"
print("success")

with tvm.build_config(auto_unroll_max_step=32,
auto_unroll_min_depth=0,
Expand Down

0 comments on commit b759d0f

Please sign in to comment.