-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add axis for mul_op
and rowwise_add_op
#3888
Changes from 4 commits
e76fa85
86655cb
af0264a
69fbc54
d71396b
e168fc4
256d6a3
f2a66ff
823bdd6
3d62c6d
0c13660
5aacd64
d7c8bdc
b744430
1d9a4d2
f6e72c9
b6a4666
856611c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -51,6 +51,18 @@ class LargerThanChecker { | |
T lower_bound_; | ||
}; | ||
|
||
template <typename T> | ||
class EqualLargerThanChecker { | ||
public: | ||
explicit EqualLargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {} | ||
void operator()(T& value) const { | ||
PADDLE_ENFORCE(value >= lower_bound_, "equal_larger_than check fail"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. check fail -> check fails There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
} | ||
|
||
private: | ||
T lower_bound_; | ||
}; | ||
|
||
// we can provide users more common Checker, like 'LessThanChecker', | ||
// 'BetweenChecker'... | ||
|
||
|
@@ -114,6 +126,11 @@ class TypedAttrChecker { | |
return *this; | ||
} | ||
|
||
TypedAttrChecker& EqualLargerThan(const T& lower_bound) { | ||
value_checkers_.push_back(EqualLargerThanChecker<T>(lower_bound)); | ||
return *this; | ||
} | ||
|
||
// we can add more common limits, like LessThan(), Between()... | ||
|
||
TypedAttrChecker& SetDefault(const T& default_value) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -63,7 +63,18 @@ struct EigenTensor { | |
|
||
template <typename T, int MajorType = Eigen::RowMajor, | ||
typename IndexType = Eigen::DenseIndex> | ||
struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {}; | ||
struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> { | ||
static typename EigenMatrix::Type Reshape(Tensor& tensor, int num_row_dims) { | ||
int rank = tensor.dims_.size(); | ||
PADDLE_ENFORCE(num_row_dims > 0 && num_row_dims < rank, | ||
"`num_row_dims` must be between (0, rank_of_tensor)."); | ||
return EigenMatrix::From( | ||
tensor, make_ddim({static_cast<int>( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe add a method in class DDim {
public:
Dim<2> FlattenToMat(int numFlattenDims) const;
}; There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
product(tensor.dims_, 0, rank - num_row_dims)), | ||
static_cast<int>(product( | ||
tensor.dims_, rank - num_row_dims, rank))})); | ||
} | ||
}; | ||
|
||
template <typename T, int MajorType = Eigen::RowMajor, | ||
typename IndexType = Eigen::DenseIndex> | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -108,5 +108,25 @@ TEST(Eigen, Matrix) { | |
} | ||
} | ||
|
||
TEST(Eigen, MatrixReshape) { | ||
Tensor t; | ||
float* p = | ||
t.mutable_data<float>(make_ddim({2, 3, 6, 4}), platform::CPUPlace()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks! |
||
for (int i = 0; i < 2 * 3 * 6 * 4; ++i) { | ||
p[i] = static_cast<float>(i); | ||
} | ||
|
||
EigenMatrix<float>::Type em = EigenMatrix<float>::Reshape(t, 2); | ||
|
||
ASSERT_EQ(2 * 3, em.dimension(0)); | ||
ASSERT_EQ(6 * 4, em.dimension(1)); | ||
|
||
for (int i = 0; i < 2 * 3; i++) { | ||
for (int j = 0; j < 6 * 4; j++) { | ||
ASSERT_NEAR(i * 6 * 4 + j, em(i, j), 1e-6f); | ||
} | ||
} | ||
} | ||
|
||
} // namespace framework | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -148,5 +148,17 @@ inline Tensor& Tensor::Resize(const DDim& dims) { | |
|
||
inline const DDim& Tensor::dims() const { return dims_; } | ||
|
||
template <typename T> | ||
inline Tensor FlattenToMatrix(const Tensor& src, int num_row_dims) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's better to add the explanation for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
Tensor res; | ||
res.ShareDataWith<T>(src); | ||
DDim src_dim = src.dims(); | ||
int rank = src_dim.size(); | ||
res.Resize(make_ddim( | ||
{static_cast<int>(product(src_dim, 0, rank - num_row_dims)), | ||
static_cast<int>(product(src_dim, rank - num_row_dims, rank))})); | ||
return res; | ||
} | ||
|
||
} // namespace framework | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -262,3 +262,16 @@ TEST(Tensor, CopyFrom) { | |
} | ||
#endif | ||
} | ||
|
||
TEST(Tensor, FlattenToMatrix) { | ||
using namespace paddle::framework; | ||
using namespace paddle::platform; | ||
Tensor src; | ||
int* src_ptr = src.mutable_data<int>(make_ddim({2, 3, 4, 9}), CPUPlace()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. make_ddim is not needed. |
||
for (int i = 0; i < 2 * 3 * 4 * 9; ++i) { | ||
src_ptr[i] = i; | ||
} | ||
Tensor res = FlattenToMatrix<int>(src, 2); | ||
ASSERT_EQ(res.dims()[0], 2 * 3); | ||
ASSERT_EQ(res.dims()[1], 4 * 9); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,18 +25,27 @@ class MulOp : public framework::OperatorWithKernel { | |
|
||
protected: | ||
void InferShape(const framework::InferShapeContext &ctx) const override { | ||
auto dim0 = ctx.Input<Tensor>("X")->dims(); | ||
auto dim1 = ctx.Input<Tensor>("Y")->dims(); | ||
PADDLE_ENFORCE_EQ(dim0.size(), 2, | ||
"input X(%s) should be a tensor with 2 dims, a matrix", | ||
ctx.op().Input("X")); | ||
PADDLE_ENFORCE_EQ(dim1.size(), 2, | ||
"input Y(%s) should be a tensor with 2 dims, a matrix", | ||
ctx.op().Input("Y")); | ||
auto x_dim = ctx.Input<Tensor>("X")->dims(); | ||
auto y_dim = ctx.Input<Tensor>("Y")->dims(); | ||
int x_num_row_dims = GetAttr<int>("x_num_row_dims"); | ||
int y_num_row_dims = GetAttr<int>("y_num_row_dims"); | ||
|
||
PADDLE_ENFORCE(x_dim.size() > x_num_row_dims, | ||
"The rank of input tensor X(%s) should be larger than " | ||
"`mul_op`'s `x_num_row_dims`.", | ||
ctx.op().Input("X")); | ||
PADDLE_ENFORCE(y_dim.size() > y_num_row_dims, | ||
"The rank of input tensor Y(%s) should be larger than " | ||
"`mul_op`'s `y_num_row_dims`.", | ||
ctx.op().Input("Y")); | ||
PADDLE_ENFORCE_EQ( | ||
dim0[1], dim1[0], | ||
product(x_dim, x_dim.size() - x_num_row_dims, x_dim.size()), | ||
product(y_dim, 0, y_dim.size() - y_num_row_dims), | ||
"First matrix's width must be equal with second matrix's height."); | ||
ctx.Output<Tensor>("Out")->Resize({dim0[0], dim1[1]}); | ||
ctx.Output<Tensor>("Out")->Resize( | ||
{static_cast<int>(product(x_dim, 0, x_dim.size() - x_num_row_dims)), | ||
static_cast<int>( | ||
product(y_dim, y_dim.size() - y_num_row_dims, y_dim.size()))}); | ||
} | ||
}; | ||
|
||
|
@@ -47,6 +56,23 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker { | |
AddInput("X", "The first input of mul op"); | ||
AddInput("Y", "The second input of mul op"); | ||
AddOutput("Out", "The output of mul op"); | ||
AddAttr<int>( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a very useful syntax in AddAttr<int>("x_num_col_dims", R"DOC(mul_op can take ...
....
)DOC");
See http://en.cppreference.com/w/cpp/language/string_literal There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Got it, Thank you! |
||
"x_num_row_dims", | ||
"mul_op can take tensors with more than two dimensions as input `X`, " | ||
"in that case, tensors will be flattened to a matrix. The matrix's " | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. flattened -> reshaped? In NumPy, "flatten" means converting to a vector. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. done |  ||
||
"second dimension(row length) will be the product of tensor's last " | ||
"`num_row_dims` dimensions, and the matrix's first dimension(column " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
matrix's first dimension是 dims[0]? second dimension是dims[1]吗? 如果是,matrix's first dimension表示的row length(也就是height), second dimension表示的是col length(也就是width)。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 我想表达的row length的意思是“行的长度”,所以似乎应该是width? |
||
"length) will be the product of tensor's first `rank - num_row_dims` " | ||
"dimensions.") | ||
.SetDefault(1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 依据上面的描述,和最常用的情况不符合,最常用的是reshape成:height = dims[0], width = product(dims[1:]) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已经修改,把参数从 |
||
.EqualLargerThan(1); | ||
AddAttr<int>( | ||
"y_num_row_dims", | ||
"mul_op can take tensors with more than two dimensions as input `Y`, " | ||
"in that case, tensors will be flattened to a matrix. Just like input " | ||
"`X`.") | ||
.SetDefault(1) | ||
.EqualLargerThan(1); | ||
AddComment(R"DOC( | ||
Two Element Mul Operator. | ||
|
||
|
@@ -70,10 +96,16 @@ class MulOpGrad : public framework::OperatorWithKernel { | |
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims(); | ||
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X")); | ||
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y")); | ||
PADDLE_ENFORCE(x_dims[0] == out_dims[0], | ||
"Out@GRAD M X N must equal to X dims 0, M "); | ||
PADDLE_ENFORCE(y_dims[1] == out_dims[1], | ||
"Out@GRAD M X N must equal to Y dims 1, N "); | ||
PADDLE_ENFORCE( | ||
product(x_dims, 0, x_dims.size() - GetAttr<int>("x_num_row_dims")) == | ||
out_dims[0], | ||
"The first dimension of Out@GRAD must equal to the first dimension of " | ||
"the first operand."); | ||
PADDLE_ENFORCE( | ||
product(y_dims, y_dims.size() - GetAttr<int>("y_num_row_dims"), | ||
y_dims.size()) == out_dims[1], | ||
"The second dimension of Out@GRAD must equal to the second " | ||
"dimension of the second operand."); | ||
|
||
if (x_grad) x_grad->Resize(x_dims); | ||
if (y_grad) y_grad->Resize(y_dims); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,13 +2,13 @@ | |
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
you may obtain a copy of the License at | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you -> You There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANy KIND, either express or implied. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ANy -> ANY There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
|
@@ -31,37 +31,65 @@ template <typename Place, typename T> | |
class MulKernel : public framework::OpKernel { | ||
public: | ||
void Compute(const framework::ExecutionContext& context) const override { | ||
auto* x = context.Input<Tensor>("X"); | ||
auto* y = context.Input<Tensor>("Y"); | ||
auto* z = context.Output<Tensor>("Out"); | ||
z->mutable_data<T>(context.GetPlace()); | ||
const Tensor* x = context.Input<Tensor>("X"); | ||
const Tensor* y = context.Input<Tensor>("Y"); | ||
Tensor* Z = context.Output<Tensor>("Out"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Z -> z There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
const Tensor x_matrix = | ||
x->dims().size() > 2 | ||
? framework::FlattenToMatrix<T>( | ||
*x, context.template GetAttr<int>("x_num_row_dims")) | ||
: *x; | ||
const Tensor y_matrix = | ||
y->dims().size() > 2 | ||
? framework::FlattenToMatrix<T>( | ||
*y, context.template GetAttr<int>("y_num_row_dims")) | ||
: *y; | ||
|
||
Z->mutable_data<T>(context.GetPlace()); | ||
auto* device_context = | ||
const_cast<platform::DeviceContext*>(context.device_context_); | ||
math::matmul<Place, T>(*x, false, *y, false, 1, z, 0, device_context); | ||
math::matmul<Place, T>(x_matrix, false, y_matrix, false, 1, Z, 0, | ||
device_context); | ||
} | ||
}; | ||
|
||
template <typename Place, typename T> | ||
class MulGradKernel : public framework::OpKernel { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const override { | ||
auto* x = ctx.Input<Tensor>("X"); | ||
auto* y = ctx.Input<Tensor>("Y"); | ||
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); | ||
int x_num_row_dims = ctx.template GetAttr<int>("x_num_row_dims"); | ||
int y_num_row_dims = ctx.template GetAttr<int>("y_num_row_dims"); | ||
const Tensor* x = ctx.Input<Tensor>("X"); | ||
const Tensor* y = ctx.Input<Tensor>("Y"); | ||
const Tensor x_matrix = | ||
x->dims().size() > 2 ? framework::FlattenToMatrix<T>(*x, x_num_row_dims) | ||
: *x; | ||
const Tensor y_matrix = | ||
y->dims().size() > 2 ? framework::FlattenToMatrix<T>(*y, y_num_row_dims) | ||
: *y; | ||
const Tensor* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); | ||
|
||
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); | ||
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y")); | ||
Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X")); | ||
Tensor* dy = ctx.Output<Tensor>(framework::GradVarName("Y")); | ||
auto* device_context = | ||
const_cast<platform::DeviceContext*>(ctx.device_context_); | ||
if (dx) { | ||
dx->mutable_data<T>(ctx.GetPlace()); | ||
Tensor dx_matrix = dx->dims().size() > 2 ? framework::FlattenToMatrix<T>( | ||
*dx, x_num_row_dims) | ||
: *dx; | ||
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N | ||
math::matmul<Place, T>(*dout, false, *y, true, 1, dx, 0, device_context); | ||
math::matmul<Place, T>(*dout, false, y_matrix, true, 1, &dx_matrix, 0, | ||
device_context); | ||
} | ||
if (dy) { | ||
dy->mutable_data<T>(ctx.GetPlace()); | ||
Tensor dy_matrix = dy->dims().size() > 2 ? framework::FlattenToMatrix<T>( | ||
*dy, y_num_row_dims) | ||
: *dy; | ||
// dy = x' * dout. dy K x N, dout : M x N, x : M x K | ||
math::matmul<Place, T>(*x, true, *dout, false, 1, dy, 0, device_context); | ||
math::matmul<Place, T>(x_matrix, true, *dout, false, 1, &dy_matrix, 0, | ||
device_context); | ||
} | ||
} | ||
}; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The name would be better if it were compatible with gtest, such as
CHECK_GE
or something similar. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
EqualLargerThan
is a function, not a macro, so the name shall not be too short.