Commit 2cf9303: Detect and lower depthwise conv to linalg.generic

asaadaldien committed Aug 4, 2020 (parent: 17acc88)
Showing 4 changed files with 174 additions and 10 deletions.
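
The new lowering (in HLOToLinalgOnBuffers.cpp below) expresses depthwise convolution through the grouped-convolution formulation, in which input channel ci contributes to output channels ci * M + co for a channel multiplier M. A minimal NumPy sketch of that formulation, VALID padding only, is included here for illustration; the function and variable names are hypothetical and not part of the commit:

import numpy as np

def depthwise_conv2d_valid_reference(x, w, strides=(1, 1)):
  # x: [N, H, W, C] input; w: [KH, KW, C, M] filter; VALID padding.
  # y[n, d1, d2, ci * M + co] =
  #   sum over (k1, k2) of x[n, d1*s1 + k1, d2*s2 + k2, ci] * w[k1, k2, ci, co]
  n, h, width, c = x.shape
  kh, kw, _, m = w.shape
  s1, s2 = strides
  oh = (h - kh) // s1 + 1
  ow = (width - kw) // s2 + 1
  y = np.zeros((n, oh, ow, c * m), dtype=x.dtype)
  for b in range(n):
    for d1 in range(oh):
      for d2 in range(ow):
        for ci in range(c):
          for co in range(m):
            acc = 0.0
            for k1 in range(kh):
              for k2 in range(kw):
                acc += x[b, d1 * s1 + k1, d2 * s2 + k2, ci] * w[k1, k2, ci, co]
            y[b, d1, d2, ci * m + co] = acc
  return y
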
2 changes: 0 additions & 2 deletions integrations/tensorflow/e2e/BUILD
@@ -67,7 +67,6 @@ VMLA_FAILING = [
# keep sorted
LLVM_FAILING = [
"broadcasting_test.py",
"depth_conv_test.py",
"dynamic_mlp_relu_test.py",
"dynamic_mlp_test.py",
"fill_test.py", # TODO(jennik): Get this test working on IREE.
@@ -82,7 +81,6 @@ LLVM_FAILING = [
VULKAN_FAILING = [
"broadcasting_test.py",
"control_flow_test.py",
"depth_conv_test.py",
"dynamic_mlp_relu_test.py",
"dynamic_mlp_test.py",
"fill_test.py", # TODO(jennik): Get this test working on IREE.
45 changes: 41 additions & 4 deletions integrations/tensorflow/e2e/depth_conv_test.py
@@ -26,19 +26,37 @@ class Conv2dModule(tf.Module):
tf.TensorSpec([2, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 2, 2, 3], tf.float32),
])
def conv2d_2452x2223_valid(self, img, kernel):
def conv2d_2423x2223_valid(self, img, kernel):
return tf.nn.depthwise_conv2d(
img, kernel, [1, 1, 1, 1], "VALID", name="result")

@tf.function(input_signature=[
tf.TensorSpec([2, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 4, 2, 3], tf.float32),
])
def conv2d_2452x2223_same(self, img, kernel):
def conv2d_2423x2223_same(self, img, kernel):
return tf.nn.depthwise_conv2d(
img, kernel, [1, 1, 1, 1], "SAME", name="result")



@tf.function(input_signature=[
tf.TensorSpec([2, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 4, 2, 3], tf.float32),
])
def conv2d_2423x2223_valid_stride_2(self, img, kernel):
return tf.nn.depthwise_conv2d(
img, kernel, [1, 2, 2, 1], "VALID", name="result")

@tf.function(input_signature=[
tf.TensorSpec([2, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 4, 2, 3], tf.float32),
])
def conv2d_2423x2223_same_stride_2(self, img, kernel):
return tf.nn.depthwise_conv2d(
img, kernel, [1, 2, 2, 1], "SAME", name="result")


@tf_test_utils.compile_module(Conv2dModule)
class ConvTest(tf_test_utils.TracedModuleTestCase):

@@ -47,7 +65,7 @@ def test_batched_feature_unpadded(self):
def batched_feature_unpadded(module):
i = tf_utils.ndarange([2, 4, 5, 2])
k = tf_utils.ndarange([2, 2, 2, 3])
module.conv2d_2452x2223_valid(i, k)
module.conv2d_2423x2223_valid(i, k)

self.compare_backends(batched_feature_unpadded)

@@ -56,10 +74,29 @@ def test_batched_feature_unpadded_same(self):
def batched_feature_unpadded_same(module):
i = tf_utils.ndarange([2, 4, 5, 2])
k = tf_utils.ndarange([2, 4, 2, 3])
module.conv2d_2452x2223_same(i, k)
module.conv2d_2423x2223_same(i, k)

self.compare_backends(batched_feature_unpadded_same)

def test_batched_feature_unpadded_same_stride_2(self):

def batched_feature_unpadded_same_stride_2(module):
i = tf_utils.ndarange([2, 4, 5, 2])
k = tf_utils.ndarange([2, 4, 2, 3])
module.conv2d_2423x2223_valid_stride_2(i, k)

self.compare_backends(batched_feature_unpadded_same_stride_2)


def test_batched_feature_padded_same_stride_2(self):

def batched_feature_padded_same_stride_2(module):
i = tf_utils.ndarange([2, 4, 5, 2])
k = tf_utils.ndarange([2, 4, 2, 3])
module.conv2d_2423x2223_same_stride_2(i, k)

self.compare_backends(batched_feature_padded_same_stride_2)


if __name__ == "__main__":
if hasattr(tf, "enable_v2_behavior"):
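
For reference (shape arithmetic only, not part of the change): with input [2, 4, 5, 2], kernel [2, 4, 2, 3], and strides [1, 2, 2, 1], the stride-2 cases above should produce outputs of shape [2, 2, 1, 6] for VALID (spatial dims floor((4-2)/2)+1 = 2 and floor((5-4)/2)+1 = 1; channels 2*3 = 6) and [2, 2, 3, 6] for SAME (ceil(4/2) = 2, ceil(5/2) = 3), per the usual TF padding rules.
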
94 changes: 91 additions & 3 deletions iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -384,9 +384,97 @@ LogicalResult ConvOpConversion::apply(
rewriter.notifyMatchFailure(op, "failed to zero fill result buffer");
return failure();
}
rewriter.create<linalg::ConvOp>(op.getLoc(), inputBuffers[1], inputBuffers[0],
resultBuffers[0], stridesArg, dilationArg,
padding);
// Depthwise conv path...
if (op.feature_group_count().getZExtValue() > 1u &&
op.feature_group_count().getZExtValue() ==
op.dimension_numbers().kernel_input_feature_dimension().getInt()) {
// Lower depthwise convolution to a linalg.generic op. The idea is to use the
// group-convolution formulation of separable depthwise convolution: given an
// n-dimensional input x and filter w, the direct convolution can be written as
// y[n, d1, d2, ..., dn, ci * groupSize + co] = sum(k1, k2, ..., kn,
// x[n, d1 * stride1 + k1, d2 * stride2 + k2, ..., dn * striden + kn]
// * w[k1, k2, ..., kn, ci, co])

// TODO(ataei): Support dilation.
if (llvm::any_of(dilation, [](Attribute attr) {
return (attr.dyn_cast<IntegerAttr>().getInt() != 1);
})) {
return failure();
}

SmallVector<AffineExpr, 4> inputExprs;
SmallVector<AffineExpr, 4> filterExprs;
SmallVector<AffineExpr, 4> outputExprs;

const auto spatialDims =
llvm::size(op.dimension_numbers().input_spatial_dimensions());
const int d1Index = 1;
const int coIndex = d1Index + spatialDims;
const int ciIndex = coIndex + 1;
const int k1Index = ciIndex + 1;
// n, d1 * stride1 + k1, d1 * stride2 + k2, ...dn * striden + kn
inputExprs.push_back(rewriter.getAffineDimExpr(0));
for (int i = 0; i < spatialDims; ++i) {
if (op.window_stridesAttr()) {
auto stride = op.window_stridesAttr().getValue<APInt>(i);
inputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i) *
stride.getZExtValue() +
rewriter.getAffineDimExpr(k1Index + i));
} else {
inputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i) +
rewriter.getAffineDimExpr(k1Index + i));
}
}
inputExprs.push_back(rewriter.getAffineDimExpr(ciIndex));
// k1, k2, ...kn, ci, co
for (int i = 0; i < spatialDims; ++i) {
filterExprs.push_back(rewriter.getAffineDimExpr(k1Index + i));
}
filterExprs.push_back(rewriter.getAffineDimExpr(ciIndex));
filterExprs.push_back(rewriter.getAffineDimExpr(coIndex));

// n, d1, d2, ....dn, ci * groupSize + co
outputExprs.push_back(rewriter.getAffineDimExpr(0));
for (int i = 0; i < spatialDims; ++i) {
outputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i));
}
outputExprs.push_back(
rewriter.getAffineDimExpr(ciIndex) *
op.dimension_numbers().kernel_output_feature_dimension().getInt() +
rewriter.getAffineDimExpr(coIndex));

// nloops = |d| + |k| + |{n, ci, co}|
int nloops = spatialDims * 2 + 3;
SmallVector<AffineMap, 4> indexingMaps;
indexingMaps.emplace_back(AffineMap::get(
nloops, /*symbolCount=*/0, inputExprs, rewriter.getContext()));
indexingMaps.emplace_back(AffineMap::get(
nloops, /*symbolCount=*/0, filterExprs, rewriter.getContext()));
indexingMaps.emplace_back(AffineMap::get(
nloops, /*symbolCount=*/0, outputExprs, rewriter.getContext()));

Location loc = op.getLoc();
SmallVector<Value, 4> linalgOpArgs = {inputBuffers[0], inputBuffers[1],
resultBuffers[0]};

SmallVector<StringRef, 3> loopAttributeTypes(spatialDims + 3, "parallel");
loopAttributeTypes.append(spatialDims, "reduction");
rewriter.create<linalg::GenericOp>(
loc, ArrayRef<Type>{}, linalgOpArgs,
2, // args_in
1, // args_out
indexingMaps, loopAttributeTypes,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
Value mul = nestedBuilder.create<MulFOp>(nestedLoc, args[0], args[1]);
Value add = nestedBuilder.create<AddFOp>(nestedLoc, mul, args[2]);
nestedBuilder.create<linalg::YieldOp>(loc, add);
});
} else {
rewriter.create<linalg::ConvOp>(op.getLoc(), inputBuffers[1],
inputBuffers[0], resultBuffers[0],
stridesArg, dilationArg, padding);
}
return success();
}

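
A worked reading of the new code for the two-spatial-dimension case with unit strides (not part of the commit): the loop dimensions are (d0, ..., d6) = (n, d1, d2, co, ci, k1, k2), the first five parallel and the last two reductions, so the three indexing maps come out as the ones checked in conv.mlir below:

input:  (d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d4)
filter: (d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d4, d3)
output: (d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4 * 3 + d3)

The factor 3 in the output map matches the channel multiplier of the filter used in that test.
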
43 changes: 42 additions & 1 deletion iree/compiler/Conversion/HLOToLinalg/test/conv.mlir
@@ -1,4 +1,4 @@
// RUN: iree-opt -iree-codegen-hlo-to-linalg-on-buffers %s | IreeFileCheck %s
// RUN: iree-opt -split-input-file -iree-codegen-hlo-to-linalg-on-buffers %s | IreeFileCheck %s

module {
// CHECK: func @conv
@@ -37,3 +37,44 @@ module {
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
}
}

// -----

module {
func @depthwise_conv() {
%c0 = constant 0 : index
%0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<2x4x5x2xf32>
%1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<2x2x2x3xf32>
%2 = "mhlo.convolution"(%0, %1) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
kernel_input_feature_dimension = 2 : i64,
kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
},
feature_group_count = 2 : i64,
padding = dense<0> : tensor<2x2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
window_strides = dense<1> : tensor<2xi64>} : (tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<2x3x4x6xf32>
hal.interface.store.tensor %2, @legacy_io::@ret0, offset = %c0 : tensor<2x3x4x6xf32>
return
}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
}
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d4)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d4, d3)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4 * 3 + d3)>
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]
// CHECK: mulf
// CHECK: addf
// CHECK: memref<2x4x5x2xf32>, memref<2x2x2x3xf32>, memref<2x3x4x6xf32>
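
The output type tensor<2x3x4x6xf32> in this test follows from the VALID window: spatial dims 4 - 2 + 1 = 3 and 5 - 2 + 1 = 4, and 2 input channels times a channel multiplier of 3 give 6 output channels.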
