Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add axis for mul_op and rowwise_add_op #3888

Merged
merged 18 commits into from
Sep 8, 2017
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions paddle/framework/attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,18 @@ class LargerThanChecker {
T lower_bound_;
};

template <typename T>
class EqualLargerThanChecker {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name would be better if it were compatible with gtest's conventions, such as CHECK_GE or something similar.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EqualLargerThan is a function, not a macro, so the name does not need to be that short.

public:
explicit EqualLargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
void operator()(T& value) const {
PADDLE_ENFORCE(value >= lower_bound_, "equal_larger_than check fail");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check fail -> check fails

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}

private:
T lower_bound_;
};

// we can provide users more common Checker, like 'LessThanChecker',
// 'BetweenChecker'...

Expand Down Expand Up @@ -114,6 +126,11 @@ class TypedAttrChecker {
return *this;
}

TypedAttrChecker& EqualLargerThan(const T& lower_bound) {
value_checkers_.push_back(EqualLargerThanChecker<T>(lower_bound));
return *this;
}

// we can add more common limits, like LessThan(), Between()...

TypedAttrChecker& SetDefault(const T& default_value) {
Expand Down
30 changes: 18 additions & 12 deletions paddle/framework/ddim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,18 +195,6 @@ std::vector<int> vectorize(const DDim& ddim) {
return result;
}

struct ProductVisitor : public boost::static_visitor<ssize_t> {
template <int D>
ssize_t operator()(const Dim<D>& dim) {
return product(dim);
}
};

ssize_t product(const DDim& ddim) {
ProductVisitor visitor;
return boost::apply_visitor(visitor, ddim);
}

struct SliceVectorizeVisitor : public boost::static_visitor<> {
std::vector<int>& vector;
int begin;
Expand Down Expand Up @@ -247,6 +235,24 @@ DDim slice_ddim(const DDim& dim, int begin, int end) {
return make_ddim(vec);
}

// Boost.Variant visitor that computes the product of all dimension sizes of
// a statically-sized Dim<D>. Applied to a DDim (a variant over Dim<1..N>)
// via boost::apply_visitor; dispatch selects the right D at runtime.
struct ProductVisitor : public boost::static_visitor<ssize_t> {
  template <int D>
  ssize_t operator()(const Dim<D>& dim) {
    // Delegates to the Dim<D> overload of product().
    return product(dim);
  }
};

// Returns the product of every dimension size in `ddim`, i.e. the total
// number of elements a tensor of this shape holds.
ssize_t product(const DDim& ddim) {
  ProductVisitor product_visitor;
  return boost::apply_visitor(product_visitor, ddim);
}

// Returns the product of the dimension sizes in the half-open range
// [begin, end) of `ddim`. Range validity is enforced by slice_ddim.
ssize_t product(const DDim& ddim, int begin, int end) {
  ProductVisitor product_visitor;
  const DDim sub_ddim = slice_ddim(ddim, begin, end);
  return boost::apply_visitor(product_visitor, sub_ddim);
}

/// \cond HIDDEN

struct ArityVisitor : boost::static_visitor<int> {
Expand Down
2 changes: 2 additions & 0 deletions paddle/framework/ddim.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ std::vector<int> vectorize(const DDim& ddim);

ssize_t product(const DDim& ddim);

ssize_t product(const DDim& ddim, int begin, int end);

/**
* \brief Slice a ddim
*
Expand Down
13 changes: 12 additions & 1 deletion paddle/framework/eigen.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,18 @@ struct EigenTensor {

template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {};
struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {
static typename EigenMatrix::Type Reshape(Tensor& tensor, int num_row_dims) {
int rank = tensor.dims_.size();
PADDLE_ENFORCE(num_row_dims > 0 && num_row_dims < rank,
"`num_row_dims` must be between (0, rank_of_tensor).");
return EigenMatrix::From(
tensor, make_ddim({static_cast<int>(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make_ddim could be removed, just {0, 10} is OK.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a method in class DDim, such as

class DDim {
 public:
  Dim<2> FlattenToMat(int numFlattenDims) const;
};

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

product(tensor.dims_, 0, rank - num_row_dims)),
static_cast<int>(product(
tensor.dims_, rank - num_row_dims, rank))}));
}
};

template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
Expand Down
20 changes: 20 additions & 0 deletions paddle/framework/eigen_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,5 +108,25 @@ TEST(Eigen, Matrix) {
}
}

TEST(Eigen, MatrixReshape) {
Tensor t;
float* p =
t.mutable_data<float>(make_ddim({2, 3, 6, 4}), platform::CPUPlace());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make_ddim is not needed, just t.mutable_data<float>({2, 3, 6, 4}) is cool.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

for (int i = 0; i < 2 * 3 * 6 * 4; ++i) {
p[i] = static_cast<float>(i);
}

EigenMatrix<float>::Type em = EigenMatrix<float>::Reshape(t, 2);

ASSERT_EQ(2 * 3, em.dimension(0));
ASSERT_EQ(6 * 4, em.dimension(1));

for (int i = 0; i < 2 * 3; i++) {
for (int j = 0; j < 6 * 4; j++) {
ASSERT_NEAR(i * 6 * 4 + j, em(i, j), 1e-6f);
}
}
}

} // namespace framework
} // namespace paddle
3 changes: 3 additions & 0 deletions paddle/framework/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ class Tensor {
template <typename T, size_t D, int MajorType, typename IndexType>
friend struct EigenTensor;

template <typename T, int MajorType, typename IndexType>
friend struct EigenMatrix;

template <typename T, int MajorType, typename IndexType>
friend struct EigenVector;

Expand Down
12 changes: 12 additions & 0 deletions paddle/framework/tensor_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,5 +148,17 @@ inline Tensor& Tensor::Resize(const DDim& dims) {

inline const DDim& Tensor::dims() const { return dims_; }

template <typename T>
inline Tensor FlattenToMatrix(const Tensor& src, int num_row_dims) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to add the explanation for num_row_dims.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

num_row_dims is not easy to use, so I use num_col_dms instead. And comments have been added.

Tensor res;
res.ShareDataWith<T>(src);
DDim src_dim = src.dims();
int rank = src_dim.size();
res.Resize(make_ddim(
{static_cast<int>(product(src_dim, 0, rank - num_row_dims)),
static_cast<int>(product(src_dim, rank - num_row_dims, rank))}));
return res;
}

} // namespace framework
} // namespace paddle
13 changes: 13 additions & 0 deletions paddle/framework/tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,16 @@ TEST(Tensor, CopyFrom) {
}
#endif
}

TEST(Tensor, FlattenToMatrix) {
using namespace paddle::framework;
using namespace paddle::platform;
Tensor src;
int* src_ptr = src.mutable_data<int>(make_ddim({2, 3, 4, 9}), CPUPlace());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make_ddim is not needed.

for (int i = 0; i < 2 * 3 * 4 * 9; ++i) {
src_ptr[i] = i;
}
Tensor res = FlattenToMatrix<int>(src, 2);
ASSERT_EQ(res.dims()[0], 2 * 3);
ASSERT_EQ(res.dims()[1], 4 * 9);
}
60 changes: 46 additions & 14 deletions paddle/operators/mul_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,27 @@ class MulOp : public framework::OperatorWithKernel {

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto dim0 = ctx.Input<Tensor>("X")->dims();
auto dim1 = ctx.Input<Tensor>("Y")->dims();
PADDLE_ENFORCE_EQ(dim0.size(), 2,
"input X(%s) should be a tensor with 2 dims, a matrix",
ctx.op().Input("X"));
PADDLE_ENFORCE_EQ(dim1.size(), 2,
"input Y(%s) should be a tensor with 2 dims, a matrix",
ctx.op().Input("Y"));
auto x_dim = ctx.Input<Tensor>("X")->dims();
auto y_dim = ctx.Input<Tensor>("Y")->dims();
int x_num_row_dims = GetAttr<int>("x_num_row_dims");
int y_num_row_dims = GetAttr<int>("y_num_row_dims");

PADDLE_ENFORCE(x_dim.size() > x_num_row_dims,
"The rank of input tensor X(%s) should be larger than "
"`mul_op`'s `x_num_row_dims`.",
ctx.op().Input("X"));
PADDLE_ENFORCE(y_dim.size() > y_num_row_dims,
"The rank of input tensor Y(%s) should be larger than "
"`mul_op`'s `y_num_row_dims`.",
ctx.op().Input("Y"));
PADDLE_ENFORCE_EQ(
dim0[1], dim1[0],
product(x_dim, x_dim.size() - x_num_row_dims, x_dim.size()),
product(y_dim, 0, y_dim.size() - y_num_row_dims),
"First matrix's width must be equal with second matrix's height.");
ctx.Output<Tensor>("Out")->Resize({dim0[0], dim1[1]});
ctx.Output<Tensor>("Out")->Resize(
{static_cast<int>(product(x_dim, 0, x_dim.size() - x_num_row_dims)),
static_cast<int>(
product(y_dim, y_dim.size() - y_num_row_dims, y_dim.size()))});
}
};

Expand All @@ -47,6 +56,23 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "The first input of mul op");
AddInput("Y", "The second input of mul op");
AddOutput("Out", "The output of mul op");
AddAttr<int>(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a very useful syntax in C++ 11.

AddAttr<int>("x_num_col_dims", R"DOC(mul_op can take ...
....
)DOC");

R"LABEL(...)LABEL" is just like python's """...""". which LABEL is a custom label to identify where the string begins and ends.

See http://en.cppreference.com/w/cpp/language/string_literal

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, Thank you!

"x_num_row_dims",
"mul_op can take tensors with more than two dimensions as input `X`, "
"in that case, tensors will be flattened to a matrix. The matrix's "
Copy link
Contributor

@qingqing01 qingqing01 Sep 6, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

flattened -> reshaped? In the Numpy, flatten means converting to a vector.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"second dimension(row length) will be the product of tensor's last "
"`num_row_dims` dimensions, and the matrix's first dimension(column "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

second dimension(row length)
matrix's first dimension(column length)

matrix's first dimension是 dims[0]? second dimension是dims[1]吗? 如果是,matrix's first dimension表示的row length(也就是height), second dimension表示的是col length(也就是width)。

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我想表达的row length的意思是“行的长度”,所以似乎应该是width?

"length) will be the product of tensor's first `rank - num_row_dims` "
"dimensions.")
.SetDefault(1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

依据上面的描述,和最常用的情况不符合,最常用的是reshape成:height = dims[0], width = product(dims[1:])

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经修改,把参数从num_raw_dims改成了num_col_dims,表示乘起来的前面维度的数目

.EqualLargerThan(1);
AddAttr<int>(
"y_num_row_dims",
"mul_op can take tensors with more than two dimensions as input `Y`, "
"in that case, tensors will be flattened to a matrix. Just like input "
"`X`.")
.SetDefault(1)
.EqualLargerThan(1);
AddComment(R"DOC(
Two Element Mul Operator.

Expand All @@ -70,10 +96,16 @@ class MulOpGrad : public framework::OperatorWithKernel {
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
PADDLE_ENFORCE(x_dims[0] == out_dims[0],
"Out@GRAD M X N must equal to X dims 0, M ");
PADDLE_ENFORCE(y_dims[1] == out_dims[1],
"Out@GRAD M X N must equal to Y dims 1, N ");
PADDLE_ENFORCE(
product(x_dims, 0, x_dims.size() - GetAttr<int>("x_num_row_dims")) ==
out_dims[0],
"The first dimension of Out@GRAD must equal to the first dimension of "
"the first operand.");
PADDLE_ENFORCE(
product(y_dims, y_dims.size() - GetAttr<int>("y_num_row_dims"),
y_dims.size()) == out_dims[1],
"The second dimension of Out@GRAD must equal to the second "
"dimension of the second operand.");

if (x_grad) x_grad->Resize(x_dims);
if (y_grad) y_grad->Resize(y_dims);
Expand Down
56 changes: 42 additions & 14 deletions paddle/operators/mul_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you -> You

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ANy -> ANY

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

See the License for the specific language governing permissions and
limitations under the License. */

Expand All @@ -31,37 +31,65 @@ template <typename Place, typename T>
class MulKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* y = context.Input<Tensor>("Y");
auto* z = context.Output<Tensor>("Out");
z->mutable_data<T>(context.GetPlace());
const Tensor* x = context.Input<Tensor>("X");
const Tensor* y = context.Input<Tensor>("Y");
Tensor* Z = context.Output<Tensor>("Out");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Z -> z

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

const Tensor x_matrix =
x->dims().size() > 2
? framework::FlattenToMatrix<T>(
*x, context.template GetAttr<int>("x_num_row_dims"))
: *x;
const Tensor y_matrix =
y->dims().size() > 2
? framework::FlattenToMatrix<T>(
*y, context.template GetAttr<int>("y_num_row_dims"))
: *y;

Z->mutable_data<T>(context.GetPlace());
auto* device_context =
const_cast<platform::DeviceContext*>(context.device_context_);
math::matmul<Place, T>(*x, false, *y, false, 1, z, 0, device_context);
math::matmul<Place, T>(x_matrix, false, y_matrix, false, 1, Z, 0,
device_context);
}
};

template <typename Place, typename T>
class MulGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
int x_num_row_dims = ctx.template GetAttr<int>("x_num_row_dims");
int y_num_row_dims = ctx.template GetAttr<int>("y_num_row_dims");
const Tensor* x = ctx.Input<Tensor>("X");
const Tensor* y = ctx.Input<Tensor>("Y");
const Tensor x_matrix =
x->dims().size() > 2 ? framework::FlattenToMatrix<T>(*x, x_num_row_dims)
: *x;
const Tensor y_matrix =
y->dims().size() > 2 ? framework::FlattenToMatrix<T>(*y, y_num_row_dims)
: *y;
const Tensor* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));

auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
Tensor* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* device_context =
const_cast<platform::DeviceContext*>(ctx.device_context_);
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
Tensor dx_matrix = dx->dims().size() > 2 ? framework::FlattenToMatrix<T>(
*dx, x_num_row_dims)
: *dx;
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
math::matmul<Place, T>(*dout, false, *y, true, 1, dx, 0, device_context);
math::matmul<Place, T>(*dout, false, y_matrix, true, 1, &dx_matrix, 0,
device_context);
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
Tensor dy_matrix = dy->dims().size() > 2 ? framework::FlattenToMatrix<T>(
*dy, y_num_row_dims)
: *dy;
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
math::matmul<Place, T>(*x, true, *dout, false, 1, dy, 0, device_context);
math::matmul<Place, T>(x_matrix, true, *dout, false, 1, &dy_matrix, 0,
device_context);
}
}
};
Expand Down