-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/operator run place #6783
Feature/operator run place #6783
Conversation
@@ -83,8 +83,7 @@ class OperatorBase { | |||
virtual std::string DebugString() const; | |||
|
|||
/// Net will call this function to Run an op. | |||
virtual void Run(const Scope& scope, | |||
const platform::DeviceContext& dev_ctx) const = 0; | |||
virtual void Run(const Scope& scope, const platform::Place& place) const = 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently, we can simply add another interface here:
virtual void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const = 0;
void Run(const Scope& scope, const platform::Place& place) const {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
Run(scope, dev_ctx);
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We hold the same believe that Operatorbase has one and only one runnable interface.
Now it's the early time that we can fix operator interface once, otherwise, we can not simply remove the second Run
in the future.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The same confusing with @QiJune , for this PR, too many operators will fetch the device_ctx
and add the following code:
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
auto &dev_ctx = *pool.Borrow(place);
I knew from @dzhwinter that one Op without Kernel would fetch the dev_ctx
from DeviceContextPool
, and I think the suggestion from @QiJune maybe a simple way, and don't add too much repeated code in the Ops.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I hold a different view. @QiJune 's suggestion is a transitional approach.
DeviceContextPool contains device related resources, so I believe only the OperatorWithKernel need to touch the pool.
void OperatorWithKernel::Run(const Scope& scope,
const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Get();
auto dev_ctx = pool.Borrow(place);
If we make an agreement on above, so we should not add a Run(DeviceContext) in the final solution.
In addition, most of the operators don't need to add the redundant snippet of get device from the global pool, so why this PR contains a lot Get()
?
The reason is our CopyFrom didn't different CPUDevice and GPUDevice, actually, only the GPUDevice need to get a devicecontext from the pool and do a copy. We will change it in the future.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I approve this for now, we shall fix following issues later.
fix #6784