Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move static variable defined in .cc #2782

Merged
merged 2 commits into from
Jul 11, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_proto op_desc)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
Expand Down
36 changes: 36 additions & 0 deletions paddle/framework/op_registry.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include <paddle/framework/op_registry.h>

namespace paddle {
namespace framework {

template <>
void AttrTypeHelper::SetAttrType<int>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::INT);
}

template <>
void AttrTypeHelper::SetAttrType<float>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::FLOAT);
}

template <>
void AttrTypeHelper::SetAttrType<std::string>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::STRING);
}

template <>
void AttrTypeHelper::SetAttrType<std::vector<int>>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::INTS);
}

template <>
void AttrTypeHelper::SetAttrType<std::vector<float>>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::FLOATS);
}

template <>
void AttrTypeHelper::SetAttrType<std::vector<std::string>>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::STRINGS);
}
}
}
154 changes: 36 additions & 118 deletions paddle/framework/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "paddle/framework/attr_checker.h"

//#include "paddle/framework/op_base.h"
#include <algorithm>
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"

Expand Down Expand Up @@ -64,36 +65,6 @@ struct AttrTypeHelper {
}
};

template <>
void AttrTypeHelper::SetAttrType<int>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::INT);
}

template <>
void AttrTypeHelper::SetAttrType<float>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::FLOAT);
}

template <>
void AttrTypeHelper::SetAttrType<std::string>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::STRING);
}

template <>
void AttrTypeHelper::SetAttrType<std::vector<int>>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::INTS);
}

template <>
void AttrTypeHelper::SetAttrType<std::vector<float>>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::FLOATS);
}

template <>
void AttrTypeHelper::SetAttrType<std::vector<std::string>>(AttrProto* attr) {
attr->set_type(paddle::framework::AttrType::STRINGS);
}

// this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {
public:
Expand All @@ -103,22 +74,22 @@ class OpProtoAndCheckerMaker {
protected:
void AddInput(const std::string& name, const std::string& comment) {
auto input = proto_->mutable_inputs()->Add();
*(input->mutable_name()) = name;
*(input->mutable_comment()) = comment;
*input->mutable_name() = name;
*input->mutable_comment() = comment;
}

void AddOutput(const std::string& name, const std::string& comment) {
auto output = proto_->mutable_outputs()->Add();
*(output->mutable_name()) = name;
*(output->mutable_comment()) = comment;
*output->mutable_name() = name;
*output->mutable_comment() = comment;
}

template <typename T>
TypedAttrChecker<T>& AddAttr(const std::string& name,
const std::string& comment) {
auto attr = proto_->mutable_attrs()->Add();
*(attr->mutable_name()) = name;
*(attr->mutable_comment()) = comment;
*attr->mutable_name() = name;
*attr->mutable_comment() = comment;
AttrTypeHelper::SetAttrType<T>(attr);
return op_checker_->AddAttrChecker<T>(name);
}
Expand All @@ -134,49 +105,51 @@ class OpProtoAndCheckerMaker {
};

class OpRegistry {
typedef std::function<OpBase*()> OpCreator;
using OpCreator = std::function<OpBase*()>;

public:
template <typename OpType, typename ProtoMakerType>
static void RegisterOp(const std::string& op_type) {
creators_[op_type] = []() { return new OpType; };
OpProto& op_proto = protos_[op_type];
OpAttrChecker& op_checker = op_checkers_[op_type];
creators()[op_type] = [] { return new OpType; };
OpProto& op_proto = protos()[op_type];
OpAttrChecker& op_checker = op_checkers()[op_type];
ProtoMakerType(&op_proto, &op_checker);
PADDLE_ENFORCE(op_proto.IsInitialized() == true,
PADDLE_ENFORCE(op_proto.IsInitialized(),
"Fail to initialize %s's OpProto !", op_type);
}

static OpBase* CreateOp(const OpDesc& op_desc) {
std::string op_type = op_desc.type();
OpBase* op = (creators_.at(op_type))();
(op->inputs_).resize(op_desc.inputs_size());
for (int i = 0; i < op_desc.inputs_size(); ++i) {
(op->inputs_)[i] = op_desc.inputs(i);
}
(op->outputs_).resize(op_desc.outputs_size());
for (int i = 0; i < op_desc.outputs_size(); ++i) {
(op->outputs_)[i] = op_desc.outputs(i);
}
for (int i = 0; i < op_desc.attrs_size(); ++i) {
const AttrDesc& ith_attr = op_desc.attrs(i);
std::string name = ith_attr.name();
(op->attr_map_)[name] = AttrTypeHelper::GetAttrValue(ith_attr);
OpBase* op = creators().at(op_type)();
op->inputs_.reserve((size_t)op_desc.inputs_size());
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
std::back_inserter(op->inputs_));
op->outputs_.reserve((size_t)op_desc.outputs_size());
std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
std::back_inserter(op->outputs_));
for (auto& attr : op_desc.attrs()) {
op->attr_map_[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
}
const OpAttrChecker& op_checker = op_checkers_.at(op_type);
op_checker.Check(op->attr_map_);
op_checkers().at(op_type).Check(op->attr_map_);
return op;
}

private:
static std::unordered_map<std::string, OpCreator> creators_;
static std::unordered_map<std::string, OpProto> protos_;
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
};
static std::unordered_map<std::string, OpCreator>& creators() {
static std::unordered_map<std::string, OpCreator> creators_;
return creators_;
}

std::unordered_map<std::string, std::function<OpBase*()>> OpRegistry::creators_;
std::unordered_map<std::string, OpProto> OpRegistry::protos_;
std::unordered_map<std::string, OpAttrChecker> OpRegistry::op_checkers_;
static std::unordered_map<std::string, OpProto>& protos() {
static std::unordered_map<std::string, OpProto> protos_;
return protos_;
};

static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_;
};
};

