Skip to content

Commit

Permalink
add back has_valid_lod
Browse files Browse the repository at this point in the history
  • Loading branch information
kexinzhao committed Jun 14, 2018
1 parent 4bc8c5b commit 1695cd9
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 2 deletions.
7 changes: 6 additions & 1 deletion paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,12 @@ PYBIND11_PLUGIN(core) {
new_lod.reserve(lod.size());
std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
return new_lod;
});
})
.def("has_valid_recursive_sequence_lengths", [](LoDTensor &self) -> bool {
// Check that the lod info is valid and match the outermost
// dimension of the LoDTensor data
return CheckLoD(self.lod(), vectorize(self.dims()).front());
});

py::class_<SelectedRows>(m, "SelectedRows")
.def("__init__",
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/fluid/data_feeder.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def done(self):
t.set(arr, self.place)
if self.lod_level > 0:
t.set_recursive_sequence_lengths(self.lod)
assert t.has_valid_recursive_sequence_lengths(
), "the provided lod info is invalid"
return t


Expand Down
2 changes: 2 additions & 0 deletions python/paddle/fluid/lod_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def create_lod_tensor(data, lod, place):
tensor = core.LoDTensor()
tensor.set(data, place)
tensor.set_recursive_sequence_lengths(lod)
assert tensor.has_valid_recursive_sequence_lengths(
), "the provided lod info is invalid"
return tensor
else:
raise TypeError(
Expand Down
22 changes: 22 additions & 0 deletions python/paddle/fluid/tests/test_lod_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,36 @@
class TestLoDTensor(unittest.TestCase):
def test_pybind_lod(self):
tensor = fluid.LoDTensor()
lod = []
tensor.set_recursive_sequence_lengths(lod)
self.assertTrue(tensor.has_valid_recursive_sequence_lengths())
lod = [[], [1], [3]]
tensor.set_recursive_sequence_lengths(lod)
self.assertFalse(tensor.has_valid_recursive_sequence_lengths())
lod = [[0], [2], [3]]
tensor.set_recursive_sequence_lengths(lod)
self.assertFalse(tensor.has_valid_recursive_sequence_lengths())

lod = [[1, 2, 3]]
tensor.set_recursive_sequence_lengths(lod)
self.assertEqual(tensor.recursive_sequence_lengths(), lod)
tensor.set(np.random.random([6, 1]), fluid.CPUPlace())
self.assertTrue(tensor.has_valid_recursive_sequence_lengths())
tensor.set(np.random.random([9, 1]), fluid.CPUPlace())
self.assertFalse(tensor.has_valid_recursive_sequence_lengths())

# Each level's sum should be equal to the number of items in the next level
# Moreover, last level's sum should be equal to the tensor height
lod = [[2, 1], [1, 3, 1, 2, 1]]
tensor.set_recursive_sequence_lengths(lod)
self.assertEqual(tensor.recursive_sequence_lengths(), lod)
tensor.set(np.random.random([8, 1]), fluid.CPUPlace())
self.assertFalse(tensor.has_valid_recursive_sequence_lengths())
lod = [[2, 3], [1, 3, 1, 2, 1]]
tensor.set_recursive_sequence_lengths(lod)
self.assertTrue(tensor.has_valid_recursive_sequence_lengths())
tensor.set(np.random.random([9, 1]), fluid.CPUPlace())
self.assertFalse(tensor.has_valid_recursive_sequence_lengths())

def test_create_lod_tensor(self):
# Create LoDTensor from a list
Expand Down
2 changes: 1 addition & 1 deletion tools/codestyle/cpplint_pre_commit.hook
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ for file in $(git diff --cached --name-status | awk '$1 != "D" {print $2}'); do
if [[ $file =~ ^(paddle/api/.*|paddle/capi/.*|paddle/contrib/.*|paddle/cuda/.*|paddle/function/.*|paddle/gserver/.*|paddle/math/.*|paddle/optimizer/.*|paddle/parameter/.*|paddle/pserver/.*|paddle/trainer/.*|paddle/utils/.*) ]]; then
continue;
else
cpplint $file;
cpplint --filter=-readability/fn_size $file;
TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?);
fi
done
Expand Down

0 comments on commit 1695cd9

Please sign in to comment.