From 95aa614085028b184ec4dcfe50f9d24334bea0ea Mon Sep 17 00:00:00 2001
From: Ahmed Taei
Date: Fri, 24 Jul 2020 17:11:11 -0700
Subject: [PATCH] Detect and lower depthwise conv to linalg.generic

---
 integrations/tensorflow/e2e/BUILD                  |  1 -
 .../tensorflow/e2e/depth_conv_test.py              | 34 ++++++-
 .../HLOToLinalg/HLOToLinalgOnBuffers.cpp           | 93 ++++++++++++++++++-
 3 files changed, 122 insertions(+), 6 deletions(-)

diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index d63b8394a2e2..90065bbe0392 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -69,7 +69,6 @@ VMLA_FAILING = [
 # keep sorted
 LLVM_FAILING = [
     "broadcasting_test.py",
-    "depth_conv_test.py",
     "dynamic_mlp_relu_test.py",
     "dynamic_mlp_test.py",
     "fill_test.py",  # TODO(jennik): Get this test working on IREE.
diff --git a/integrations/tensorflow/e2e/depth_conv_test.py b/integrations/tensorflow/e2e/depth_conv_test.py
index 1e8a002caa0f..f827d0c8e8a7 100644
--- a/integrations/tensorflow/e2e/depth_conv_test.py
+++ b/integrations/tensorflow/e2e/depth_conv_test.py
@@ -38,21 +38,51 @@ def conv2d_2452x2223_same(self, img, kernel):
         img, kernel, [1, 1, 1, 1], "SAME", name="result")
 
+  @tf.function(input_signature=[
+      tf.TensorSpec([2, 4, 5, 2], tf.float32),
+      tf.TensorSpec([2, 4, 2, 3], tf.float32),
+  ])
+  def conv2d_2452x2223_valid_stride_2(self, img, kernel):
+    return tf.nn.depthwise_conv2d(
+        img, kernel, [1, 2, 2, 1], "VALID", name="result")
+
+  @tf.function(input_signature=[
+      tf.TensorSpec([2, 4, 5, 2], tf.float32),
+      tf.TensorSpec([2, 4, 2, 3], tf.float32),
+  ])
+  def conv2d_2452x2223_same_stride_2(self, img, kernel):
+    return tf.nn.depthwise_conv2d(
+        img, kernel, [1, 2, 2, 1], "SAME", name="result")
+
 
 @tf_test_utils.compile_module(Conv2dModule)
 class ConvTest(tf_test_utils.CompiledModuleTestCase):
 
-  def test_batched_feature_unpadded(self):
+  def test_batched_feature_padded(self):
     i = np.arange(80, dtype=np.float32).reshape([2, 4, 5, 2])
     k = np.arange(24, dtype=np.float32).reshape([2, 2, 2, 3])
     r = self.get_module().conv2d_2452x2223_valid(i, k)
     r.print().assert_all_close()
 
-  def test_batched_feature_unpadded_smae(self):
+  def test_batched_feature_unpadded_same(self):
     i = np.arange(80, dtype=np.float32).reshape([2, 4, 5, 2])
     k = np.arange(48, dtype=np.float32).reshape([2, 4, 2, 3])
     r = self.get_module().conv2d_2452x2223_same(i, k)
     r.print().assert_all_close()
 
+  def test_batched_feature_unpadded_same_stride_2(self):
+    i = np.arange(80, dtype=np.float32).reshape([2, 4, 5, 2])
+    k = np.arange(48, dtype=np.float32).reshape([2, 4, 2, 3])
+    r = self.get_module().conv2d_2452x2223_same_stride_2(i, k)
+    r.print().assert_all_close()
+
+  def test_batched_feature_padded_same_stride_2(self):
+    i = np.arange(80, dtype=np.float32).reshape([2, 4, 5, 2])
+    k = np.arange(48, dtype=np.float32).reshape([2, 4, 2, 3])
+    r = self.get_module().conv2d_2452x2223_valid_stride_2(i, k)
+    r.print().assert_all_close()
+
 
 if __name__ == "__main__":
   if hasattr(tf, "enable_v2_behavior"):
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index 90178fcbf73d..aded66057015 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -384,9 +384,96 @@ LogicalResult ConvOpConversion::apply(
     rewriter.notifyMatchFailure(op, "failed to zero fill result buffer");
     return failure();
   }
-  rewriter.create<linalg::ConvOp>(op.getLoc(), inputBuffers[1], inputBuffers[0],
-                                  resultBuffers[0], stridesArg, dilationArg,
-                                  padding);
+  // Depthwise conv path...
+  if (op.feature_group_count().getZExtValue() > 1u &&
+      op.feature_group_count().getZExtValue() ==
+          op.dimension_numbers().kernel_input_feature_dimension().getInt()) {
+    // Lowering depthwise convolution to linalg.generic op.
+    // The idea is to use
+    // the group convolution formulation to perform the separable depth wise
+    // convolution as the following, given an n-dimensional input x and filter w
+    // the direct convolution operation can be written as:
+    // y[n, d1, d2, ....dn, ci * groupSize + co] = sum(k1, k2, ....kn,
+    //   x[n, d1 * stride1 + k1, d1 * stride2 + k2, ...dn * striden + kn]
+    //   * w[k1, k2, ...kn, ci, co])
+    SmallVector<AffineExpr, 4> inputExprs;
+    SmallVector<AffineExpr, 4> filterExprs;
+    SmallVector<AffineExpr, 4> outputExprs;
+
+    const auto spatialDims =
+        llvm::size(op.dimension_numbers().input_spatial_dimensions());
+    const int d1Index = 1;
+    const int coIndex = d1Index + spatialDims;
+    const int ciIndex = coIndex + 1;
+    const int k1Index = ciIndex + 1;
+    // n
+    inputExprs.push_back(rewriter.getAffineDimExpr(0));
+    // d1, d2....dn
+    for (int i = 0; i < spatialDims; ++i) {
+      if (op.window_stridesAttr()) {
+        auto stride = op.window_stridesAttr().getValue<APInt>(i);
+        inputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i) *
+                                 stride.getZExtValue() +
+                             rewriter.getAffineDimExpr(k1Index + i));
+      } else {
+        inputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i) +
+                             rewriter.getAffineDimExpr(k1Index + i));
+      }
+    }
+    // ci
+    inputExprs.push_back(rewriter.getAffineDimExpr(ciIndex));
+    // k1, k2,...kn
+    for (int i = 0; i < spatialDims; ++i) {
+      filterExprs.push_back(rewriter.getAffineDimExpr(k1Index + i));
+    }
+    // ci, co
+    filterExprs.push_back(rewriter.getAffineDimExpr(ciIndex));
+    filterExprs.push_back(rewriter.getAffineDimExpr(coIndex));
+
+    // n
+    outputExprs.push_back(rewriter.getAffineDimExpr(0));
+    for (int i = 0; i < spatialDims; ++i) {
+      outputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i));
+    }
+    // ci * groupSize + co
+    outputExprs.push_back(
+        rewriter.getAffineDimExpr(ciIndex) *
+            op.dimension_numbers().kernel_output_feature_dimension().getInt() +
+        rewriter.getAffineDimExpr(coIndex));
+
+    // nloops = |d| + |k| + 3
+    int nloops = spatialDims * 2 + 3;
+    SmallVector<AffineMap, 4> indexingMaps;
+    indexingMaps.emplace_back(AffineMap::get(
+        nloops, /*symbolCount=*/0, inputExprs, rewriter.getContext()));
+    indexingMaps.emplace_back(AffineMap::get(
+        nloops, /*symbolCount=*/0, filterExprs, rewriter.getContext()));
+    indexingMaps.emplace_back(AffineMap::get(
+        nloops, /*symbolCount=*/0, outputExprs, rewriter.getContext()));
+
+    Location loc = op.getLoc();
+    SmallVector<Type, 4> bodyArgTypes, opResultTypes;
+    SmallVector<Value, 4> linalgOpArgs = {inputBuffers[0], inputBuffers[1],
+                                          resultBuffers[0]};
+
+    SmallVector<StringRef, 8> parallelLoopsNum(spatialDims + 3, "parallel");
+    for (int i = 0; i < spatialDims; ++i) {
+      parallelLoopsNum.push_back("reduction");
+    }
+    rewriter.create<linalg::GenericOp>(
+        loc, opResultTypes, linalgOpArgs,
+        2,  // args_in
+        1,  // args_out
+        indexingMaps, parallelLoopsNum,
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          Value mul = nestedBuilder.create<MulFOp>(nestedLoc, args[0], args[1]);
+          Value add = nestedBuilder.create<AddFOp>(nestedLoc, mul, args[2]);
+          nestedBuilder.create<linalg::YieldOp>(loc, add);
+        });
+  } else {
+    rewriter.create<linalg::ConvOp>(op.getLoc(), inputBuffers[1],
+                                    inputBuffers[0], resultBuffers[0],
+                                    stridesArg, dilationArg, padding);
+  }
   return success();
 }