Detect and lower depthwise conv to linalg.generic
asaadaldien committed Jul 27, 2020
1 parent ffc7080 commit 95aa614
Showing 3 changed files with 122 additions and 6 deletions.
1 change: 0 additions & 1 deletion integrations/tensorflow/e2e/BUILD
@@ -69,7 +69,6 @@ VMLA_FAILING = [
# keep sorted
LLVM_FAILING = [
"broadcasting_test.py",
"depth_conv_test.py",
"dynamic_mlp_relu_test.py",
"dynamic_mlp_test.py",
"fill_test.py", # TODO(jennik): Get this test working on IREE.
34 changes: 32 additions & 2 deletions integrations/tensorflow/e2e/depth_conv_test.py
@@ -38,21 +38,51 @@ def conv2d_2452x2223_same(self, img, kernel):
img, kernel, [1, 1, 1, 1], "SAME", name="result")



@tf.function(input_signature=[
tf.TensorSpec([2, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 4, 2, 3], tf.float32),
])
def conv2d_2452x2223_valid_stride_2(self, img, kernel):
return tf.nn.depthwise_conv2d(
img, kernel, [1, 2, 2, 1], "VALID", name="result")

@tf.function(input_signature=[
tf.TensorSpec([2, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 4, 2, 3], tf.float32),
])
def conv2d_2452x2223_same_stride_2(self, img, kernel):
return tf.nn.depthwise_conv2d(
img, kernel, [1, 2, 2, 1], "SAME", name="result")


@tf_test_utils.compile_module(Conv2dModule)
class ConvTest(tf_test_utils.CompiledModuleTestCase):

- def test_batched_feature_unpadded(self):
+ def test_batched_feature_padded(self):
i = np.arange(80, dtype=np.float32).reshape([2, 4, 5, 2])
k = np.arange(24, dtype=np.float32).reshape([2, 2, 2, 3])
r = self.get_module().conv2d_2452x2223_valid(i, k)
r.print().assert_all_close()

- def test_batched_feature_unpadded_smae(self):
+ def test_batched_feature_unpadded_same(self):
i = np.arange(80, dtype=np.float32).reshape([2, 4, 5, 2])
k = np.arange(48, dtype=np.float32).reshape([2, 4, 2, 3])
r = self.get_module().conv2d_2452x2223_same(i, k)
r.print().assert_all_close()

def test_batched_feature_unpadded_same_stride_2(self):
i = np.arange(80, dtype=np.float32).reshape([2, 4, 5, 2])
k = np.arange(48, dtype=np.float32).reshape([2, 4, 2, 3])
r = self.get_module().conv2d_2452x2223_same_stride_2(i, k)
r.print().assert_all_close()

def test_batched_feature_padded_same_stride_2(self):
i = np.arange(80, dtype=np.float32).reshape([2, 4, 5, 2])
k = np.arange(48, dtype=np.float32).reshape([2, 4, 2, 3])
r = self.get_module().conv2d_2452x2223_valid_stride_2(i, k)
r.print().assert_all_close()


if __name__ == "__main__":
if hasattr(tf, "enable_v2_behavior"):
93 changes: 90 additions & 3 deletions iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -384,9 +384,96 @@ LogicalResult ConvOpConversion::apply(
rewriter.notifyMatchFailure(op, "failed to zero fill result buffer");
return failure();
}
- rewriter.create<linalg::ConvOp>(op.getLoc(), inputBuffers[1], inputBuffers[0],
-                                 resultBuffers[0], stridesArg, dilationArg,
-                                 padding);
// Depthwise conv path...
if (op.feature_group_count().getZExtValue() > 1u &&
op.feature_group_count().getZExtValue() ==
op.dimension_numbers().kernel_input_feature_dimension().getInt()) {
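// feature_group_count > 1 marks a grouped convolution; in the depthwise
// case each input channel is convolved with its own set of filters.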
// Lower the depthwise convolution to a linalg.generic op. The idea is to
// use the group-convolution formulation to perform the separable depthwise
// convolution as follows: given an n-dimensional input x and filter w, the
// direct convolution operation can be written as:
// y[n, d1, d2, ..., dn, ci * groupSize + co] =
//     sum(k1, k2, ..., kn,
//         x[n, d1 * stride1 + k1, d2 * stride2 + k2, ..., dn * striden + kn, ci]
//         * w[k1, k2, ..., kn, ci, co])
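// For example, with the 2-D NHWC shapes from the e2e tests (input 2x4x5x2,
// filter 2x4x2x3, channel multiplier 3), input channel ci produces output
// channels ci * 3 + co for co in [0, 3), giving 2 * 3 = 6 output channels.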
SmallVector<AffineExpr, 4> inputExprs;
SmallVector<AffineExpr, 4> filterExprs;
SmallVector<AffineExpr, 4> outputExprs;

const auto spatialDims =
llvm::size(op.dimension_numbers().input_spatial_dimensions());
const int d1Index = 1;
const int coIndex = d1Index + spatialDims;
const int ciIndex = coIndex + 1;
const int k1Index = ciIndex + 1;
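// E.g. for spatialDims == 2 the loop dimensions are ordered
// (n, d1, d2, co, ci, k1, k2), so d1Index = 1, coIndex = 3, ciIndex = 4,
// and k1Index = 5.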
// n
inputExprs.push_back(rewriter.getAffineDimExpr(0));
// d1, d2....dn
for (int i = 0; i < spatialDims; ++i) {
if (op.window_stridesAttr()) {
auto stride = op.window_stridesAttr().getValue<APInt>(i);
inputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i) *
stride.getZExtValue() +
rewriter.getAffineDimExpr(k1Index + i));
} else {
inputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i) +
rewriter.getAffineDimExpr(k1Index + i));
}
}
// ci
inputExprs.push_back(rewriter.getAffineDimExpr(ciIndex));
// k1, k2,...kn
for (int i = 0; i < spatialDims; ++i) {
filterExprs.push_back(rewriter.getAffineDimExpr(k1Index + i));
}
// ci, co
filterExprs.push_back(rewriter.getAffineDimExpr(ciIndex));
filterExprs.push_back(rewriter.getAffineDimExpr(coIndex));

// n
outputExprs.push_back(rewriter.getAffineDimExpr(0));
for (int i = 0; i < spatialDims; ++i) {
outputExprs.push_back(rewriter.getAffineDimExpr(d1Index + i));
}
// ci * groupSize + co
outputExprs.push_back(
rewriter.getAffineDimExpr(ciIndex) *
op.dimension_numbers().kernel_output_feature_dimension().getInt() +
rewriter.getAffineDimExpr(coIndex));
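// For the 2-D case above, writing s1/s2 for the window strides, this yields
//   input:  (n, d1, d2, co, ci, k1, k2) -> (n, d1*s1 + k1, d2*s2 + k2, ci)
//   filter: (n, d1, d2, co, ci, k1, k2) -> (k1, k2, ci, co)
//   output: (n, d1, d2, co, ci, k1, k2) -> (n, d1, d2, ci * groupSize + co)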

// nloops = |d| + |k| + 3
int nloops = spatialDims * 2 + 3;
SmallVector<AffineMap, 4> indexingMaps;
indexingMaps.emplace_back(AffineMap::get(
nloops, /*symbolCount=*/0, inputExprs, rewriter.getContext()));
indexingMaps.emplace_back(AffineMap::get(
nloops, /*symbolCount=*/0, filterExprs, rewriter.getContext()));
indexingMaps.emplace_back(AffineMap::get(
nloops, /*symbolCount=*/0, outputExprs, rewriter.getContext()));

Location loc = op.getLoc();
SmallVector<Type, 4> bodyArgTypes, opResultTypes;
SmallVector<Value, 2> linalgOpArgs = {inputBuffers[0], inputBuffers[1],
resultBuffers[0]};

SmallVector<StringRef, 3> parallelLoopsNum(spatialDims + 3, "parallel");
for (int i = 0; i < spatialDims; ++i) {
parallelLoopsNum.push_back("reduction");
}
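// For the 2-D case this is 5 "parallel" dims (n, d1, d2, co, ci) followed
// by 2 "reduction" dims (k1, k2), matching nloops == 7; the body below
// accumulates input * filter into the zero-filled result buffer.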
rewriter.create<linalg::GenericOp>(
loc, opResultTypes, linalgOpArgs,
2, // args_in
1, // args_out
indexingMaps, parallelLoopsNum,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
Value mul = nestedBuilder.create<MulFOp>(nestedLoc, args[0], args[1]);
Value add = nestedBuilder.create<AddFOp>(nestedLoc, mul, args[2]);
nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
});
} else {
rewriter.create<linalg::ConvOp>(op.getLoc(), inputBuffers[1],
inputBuffers[0], resultBuffers[0],
stridesArg, dilationArg, padding);
}
return success();
}

