[Executorch] Refactor op_add to support op_sub broadcasting #8255
base: gh/kimishpatel/155/base
Conversation
Summary: Refactor op_add to consolidate common broadcasting-related improvements. Test Plan: Previously added tests. [ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/8255
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 new failure as of commit 4f81db5 with merge base 5dd2ed3 (one job has failed).
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Looks like this stack is going to need some changes in response to comments on this diff and the previous ones; pausing review here.
    out,
    "Failed to resize output tensor.");

ET_SWITCH_REALB_TYPES(a_type, ctx, "add.out", CTYPE, [&]() {
Need to put const char* op_name in the template parameters and fix this.
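For illustration, here is a minimal, self-contained sketch of passing an op name through as a non-type template parameter so that shared code does not hard-code the "add.out" literal. The names below (report_op, kAddOut, kSubOut) are hypothetical stand-ins, not the ExecutorTorch ET_SWITCH helpers; they only show the mechanism being asked for.

#include <cstdio>

// Strings used as non-type template arguments need static storage duration.
static constexpr const char kAddOut[] = "add.out";
static constexpr const char kSubOut[] = "sub.out";

// The op name is supplied at compile time instead of as a string literal at
// each call site.
template <const char* op_name>
void report_op() {
  std::printf("dispatching %s\n", op_name);
}

int main() {
  report_op<kAddOut>();  // prints "dispatching add.out"
  report_op<kSubOut>();  // prints "dispatching sub.out"
  return 0;
}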
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
  ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
ditto op name
Yeah, sounds good. Let me address your comments in the previous diffs.
if constexpr (is_sub) {
  if (selected_optimized_path ==
          ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
      selected_optimized_path ==
          ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
      selected_optimized_path ==
          ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
    auto add_lambda = [&alpha_val_vec](auto x, auto y) {
      return y - alpha_val_vec * x;
    };
    return torch::executor::handle_broadcast_elementwise<CTYPE>(
        ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
  } else {
    auto add_lambda = [&alpha_val_vec](auto x, auto y) {
      return x - alpha_val_vec * y;
    };
    return torch::executor::handle_broadcast_elementwise<CTYPE>(
        ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
  }
} else {
  if (selected_optimized_path ==
          ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
      selected_optimized_path ==
          ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
      selected_optimized_path ==
          ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
    // The reason we swap the args here is that handle_broadcast_elementwise
    // handles this selected_optimized_path option a bit differently. This
    // should really be resolved in handle_broadcast_elementwise. However,
    // the current blocker is that handle_broadcast_elementwise tries to be
    // agnostic of the op. This should be fixed, likely by moving lambda
    // creation into handle_broadcast_elementwise and making it aware of
    // which op is being executed.
    auto add_lambda = [&alpha_val_vec](auto x, auto y) {
      return y + alpha_val_vec * x;
    };
    return torch::executor::handle_broadcast_elementwise<CTYPE>(
        ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
  } else {
    auto add_lambda = [&alpha_val_vec](auto x, auto y) {
      return x + alpha_val_vec * y;
    };
    return torch::executor::handle_broadcast_elementwise<CTYPE>(
        ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
  }
}
});
I think you can just select the lambdas based on is_sub rather than duplicating the rest of the code under this if constexpr: https://godbolt.org/z/Esdz1exKj
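For reference, a standalone sketch of that idea (hypothetical names; this is not the snippet behind the godbolt link). The point is that if constexpr picks only the arithmetic, while the rest of the dispatch is written once; alpha and reverse_args stand in for alpha_val_vec and the ReverseArguments broadcast paths in the diff above.

#include <iostream>

// Select just the operation under if constexpr and keep a single call site.
template <bool is_sub, typename T>
T broadcast_elementwise_op(T x, T y, T alpha, bool reverse_args) {
  // The reverse-arguments broadcast paths hand the lambda swapped operands,
  // so undo the swap before applying the op.
  const T lhs = reverse_args ? y : x;
  const T rhs = reverse_args ? x : y;
  if constexpr (is_sub) {
    return lhs - alpha * rhs;  // sub.out semantics: a - alpha * b
  } else {
    return lhs + alpha * rhs;  // add.out semantics: a + alpha * b
  }
}

int main() {
  std::cout << broadcast_elementwise_op<false>(1.0, 2.0, 0.5, false) << "\n";  // 2
  std::cout << broadcast_elementwise_op<true>(1.0, 2.0, 0.5, true) << "\n";    // 1.5
  return 0;
}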
That's a good suggestion. I tried doing the same thing in a different way, which didn't quite work, but I can try out your suggestion.
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Stack from ghstack (oldest at bottom):
Summary:
Refactor op_add to consolidate common broadcasting-related improvements
Test Plan:
Previously added tests
Reviewers:
Subscribers:
Tasks:
Tags:
cc @larryliu0820 @manuelcandales
Differential Revision: D69491817