-
Notifications
You must be signed in to change notification settings - Fork 2.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[async] Support global temporaries value state #2061
Conversation
taichi/program/async_utils.cpp
Outdated
@@ -119,18 +143,33 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { | |||
|
|||
// TODO: this is an abuse since it gathers nothing... | |||
gather_statements(root_stmt, [&](Stmt *stmt) { | |||
if (auto global_load = stmt->cast<GlobalLoadStmt>()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the real change. The rest is just to adapt to snode_or_global_tmp
..
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool -- We may use a snode for global temps, merge GlobalTemporaryStmt
into GlobalPtrStmt
and handle this in get_meta_input_value_states()
(which calls ControlFlowGraph::gather_loaded_snodes()
) in the future.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool!
taichi/program/async_utils.cpp
Outdated
@@ -119,18 +143,33 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { | |||
|
|||
// TODO: this is an abuse since it gathers nothing... | |||
gather_statements(root_stmt, [&](Stmt *stmt) { | |||
if (auto global_load = stmt->cast<GlobalLoadStmt>()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool -- We may use a snode for global temps, merge GlobalTemporaryStmt
into GlobalPtrStmt
and handle this in get_meta_input_value_states()
(which calls ControlFlowGraph::gather_loaded_snodes()
) in the future.
if (state.type == AsyncState::Type::value && state.holds_snode()) { | ||
const auto *sn = state.snode(); | ||
if (meta.element_wise.find(sn) == meta.element_wise.end() || | ||
!meta.element_wise[sn]) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should also apply to global temps because we never completely overwriting the state.
Maybe we should directly write sth like this in the above gather_statements
, whose correctness is clearer to me:
if (stmt->is<GlobalTemporaryStmt>()) {
meta.input_states.insert(AsyncState::for_global_tmp(t.kernel));
meta.output_states.insert(AsyncState::for_global_tmp(t.kernel));
}
This may cause some tasks only reading the global temp buffer not swappable, but I think there won't be >1 consecutive tasks reading the buffer without writing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Marking GlobalTemporaryStmt
whenever we see it, as both input and output states sounds good to me!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ha, that's actually where I started (locally)... Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great! LGTM!
if (state.type == AsyncState::Type::value && state.holds_snode()) { | ||
const auto *sn = state.snode(); | ||
if (meta.element_wise.find(sn) == meta.element_wise.end() || | ||
!meta.element_wise[sn]) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Marking GlobalTemporaryStmt
whenever we see it, as both input and output states sounds good to me!
Codecov Report
@@ Coverage Diff @@
## master #2061 +/- ##
==========================================
+ Coverage 42.60% 43.56% +0.95%
==========================================
Files 45 45
Lines 6478 6267 -211
Branches 1110 1110
==========================================
- Hits 2760 2730 -30
+ Misses 3546 3366 -180
+ Partials 172 171 -1
Continue to review full report at Codecov.
|
I've extended
AsyncNode::snode
to hold astd::variant<SNode*, Kernel*>
, because global temporaries can be identified by its enclosing kernel. (I feel like this is more intuitive than having aglobal_tmp
in the enum class. Otherwise, what should the SNode be?)Related issue = #2024
[Click here for the format server]