-
Notifications
You must be signed in to change notification settings - Fork 2.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[refactor] Add SNode::GradInfoProvider to isolate SNode from Expr #2207
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome! I like the GradInfoProvider
class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
ExprGroup indices; | ||
for (int i = 0; i < snode->num_active_indices; i++) { | ||
indices.push_back(Expr::make<ArgLoadExpression>(i, PrimitiveType::i32)); | ||
} | ||
auto ret = Stmt::make<FrontendKernelReturnStmt>( | ||
load_if_ptr((snode->expr)[indices])); | ||
load_if_ptr(Expr(snode_to_glb_var_exprs_.at(snode))[indices])); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder why we need this Expr()
? (same question at line 663)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. Note that SNode::expr
used to be of type Expr
, which contains a GlobalVariableExpression
. Here, snode_to_glb_var_exprs_
maps directly from SNode*
to GlobalVariableExpression
. As a result, when we do need to use an Expr
, we have to wrap that up.
The reason for mapping to GlobalVariableExpression
is that, If you look at the previous implementation of is_primal()
(has_grad()
, get_grad()
), they do a lot of dynamic casts from Expr
to GlobalVariableExpression
.
Related issue = #2196
This PR is still about decoupling
SNode
fromtaichi/program
. The major issue here is around its member variableexpr
. BecauseExpr
is part of the frontend and relies ontaichi/program
, we have to move it out fromSNode
:place()
andlazy_grad()
fromSNode
itself to a separate filetaichi/program/snode_expr_utils.h/cpp
(i cannot think of a better file name...). Other functions inSNode
do not depend onexpr
.SNode::GradInfoProvider
abstract class. This is essentially a wrapper aroundGlobalVariableExpression
. From what I can tell,SNode::expr
is mostly used to provide gradient information, e.g.is_primal()
,has_grad()
, etc. The actual implementation,GradInfoImpl
, is intaichi/program/snode_expr_utils.cpp
.SNode*
toGlobalVariableExpression
. The map is stored insideProgram
. While it is possible to retrieveGlobalVariableExpression
fromGradInfoImpl
, that involves castingGradInfoProvider
toGradInfoImpl
and kind of exposes the implementation detail. I'd like to makeGradInfoProvider
have a single responsibility.[Click here for the format server]