
Detect and lower depthwise conv to linalg.generic #2678

Merged
merged 1 commit into main on Aug 5, 2020

Conversation

asaadaldien
Contributor

Lower 'mhlo.conv' with depthwise convolution properties to 'linalg.generic' on buffers.

Lowering depthwise convolution to a linalg.generic op. The idea is to use the group-convolution formulation to perform the separable depthwise convolution. Given an n-dimensional input x and filter w, the direct convolution can be written as the following expression, which reduces over the kernel dimensions k1..kn.

y[n, d1, d2, ..., dn, ci * groupSize + co] = sum(x[n, d1 * stride1 + k1, d2 * stride2 + k2, ..., dn * striden + kn, ci] * w[k1, k2, ..., kn, ci, co], k1, k2, ..., kn)
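The 2-D case of this expression can be sketched as a NumPy reference (a minimal illustration, not the PR's actual lowering code; names like `depthwise_conv2d_ref` are hypothetical, and it assumes NHWC layout, VALID padding, and a filter of shape [Kh, Kw, Ci, Co] where Co is the channel multiplier):

```python
import numpy as np

def depthwise_conv2d_ref(x, w, strides=(1, 1)):
    # x: [N, H, W, Ci] input; w: [Kh, Kw, Ci, Co] filter (Co = channel multiplier).
    # Computes y[n, d1, d2, ci * Co + co] =
    #   sum_{k1, k2} x[n, d1*s1 + k1, d2*s2 + k2, ci] * w[k1, k2, ci, co]
    n, h, wd, ci = x.shape
    kh, kw, ci_w, co = w.shape
    assert ci == ci_w, "input and filter channel counts must match"
    s1, s2 = strides
    ho, wo = (h - kh) // s1 + 1, (wd - kw) // s2 + 1  # VALID padding
    y = np.zeros((n, ho, wo, ci * co), dtype=x.dtype)
    for d1 in range(ho):
        for d2 in range(wo):
            # Window covering the kernel extent: [N, Kh, Kw, Ci].
            win = x[:, d1 * s1:d1 * s1 + kh, d2 * s2:d2 * s2 + kw, :]
            # Reduce over k1, k2 for each (ci, co) pair, then flatten the
            # channel pair so output channel ci * Co + co matches the formula.
            y[:, d1, d2, :] = np.einsum('nhwc,hwco->nco', win, w).reshape(n, ci * co)
    return y
```

Each input channel ci is convolved with its own Co filters, and the results are interleaved into the ci * Co + co output channel, which is exactly the grouped-convolution reading of depthwise conv used above.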

@google-cla google-cla bot added the cla: yes label Jul 27, 2020
@asaadaldien asaadaldien force-pushed the ataei-linalg_depthwise_conv branch from 95aa614 to 4ce60a0 Compare July 27, 2020 23:09
Contributor

@MaheshRavishankar MaheshRavishankar left a comment


Maybe add a TODO that we could also lower this to a linalg.depthwise-convolution named op. There were named ops for conv1D, conv2D, etc. added to MLIR. It is unclear whether adding named ops is OK or not, but that's an option for the future.

Comment on lines +395 to +399
// y[n, d1, d2, ....dn, ci * groupSize + co] = sum(k1, k2, ....kn,
//     x[n, d1 * stride1 + k1, d2 * stride2 + k2, ...dn * striden + kn, ci]
//     * w[k1, k2, ...kn, ci, co])
Contributor


You should include dilation in the comment too since it is supported.

},
feature_group_count = 2 : i64,
padding = dense<0> : tensor<2x2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
Contributor


I don't know if we support dilation in convolution itself, but it would be good to check that dilation lowers correctly (linalg.conv lowers dilation properly).

Contributor Author


There is a front-end problem: when constructing a tf.depthwise_conv2d op, the saved model doesn't carry the correct dilation attributes, only default values. For conv I will add more test cases.
So I will create separate issues to track:

  • Increase coverage of conv tests to include dilated convolution.
  • Triage the dilated depthwise convolution SavedModel issue.
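To make the dilation discussion concrete, here is a hedged NumPy sketch of how rhs (filter) dilation enters the index arithmetic of the depthwise formula: filter tap k reads input position d * stride + k * dilation. This is an illustration under assumed NHWC layout and VALID padding, not the lowering code itself, and the function name is hypothetical:

```python
import numpy as np

def dilated_depthwise_conv2d_ref(x, w, strides=(1, 1), dilations=(1, 1)):
    # x: [N, H, W, Ci]; w: [Kh, Kw, Ci, Co]. Filter tap (k1, k2) reads input
    # position (d1 * s1 + k1 * r1, d2 * s2 + k2 * r2), i.e. rhs_dilation
    # spreads the taps apart without adding parameters.
    n, h, wd, ci = x.shape
    kh, kw, _, co = w.shape
    s1, s2 = strides
    r1, r2 = dilations
    # Effective kernel extent once the taps are spread out by the dilation.
    eh, ew = (kh - 1) * r1 + 1, (kw - 1) * r2 + 1
    ho, wo = (h - eh) // s1 + 1, (wd - ew) // s2 + 1  # VALID padding
    y = np.zeros((n, ho, wo, ci * co), dtype=x.dtype)
    for d1 in range(ho):
        for d2 in range(wo):
            for k1 in range(kh):
                for k2 in range(kw):
                    pix = x[:, d1 * s1 + k1 * r1, d2 * s2 + k2 * r2, :]  # [N, Ci]
                    # Accumulate into output channel ci * Co + co.
                    y[:, d1, d2, :] += (pix[:, :, None] * w[k1, k2]).reshape(n, ci * co)
    return y
```

With dilations=(1, 1) this reduces to the plain depthwise formula, which is a useful sanity check when wiring the rhs_dilation attribute through.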

@asaadaldien asaadaldien force-pushed the ataei-linalg_depthwise_conv branch 2 times, most recently from 5eba976 to 2cf9303 Compare August 4, 2020 04:54
Contributor

@hanhanW hanhanW left a comment


Just a few nits, thanks!

@@ -26,19 +26,37 @@ class Conv2dModule(tf.Module):
tf.TensorSpec([2, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 2, 2, 3], tf.float32),
])
def conv2d_2452x2223_valid(self, img, kernel):
def conv2d_2423x2223_valid(self, img, kernel):
Contributor


2452x2223

return tf.nn.depthwise_conv2d(
img, kernel, [1, 1, 1, 1], "VALID", name="result")

@tf.function(input_signature=[
tf.TensorSpec([2, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 4, 2, 3], tf.float32),
])
def conv2d_2452x2223_same(self, img, kernel):
def conv2d_2423x2223_same(self, img, kernel):
Contributor


2452x2423

tf.TensorSpec([2, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 4, 2, 3], tf.float32),
])
def conv2d_2423x2223_valid_stride_2(self, img, kernel):
Contributor


2452x2423

tf.TensorSpec([2, 4, 5, 2], tf.float32),
tf.TensorSpec([2, 4, 2, 3], tf.float32),
])
def conv2d_2423x2223_same_stride_2(self, img, kernel):
Contributor


2452x2423

}
}
inputExprs.push_back(rewriter.getAffineDimExpr(ciIndex));
// k1, k2, ...kn, ci, co
Contributor


nit: add a blank line before the comment, since it belongs to a different set of exprs.

@asaadaldien asaadaldien force-pushed the ataei-linalg_depthwise_conv branch from 2cf9303 to 8344da9 Compare August 5, 2020 16:55
@asaadaldien asaadaldien merged commit 068da1f into main Aug 5, 2020
@asaadaldien asaadaldien deleted the ataei-linalg_depthwise_conv branch August 5, 2020 17:26