Skip to content

Commit

Permalink
Merge pull request #12658 from velconia/port_pybind11
Browse files Browse the repository at this point in the history
Port pybind11 and python code to support py3 CI test
  • Loading branch information
velconia authored Aug 16, 2018
2 parents 91e84d3 + a32ce8c commit 340a104
Show file tree
Hide file tree
Showing 99 changed files with 1,363 additions and 466 deletions.
59 changes: 59 additions & 0 deletions paddle/fluid/framework/op_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,52 @@ std::vector<std::string> OpDesc::AttrNames() const {
}

void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
// NOTICE(minqiyang): pybind11 will take the empty list in python as
// the std::vector<int> type in C++; so we have to change the attr's type
// here if we meet this issue
proto::AttrType attr_type = static_cast<proto::AttrType>(v.which() - 1);
if (attr_type == proto::AttrType::INTS &&
boost::get<std::vector<int>>(v).size() == 0u) {
// Find current attr via attr name and set the correct attribute value
const proto::OpProto::Attr &attr = GetProtoAttr(name);
switch (attr.type()) {
case proto::AttrType::BOOLEANS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to BOOLEANS";
this->attrs_[name] = std::vector<bool>();
break;
}
case proto::AttrType::INTS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to INTS";
this->attrs_[name] = std::vector<int>();
break;
}
case proto::AttrType::FLOATS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to FLOATS";
this->attrs_[name] = std::vector<float>();
break;
}
case proto::AttrType::STRINGS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to STRINGS";
this->attrs_[name] = std::vector<std::string>();
break;
}
case proto::AttrType::BLOCKS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to BLOCKS";
this->SetBlocksAttr(name, std::vector<BlockDesc *>());
return;
}
default:
PADDLE_THROW("Wrong attr type %d", attr.type());
}
need_update_ = true;
return;
}

this->attrs_[name] = v;
need_update_ = true;
}
Expand Down Expand Up @@ -229,6 +275,19 @@ Attribute OpDesc::GetAttr(const std::string &name) const {
return it->second;
}

const proto::OpProto::Attr &OpDesc::GetProtoAttr(
const std::string &name) const {
const proto::OpProto &proto = OpInfoMap::Instance().Get(Type()).Proto();
for (int i = 0; i != proto.attrs_size(); ++i) {
const proto::OpProto::Attr &attr = proto.attrs(i);
if (attr.name() == name) {
return attr;
}
}

PADDLE_THROW("Attribute %s is not found in proto %s", name, proto.type());
}

