Skip to content

Commit

Permalink
Merge pull request #752 from reyoung/feature/fix_data_loss_in_pydp2
Browse files Browse the repository at this point in the history
Add unittest related #653
  • Loading branch information
gangliao authored Dec 7, 2016
2 parents b6d036a + 1539335 commit adc23f6
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 19 deletions.
61 changes: 42 additions & 19 deletions paddle/gserver/tests/test_PyDataProvider2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ limitations under the License. */
#ifndef PADDLE_NO_PYTHON
#include <gtest/gtest.h>
#include <fstream>
#include "paddle/utils/Util.h"
#include "paddle/utils/PythonUtil.h"
#include "paddle/gserver/dataproviders/DataProvider.h"
#include "paddle/utils/PythonUtil.h"
#include "paddle/utils/Util.h"

P_DEFINE_string(train_list, "unittest.list", "file list for unittest");

namespace paddle {
namespace unittest {
namespace pydp2 {
extern void setOnPoolFilledHook(const std::function<void(size_t)>& func);
extern void setOnPoolFilledHook(const std::function<void(size_t)> &func);
extern void clearOnPoolFilledHook();

} // namespace pydp2
Expand All @@ -33,8 +33,8 @@ extern void clearOnPoolFilledHook();

const paddle::real epsilon = 1e-5;

static inline int64_t readDataBatch(paddle::DataBatch* batch,
const std::string& funcName,
static inline int64_t readDataBatch(paddle::DataBatch *batch,
const std::string &funcName,
int64_t batchSize = 65535) {
paddle::DataConfig config;
config.set_type("py2");
Expand Down Expand Up @@ -143,7 +143,7 @@ TEST(PyDataProvider2, init_hook) {
paddle::DataBatch batch;
int64_t num = provider->getNextBatchInternal(100000, &batch);
ASSERT_EQ(num, 200);
auto& mat = batch.getStreams()[0].value;
auto &mat = batch.getStreams()[0].value;
ASSERT_EQ((size_t)mat->getWidth(), (size_t)20);
for (size_t i = 0; i < 200; ++i) {
for (size_t j = 0; j < 20; ++j) {
Expand All @@ -170,7 +170,7 @@ TEST(PyDataProvider2, sparse_no_value_no_seq) {
CHECK(csm != nullptr);
for (int i = 0; i < 200; ++i) {
CHECK_EQ(csm->getColNum(i), (size_t)10);
int* cols = csm->getRowCols(i);
int *cols = csm->getRowCols(i);
for (int j = 0; j < 10; ++j) {
CHECK_EQ(cols[j], (i + 1) * (j + 1));
}
Expand All @@ -185,8 +185,8 @@ TEST(PyDataProvider2, sparse_value_no_seq) {
CHECK(csm != nullptr);
for (int i = 0; i < 200; ++i) {
CHECK_EQ(csm->getColNum(i), (size_t)10);
int* cols = csm->getRowCols(i);
real* dat = csm->getRowValues(i);
int *cols = csm->getRowCols(i);
real *dat = csm->getRowValues(i);
for (int j = 0; j < 10; ++j) {
EXPECT_EQ(cols[j], (i + 1) * (j + 1));
EXPECT_EQ(dat[j], real(j) / real(i + 1));
Expand All @@ -197,7 +197,7 @@ TEST(PyDataProvider2, sparse_value_no_seq) {
TEST(PyDataProvider2, index_seq) {
paddle::DataBatch batch;
CHECK_EQ(readDataBatch(&batch, "test_index_seq"), 200);
auto& arg = batch.getStreams()[0];
auto &arg = batch.getStreams()[0];
CHECK_EQ((int)arg.ids->getSize(), (200 + 1) * 200 / 2);
size_t tmp = 0;
for (size_t i = 0; i < 200; ++i) { // CHECK DATA CORRECT
Expand All @@ -219,7 +219,7 @@ TEST(PyDataProvider2, index_seq) {
TEST(PyDataProvider2, index_sub_seq) {
paddle::DataBatch batch;
ASSERT_EQ(readDataBatch(&batch, "test_index_sub_seq"), 200);
auto& arg = batch.getStreams()[0];
auto &arg = batch.getStreams()[0];
size_t tmp = 0;
for (size_t i = 0; i < 200; ++i) {
for (size_t j = 0; j < i + 1; ++j) {
Expand Down Expand Up @@ -268,7 +268,7 @@ TEST(PyDataProvider2, min_pool_size) {
}
});
while (true) {
size_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
if (realBatchSize) {
totalData -= realBatchSize;
} else {
Expand All @@ -291,7 +291,7 @@ TEST(PyDataProvider2, can_over_batch_size) {
provider->reset();
constexpr size_t batchSize = 100;
while (true) {
size_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
if (realBatchSize) {
CHECK_LE(realBatchSize, batchSize);
} else {
Expand All @@ -317,12 +317,12 @@ TEST(PyDataProvider2, input_order) {
provider->reset();
constexpr size_t batchSize = 100;
while (true) {
size_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
if (!realBatchSize) {
break;
}
ASSERT_EQ(batch.getStreams().size(), (size_t)2);
for (size_t i = 0; i < realBatchSize; ++i) {
ASSERT_EQ(batch.getStreams().size(), static_cast<size_t>(2));
for (int64_t i = 0; i < realBatchSize; ++i) {
ASSERT_EQ(batch.getStream(0).ids->getData()[i], 0);
ASSERT_EQ(batch.getStream(1).ids->getData()[i], 1);
}
Expand All @@ -341,11 +341,11 @@ TEST(PyDataProvider2, test_check) {
paddle::DataProvider::create(config, false));
provider->reset();
while (true) {
size_t realBatchSize = provider->getNextBatchInternal(100, &batch);
int64_t realBatchSize = provider->getNextBatchInternal(100, &batch);
if (!realBatchSize) {
break;
} else {
auto& ivec = batch.getStream(0).ids;
auto &ivec = batch.getStream(0).ids;
for (size_t i = 0; i < ivec->getSize(); ++i) {
CHECK_LT(ivec->getData()[i], 10);
}
Expand All @@ -370,7 +370,30 @@ TEST(PyDataProvider2, multiThread) {
provider.reset();
}

int main(int argc, char** argv) {
TEST(PyDataProvider2, minPoolSizeWithCache) {
paddle::DataConfig config;
config.set_type("py2");
config.set_files(FLAGS_train_list.c_str());
config.set_load_data_module("test_PyDataProvider2");
config.set_load_data_object("test_min_pool_size_with_cache");
config.set_async_load_data(true);

std::unique_ptr<paddle::DataProvider> provider(
paddle::DataProvider::create(config, false));

paddle::DataBatch batch;

for (int i = 0; i < 10; ++i) {
provider->reset();
int64_t sum = 0;
while (int64_t actualNum = provider->getNextBatch(100, &batch)) {
sum += actualNum;
}
ASSERT_EQ(1 << 20, sum);
}
}

int main(int argc, char **argv) {
testing::InitGoogleTest(&argc, argv);
paddle::initMain(argc, argv);
paddle::initPython(argc, argv);
Expand Down
10 changes: 10 additions & 0 deletions paddle/gserver/tests/test_PyDataProvider2.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,13 @@ def test_check(settings, filename):
if i < 10:
yield_good_value = True
yield i


@provider(
input_types=[index_slot(10)],
min_pool_size=1000,
cache=CacheType.CACHE_PASS_IN_MEM, )
def test_min_pool_size_with_cache(settings, filename):
import random
for _ in xrange(2**20):
yield random.randint(0, 9)

0 comments on commit adc23f6

Please sign in to comment.