-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add SyncBatchNorm #26032
Add SyncBatchNorm #26032
Conversation
892bbda
to
808572f
Compare
2a24b15
to
62eb999
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for op_function_generator.cc
@@ -0,0 +1,90 @@ | |||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copyright (c) 2020
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, thanks
python/paddle/fluid/tests/unittests/test_parallel_dygraph_sync_batch_norm.py
Outdated
Show resolved
Hide resolved
python/paddle/fluid/dygraph/nn.py
Outdated
with fluid.dygraph.guard(): | ||
x = to_variable(x) | ||
if fluid.is_compiled_with_cuda(): | ||
sync_batch_norm = nn.SyncBatchNorm(10) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
辛苦用注释的形式给出输入、输出示例。如果PR着急,辛苦merge后再提一个PR进行更新。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for op_function_generator.cc
python/paddle/fluid/dygraph/nn.py
Outdated
import numpy as np | ||
|
||
x = np.array([[[[0.3, 0.4], [0.3, 0.07]], [[0.83, 0.37], [0.18, 0.93]]]]).astype('float32') | ||
with fluid.dygraph.guard(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
示例可以直接用paddle2.0-alpha的api
with paddle.imperative.guard():
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for op_function_generator.cc
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need subsequent change for converting batchnorms, otherwise LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Others
PR changes
APIs
Describe
Add SyncBatchNorm API for dygraph
use_global_stas
andtrainable_stats
to jugde whether to use global or mini-batch stats