template <typename OpType, typename ProtoMakerType>
class OpRegisterHelper {
Expand All @@ -194,60 +167,5 @@ class OpRegisterHelper {
const OpRegisterHelper<__op_class, __op_maker_class> \
__op_class##Register::reg(#__op_type);

// Demos

class CosineOp : public OpBase {
public:
virtual std::string Run() const {
std::string msg = "CosineOp runs! scale = " +
std::to_string(boost::get<float>(attr_map_.at("scale")));
return msg;
}
};

class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddType("cos");
AddComment("This is cos op");
}
};

REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)

class MyTestOp : public OpBase {
public:
virtual std::string Run() const {
std::string msg =
"MyTestOp runs! test_attr = " +
std::to_string(boost::get<int>(attr_map_.at("test_attr")));
return msg;
}
};

class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
};
AddAttr<int>("test_attr", "a simple test attribute")
.AddCustomChecker(my_checker);
AddType("my_test_op");
AddComment("This is my_test op");
}
};

REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op)

} // namespace framework
} // namespace paddle
62 changes: 62 additions & 0 deletions paddle/framework/op_registry_test.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,63 @@
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>

namespace paddle {
namespace framework {
class CosineOp : public OpBase {
public:
virtual std::string Run() const {
std::string msg = "CosineOp runs! scale = " +
std::to_string(boost::get<float>(attr_map_.at("scale")));
return msg;
}
};

class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
AddType("cos");
AddComment("This is cos op");
}
};

REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)

class MyTestOp : public OpBase {
public:
virtual std::string Run() const {
std::string msg =
"MyTestOp runs! test_attr = " +
std::to_string(boost::get<int>(attr_map_.at("test_attr")));
return msg;
}
};

class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op");
AddOutput("output", "output of cosine op");
auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
};
AddAttr<int>("test_attr", "a simple test attribute")
.AddCustomChecker(my_checker);
AddType("my_test_op");
AddComment("This is my_test op");
}
};

REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op)
} // namespace framework
} // namespace paddle

TEST(OpRegistry, CreateOp) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
Expand Down Expand Up @@ -120,3 +177,8 @@ TEST(OpRegistry, CustomChecker) {
ASSERT_EQ(debug_str[i], str[i]);
}
}

int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}