Add unittest related #653 #752

Merged
61 changes: 42 additions & 19 deletions paddle/gserver/tests/test_PyDataProvider2.cpp
@@ -15,16 +15,16 @@ limitations under the License. */
 #ifndef PADDLE_NO_PYTHON
 #include <gtest/gtest.h>
 #include <fstream>
-#include "paddle/utils/Util.h"
-#include "paddle/utils/PythonUtil.h"
 #include "paddle/gserver/dataproviders/DataProvider.h"
+#include "paddle/utils/PythonUtil.h"
+#include "paddle/utils/Util.h"
 
 P_DEFINE_string(train_list, "unittest.list", "file list for unittest");
 
 namespace paddle {
 namespace unittest {
 namespace pydp2 {
-extern void setOnPoolFilledHook(const std::function<void(size_t)>& func);
+extern void setOnPoolFilledHook(const std::function<void(size_t)> &func);
 extern void clearOnPoolFilledHook();
 
 }  // namespace pydp2
@@ -33,8 +33,8 @@ extern void clearOnPoolFilledHook();
 
 const paddle::real epsilon = 1e-5;
 
-static inline int64_t readDataBatch(paddle::DataBatch* batch,
-                                    const std::string& funcName,
+static inline int64_t readDataBatch(paddle::DataBatch *batch,
+                                    const std::string &funcName,
                                     int64_t batchSize = 65535) {
   paddle::DataConfig config;
   config.set_type("py2");
@@ -143,7 +143,7 @@ TEST(PyDataProvider2, init_hook) {
   paddle::DataBatch batch;
   int64_t num = provider->getNextBatchInternal(100000, &batch);
   ASSERT_EQ(num, 200);
-  auto& mat = batch.getStreams()[0].value;
+  auto &mat = batch.getStreams()[0].value;
   ASSERT_EQ((size_t)mat->getWidth(), (size_t)20);
   for (size_t i = 0; i < 200; ++i) {
     for (size_t j = 0; j < 20; ++j) {
@@ -170,7 +170,7 @@ TEST(PyDataProvider2, sparse_no_value_no_seq) {
   CHECK(csm != nullptr);
   for (int i = 0; i < 200; ++i) {
     CHECK_EQ(csm->getColNum(i), (size_t)10);
-    int* cols = csm->getRowCols(i);
+    int *cols = csm->getRowCols(i);
     for (int j = 0; j < 10; ++j) {
       CHECK_EQ(cols[j], (i + 1) * (j + 1));
     }
@@ -185,8 +185,8 @@ TEST(PyDataProvider2, sparse_value_no_seq) {
   CHECK(csm != nullptr);
   for (int i = 0; i < 200; ++i) {
     CHECK_EQ(csm->getColNum(i), (size_t)10);
-    int* cols = csm->getRowCols(i);
-    real* dat = csm->getRowValues(i);
+    int *cols = csm->getRowCols(i);
+    real *dat = csm->getRowValues(i);
     for (int j = 0; j < 10; ++j) {
       EXPECT_EQ(cols[j], (i + 1) * (j + 1));
       EXPECT_EQ(dat[j], real(j) / real(i + 1));
@@ -197,7 +197,7 @@ TEST(PyDataProvider2, sparse_value_no_seq) {
 TEST(PyDataProvider2, index_seq) {
   paddle::DataBatch batch;
   CHECK_EQ(readDataBatch(&batch, "test_index_seq"), 200);
-  auto& arg = batch.getStreams()[0];
+  auto &arg = batch.getStreams()[0];
   CHECK_EQ((int)arg.ids->getSize(), (200 + 1) * 200 / 2);
   size_t tmp = 0;
   for (size_t i = 0; i < 200; ++i) {  // CHECK DATA CORRECT
@@ -219,7 +219,7 @@ TEST(PyDataProvider2, index_seq) {
 TEST(PyDataProvider2, index_sub_seq) {
   paddle::DataBatch batch;
   ASSERT_EQ(readDataBatch(&batch, "test_index_sub_seq"), 200);
-  auto& arg = batch.getStreams()[0];
+  auto &arg = batch.getStreams()[0];
   size_t tmp = 0;
   for (size_t i = 0; i < 200; ++i) {
     for (size_t j = 0; j < i + 1; ++j) {
@@ -268,7 +268,7 @@ TEST(PyDataProvider2, min_pool_size) {
     }
   });
   while (true) {
-    size_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
+    int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
     if (realBatchSize) {
       totalData -= realBatchSize;
     } else {
@@ -291,7 +291,7 @@ TEST(PyDataProvider2, can_over_batch_size) {
   provider->reset();
   constexpr size_t batchSize = 100;
   while (true) {
-    size_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
+    int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
     if (realBatchSize) {
       CHECK_LE(realBatchSize, batchSize);
     } else {
@@ -317,12 +317,12 @@ TEST(PyDataProvider2, input_order) {
   provider->reset();
   constexpr size_t batchSize = 100;
   while (true) {
-    size_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
+    int64_t realBatchSize = provider->getNextBatchInternal(batchSize, &batch);
     if (!realBatchSize) {
       break;
     }
-    ASSERT_EQ(batch.getStreams().size(), (size_t)2);
-    for (size_t i = 0; i < realBatchSize; ++i) {
+    ASSERT_EQ(batch.getStreams().size(), static_cast<size_t>(2));
+    for (int64_t i = 0; i < realBatchSize; ++i) {
      ASSERT_EQ(batch.getStream(0).ids->getData()[i], 0);
      ASSERT_EQ(batch.getStream(1).ids->getData()[i], 1);
    }
@@ -341,11 +341,11 @@ TEST(PyDataProvider2, test_check) {
       paddle::DataProvider::create(config, false));
   provider->reset();
   while (true) {
-    size_t realBatchSize = provider->getNextBatchInternal(100, &batch);
+    int64_t realBatchSize = provider->getNextBatchInternal(100, &batch);
     if (!realBatchSize) {
       break;
     } else {
-      auto& ivec = batch.getStream(0).ids;
+      auto &ivec = batch.getStream(0).ids;
       for (size_t i = 0; i < ivec->getSize(); ++i) {
         CHECK_LT(ivec->getData()[i], 10);
       }
@@ -370,7 +370,30 @@ TEST(PyDataProvider2, multiThread) {
   provider.reset();
 }
 
-int main(int argc, char** argv) {
+TEST(PyDataProvider2, minPoolSizeWithCache) {
+  paddle::DataConfig config;
+  config.set_type("py2");
+  config.set_files(FLAGS_train_list.c_str());
+  config.set_load_data_module("test_PyDataProvider2");
+  config.set_load_data_object("test_min_pool_size_with_cache");
+  config.set_async_load_data(true);
+
+  std::unique_ptr<paddle::DataProvider> provider(
+      paddle::DataProvider::create(config, false));
+
+  paddle::DataBatch batch;
+
+  for (int i = 0; i < 10; ++i) {
+    provider->reset();
+    int64_t sum = 0;
+    while (int64_t actualNum = provider->getNextBatch(100, &batch)) {
+      sum += actualNum;
+    }
+    ASSERT_EQ(1 << 20, sum);
+  }
+}
+
+int main(int argc, char **argv) {
   testing::InitGoogleTest(&argc, argv);
   paddle::initMain(argc, argv);
   paddle::initPython(argc, argv);
10 changes: 10 additions & 0 deletions paddle/gserver/tests/test_PyDataProvider2.py
@@ -111,3 +111,13 @@ def test_check(settings, filename):
         if i < 10:
             yield_good_value = True
     yield i
+
+
+@provider(
+    input_types=[index_slot(10)],
+    min_pool_size=1000,
+    cache=CacheType.CACHE_PASS_IN_MEM, )
+def test_min_pool_size_with_cache(settings, filename):
+    import random
+    for _ in xrange(2**20):
+        yield random.randint(0, 9)
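
Taken together, the two files exercise min_pool_size combined with pass-level caching: the Python provider test_min_pool_size_with_cache yields 2**20 random labels with min_pool_size=1000 and cache=CacheType.CACHE_PASS_IN_MEM, and the C++ test minPoolSizeWithCache resets that provider ten times, expecting each pass to still return all 1 << 20 samples, i.e. passes after the first should be served from the in-memory cache rather than by re-running the Python generator. A condensed restatement of the C++ check (all calls and names are taken from the diff above; only the comments are new):

// Condensed restatement of TEST(PyDataProvider2, minPoolSizeWithCache).
// The includes and the FLAGS_train_list flag are the ones already shown
// in test_PyDataProvider2.cpp above.
paddle::DataConfig config;
config.set_type("py2");
config.set_files(FLAGS_train_list.c_str());
config.set_load_data_module("test_PyDataProvider2");
config.set_load_data_object("test_min_pool_size_with_cache");
config.set_async_load_data(true);  // async loading together with min_pool_size and the cache

std::unique_ptr<paddle::DataProvider> provider(
    paddle::DataProvider::create(config, false));
paddle::DataBatch batch;

for (int pass = 0; pass < 10; ++pass) {
  provider->reset();  // after the first pass, data should come from CACHE_PASS_IN_MEM
  int64_t sum = 0;
  while (int64_t actualNum = provider->getNextBatch(100, &batch)) {
    sum += actualNum;  // accumulate the sizes of the batches actually returned
  }
  ASSERT_EQ(1 << 20, sum);  // every pass must still see all 2**20 samples
}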