-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Skeleton Of fully connected operator #2945
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,6 @@ limitations under the License. */ | |
|
||
#include <paddle/framework/op_desc.pb.h> | ||
#include <paddle/framework/operator.h> | ||
#include "paddle/framework/net_proto.pb.h" | ||
#include "paddle/framework/op_proto.pb.h" | ||
#include "paddle/framework/op_registry.h" | ||
#include "paddle/framework/scope.h" | ||
|
@@ -41,7 +40,7 @@ namespace framework { | |
class Net : public OperatorBase { | ||
public: | ||
virtual void AddOp(const OperatorPtr& op) = 0; | ||
virtual void CompleteAddOp() = 0; | ||
virtual void CompleteAddOp(bool calc) = 0; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The declared parameter should be the same as the one in the implementation. |
||
}; | ||
|
||
using NetPtr = std::shared_ptr<Net>; | ||
|
@@ -86,7 +85,7 @@ class PlainNet : public Net { | |
ops_.push_back(op); | ||
} | ||
|
||
void CompleteAddOp() override; | ||
void CompleteAddOp(bool calculate = true) override; | ||
|
||
std::string DebugString() const override; | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/framework/net.h" | ||
#include "paddle/framework/op_registry.h" | ||
#include "paddle/framework/operator.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class FullyConnectedOp : public framework::PlainNet { | ||
public: | ||
void Init() override { | ||
AddOp(framework::OpRegistry::CreateOp("mul", | ||
{ | ||
Input("X"), Input("W"), | ||
}, | ||
{Output("before_act")}, | ||
{})); | ||
auto b = Input("b"); | ||
if (b != framework::OperatorBase::EMPTY_VAR_NAME()) { | ||
AddOp(framework::OpRegistry::CreateOp("rowwise_add", | ||
{Output("before_act"), Input("b")}, | ||
{Output("before_act")}, | ||
{})); | ||
} | ||
|
||
auto activation = GetAttr<std::string>("activation"); | ||
AddOp(framework::OpRegistry::CreateOp( | ||
activation, {Output("before_act")}, {Output("Y")}, {})); | ||
CompleteAddOp(false); | ||
} | ||
}; | ||
|
||
class FullyConnectedOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
FullyConnectedOpMaker(framework::OpProto *proto, | ||
framework::OpAttrChecker *op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("X", "the input of fc operator"); | ||
AddInput("W", "the weight of fc operator"); | ||
AddInput("b", "the bias of fc operator"); | ||
|
||
AddOutput("Y", "the output of fc operator"); | ||
AddOutput( | ||
"before_act", "the before activation output of fc operator", true); | ||
AddAttr<std::string>("activation", "The activation key for fc layer") | ||
.SetDefault("sigmoid") | ||
.InEnum({"sigmoid", "softmax"}); | ||
|
||
//! TODO(yuyang18): Complete comment; | ||
AddComment("FullyConnected Operator"); | ||
} | ||
}; | ||
} // namespace operators | ||
} // namespace paddle | ||
|
||
USE_OP(mul); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 请教之后得知,因为 USE_OP 会在当前文件生成一个static变量,而static是private的,在不同文件中彼此不会冲突。所以不存在问题。 |
||
USE_OP(rowwise_add); | ||
USE_OP(sigmoid); | ||
USE_OP(softmax); | ||
|
||
REGISTER_OP(fc, | ||
paddle::operators::FullyConnectedOp, | ||
paddle::operators::FullyConnectedOpMaker); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python | ||
add_op mul_op rowwise_add_op sigmoid_op softmax_op) | ||
add_op fc_op) |
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -14,6 +14,7 @@ limitations under the License. */ | |||||||||
|
||||||||||
#include <Python.h> | ||||||||||
#include <paddle/framework/op_registry.h> | ||||||||||
#include <paddle/framework/operator.h> | ||||||||||
#include <paddle/framework/scope.h> | ||||||||||
#include <paddle/pybind/tensor_bind.h> | ||||||||||
#include <pybind11/numpy.h> | ||||||||||
|
@@ -26,10 +27,7 @@ namespace py = pybind11; | |||||||||
namespace pd = paddle::framework; | ||||||||||
|
||||||||||
USE_OP(add_two); | ||||||||||
USE_OP(softmax); | ||||||||||
USE_OP(mul); | ||||||||||
USE_OP(rowwise_add); | ||||||||||
USE_OP(sigmoid); | ||||||||||
USE_OP_WITHOUT_KERNEL(fc); | ||||||||||
|
||||||||||
PYBIND11_PLUGIN(core) { | ||||||||||
py::module m("core", "C++ core of Paddle Paddle"); | ||||||||||
|
@@ -53,7 +51,9 @@ PYBIND11_PLUGIN(core) { | |||||||||
self.mutable_data<int>(paddle::platform::CPUPlace()); | ||||||||||
}) | ||||||||||
.def("set", paddle::pybind::PyTensorSetFromArray<float>) | ||||||||||
.def("set", paddle::pybind::PyTensorSetFromArray<int>); | ||||||||||
.def("set", paddle::pybind::PyTensorSetFromArray<int>) | ||||||||||
.def("shape", | ||||||||||
[](pd::Tensor& self) { return pd::vectorize(self.dims()); }); | ||||||||||
|
||||||||||
py::class_<pd::Variable>(m, "Variable", R"DOC(Variable Class. | ||||||||||
|
||||||||||
|
@@ -83,15 +83,16 @@ All parameter, weight, gradient are variables in Paddle. | |||||||||
|
||||||||||
//! @note: Be careful! PyBind will return std::string as an unicode, not | ||||||||||
//! Python str. If you want a str object, you should cast them in Python. | ||||||||||
m.def("get_all_op_protos", []() -> std::vector<std::string> { | ||||||||||
m.def("get_all_op_protos", []() -> std::vector<py::bytes> { | ||||||||||
auto& protos = pd::OpRegistry::protos(); | ||||||||||
std::vector<std::string> ret_values; | ||||||||||
std::vector<py::bytes> ret_values; | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 看起来这里碰到坑了,能记录一下么? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. PyBind 11的一些使用问题PyBind11是一个非常简单的库。这个库的使用问题记录在如下几个方面。 内存问题Python使用GC和引用技术并存的内存问题。C++是手动管理内存。这二者之间的内存交换就有一些问题了。 对于简单类型还比较简单,直接返回值类型即可。但是对于复杂类型,就分为三种情况。
字符串问题C++中 而对应的Python对象有三个
C++ 端返回 |
||||||||||
for (auto it = protos.begin(); it != protos.end(); ++it) { | ||||||||||
PADDLE_ENFORCE(it->second.IsInitialized(), | ||||||||||
"OpProto must all be initialized"); | ||||||||||
ret_values.emplace_back(); | ||||||||||
PADDLE_ENFORCE(it->second.SerializeToString(&ret_values.back()), | ||||||||||
std::string str; | ||||||||||
PADDLE_ENFORCE(it->second.SerializeToString(&str), | ||||||||||
"Serialize OpProto Error. This could be a bug of Paddle."); | ||||||||||
ret_values.push_back(py::bytes(str)); | ||||||||||
} | ||||||||||
return ret_values; | ||||||||||
}); | ||||||||||
|
@@ -101,17 +102,26 @@ All parameter, weight, gradient are variables in Paddle. | |||||||||
.def("empty", pd::OperatorBase::EMPTY_VAR_NAME) | ||||||||||
.def("temp", pd::OperatorBase::TMP_VAR_NAME); | ||||||||||
|
||||||||||
py::class_<paddle::platform::DeviceContext>(m, "DeviceContext") | ||||||||||
.def_static("cpu_context", []() -> paddle::platform::DeviceContext* { | ||||||||||
return new paddle::platform::CPUDeviceContext(); | ||||||||||
}); | ||||||||||
|
||||||||||
py::class_<pd::OperatorBase, pd::OperatorPtr>(m, "Operator") | ||||||||||
.def("__str__", &pd::OperatorBase::DebugString) | ||||||||||
.def_static("create", [](const std::string& protobin) { | ||||||||||
pd::OpDesc desc; | ||||||||||
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), | ||||||||||
"Cannot parse user input to OpDesc"); | ||||||||||
PADDLE_ENFORCE(desc.IsInitialized(), | ||||||||||
"User OpDesc is not initialized, reason %s", | ||||||||||
desc.InitializationErrorString()); | ||||||||||
return pd::OpRegistry::CreateOp(desc); | ||||||||||
}); | ||||||||||
.def_static("create", | ||||||||||
[](py::bytes protobin) { | ||||||||||
pd::OpDesc desc; | ||||||||||
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), | ||||||||||
"Cannot parse user input to OpDesc"); | ||||||||||
PADDLE_ENFORCE(desc.IsInitialized(), | ||||||||||
"User OpDesc is not initialized, reason %s", | ||||||||||
desc.InitializationErrorString()); | ||||||||||
return pd::OpRegistry::CreateOp(desc); | ||||||||||
}) | ||||||||||
.def("infer_shape", &pd::OperatorBase::InferShape) | ||||||||||
.def("run", &pd::OperatorBase::Run) | ||||||||||
.def("outputs", [](const pd::OperatorPtr& op) { return op->outputs_; }); | ||||||||||
|
||||||||||
return m.ptr(); | ||||||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
add_python_test(test_framework test_protobuf.py test_scope.py | ||
test_default_scope_funcs.py test_op_creation_methods.py | ||
test_tensor.py) | ||
test_tensor.py test_fc_op.py) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import paddle.v2.framework.core as core | ||
import unittest | ||
import numpy | ||
import paddle.v2.framework.create_op_creation_methods as creation | ||
|
||
|
||
class TestFc(unittest.TestCase): | ||
def test_fc(self): | ||
scope = core.Scope(None) | ||
x = scope.create_var("X") | ||
x_tensor = x.get_tensor() | ||
x_tensor.set_dims([1000, 784]) | ||
x_tensor.alloc_float() | ||
|
||
w = scope.create_var("W") | ||
w_tensor = w.get_tensor() | ||
w_tensor.set_dims([784, 100]) | ||
w_tensor.alloc_float() | ||
|
||
w_tensor.set(numpy.random.random((784, 100)).astype("float32")) | ||
|
||
# Set a real numpy array here. | ||
# x_tensor.set(numpy.array([])) | ||
|
||
op = creation.op_creations.fc(X="X", Y="Y", W="W") | ||
|
||
for out in op.outputs(): | ||
if scope.get_var(out) is None: | ||
scope.create_var(out).get_tensor() | ||
|
||
tensor = scope.get_var("Y").get_tensor() | ||
op.infer_shape(scope) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What if we decide to do infer_shape in Python — or might this change in the future?
||
self.assertEqual([1000, 100], tensor.shape()) | ||
|
||
ctx = core.DeviceContext.cpu_context() | ||
|
||
op.run(scope, ctx) | ||
|
||
# After complete all ops, check Y is expect or not. | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the meaning of
calc