Attribute OpDesc::GetNullableAttr(const std::string &name) const {
auto it = attrs_.find(name);
if (it != attrs_.end()) {
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/op_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ class OpDesc {

Attribute GetAttr(const std::string &name) const;

const proto::OpProto::Attr &GetProtoAttr(const std::string &name) const;

Attribute GetNullableAttr(const std::string &name) const;

int GetBlockAttrId(const std::string &name) const;
Expand Down
7 changes: 1 addition & 6 deletions paddle/fluid/pybind/protobuf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,7 @@ void BindBlockDesc(pybind11::module *m) {
void BindVarDsec(pybind11::module *m) {
pybind11::class_<pd::VarDesc> var_desc(*m, "VarDesc", "");
var_desc
.def("name",
[](pd::VarDesc &self) {
pybind11::bytes name = self.Name();
return name;
},
pybind11::return_value_policy::reference)
.def("name", &pd::VarDesc::Name, pybind11::return_value_policy::reference)
.def("set_name", &pd::VarDesc::SetName)
.def("set_shape", &pd::VarDesc::SetShape)
.def("set_shapes", &pd::VarDesc::SetShapes)
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ limitations under the License. */
#include "paddle/fluid/platform/gpu_info.h"
#endif

#include "pybind11/stl.h"

// disable auto conversion to list in Python
PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray);

Expand Down
1 change: 1 addition & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@
import paddle.reader
import paddle.dataset
import paddle.batch
import paddle.compat
batch = batch.batch
237 changes: 237 additions & 0 deletions python/paddle/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import six
import math

__all__ = [
'long_type',
'to_text',
'to_bytes',
'round',
'floor_division',
'get_exception_message',
]

if six.PY2:
int_type = int
long_type = long
else:
int_type = int
long_type = int


# str and bytes related functions
def to_text(obj, encoding='utf-8', inplace=False):
"""
All string in PaddlePaddle should be represented as a literal string.
This function will convert object to a literal string without any encoding.
Especially, if the object type is a list or set container, we will iterate
all items in the object and convert them to literal string.
In Python3:
Decode the bytes type object to str type with specific encoding
In Python2:
Decode the str type object to unicode type with specific encoding
Args:
obj(unicode|str|bytes|list|set) : The object to be decoded.
encoding(str) : The encoding format to decode a string
inplace(bool) : If we change the original object or we create a new one
Returns:
Decoded result of obj
"""
if obj is None:
return obj

if isinstance(obj, list):
if inplace:
for i in six.moves.xrange(len(obj)):
obj[i] = _to_text(obj[i], encoding)
return obj
else:
return [_to_text(item, encoding) for item in obj]
elif isinstance(obj, set):
if inplace:
for item in obj:
obj.remove(item)
obj.add(_to_text(item, encoding))
return obj
else:
return set([_to_text(item, encoding) for item in obj])
else:
return _to_text(obj, encoding)


def _to_text(obj, encoding):
"""
In Python3:
Decode the bytes type object to str type with specific encoding
In Python2:
Decode the str type object to unicode type with specific encoding,
or we just return the unicode string of object
Args:
obj(unicode|str|bytes) : The object to be decoded.
encoding(str) : The encoding format
Returns:
decoded result of obj
"""
if obj is None:
return obj

if isinstance(obj, six.binary_type):
return obj.decode(encoding)
elif isinstance(obj, six.text_type):
return obj
else:
return six.u(obj)


def to_bytes(obj, encoding='utf-8', inplace=False):
"""
All string in PaddlePaddle should be represented as a literal string.
This function will convert object to a bytes with specific encoding.
Especially, if the object type is a list or set container, we will iterate
all items in the object and convert them to bytes.
In Python3:
Encode the str type object to bytes type with specific encoding
In Python2:
Encode the unicode type object to str type with specific encoding,
or we just return the 8-bit string of object
Args:
obj(unicode|str|bytes|list|set) : The object to be encoded.
encoding(str) : The encoding format to encode a string
inplace(bool) : If we change the original object or we create a new one
Returns:
Decoded result of obj
"""
if obj is None:
return obj

if isinstance(obj, list):
if inplace:
for i in six.moves.xrange(len(obj)):
obj[i] = _to_bytes(obj[i], encoding)
return obj
else:
return [_to_bytes(item, encoding) for item in obj]
elif isinstance(obj, set):
if inplace:
for item in obj:
obj.remove(item)
obj.add(_to_bytes(item, encoding))
return obj
else:
return set([_to_bytes(item, encoding) for item in obj])
else:
return _to_bytes(obj, encoding)


def _to_bytes(obj, encoding):
"""
In Python3:
Encode the str type object to bytes type with specific encoding
In Python2:
Encode the unicode type object to str type with specific encoding,
or we just return the 8-bit string of object
Args:
obj(unicode|str|bytes) : The object to be encoded.
encoding(str) : The encoding format
Returns:
encoded result of obj
"""
if obj is None:
return obj

assert encoding is not None
if isinstance(obj, six.text_type):
return obj.encode(encoding)
elif isinstance(obj, six.binary_type):
return obj
else:
return six.b(obj)


# math related functions
def round(x, d=0):
"""
Compatible round which act the same behaviour in Python3.
Args:
x(float) : The number to round halfway.
Returns:
round result of x
"""
if six.PY3:
# The official walkaround of round in Python3 is incorrect
# we implement accroding this answer: https://www.techforgeek.info/round_python.html
if x > 0.0:
p = 10**d
return float(math.floor((x * p) + math.copysign(0.5, x))) / p
elif x < 0.0:
p = 10**d
return float(math.ceil((x * p) + math.copysign(0.5, x))) / p
else:
return math.copysign(0.0, x)
else:
import __builtin__
return __builtin__.round(x, d)


def floor_division(x, y):
"""
Compatible division which act the same behaviour in Python3 and Python2,
whose result will be a int value of floor(x / y) in Python3 and value of
(x / y) in Python2.
Args:
x(int|float) : The number to divide.
y(int|float) : The number to be divided
Returns:
division result of x // y
"""
return x // y


# exception related functions
def get_exception_message(exc):
"""
Get the error message of a specific exception
Args:
exec(Exception) : The exception to get error message.
Returns:
the error message of exec
"""
assert exc is not None

if six.PY2:
return exc.message
else:
return str(exc)
15 changes: 10 additions & 5 deletions python/paddle/dataset/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import numpy
import paddle.dataset.common
import tarfile
from six.moves import zip
import six
from six.moves import cPickle as pickle

__all__ = ['train100', 'test100', 'train10', 'test10', 'convert']
Expand All @@ -46,10 +46,11 @@

def reader_creator(filename, sub_name, cycle=False):
def read_batch(batch):
data = batch['data']
labels = batch.get('labels', batch.get('fine_labels', None))
data = batch[six.b('data')]
labels = batch.get(
six.b('labels'), batch.get(six.b('fine_labels'), None))
assert labels is not None
for sample, label in zip(data, labels):
for sample, label in six.moves.zip(data, labels):
yield (sample / 255.0).astype(numpy.float32), int(label)

def reader():
Expand All @@ -59,7 +60,11 @@ def reader():

while True:
for name in names:
batch = pickle.load(f.extractfile(name))
if six.PY2:
batch = pickle.load(f.extractfile(name))
else:
batch = pickle.load(
f.extractfile(name), encoding='bytes')
for item in read_batch(batch):
yield item
if not cycle:
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/dataset/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def download(url, module_name, md5sum, save_name=None):
total_length = r.headers.get('content-length')

if total_length is None:
with open(filename, 'w') as f:
with open(filename, 'wb') as f:
shutil.copyfileobj(r.raw, f)
else:
with open(filename, 'w') as f:
with open(filename, 'wb') as f:
dl = 0
total_length = int(total_length)
for data in r.iter_content(chunk_size=4096):
Expand Down
Loading

0 comments on commit 340a104

Please sign in to comment.