-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add expand operator #4061
Add expand operator #4061
Conversation
namespace paddle { | ||
namespace operators { | ||
|
||
using Tensor = framework::Tensor; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do not using namespace
or type alias in the header file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/expand_op.h
Outdated
template <int Rank> | ||
void Expand(const framework::ExecutionContext& context) const { | ||
auto* in0 = context.Input<Tensor>("X"); | ||
auto expand_times = context.Attr<std::vector<int>>("expandTimes"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto -> auto&
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/expand_op.h
Outdated
auto* in0 = context.Input<Tensor>("X"); | ||
auto expand_times = context.Attr<std::vector<int>>("expandTimes"); | ||
auto* out0 = context.Output<Tensor>("Out"); | ||
Eigen::DSizes<int, Rank> bcast_dims; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe the problem is how to cast vector<int> to Eigen::DSizes<Rank>
Related issue #4091 |
paddle/operators/expand_op.h
Outdated
} | ||
} | ||
|
||
int dims = reshape_dims_vec.size() * 6 + reduce_dims_vec.size() - 7; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里不-7
的话是否ExpandBackward
中 Dims / 6 + 1
和 Dims % 6 + 1
就可以把+1
去掉了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not really, please consider 6*1 + 6 = 12, 12/6=2 (expect 1)
paddle/operators/expand_op.cc
Outdated
AddComment(R"DOC( | ||
Expand operator tiles the input by given times number. You should set times | ||
number for each dimension by providing attribute 'expandTimes'. Rank of input | ||
tensor should be in [1, 6]. Please draw an attention that size of |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why Rank of input tensor should be in [1, 6]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The implementation will specialize the template due to the limitation of Eigen. I think rank 6 is big enough.
paddle/operators/expand_op.cc
Outdated
void InferShape(const framework::InferShapeContext& ctx) const override { | ||
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X must be initialized."); | ||
std::vector<int> expand_times = Attr<std::vector<int>>("expandTimes"); | ||
auto* x = ctx.Input<Tensor>("X"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seems x is not used, can use
auto x_dims = ctx.Input<Tensor>("X")->dims();
instead
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
self.op_type = "expand" | ||
self.inputs = {'X': np.random.random((12, 14)).astype("float32")} | ||
self.attrs = {'expandTimes': [1, 1]} | ||
output = np.tile(self.inputs['X'], (1, 1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think use all one as expand time is not good for test ~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, this is a corner case and I will add more robust test cases. Same as bellow.
def setUp(self): | ||
self.op_type = "expand" | ||
self.inputs = {'X': np.random.random((2, 4, 5)).astype("float32")} | ||
self.attrs = {'expandTimes': [1, 1, 1]} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seem as above
Maybe could use this PR #4205, too? |
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix-4029
Please update this PR and merge it asap. |
paddle/operators/expand_op.cc
Outdated
"dimension size of Input(X) multiplying corresponding value of " | ||
"Attr(expandTimes)."); | ||
AddAttr<std::vector<int>>("expandTimes", | ||
"Expand times number for each dimension."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
expandTimes -> expand_times
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/expand_op.cc
Outdated
} | ||
|
||
ctx->SetOutputDim("Out", framework::make_ddim(out_shape)); | ||
ctx->ShareLoD("X", "Out"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里应该out_shape[0] == x_dims[0]
时,才能SharedLoD
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/expand_op.cc
Outdated
|
||
protected: | ||
void InferShape(framework::InferShapeContext* ctx) const override { | ||
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must be initialized."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should not be null.
Also needs to check the output.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/expand_op.h
Outdated
auto& expand_times = context.Attr<std::vector<int>>("expandTimes"); | ||
auto x_dims = in0->dims(); | ||
std::vector<int> reshape_dims_vec; | ||
std::vector<int> reduce_dims_vec; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add the comments about how to compute gradients.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/operators/expand_op.h
Outdated
const std::vector<int>& reshape_dims_vec, | ||
const std::vector<int>& reduce_dims_vec) const { | ||
size_t reshape_size = Dims / 6 + 1; | ||
size_t reduce_size = Dims % 6 + 1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use a constant variable to represent 6, do not use 6 directly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
self.check_grad(['X'], 'Out') | ||
|
||
|
||
class TestExpandOpRank2_2(OpTest): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The name _1 _2 is not clear.
Resolves #4029