Complete register gradient for compile time #4566
Conversation
Force-pushed from 97c6c95 to e119177
Force-pushed from 1e3b45f to c4effc7
paddle/framework/op_registry.cc (outdated diff)

    std::unique_ptr<OperatorBase> OpRegistry::CreateGradOp(const OperatorBase& op) {
      PADDLE_ENFORCE(!op.IsNetOp(), "Use framework::Backward to get backward ops");
      return std::unique_ptr<OperatorBase>(BuildGradOp(&op));
    std::unique_ptr<OperatorBase> OpRegistry::CreateOp(OpDescBind* op_desc) {
op_desc should be a reference.
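A minimal sketch of the suggested signature (taking the descriptor by reference; the const qualifier is an assumption, since the comment only asks for a reference):

    // Sketch of the suggested change: pass OpDescBind by (const) reference
    // instead of a raw pointer. const-ness is assumed here.
    std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDescBind& op_desc);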
    @@ -43,19 +43,17 @@ struct OpInfo {
        return *proto_;
      }

      const OpAttrChecker& Checker() const {
        PADDLE_ENFORCE_NOT_NULL(checker_,
Although the PADDLE_ENFORCE is not necessary, Checker() should still be retained so that it stays unified with Creator(), GradOpMaker(), and so on.
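For context, the accessor style being argued for keeps one null-checked getter per registered piece on OpInfo. A rough sketch under assumed member names (creator_, checker_), not the exact code in the PR:

    // Sketch: Checker() stays symmetric with Creator(), GradOpMaker(), etc.
    struct OpInfo {
      OpCreator creator_;
      OpAttrChecker* checker_{nullptr};

      const OpCreator& Creator() const {
        PADDLE_ENFORCE_NOT_NULL(creator_, "Operator's Creator has not been registered");
        return creator_;
      }

      const OpAttrChecker& Checker() const {
        PADDLE_ENFORCE_NOT_NULL(checker_, "Operator's AttrChecker has not been registered");
        return *checker_;
      }
    };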
"REGISTER_OPERATOR must be called in global namespace"); \ | ||
class _OpClass_##op_type##_ : public op_class { \ | ||
public: \ | ||
DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \ |
Do we still need to make Op able to copy itself?
Need discussion. Will change later if necessary.
> Do we still need to make Op able to copy itself?

I think it is not needed.
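For context, the clone machinery under discussion roughly expands to the following (a sketch; the exact macro body in the repository may differ):

    // Rough sketch of DEFINE_OP_CLONE_METHOD(cls): each concrete operator gets a
    // Clone() that copy-constructs itself behind the OperatorBase interface.
    #define DEFINE_OP_CLONE_METHOD(cls)                        \
      std::unique_ptr<OperatorBase> Clone() const final {      \
        return std::unique_ptr<OperatorBase>(new cls(*this));  \
      }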
    grad_op->SetType("mean_grad");
    grad_op->SetInput("X", Input("X"));
    grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
I'm adding a new constructor to OpDescBind, which initializes all its fields at one time.

    OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
                           const VariableNameMap &outputs, const AttributeMap &attrs) {
      op_desc_.set_type(type);
      inputs_ = inputs;
      outputs_ = outputs;
      attrs_ = attrs;
    }
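With such a constructor, a grad maker could build the descriptor in a single expression instead of chained setters. A hypothetical usage, following the mean_grad snippet above (illustrative only, not code from the PR):

    // Hypothetical: build the mean_grad descriptor in one call. The brace-initialized
    // maps assume VariableNameMap is a map from string to vector<string>.
    std::unique_ptr<framework::OpDescBind> grad_op(new framework::OpDescBind(
        "mean_grad",
        {{"X", Input("X")}, {framework::GradVarName("Out"), OutputGrad("Out")}},
        {{framework::GradVarName("X"), InputGrad("X")}},
        framework::AttributeMap{}));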
      using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

     protected:
      std::unique_ptr<framework::OpDescBind> Apply() const override {
I still think Apply() is not necessary. We can let op developers override operator() directly; its definition already dictates that the return value should be a vector.
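For reference, the two shapes being discussed look roughly like this (a sketch of the assumed interfaces, not the exact classes in the PR):

    #include <memory>
    #include <vector>

    // Base maker: developers could override operator() themselves and return
    // however many grad op descs they need.
    class GradOpDescMakerBase {
     public:
      virtual ~GradOpDescMakerBase() = default;
      virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
    };

    // Convenience subclass for the common single-grad-op case: developers
    // override Apply(), and operator() wraps the result in a one-element vector.
    class SingleGradOpDescMaker : public GradOpDescMakerBase {
     public:
      std::vector<std::unique_ptr<OpDescBind>> operator()() const final {
        std::vector<std::unique_ptr<OpDescBind>> retv;
        retv.emplace_back(this->Apply());
        return retv;
      }

     protected:
      virtual std::unique_ptr<OpDescBind> Apply() const = 0;
    };

The tradeoff is boilerplate: Apply() saves single-grad-op makers from wrapping the result into a vector, while overriding operator() directly keeps the interface flatter.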
We can discuss later and change if necessary.
LGTM