Detect and lower depthwise conv to linalg.generic #2678
Conversation
Force-pushed 95aa614 to 4ce60a0
Maybe add a TODO that we can also just lower this to the linalg.depthwise-convolution named op. There were named ops for conv1D, conv2D, etc. added to MLIR. So it is unclear whether adding named ops is OK or not, but that's an option for the future.
// y[n, d1, d2, ..., dn, ci * groupSize + co] = sum(k1, k2, ..., kn,
//     x[n, d1 * stride1 + k1, d2 * stride2 + k2, ..., dn * striden + kn]
//     * w[k1, k2, ..., kn, ci, co])
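The index formula above can be sanity-checked against a minimal NumPy reference. This is a sketch, not the PR's code; the helper name and the restriction to the 2-D case (with `groupSize` equal to the channel multiplier `m`) are assumptions for illustration:

```python
import numpy as np

def depthwise_conv2d_ref(x, w, stride=1):
    """Direct loop reference (hypothetical helper) for
    y[n, d1, d2, ci * m + co] =
        sum over k1, k2 of x[n, d1*stride + k1, d2*stride + k2, ci] * w[k1, k2, ci, co].
    x: (N, H, W, Ci), w: (Kh, Kw, Ci, M), valid padding."""
    n, h, wd, ci = x.shape
    k1, k2, ci_w, m = w.shape
    assert ci == ci_w
    ho = (h - k1) // stride + 1
    wo = (wd - k2) // stride + 1
    y = np.zeros((n, ho, wo, ci * m), dtype=x.dtype)
    for b in range(n):
        for d1 in range(ho):
            for d2 in range(wo):
                for c in range(ci):
                    for co in range(m):
                        # Reduce over the kernel window for one (ci, co) pair.
                        y[b, d1, d2, c * m + co] = np.sum(
                            x[b, d1*stride:d1*stride+k1,
                                 d2*stride:d2*stride+k2, c] * w[:, :, c, co])
    return y
```

Note the output-channel layout `ci * m + co`, matching the comment's `ci * groupSize + co` ordering.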
You should include dilation in the comment too since it is supported.
},
feature_group_count = 2 : i64,
padding = dense<0> : tensor<2x2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
I don't know if we support dilation in convolution itself, but it would be good to check that dilation lowers correctly (linalg.conv lowers dilation properly).
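One way to check that dilation lowers correctly is to compare against a direct reference where the kernel index is scaled by the dilation factor, i.e. `x[d*stride + k*dilation]`. A minimal 1-D, single-channel NumPy sketch (hypothetical helper, not part of the PR):

```python
import numpy as np

def dilated_conv1d_ref(x, w, stride=1, dilation=1):
    """y[d] = sum over k of x[d*stride + k*dilation] * w[k], valid padding."""
    k = w.shape[0]
    span = (k - 1) * dilation + 1            # effective (dilated) kernel extent
    out = (x.shape[0] - span) // stride + 1
    y = np.zeros(out, dtype=x.dtype)
    for d in range(out):
        for i in range(k):
            y[d] += x[d * stride + i * dilation] * w[i]
    return y

# With w = [1, 1] and dilation = 2, each output is x[d] + x[d + 2]:
print(dilated_conv1d_ref(np.arange(6.0), np.ones(2), dilation=2))  # [2. 4. 6. 8.]
```

A lowering that ignores `rhs_dilation` would instead reproduce the dilation=1 result, so a test like this distinguishes the two cases.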
There is a front-end problem: when constructing the tf.depthwise_conv2d op, the saved model doesn't have the correct dilation attributes, only the default values. For conv I will add more test cases.
So I will create separate issues to track:
- Increase coverage of conv tests to include dilated convolution.
- Triage the dilated depthwise convolution SavedModel issue.
Force-pushed 5eba976 to 2cf9303
Just a few nits, thanks!
@@ -26,19 +26,37 @@ class Conv2dModule(tf.Module):
       tf.TensorSpec([2, 4, 5, 2], tf.float32),
       tf.TensorSpec([2, 2, 2, 3], tf.float32),
   ])
-  def conv2d_2452x2223_valid(self, img, kernel):
+  def conv2d_2423x2223_valid(self, img, kernel):
2452x2223
    return tf.nn.depthwise_conv2d(
        img, kernel, [1, 1, 1, 1], "VALID", name="result")
   @tf.function(input_signature=[
       tf.TensorSpec([2, 4, 5, 2], tf.float32),
       tf.TensorSpec([2, 4, 2, 3], tf.float32),
   ])
-  def conv2d_2452x2223_same(self, img, kernel):
+  def conv2d_2423x2223_same(self, img, kernel):
2452x2423
       tf.TensorSpec([2, 4, 5, 2], tf.float32),
       tf.TensorSpec([2, 4, 2, 3], tf.float32),
   ])
   def conv2d_2423x2223_valid_stride_2(self, img, kernel):
2452x2423
       tf.TensorSpec([2, 4, 5, 2], tf.float32),
       tf.TensorSpec([2, 4, 2, 3], tf.float32),
   ])
   def conv2d_2423x2223_same_stride_2(self, img, kernel):
2452x2423
    }
  }
  inputExprs.push_back(rewriter.getAffineDimExpr(ciIndex));
  // k1, k2, ...kn, ci, co
nit: add a blank line before the comment since they are for different exprs.
Force-pushed 2cf9303 to 8344da9
Lower 'mhlo.conv' with depthwise convolution properties to 'linalg.generic' on buffers.
Lowering depthwise convolution to a linalg.generic op. The idea is to use the group convolution formulation to perform the separable depthwise convolution. Given an n-dimensional input x and filter w, the direct convolution can be written as an expression that performs a reduction over the kernel dimensions k1..kn.
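Under the group-convolution view, a depthwise convolution is a grouped convolution whose group count equals the input channel count: each input channel is convolved only with its own filter slice. A minimal NumPy sketch of that equivalence (hypothetical helper names, single image, 2-D valid padding, not the PR's implementation):

```python
import numpy as np

def conv2d_single_channel(x, w, stride=1):
    """Plain valid conv of a one-channel image x: (H, W) with filters w: (Kh, Kw, Co)."""
    h, wd = x.shape
    kh, kw, co = w.shape
    ho, wo = (h - kh) // stride + 1, (wd - kw) // stride + 1
    y = np.zeros((ho, wo, co), dtype=x.dtype)
    for d1 in range(ho):
        for d2 in range(wo):
            patch = x[d1*stride:d1*stride+kh, d2*stride:d2*stride+kw]
            for c in range(co):
                y[d1, d2, c] = np.sum(patch * w[:, :, c])
    return y

def depthwise_via_groups(x, w, stride=1):
    """Depthwise conv as a grouped conv with one group per input channel.
    x: (H, W, Ci), w: (Kh, Kw, Ci, M) -> y: (Ho, Wo, Ci*M)."""
    ci = w.shape[2]
    # Each group convolves one input channel with its M filter slices.
    groups = [conv2d_single_channel(x[:, :, c], w[:, :, c, :], stride)
              for c in range(ci)]
    # Concatenating groups along the channel axis yields the ci*M + co layout.
    return np.concatenate(groups, axis=-1)
```

Because each group touches a single input channel, the grouped formulation maps naturally onto a linalg.generic whose indexing maps tie the input channel dimension to the output channel's `ci` component, as in the lowering above.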