Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type Reader for VarDesc #8135

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions paddle/framework/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ ParamGradInfoMap AppendBackward(
auto root_block = program_desc.MutableBlock(root_block_idx);

std::string fill_one_op_out = GradVarName(target.Name());
bool is_scalar = target.Shape() == std::vector<int64_t>{1};
bool is_scalar = target.GetShape() == std::vector<int64_t>{1};
PADDLE_ENFORCE(is_scalar, "target should be scalar");
VLOG(3) << "backward from loss=" << target.Name()
<< " data_type=" << target.GetDataType();
Expand Down Expand Up @@ -565,7 +565,7 @@ ParamGradInfoMap AppendBackward(

auto var = root_block->Var(fill_one_op_out);
var->SetDataType(target.GetDataType());
var->SetShape(target.Shape());
var->SetShape(target.GetShape());
auto& target_grad = retv[target.Name()];
target_grad.name_ = fill_one_op_out;
target_grad.block_idx_ = root_block_idx;
Expand Down
10 changes: 7 additions & 3 deletions paddle/framework/framework.proto
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ message LoDTensorArrayDesc {
optional int32 lod_level = 2 [ default = 0 ];
}

message Reader { repeated LoDTensorDesc lod_tensor = 1; }

message VarDesc {
enum VarType {
LOD_TENSOR = 1;
Expand All @@ -126,13 +128,15 @@ message VarDesc {
LOD_RANK_TABLE = 6;
LOD_TENSOR_ARRAY = 7;
PLACE_LIST = 8;
READER = 9;
}
required string name = 1;
required VarType type = 2;
optional LoDTensorDesc lod_tensor = 3;
optional TensorDesc selected_rows = 4;
optional bool persistable = 3 [ default = false ];
optional LoDTensorDesc lod_tensor = 4;
optional TensorDesc selected_rows = 5;
optional LoDTensorArrayDesc tensor_array = 6;
optional bool persistable = 5 [ default = false ];
optional Reader reader = 7;
}

message BlockDesc {
Expand Down
4 changes: 2 additions & 2 deletions paddle/framework/op_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -458,11 +458,11 @@ DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
try {
auto shape = var->Shape();
auto shape = var->GetShape();
if (shape.empty()) {
return framework::make_ddim({0UL});
} else {
return framework::make_ddim(var->Shape());
return framework::make_ddim(var->GetShape());
}
} catch (...) {
VLOG(5) << "GetDim of variable " << name << " error";
Expand Down
4 changes: 2 additions & 2 deletions paddle/framework/program_desc_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ TEST(ProgramDesc, copy_ctor) {
ASSERT_NE(copy, var_before);
ASSERT_EQ(copy->Name(), var_before->Name());
ASSERT_EQ(copy->GetType(), var_before->GetType());
ASSERT_EQ(copy->Shape(), var_before->Shape());
ASSERT_EQ(copy->GetShape(), var_before->GetShape());
ASSERT_EQ(copy->Proto()->SerializeAsString(),
var_before->Proto()->SerializeAsString());
};
Expand Down Expand Up @@ -117,7 +117,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
ASSERT_NE(restored, var_before);
ASSERT_EQ(restored->Name(), var_before->Name());
ASSERT_EQ(restored->GetType(), var_before->GetType());
ASSERT_EQ(restored->Shape(), var_before->Shape());
ASSERT_EQ(restored->GetShape(), var_before->GetShape());
ASSERT_EQ(restored->Proto()->SerializeAsString(),
var_before->Proto()->SerializeAsString());
};
Expand Down
174 changes: 163 additions & 11 deletions paddle/framework/var_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,91 @@ void VarDesc::SetShape(const std::vector<int64_t> &dims) {
VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
}

void VarDesc::SetTensorDescNum(size_t num) {
switch (desc_.type()) {
case proto::VarDesc::READER: {
auto *lod_tensors_ptr = desc_.mutable_reader()->mutable_lod_tensor();
lod_tensors_ptr->Clear();
for (size_t i = 0; i < num; ++i) {
lod_tensors_ptr->Add();
}
return;
} break;
default:
PADDLE_THROW(
"Setting 'sub_tensor_number' is not supported by the type of var %s.",
this->Name());
}
}

size_t VarDesc::GetTensorDescNum() const {
switch (desc_.type()) {
case proto::VarDesc::READER:
return desc_.reader().lod_tensor_size();
break;
default:
PADDLE_THROW(
"Getting 'sub_tensor_number' is not supported by the type of var %s.",
this->Name());
}
}

void VarDesc::SetShapes(
const std::vector<std::vector<int64_t>> &multiple_dims) {
PADDLE_ENFORCE_EQ(multiple_dims.size(), GetTensorDescNum(),
"The number of given shapes(%d) doesn't equal to the "
"number of sub tensor.",
multiple_dims.size(), GetTensorDescNum());
std::vector<proto::TensorDesc *> tensors = mutable_tensor_descs();
for (size_t i = 0; i < multiple_dims.size(); ++i) {
VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims());
}
}

std::vector<int64_t> VarDesc::GetShape() const {
return RepeatedToVector(tensor_desc().dims());
}

std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
std::vector<proto::TensorDesc> descs = tensor_descs();
std::vector<std::vector<int64_t>> res;
res.reserve(descs.size());
for (const auto &tensor_desc : descs) {
res.push_back(RepeatedToVector(tensor_desc.dims()));
}
return res;
}

void VarDesc::SetDataType(proto::DataType data_type) {
mutable_tensor_desc()->set_data_type(data_type);
}

std::vector<int64_t> VarDesc::Shape() const {
return RepeatedToVector(tensor_desc().dims());
void VarDesc::SetDataTypes(
const std::vector<proto::DataType> &multiple_data_type) {
PADDLE_ENFORCE_EQ(multiple_data_type.size(), GetTensorDescNum(),
"The number of given data types(%d) doesn't equal to the "
"number of sub tensor.",
multiple_data_type.size(), GetTensorDescNum());
std::vector<proto::TensorDesc *> tensor_descs = mutable_tensor_descs();
for (size_t i = 0; i < multiple_data_type.size(); ++i) {
tensor_descs[i]->set_data_type(multiple_data_type[i]);
}
}

proto::DataType VarDesc::GetDataType() const {
return tensor_desc().data_type();
}

std::vector<proto::DataType> VarDesc::GetDataTypes() const {
std::vector<proto::TensorDesc> descs = tensor_descs();
std::vector<proto::DataType> res;
res.reserve(descs.size());
for (const auto &tensor_desc : descs) {
res.push_back(tensor_desc.data_type());
}
return res;
}

void VarDesc::SetLoDLevel(int32_t lod_level) {
switch (desc_.type()) {
case proto::VarDesc::LOD_TENSOR:
Expand All @@ -47,8 +120,28 @@ void VarDesc::SetLoDLevel(int32_t lod_level) {
desc_.mutable_tensor_array()->set_lod_level(lod_level);
break;
default:
PADDLE_THROW("Tensor type=%d does not support LoDLevel",
desc_.tensor_array().lod_level());
PADDLE_THROW(
"Setting 'lod_level' is not supported by the type of var %s.",
this->Name());
}
}

void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
PADDLE_ENFORCE_EQ(multiple_lod_level.size(), GetTensorDescNum(),
"The number of given data types(%d) doesn't equal to the "
"number of sub tensor.",
multiple_lod_level.size(), GetTensorDescNum());
switch (desc_.type()) {
case proto::VarDesc::READER: {
size_t i = 0;
for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) {
lod_tensor.set_lod_level(multiple_lod_level[i++]);
}
} break;
default:
PADDLE_THROW(
"Setting 'lod_levels' is not supported by the type of var %s.",
this->Name());
}
}

Expand All @@ -59,13 +152,31 @@ int32_t VarDesc::GetLoDLevel() const {
case proto::VarDesc::LOD_TENSOR_ARRAY:
return desc_.tensor_array().lod_level();
default:
PADDLE_THROW("Tensor type=%d does not support LoDLevel",
desc_.tensor_array().lod_level());
PADDLE_THROW(
"Getting 'lod_level' is not supported by the type of var %s.",
this->Name());
}
}

std::vector<int32_t> VarDesc::GetLoDLevels() const {
std::vector<int32_t> res;
switch (desc_.type()) {
case proto::VarDesc::READER:
res.reserve(desc_.reader().lod_tensor_size());
for (auto &lod_tensor : desc_.reader().lod_tensor()) {
res.push_back(lod_tensor.lod_level());
}
return res;
break;
default:
PADDLE_THROW(
"Getting 'lod_levels' is not supported by the type of var %s.",
this->Name());
}
}

const proto::TensorDesc &VarDesc::tensor_desc() const {
PADDLE_ENFORCE(desc_.has_type(), "invoke TensorDesc must after set type");
PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
switch (desc_.type()) {
case proto::VarDesc::SELECTED_ROWS:
return desc_.selected_rows();
Expand All @@ -74,13 +185,32 @@ const proto::TensorDesc &VarDesc::tensor_desc() const {
case proto::VarDesc::LOD_TENSOR_ARRAY:
return desc_.tensor_array().tensor();
default:
PADDLE_THROW("The type of var %s is unsupported.", this->Name());
PADDLE_THROW(
"Getting 'tensor_desc' is not supported by the type of var %s.",
this->Name());
}
}

std::vector<proto::TensorDesc> VarDesc::tensor_descs() const {
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
std::vector<proto::TensorDesc> res;
res.reserve(GetTensorDescNum());
switch (desc_.type()) {
case proto::VarDesc::READER:
for (const auto &lod_tensor : desc_.reader().lod_tensor()) {
res.push_back(lod_tensor.tensor());
}
return res;
default:
PADDLE_THROW(
"Getting 'tensor_descs' is not supported by the type of var "
"%s.",
this->Name());
}
}

proto::TensorDesc *VarDesc::mutable_tensor_desc() {
PADDLE_ENFORCE(desc_.has_type(),
"invoke MutableTensorDesc must after set type");
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
switch (desc_.type()) {
case proto::VarDesc::SELECTED_ROWS:
return desc_.mutable_selected_rows();
Expand All @@ -89,8 +219,30 @@ proto::TensorDesc *VarDesc::mutable_tensor_desc() {
case proto::VarDesc::LOD_TENSOR_ARRAY:
return desc_.mutable_tensor_array()->mutable_tensor();
default:
PADDLE_THROW("Unexpected branch.");
PADDLE_THROW(
"Getting 'mutable_tensor_desc' is not supported by the type of var "
"%s.",
this->Name());
}
}

std::vector<proto::TensorDesc *> VarDesc::mutable_tensor_descs() {
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
std::vector<proto::TensorDesc *> res;
res.reserve(GetTensorDescNum());
switch (desc_.type()) {
case proto::VarDesc::READER:
for (auto &lod_tensor : *desc_.mutable_reader()->mutable_lod_tensor()) {
res.push_back(lod_tensor.mutable_tensor());
}
return res;
default:
PADDLE_THROW(
"Getting 'tensor_descs' is not supported by the type of var "
"%s.",
this->Name());
}
}

} // namespace framework
} // namespace paddle
20 changes: 19 additions & 1 deletion paddle/framework/var_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,34 @@ class VarDesc {

void SetName(std::string name) { desc_.set_name(name); }

void SetTensorDescNum(size_t num);

size_t GetTensorDescNum() const;

void SetShape(const std::vector<int64_t> &dims);

void SetShapes(const std::vector<std::vector<int64_t>> &multiple_dims);

std::vector<int64_t> GetShape() const;

std::vector<std::vector<int64_t>> GetShapes() const;

void SetDataType(proto::DataType data_type);

std::vector<int64_t> Shape() const;
void SetDataTypes(const std::vector<proto::DataType> &multiple_data_type);

proto::DataType GetDataType() const;

std::vector<proto::DataType> GetDataTypes() const;

void SetLoDLevel(int32_t lod_level);

void SetLoDLevels(const std::vector<int32_t> &multiple_lod_level);

int32_t GetLoDLevel() const;

std::vector<int32_t> GetLoDLevels() const;

proto::VarDesc::VarType GetType() const;

void SetType(proto::VarDesc::VarType type);
Expand All @@ -90,7 +106,9 @@ class VarDesc {

private:
const proto::TensorDesc &tensor_desc() const;
std::vector<proto::TensorDesc> tensor_descs() const;
proto::TensorDesc *mutable_tensor_desc();
std::vector<proto::TensorDesc *> mutable_tensor_descs();

proto::VarDesc desc_;
};
Expand Down
2 changes: 1 addition & 1 deletion paddle/inference/io.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void LoadPersistables(framework::Executor& executor,
VLOG(3) << "parameter's name: " << var->Name();

framework::VarDesc* new_var = load_block->Var(var->Name());
new_var->SetShape(var->Shape());
new_var->SetShape(var->GetShape());
new_var->SetDataType(var->GetDataType());
new_var->SetType(var->GetType());
new_var->SetLoDLevel(var->GetLoDLevel());
Expand Down
14 changes: 12 additions & 2 deletions paddle/pybind/protobuf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,20 @@ void BindVarDsec(py::module &m) {
py::return_value_policy::reference)
.def("set_name", &VarDesc::SetName)
.def("set_shape", &VarDesc::SetShape)
.def("set_shapes", &VarDesc::SetShapes)
.def("set_dtype", &VarDesc::SetDataType)
.def("shape", &VarDesc::Shape, py::return_value_policy::reference)
.def("set_dtypes", &VarDesc::SetDataTypes)
.def("set_tensor_num", &VarDesc::SetTensorDescNum)
.def("tensor_num", &VarDesc::GetTensorDescNum)
.def("shape", &VarDesc::GetShape, py::return_value_policy::reference)
.def("shapes", &VarDesc::GetShapes, py::return_value_policy::reference)
.def("dtype", &VarDesc::GetDataType, py::return_value_policy::reference)
.def("dtypes", &VarDesc::GetDataTypes, py::return_value_policy::reference)
.def("lod_level", &VarDesc::GetLoDLevel)
.def("lod_levels", &VarDesc::GetLoDLevels,
py::return_value_policy::reference)
.def("set_lod_level", &VarDesc::SetLoDLevel)
.def("set_lod_levels", &VarDesc::SetLoDLevels)
.def("type", &VarDesc::GetType)
.def("set_type", &VarDesc::SetType)
.def("serialize_to_string", SerializeMessage<VarDesc>)
Expand All @@ -233,7 +242,8 @@ void BindVarDsec(py::module &m) {
.value("STEP_SCOPES", proto::VarDesc::STEP_SCOPES)
.value("LOD_RANK_TABLE", proto::VarDesc::LOD_RANK_TABLE)
.value("LOD_TENSOR_ARRAY", proto::VarDesc::LOD_TENSOR_ARRAY)
.value("PLACE_LIST", proto::VarDesc::PLACE_LIST);
.value("PLACE_LIST", proto::VarDesc::PLACE_LIST)
.value("READER", proto::VarDesc::READER);
}

void BindOpDesc(py::module &m) {
Expand Down
Loading