Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify shape inference code #8087

Merged
merged 3 commits into from
Feb 2, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 0 additions & 19 deletions paddle/framework/op_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ class CompileTimeInferShapeContext : public InferShapeContext {

bool HasOutputs(const std::string &name) const override;

DDim GetInputDim(const std::string &name) const override;

void SetOutputDim(const std::string &name, const DDim &dim) override;

AttrReader Attrs() const override;

const std::vector<std::string> &Inputs(
Expand Down Expand Up @@ -444,21 +440,6 @@ bool CompileTimeInferShapeContext::HasOutputs(const std::string &name) const {
return true;
}

DDim CompileTimeInferShapeContext::GetInputDim(const std::string &name) const {
std::vector<DDim> ddims = GetInputsDim(name);
auto length = ddims.size();
PADDLE_ENFORCE_EQ(length, 1UL,
"Input(%s) should have 1 value, "
"but it has %d now",
name, length);
return ddims[0];
}

void CompileTimeInferShapeContext::SetOutputDim(const std::string &name,
const DDim &dim) {
SetOutputsDim(name, {dim});
}

AttrReader CompileTimeInferShapeContext::Attrs() const {
return AttrReader(op_.GetAttrMap());
}
Expand Down
8 changes: 0 additions & 8 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -366,14 +366,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
return true;
}

DDim GetInputDim(const std::string& name) const override {
return GetDim(op_.Input(name));
}

void SetOutputDim(const std::string& name, const DDim& dim) override {
SetDim(op_.Output(name), dim);
}

AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }

const std::vector<std::string>& Inputs(
Expand Down
33 changes: 24 additions & 9 deletions paddle/framework/shape_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,18 @@ limitations under the License. */
namespace paddle {
namespace framework {

std::vector<framework::DDim> InferShapeContext::GetInputsDim(
DDim InferShapeContext::GetInputDim(const std::string &name) const {
const std::vector<std::string> &arg_names = Inputs(name);
PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
"Input(%s) should hold one element, but now it holds %d",
name, arg_names.size());
return this->GetDim(arg_names[0]);
}

std::vector<DDim> InferShapeContext::GetInputsDim(
const std::string &name) const {
const std::vector<std::string> &names = Inputs(name);
return GetDims(names);
const std::vector<std::string> &arg_names = Inputs(name);
return GetDims(arg_names);
}

DDim InferShapeContext::GetInputsElementDim(const std::string &name,
Expand All @@ -30,24 +38,31 @@ DDim InferShapeContext::GetInputsElementDim(const std::string &name,
return this->GetDim(names[idx]);
}

void InferShapeContext::SetOutputsDim(
const std::string &name, const std::vector<framework::DDim> &dims) {
void InferShapeContext::SetOutputDim(const std::string &name, const DDim &dim) {
auto &arg_names = Outputs(name);
PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
"Output(%s) should hold one element, but now it holds %d",
name, arg_names.size());
SetDim(arg_names[0], dim);
}

void InferShapeContext::SetOutputsDim(const std::string &name,
const std::vector<DDim> &dims) {
auto &names = Outputs(name);
SetDims(names, dims);
}

std::vector<framework::DDim> InferShapeContext::GetDims(
std::vector<DDim> InferShapeContext::GetDims(
const std::vector<std::string> &names) const {
std::vector<framework::DDim> ret;
std::vector<DDim> ret;
ret.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(ret),
[this](const std::string &name) { return this->GetDim(name); });
return ret;
}

void InferShapeContext::SetDims(const std::vector<std::string> &names,
const std::vector<framework::DDim> &dims) {
const std::vector<DDim> &dims) {
size_t length = names.size();
PADDLE_ENFORCE_EQ(length, dims.size());
for (size_t i = 0; i < length; ++i) {
Expand Down
19 changes: 8 additions & 11 deletions paddle/framework/shape_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,13 @@ class InferShapeContext {
virtual bool HasInputs(const std::string &name) const = 0;
virtual bool HasOutputs(const std::string &name) const = 0;

virtual framework::DDim GetInputDim(const std::string &name) const = 0;
DDim GetInputDim(const std::string &name) const;

std::vector<framework::DDim> GetInputsDim(const std::string &name) const;
std::vector<DDim> GetInputsDim(const std::string &name) const;
DDim GetInputsElementDim(const std::string &name, int idx) const;

virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0;
void SetOutputsDim(const std::string &name,
const std::vector<framework::DDim> &dims);
void SetOutputDim(const std::string &name, const DDim &dim);
void SetOutputsDim(const std::string &name, const std::vector<DDim> &dims);

virtual AttrReader Attrs() const = 0;
virtual const std::vector<std::string> &Inputs(
Expand All @@ -57,15 +56,13 @@ class InferShapeContext {

// Note: In while op, we need this to be public
void SetDims(const std::vector<std::string> &names,
const std::vector<framework::DDim> &dims);
const std::vector<DDim> &dims);

protected:
virtual framework::DDim GetDim(const std::string &name) const = 0;
virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0;

std::vector<framework::DDim> GetDims(
const std::vector<std::string> &names) const;
virtual DDim GetDim(const std::string &name) const = 0;
virtual void SetDim(const std::string &name, const DDim &dim) = 0;

std::vector<DDim> GetDims(const std::vector<std::string> &names) const;
std::vector<proto::VarDesc::VarType> GetVarTypes(
const std::vector<std::string> &names) const;

Expand Down