-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add rnn op interfaces #2775
add rnn op interfaces #2775
Changes from 17 commits
c418dac
6042795
13d8ca9
a645ae6
8640f96
d4cde51
6e99289
63b5841
08f69f6
007ca1e
2538b2f
5eb87f0
4dcb02e
ca53f3a
671cc26
1e48cc8
e0cbcd0
f7916a6
089c448
bffd11e
c7947de
94766b6
6dca711
eabf1bf
d210b0b
6674fee
778ebb4
c60ed35
8642b27
b0938ed
3921fbb
244fe51
020c189
8e70b37
4150fa7
1584414
ce802c0
a883b4c
b98cae4
a81be58
acde9b7
638384e
82464f5
bbcc149
c92ce74
5c5d890
522445b
01f20be
08003de
a6483e8
7b1d123
bcd03bf
de319bb
0a4a502
e64b5d3
e700bf6
f525390
3a27b02
aede869
45682d2
497c7ff
fc5acee
14dd843
3c15641
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
#include <glog/logging.h> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. copy right There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
#include <cstring> | ||
|
||
#include "paddle/framework/recurrent_network_op.h" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the corresponding .h file of this .cc file should be the first include. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
#include "paddle/framework/tensor.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
// fake op implementations | ||
namespace fake { | ||
class FcOp : public OperatorBase { | ||
public: | ||
FcOp(NetDesc& net_desc) : name_(net_desc.name_) {} | ||
|
||
virtual void InferShape(const Scope* scope) const override { | ||
LOG(INFO) << "fc InferShape"; | ||
} | ||
|
||
virtual void Run(OpRunContext* contex) const override { | ||
LOG(INFO) << "fc Run"; | ||
} | ||
|
||
private: | ||
std::string name_; | ||
}; | ||
|
||
class SGDOptimizerOp : public OperatorBase { | ||
public: | ||
SGDOptimizerOp(NetDesc& net_desc) : name_(net_desc.name_) {} | ||
|
||
virtual void InferShape(const Scope* scope) const override { | ||
LOG(INFO) << "optimizer InferShape"; | ||
} | ||
|
||
virtual void Run(OpRunContext* contex) const override { | ||
LOG(INFO) << "optimizer Run"; | ||
} | ||
|
||
private: | ||
std::string name_; | ||
}; | ||
}; // namespace fake | ||
|
||
void RecurrentOp::Run(OpRunContext* contex) const { | ||
auto scope = contex->scope; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. contex => context or ctx? |
||
|
||
if (!scope->HasVariable(net_name_)) { | ||
CreateStepNet(scope); | ||
} | ||
Variable* net = scope->GetVariable(net_name_); | ||
PADDLE_ENFORCE(net, "failed to get step net"); | ||
|
||
CreateScopes(scope); | ||
SegmentInputs(scope); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 32和34行和53行的LOG可以都打在函数里面。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 测试完后会全部删掉。 |
||
|
||
Variable* step_scopes = scope->GetVariable(step_scopes_name_); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rename |
||
PADDLE_ENFORCE(step_scopes, "failed to get step scopes"); | ||
// forward | ||
auto dims = Input(scope, 0)->GetMutable<Tensor>()->dims(); | ||
size_t seq_len = dims[1]; | ||
auto& scopes = *step_scopes->GetMutable<std::vector<Scope*>>(); | ||
for (size_t step_id = 0; step_id < seq_len; step_id++) { | ||
Scope* step_scope = scopes[step_id]; | ||
// TODO replace memories' copy with reference | ||
LinkMemories(scope, scopes, step_id); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个函数里,scope, scopes挺难区分的。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 留第二个变量。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LinkMemories(scope, scopes, step_id); -> LinkMemories(scopes, step_id); |
||
|
||
net->GetMutable<PlainNet>()->Run(step_scope); | ||
} | ||
|
||
// prepare outputs | ||
ConcateOutputs(scope); | ||
} | ||
|
||
void RecurrentOp::CreateScopes(Scope* scope) const { | ||
auto dims = Input(scope, 0)->GetMutable<Tensor>()->dims(); | ||
size_t seq_len = dims[1]; | ||
Variable* scopes_var = scope->GetVariable(step_scopes_name_); | ||
// auto step_scopes = | ||
// scopes_var->GetMutable<std::vector<std::shared_ptr<Scope>>>(); | ||
auto step_scopes = scopes_var->GetMutable<std::vector<Scope*>>(); | ||
// TODO Only two scopes are needed for inference, this case will be supported | ||
// later. | ||
if (seq_len > step_scopes->size()) { | ||
for (size_t i = step_scopes->size(); i < seq_len; ++i) { | ||
// step_scopes->push_back(std::make_shared<Scope>( | ||
// std::shared_ptr<Scope>(scope))); | ||
step_scopes->push_back(new Scope(std::shared_ptr<Scope>(scope))); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
} | ||
} | ||
} | ||
|
||
void RecurrentOp::CreateStepNet(Scope* scope) const { | ||
Variable* var = scope->CreateVariable(net_name_); | ||
auto step_net = GetAttr<std::string>("step_net"); | ||
// get the step net proto from the string. | ||
// PADDLE_ENFORCE( | ||
// google::protobuf::TextFormat::ParseFromString(step_net, | ||
// &step_net_desc_)); | ||
// this is a fake net, it will be rewritten after the network has been merged. | ||
var->Reset<PlainNet>(new PlainNet(step_net)); | ||
} | ||
|
||
void RecurrentOp::LinkMemories(Scope* scope, std::vector<Scope*>& step_scopes, | ||
size_t step) const { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 有一个疑问,传了scope后,为什么还要传step_scopes,后者可以从前者获得吧 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
PADDLE_ENFORCE(step < step_scopes.size(), | ||
"step [%d] out of range of step scopes' size [%d]", step, | ||
step_scopes.size()); | ||
// copy boot memory | ||
for (auto& attr : memory_attrs_) { | ||
Scope* step_scope = step_scopes[step]; | ||
|
||
Tensor* boot_tensor{nullptr}; | ||
Variable* memory_var = step_scope->CreateVariable(attr.pre_var); | ||
if (step == 0) { | ||
PADDLE_ENFORCE(scope->HasVariable(attr.boot_var), | ||
"memory [%s]'s boot variable [%s] not exists", attr.var, | ||
attr.boot_var); | ||
// update memory's ddim | ||
boot_tensor = scope->CreateVariable(attr.boot_var)->GetMutable<Tensor>(); | ||
attr.dims = boot_tensor->dims(); | ||
} | ||
|
||
// copy from boot memory | ||
// TODO support more device | ||
float* memory_tensor_val = | ||
memory_var->GetMutable<Tensor>()->mutable_data<float>( | ||
attr.dims, platform::CPUPlace()); | ||
if (step == 0) { | ||
PADDLE_ENFORCE(boot_tensor, "boot_tensor should be retrieved before"); | ||
// copy from boot memory | ||
std::memcpy(memory_tensor_val, boot_tensor->data<float>(), | ||
product(attr.dims)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
} else { | ||
// copy from previous step scope's memory to this scope's `pre-memory` | ||
Tensor* pre_step_memory = | ||
step_scopes[step - 1]->GetVariable(attr.var)->GetMutable<Tensor>(); | ||
std::memcpy(memory_tensor_val, pre_step_memory->data<float>(), | ||
product(attr.dims)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里的memcpy可以用tensor.h中最新的shardData来写吧,可以节省很多代码。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
} | ||
} | ||
} | ||
|
||
} // namespace framework | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include <google/protobuf/text_format.h> | ||
#include "paddle/framework/attr_checker.h" | ||
#include "paddle/framework/ddim.h" | ||
#include "paddle/framework/enforce.h" | ||
#include "paddle/framework/scope.h" | ||
#include "paddle/framework/variable.h" | ||
|
||
// Remove when including operator.h | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 现在可以加operator.h了么 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 暂时不加。 |
||
#include "paddle/framework/attr_checker.h" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 和18行重复了。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
#include "paddle/framework/op_desc.pb.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
// -------------------------------------------------------------------- | ||
// fake interfaces that have not been implemented by other modules. | ||
// TODO keep updating according to other modules' designs. | ||
struct OpRunContext { | ||
Scope* scope; | ||
}; | ||
|
||
// TODO replace this with Net's proto. | ||
struct NetDesc { | ||
std::string name_; | ||
}; | ||
|
||
class PlainNet { | ||
public: | ||
PlainNet() {} | ||
PlainNet(const NetDesc& desc) {} | ||
PlainNet(const std::string desc) {} | ||
void Run(Scope* scope) {} | ||
}; | ||
|
||
class OperatorBase { | ||
public: | ||
virtual ~OperatorBase() {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks to me that the constructor needs a parameter class OperatorBase {
public:
OperatorBase(const proto::OperatorDesc& desc) : desc_(desc) {}
virtual void Run(OpRunContext* context) const = 0;
protected:
virtual void InferShape(const Scope* scope) const = 0; // needs to read from and write to desc_
proto::OperatorDesc desc_;
}; So the information in
@Superjom @reyoung @jacquesqiao There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the new design of Operator, OpDesc will store in Op, and InferShape can get the information from scope, but it seems that it need not store the shape into the desc There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jacquesqiao You are right. The first clue about input/output sizes is in training data instances, and we get the instance when we do training, i.e,. call operator's Should we just remove template <typename Context> class MyOperator;
template <>
class MyOperator<GPUContext> : public OperatorBase {
public:
MyOperator(const proto::OperatorDesc& desc) : OperatorBase(desc) {}
virtual void Run(OpRunContext* ctx) const {
cudnnGemm(
ctx->cudnnHandle,
Output(0, ctx)->GetMutable<Tensor>(Output0Size(ctx))->mutable_data(),
Input(0, ctx)->Get<Tensor>()->data(),
Input(1, ctx)->Get<Tensor>()->data(),
);
}
private:
DDim Output0Size(OpRunContext* ctx) const { ...}
DDim Output1Size(OpRunContext* ctx) const { ...}
}; |
||
void Init(const OpDesc& op_desc, AttributeMap& attrs) {} | ||
virtual void Run(OpRunContext* context) const = 0; | ||
virtual void InferShape(const Scope* scope) const = 0; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what does InferShape do? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the purpose of InferShape is to inference the size of inputs/outputs from some of them that we already know the size. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. InferShape will set the output variable dim according to the input variable dim. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. RNNOp.InferShape will just call its step net's InferShape, and will
It is offered as a public method because we want to keep checking dynamically during user adding operators. |
||
inline Variable* Input(Scope* scope, int index) const { | ||
return scope->GetVariable(inputs_[index]); | ||
}; | ||
|
||
template <typename T> | ||
inline const T GetAttr(const std::string& name) const { | ||
return boost::get<T>(attrs_.at(name)); | ||
} | ||
|
||
protected: | ||
std::vector<std::string> inputs_; | ||
std::vector<std::string> outputs_; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add attributes |
||
AttributeMap attrs_; | ||
}; | ||
// fake interfaces end | ||
// -------------------------------------------------------------------- | ||
// TODO: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. need to define the data structure for sequence. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The sequence format in RecurrentOp is Tensor<seq_len, batch_size, dim> now, we consider all the sequences have the same length. After this simplest version, we will design new sequence format like sequenceStartPositions of original Paddle. |
||
// 1. No-padding computing for sequences with indefinite length in one batch. | ||
// 2. Hierarchical RNN for sequence with sub-sequence. | ||
// 3. Multi-inputs with indefinite length for RecurrentOp. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 第3点可以去掉了 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 换成External Memory |
||
class RecurrentOp : public OperatorBase { | ||
public: | ||
void Init(const OpDesc& op_desc, AttributeMap& attrs) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Only the implementation of very short functions can be in .h file. Other should be put into .cc file There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
OperatorBase::Init(op_desc, attrs); | ||
name_ = op_desc.name(); | ||
net_name_ = op_desc.name() + "_net"; | ||
step_scopes_name_ = op_desc.name() + "_step_scopes"; | ||
auto memories = GetAttr<std::vector<std::string>>("memories"); | ||
auto boot_memories = GetAttr<std::vector<std::string>>("boot_memories"); | ||
PADDLE_ENFORCE(memories.size() == boot_memories.size(), | ||
"The size of memories and boot_memories is mismatched."); | ||
for (size_t i = 0; i < memories.size(); ++i) { | ||
MemoryAttr mem_attr; | ||
mem_attr.var = memories[i]; | ||
mem_attr.boot_var = boot_memories[i]; | ||
memory_attrs_.push_back(mem_attr); | ||
} | ||
} | ||
|
||
virtual void InferShape(const Scope* scope) const override; | ||
|
||
/* | ||
* Forward run the RNN. | ||
* | ||
* NOTE the context's scope is not given until `Run` called, so step scopes' | ||
* father should be set/updated in this method. | ||
*/ | ||
virtual void Run(OpRunContext* contex) const override; | ||
|
||
protected: | ||
/* | ||
* Prepare inputs for each stepnet. | ||
*/ | ||
void SegmentInputs(Scope* scope) const {}; | ||
|
||
/* | ||
* Process outputs of stepnets and merge to variables. | ||
*/ | ||
void ConcateOutputs(Scope* scope) const {}; | ||
|
||
/* | ||
* Create a `Net` which is shared across all steps. | ||
*/ | ||
void CreateStepNet(Scope* scope) const; | ||
|
||
/* | ||
* Create a scope for each step, the context's scope is shared across all | ||
* the step scopes as the father scope. The step scopes will be stored in | ||
* the father scope as a variable whose name is specified by | ||
* `step_scopes_name_`. | ||
* | ||
* NOTE the scopes are reused by both the `Forward` and `Backward`, so just | ||
* create once and expand its size if more steps need. | ||
*/ | ||
void CreateScopes(Scope* scope) const; | ||
|
||
/* | ||
* Create memories in each step scope. | ||
*/ | ||
// void CreateMemories(Scope* scope) const; | ||
|
||
/* | ||
* Link memory in previous step scope to current scope. | ||
*/ | ||
void LinkMemories(Scope* scope, std::vector<Scope*>& step_scopes, | ||
size_t step) const; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 最后一个变量step改成step_id,是否更容易理解 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
||
private: | ||
/* | ||
* Memory of a RNN (same as the role of `Momory` in PaddlePaddle). | ||
* | ||
* Memory attributes cached by this op, dims will be inferred from | ||
* boot memories in father scope. Other attributes are copied from Op's proto | ||
* attributes. | ||
*/ | ||
struct MemoryAttr { | ||
// name of current state variable | ||
std::string var; | ||
// name of previous step's state variable | ||
std::string pre_var; | ||
// name of the variables to init this memory (same role of `boot_layer` in | ||
// PaddlePaddle), which is stored in father's scope. | ||
std::string boot_var; | ||
// this dim will be inferred from boot memories' tensor in the first step. | ||
DDim dims; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
}; | ||
|
||
/* | ||
* The attributes in protobuf about the memory description and the booted | ||
* memory description are as follows. The number of booted memories should | ||
* equal to the memories number. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里是把置0的memory,也当成boot memory么? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 直接可以支持 boot_memory 啊,不需要置0模拟,LinkMemories里面已经支持了 |
||
* | ||
* arg { | ||
* name: “memories” | ||
* strings: "hidden” | ||
* strings: "state” | ||
* } | ||
* arg { | ||
* name: “boot_memories” | ||
* strings: "boot_hidden” | ||
* strings: "boot_state” | ||
* } | ||
*/ | ||
// TODO copy from OpBase's | ||
mutable std::vector<MemoryAttr> memory_attrs_; | ||
|
||
// this op's name, used as a unique key in father scope. | ||
// TODO replace it with OpBase's interface if supported. | ||
std::string name_; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove this name. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
// name of rnn op's step net, the step net will be shared by both `Forward` | ||
// and `Backward`, so we store it as a variable in father's scope, with a | ||
// unique key specified by `net_name_`. | ||
std::string net_name_; | ||
// name of steps' scopes which is stored in father scope with a unique key | ||
// specified by `step_scopes_name_`. | ||
std::string step_scopes_name_; | ||
|
||
NetDesc step_net_desc_; | ||
}; | ||
|
||
class RecurrentGradientOp; | ||
|
||
} // namespace framework | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
/* | ||
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
#include "paddle/framework/recurrent_network_op.h" | ||
#include "gtest/gtest.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
class RecurrentOpTest : public ::testing::Test { | ||
protected: | ||
virtual void SetUp() override {} | ||
}; | ||
} // namespace framework | ||
|
||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove this name.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done