-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement basic Load()
and modify example based on updated inference design
#7690
Changes from all commits
34c94d1
9506cc1
54f679e
ddfb3e5
eaaf19b
c06df11
997f5df
bfb82e4
31d249a
2a892a0
fe522b5
c7d6f74
056a71d
cf82b9d
dc1c8ca
b6f62e4
39b9adb
c7f4891
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,9 @@ limitations under the License. */ | |
namespace paddle { | ||
namespace framework { | ||
|
||
const std::string kFeedOpType = "feed"; | ||
const std::string kFetchOpType = "fetch"; | ||
|
||
BlockDesc *ProgramDesc::AppendBlock(const BlockDesc &parent) { | ||
auto *b = desc_.add_blocks(); | ||
b->set_parent_idx(parent.ID()); | ||
|
@@ -64,5 +67,27 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) { | |
} | ||
} | ||
|
||
const std::vector<std::string> ProgramDesc::GetFeedVarNames() { | ||
BlockDesc *global_block = blocks_[0].get(); | ||
std::vector<std::string> feed_var_names; | ||
for (auto *op : global_block->AllOps()) { | ||
if (op->Type() == "feed") { | ||
feed_var_names.insert(feed_var_names.begin(), op->Output("Out")[0]); | ||
} | ||
} | ||
return feed_var_names; | ||
} | ||
|
||
const std::vector<std::string> ProgramDesc::GetFetchVarNames() { | ||
BlockDesc *global_block = blocks_[0].get(); | ||
std::vector<std::string> fetch_var_names; | ||
for (auto *op : global_block->AllOps()) { | ||
if (op->Type() == "fetch") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fetch -> kFetchOpType |
||
fetch_var_names.push_back(op->Input("X")[0]); | ||
} | ||
} | ||
return fetch_var_names; | ||
} | ||
|
||
} // namespace framework | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,6 +45,10 @@ class ProgramDesc { | |
|
||
proto::ProgramDesc *Proto(); | ||
|
||
const std::vector<std::string> GetFeedVarNames(); | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove this blank line. |
||
const std::vector<std::string> GetFetchVarNames(); | ||
|
||
private: | ||
proto::ProgramDesc desc_; | ||
|
||
|
This file was deleted.
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
feed -> kFeedOpType
and let's rename
feed_var_names
tofeed_target_names
if there is no better candidate.