add rnn op interfaces #2775
Conversation
protected:
  std::vector<std::string> inputs_;
  std::vector<std::string> outputs_;
add attributes
TODO: for the final test, add the few ops a simple RNN needs, then debug.
  std::vector<std::string> outputs_;
}

class RecurrentForwardOp {
- RecurrentGroupOp
- RecurrentGroupGradientOp
  std::vector<StateAttr> states_;
};

class RecurrentBackwardOp;
Here are the corresponding interfaces:
/*
 * Prepare inputs for each stepnet.
 */
void ApplyInLinks(Scope* scope);
ScatterInLinks(Scope* scope)
GatherOutLinks(Scope* scope)
/*
 * Process outputs of stepnets and merge to variables.
 */
void ApplyOutLinks(Scope* scope);
/*
 * Build a `Net` which is shared across all steps.
 */
void BuildStepNet(Scope* scope);
 * step scopes as the father scope. The step scopes will be stored in the
 * father scope as a variable.
 */
void CreateScopes(Scope* scope);
 */

// State of a RNN (same as the role of `Memory` in PaddlePaddle)
struct StateAttr {
Following Paddle's convention, call it Memory.
 * NOTE the context's scope is not given until `Run` is called, so the step scopes'
 * father should be set/updated in this method.
 */
virtual void Run(OpRunContext* context) const = 0;
Write `Run` first to pin down the interface call order.
 public:
  virtual ~OperatorBase() {}
  virtual void Run(OpRunContext* context) const = 0;
  virtual void InferShape(const Scope* scope) const = 0;
what does InferShape do?
I think the purpose of InferShape is to infer the sizes of inputs/outputs from those whose sizes we already know.
InferShape will set the output variable dim according to the input variable dim.
RNNOp.InferShape will just call its step net's InferShape, and will
- check input variable/tensors' shape, raise an error if wrong
- update all outputs' variable/tensors' shape according to this mini-batch of input
It is offered as a public method because we want to keep checking dynamically while the user adds operators.
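The contract described above can be sketched in a standalone way. This is a toy sketch, not Paddle's actual API: `DDim` here is just a vector of ints, and the fully-connected shape rule is an assumed example of "check inputs, derive outputs".

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

// Toy stand-in for Paddle's DDim; only the dimension values matter here.
using DDim = std::vector<int>;

// Sketch of the InferShape behavior discussed above for a hypothetical
// fully-connected op: check the input shapes, raise an error on mismatch,
// and set the output shape from the inputs ([batch, in] x [in, out]).
DDim InferFcOutputShape(const DDim& input, const DDim& weight) {
  if (input.size() != 2 || weight.size() != 2)
    throw std::invalid_argument("expected rank-2 tensors");
  if (input[1] != weight[0])
    throw std::invalid_argument("input/weight dimension mismatch");
  return {input[0], weight[1]};  // output dims derived from input dims
}
```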
 * NOTE the context's scope is not given until `Run` is called, so the step scopes'
 * father should be set/updated in this method.
 */
virtual void Run(OpRunContext* context) const {
should be in .cpp
Yes, will move to .cpp later.
We are working on a simple implementation to verify the whole process and will give a version soon.
  std::vector<std::string> outputs_;
}

class RecurrentGroupForwardOp {
RecurrentGroupForwardOp => RecurrentOp
good, short enough.
And the backward op's name: RecurrentBackwardOp, or RecurrentGradientOp?
/*
 * Prepare inputs for each stepnet.
 */
void ScatterInLinks(Scope* scope);
ScatterInLinks => SegmentInputs. Let us use accurate English wording.
done
/*
 * Process outputs of stepnets and merge to variables.
 */
void GatherOutLinks(Scope* scope);
GatherOutLinks => ConcatenateOutputs
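To make the suggested names concrete, here is a minimal standalone sketch of what SegmentInputs/ConcatenateOutputs do. It is toy code over flat float vectors, not Paddle's Scope/Tensor API: a sequence-major input of shape [seq_len, step_size] is split per step and later merged back.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// SegmentInputs hands each step its own slice of the flattened
// [seq_len, step_size] input.
std::vector<std::vector<float>> SegmentInputs(const std::vector<float>& input,
                                              std::size_t seq_len,
                                              std::size_t step_size) {
  std::vector<std::vector<float>> steps(seq_len);
  for (std::size_t i = 0; i < seq_len; ++i)
    steps[i].assign(input.begin() + i * step_size,
                    input.begin() + (i + 1) * step_size);
  return steps;
}

// ConcatenateOutputs merges the per-step results back into one buffer.
std::vector<float> ConcatenateOutputs(
    const std::vector<std::vector<float>>& steps) {
  std::vector<float> merged;
  for (const auto& s : steps) merged.insert(merged.end(), s.begin(), s.end());
  return merged;
}
```

Segmenting and then concatenating is a round trip, which is the invariant the two methods should preserve.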
/*
 * Build a `Net` which is shared across all steps.
 */
void BuildStepNet(Scope* scope);
BuildStepNet => CreateStepNet
class OperatorBase {
 public:
  virtual ~OperatorBase() {}
It looks to me that the constructor needs a parameter paddle::framework::proto::OperatorDesc, so that it would be possible to call InferShape, which saves the sizes of inputs/outputs into the desc. Only then would we have all the necessary information for calling OperatorBase::Run:
class OperatorBase {
 public:
  OperatorBase(const proto::OperatorDesc& desc) : desc_(desc) {}
  virtual void Run(OpRunContext* context) const = 0;

 protected:
  virtual void InferShape(const Scope* scope) const = 0;  // needs to read from and write to desc_
  proto::OperatorDesc desc_;
};
So the information in proto::OperatorDesc propagates along the path:
Operator's constructor
① ↓
OperatorBase::desc_ → Operator's Run
②↓ ↑③ ④
Operator's InferShape
@Superjom @reyoung @jacquesqiao
In the new design of Operator, the OpDesc will be stored in the Op, and InferShape can get the information from the scope, but it seems there is no need to store the shape back into the desc.
@jacquesqiao You are right.
The first clue about input/output sizes is in the training data instances, and we get an instance when we do training, i.e., when we call the operator's Run.
Should we just remove InferShape and let each operator define its own shape inference methods, i.e., one method per output, to shorten the code in its Run method, like this:
template <typename Context> class MyOperator;

template <>
class MyOperator<GPUContext> : public OperatorBase {
 public:
  MyOperator(const proto::OperatorDesc& desc) : OperatorBase(desc) {}
  virtual void Run(OpRunContext* ctx) const {
    cudnnGemm(
        ctx->cudnnHandle,
        Output(0, ctx)->GetMutable<Tensor>(Output0Size(ctx))->mutable_data(),
        Input(0, ctx)->Get<Tensor>()->data(),
        Input(1, ctx)->Get<Tensor>()->data());
  }

 private:
  DDim Output0Size(OpRunContext* ctx) const { ... }
  DDim Output1Size(OpRunContext* ctx) const { ... }
};
for (size_t i = step_scopes->size(); i < seq_len; ++i) {
  // step_scopes->push_back(std::make_shared<Scope>(
  //     std::shared_ptr<Scope>(scope)));
  step_scopes->push_back(new Scope(std::shared_ptr<Scope>(scope)));
The shared_ptr is used for the scope constructor. I created an issue.
  }
}

// void RecurrentOp::CreateStepNet(ScopePtr scope) const {
Delete this code.
Done
What is the difference between input/output_alias and inlinks/outlinks? Is the former the name used within each step?
auto dims = Input(scope, inlinks_[0])->GetMutable<Tensor>()->dims();
size_t seq_len = dims[0];
Variable* scopes_var = scope->GetVariable(step_scopes_name_);
auto step_scopes = scopes_var->GetMutable<std::vector<ScopePtr>>();
- Lines 116 and 117 are repeated in both SegmentInputs and ConcatOutputs; could step_scopes be kept as a member variable in the .h file to keep the code tidy?
- seq_len could also be a member variable in the .h file, perhaps named maxSequenceLength_; it is used in many places, so it would not need to be fetched repeatedly.
Split it into two functions.
Done
//   NetDesc desc;
//   desc.name_ = "rnn_step_net";
//   var->Reset<PlainNet>(new PlainNet(desc));
// }
Can lines 127-140 be removed?
Done
LOG(INFO) << "create scopes";
CreateScopes(scope);
LOG(INFO) << "segment input";
SegmentInputs(scope);
The LOGs at lines 32, 34, and 53 could all be emitted inside the functions.
They will all be deleted after testing.
LOG(INFO) << "run step " << step_id;
ScopePtr step_scope = scopes[step_id];
// TODO replace memories' copy with reference
LinkMemories(scope, scopes, step_id);
In this function, scope and scopes are hard to tell apart.
Keep the second variable.
LinkMemories(scope, scopes, step_id); -> LinkMemories(scopes, step_id);
scopes - > step_scopes
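The overall call order being discussed can be sketched with stubs. The method names follow this PR; the bodies are hypothetical placeholders that only record which phase ran, to make the sequencing explicit.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Records which phase ran, in order; stand-ins for the real methods.
std::vector<std::string> trace;
void CreateScopes() { trace.push_back("CreateScopes"); }
void SegmentInputs() { trace.push_back("SegmentInputs"); }
void LinkMemories(std::size_t) { trace.push_back("LinkMemories"); }
void RunStepNet(std::size_t) { trace.push_back("RunStepNet"); }
void ConcatOutputs() { trace.push_back("ConcatOutputs"); }

// Sketch of RecurrentOp::Run: set up step scopes, split the inputs, run
// every step (linking memories from the previous step first), then merge
// the per-step outputs.
void RecurrentOpRun(std::size_t seq_len) {
  CreateScopes();
  SegmentInputs();
  for (std::size_t step_id = 0; step_id < seq_len; ++step_id) {
    LinkMemories(step_id);
    RunStepNet(step_id);
  }
  ConcatOutputs();
}
```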
void RecurrentOp::LinkMemories(ScopePtr scope,
                               std::vector<ScopePtr>& step_scopes,
                               size_t step) const {
One question: since scope is passed, why also pass step_scopes? The latter can be obtained from the former.
Done
#include "paddle/framework/scope.h"
#include "paddle/framework/variable.h"

// Remove when including operator.h
Can operator.h be included now?
Not for now.
// TODO:
// 1. No-padding computing for sequences with indefinite length in one batch.
// 2. Hierarchical RNN for sequences with sub-sequences.
// 3. Multi-inputs with indefinite length for RecurrentOp.
Item 3 can be removed now.
Changed it to External Memory.
/*
 * Create a `Net` which is shared across all steps.
 */
// void CreateStepNet(ScopePtr scope) const;
Can lines 148-151 be removed?
Done
/*
 * Create memories in each step scope.
 */
// void CreateMemories(ScopePtr scope) const;
Can lines 163-166 be removed?
Done
 * Link memory in previous step scope to current scope.
 */
void LinkMemories(ScopePtr scope, std::vector<ScopePtr>& step_scopes,
                  size_t step) const;
Would renaming the last parameter step to step_id be easier to understand?
Done
  outputs_.push_back(output);
}

name_ = op_desc.name();
This name will be removed.
Done
for (int j = 0; j < seq_len; j++) {
  Variable* input_var = step_scopes[j]->CreateVariable(input_alias[i]);
  Tensor* step_input_tensor = input_var->GetMutable<Tensor>();
  *step_input_tensor = scope_input_tensor->Slice(j, j + 1);
The Slice operation will not change the rank of the tensor. If the input tensor shape is [10, 20, 30], the step_input_tensor shape is [1, 20, 30]. See the code: https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/tensor.h#L82
There is no need to change the tensor's layout here; each step's tensor just takes its corresponding part.
Add a reshape function to Tensor.
Need to reshape
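The rank-preserving behavior of Slice, and the reshape needed afterwards, can be checked with toy dimension arithmetic. This is just int vectors standing in for shapes, not Paddle's Tensor.

```cpp
#include <cassert>
#include <vector>

using DDim = std::vector<int>;

// Slice(begin, end) on the outermost axis keeps the rank:
// [10, 20, 30] sliced to one step becomes [1, 20, 30].
DDim SliceDims(const DDim& dims, int begin, int end) {
  DDim out = dims;
  out[0] = end - begin;  // only the outermost extent changes
  return out;
}

// The reshape under discussion: drop the leading unit dim so the
// per-step tensor has shape [20, 30].
DDim DropLeadingDim(const DDim& dims) {
  return DDim(dims.begin() + 1, dims.end());
}
```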
auto& step_scopes = *scopes_var->GetMutable<std::vector<ScopePtr>>();
auto dims = Input(scope, inlinks_[0])->GetMutable<Tensor>()->dims();
int seq_len = dims[0];
int batch_size = dims[1];
Need to support more than 3 dimensions.
Once the copyFrom function in tensor.h is merged, this will be updated.
PADDLE_ENFORCE(boot_tensor, "boot_tensor should be retrieved before");
// copy from boot memory
std::memcpy(memory_tensor_val, boot_tensor->data<float>(),
            product(attr.dims));
Use ShareDataFrom.
Done
step_scopes[step - 1]->GetVariable(attr.var)->GetMutable<Tensor>();

std::memcpy(memory_tensor_val, pre_step_memory->data<float>(),
            product(attr.dims));
Use ShareDataFrom.
Done
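The difference between the memcpy above and the suggested ShareDataFrom can be illustrated with a toy tensor. This is a mock (a shared float buffer), not Paddle's Tensor, and ShareDataFrom's real signature may differ.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Toy tensor: just a shared float buffer.
struct ToyTensor {
  std::shared_ptr<std::vector<float>> data;
};

// Sharing instead of copying: afterwards both tensors refer to the same
// buffer, so no per-element memcpy is needed and writes through one are
// visible through the other.
void ShareDataFrom(ToyTensor& dst, const ToyTensor& src) {
  dst.data = src.data;
}
```

For linking memories across steps this also means the "copy" TODO goes away: the current step's pre-memory is a view of the previous step's memory.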
/*
 * RecurrentOp inputs stored in proto:
 * - real inputs that need to be segmented to steps.
weights in step net.
// this op's name, used as a unique key in father scope.
// TODO replace it with OpBase's interface if supported.
std::string name_;
Remove this name.
Done
op_desc.add_inputs("rnn/h_pre");
op_desc.add_inputs("rnn/w");
op_desc.add_outputs("rnn/s");
// s = h_pre * check
s = rnn/h_pre * rnn/w
Names like rnn.h_pre and rnn.w would read more naturally; otherwise, in s = rnn/h_pre * rnn/w the slash could be mistaken for a division sign.
Done
op_desc.add_inputs("rnn/x");
op_desc.add_inputs("rnn/s");
op_desc.add_outputs("rnn/h");
// h = x + s
rnn/h = rnn/x + rnn/s
Done
std::string step_scopes_name_;
// real inputs that need to be segmented.
std::vector<int> inlinks_;
std::vector<std::string> outlinks_;
inlinks and outlinks correspond to each other; why does one store int and the other string?
// Remove when including operator.h
#include <glog/logging.h>
#include "paddle/framework/attr_checker.h"
This duplicates line 18.
Done
// std::ostringstream stream;
// op_desc.SerializeToOstream(&stream);
// std::string text = stream.str();
// LOG(INFO) << text;
Lines 196-199 can be deleted.
Done
  CreateRNNOp();
}

virtual void TearDown() override {}
What is this function for?
done
// PaddlePaddle), which is stored in the father's scope.
std::string boot_var;
// this dim will be inferred from the boot memory's tensor in the first step.
DDim dims;
Remove DDim dims after using SharedDataFrom.
Done
paddle/framework/op_desc.proto (outdated)
@@ -51,6 +51,9 @@ message OpDesc {
// type of this Operator, such as "add", "sub", "fc".
required string type = 3;

// the name of this Operator.
required string name = 4;
remove this name.
Done
LOG(INFO) << "run step " << step_id;
ScopePtr step_scope = scopes[step_id];
// TODO replace memories' copy with reference
LinkMemories(scope, scopes, step_id);
LinkMemories(scope, scopes, step_id); -> LinkMemories(scopes, step_id);
scopes - > step_scopes
LOG(INFO) << "segment input";
SegmentInputs(scope);

Variable* step_scopes = scope->GetVariable(step_scopes_name_);
Rename step_scopes.
Per the review comments, removed unnecessary functions, member variables, and comments.
LOG(INFO) << "create scopes";
CreateScopes(scope);
LOG(INFO) << "segment input";
SegmentInputs(scope);
They will all be deleted after testing.
LOG(INFO) << "run step " << step_id;
ScopePtr step_scope = scopes[step_id];
// TODO replace memories' copy with reference
LinkMemories(scope, scopes, step_id);
Keep the second variable.
auto dims = Input(scope, inlinks_[0])->GetMutable<Tensor>()->dims();
size_t seq_len = dims[0];
Variable* scopes_var = scope->GetVariable(step_scopes_name_);
auto step_scopes = scopes_var->GetMutable<std::vector<ScopePtr>>();
Split it into two functions.
for (int j = 0; j < seq_len; j++) {
  Variable* input_var = step_scopes[j]->CreateVariable(input_alias[i]);
  Tensor* step_input_tensor = input_var->GetMutable<Tensor>();
  *step_input_tensor = scope_input_tensor->Slice(j, j + 1);
Add a reshape function to Tensor.
#include "paddle/framework/scope.h"
#include "paddle/framework/variable.h"

// Remove when including operator.h
Not for now.
/*
 * Create memories in each step scope.
 */
// void CreateMemories(ScopePtr scope) const;
Done
// this op's name, used as a unique key in father scope.
// TODO replace it with OpBase's interface if supported.
std::string name_;
Done
// std::ostringstream stream;
// op_desc.SerializeToOstream(&stream);
// std::string text = stream.str();
// LOG(INFO) << text;
Done
op_desc.add_inputs("rnn/h_pre");
op_desc.add_inputs("rnn/w");
op_desc.add_outputs("rnn/s");
// s = h_pre * check
Done
op_desc.add_inputs("rnn/x");
op_desc.add_inputs("rnn/s");
op_desc.add_outputs("rnn/h");
// h = x + s
Done
  ConcatOutputs(scope);
}

void RecurrentOp::Init(const OpDesc& op_desc, AttributeMap& attrs) {
Init should take no parameters; op_desc and attrs can both be fetched from the Op's member variables. A minor issue.
This will be wired into OpBase.Run, as if (!is_inited) Init(...); it's just that OpBase hasn't been introduced yet.
resolve #2801