From 9d912c87e96db9e8bf7f1f7c25e4dd83c3c54ba2 Mon Sep 17 00:00:00 2001 From: Sebastian Puetz Date: Thu, 19 Sep 2019 14:09:09 +0200 Subject: [PATCH] Add Lookup Op and unit tests. --- .travis.yml | 21 ++++--- CMakeLists.txt | 2 +- ci/script.sh | 2 +- finalfusion-cxx | 2 +- finalfusion-tf/kernel/FFLookupKernels.cc | 61 +++++++++++++++++- finalfusion-tf/ops/FFLookupOps.cc | 23 +++++++ tests/CMakeLists.txt | 9 ++- tests/conftest.py | 10 +-- tests/test_eager_mode.py | 71 +++++++++++++++++++++ tests/test_graph_mode.py | 79 ++++++++++++++++++++++++ tests/test_init_close.py | 14 ----- 11 files changed, 258 insertions(+), 36 deletions(-) create mode 100644 tests/test_eager_mode.py create mode 100644 tests/test_graph_mode.py delete mode 100644 tests/test_init_close.py diff --git a/.travis.yml b/.travis.yml index 7e698f8..e53fcaa 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,8 +3,7 @@ addons: apt: packages: - cmake - - python3.6-dev - - python3.6-venv + matrix: fast_finish: true include: @@ -12,15 +11,23 @@ matrix: rust: stable os: osx osx_image: xcode10.1 - - language: rust - os: linux - rust: stable + - language: python + python: 3.6 + addons: + apt: + packages: + - g++-4.8 + env: + - CC=gcc-4.8 + - CXX=g++-4.8 + - language: python + python: 3.7 install: - | if [ "$TRAVIS_OS_NAME" == "linux" ]; then - python3.6 -m venv venv - source venv/bin/activate + curl -sSf https://build.travis-ci.org/files/rustup-init.sh | sh -s -- -y + source $HOME/.cargo/env pip install tensorflow virtualenv pytest fi - | diff --git a/CMakeLists.txt b/CMakeLists.txt index 08ef6bb..6878c3e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,7 +4,7 @@ project(finalfusion_tf) enable_testing() -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 11) set(CMAKE_POSITION_INDEPENDENT_CODE ON) file(COPY tf_flags.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/ci/script.sh b/ci/script.sh index 9a7d89e..08688ba 100755 --- a/ci/script.sh +++ b/ci/script.sh @@ -7,4 +7,4 @@ cd build cmake .. make -ctest +ctest -V diff --git a/finalfusion-cxx b/finalfusion-cxx index 518196c..1e4890c 160000 --- a/finalfusion-cxx +++ b/finalfusion-cxx @@ -1 +1 @@ -Subproject commit 518196c3027c299fa83205503804d000886bb772 +Subproject commit 1e4890c3d00044738c8a72cc930d3c94ca7056cb diff --git a/finalfusion-tf/kernel/FFLookupKernels.cc b/finalfusion-tf/kernel/FFLookupKernels.cc index 7b91055..567fe97 100644 --- a/finalfusion-tf/kernel/FFLookupKernels.cc +++ b/finalfusion-tf/kernel/FFLookupKernels.cc @@ -56,4 +56,63 @@ class CloseFFEmbeddingsOp : public OpKernel { }; REGISTER_KERNEL_BUILDER( - Name("CloseFFEmbeddings").Device(DEVICE_CPU), CloseFFEmbeddingsOp); + Name("CloseFFEmbeddings").Device(DEVICE_CPU), + CloseFFEmbeddingsOp); + +class FFLookupOp : public OpKernel { +public: + explicit FFLookupOp(OpKernelConstruction *context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("mask_empty_string", &mask_empty_string_)); + OP_REQUIRES_OK(context, context->GetAttr("mask_failed_lookup", &mask_failed_lookup_)); + OP_REQUIRES_OK(context, context->GetAttr("embedding_len", &embedding_len_)); + } + + void Compute(OpKernelContext *context) override { + FFLookup *lookup; + OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &lookup)); + core::ScopedUnref unref(lookup); + + // verify length from construction with actual length + size_t const dims = lookup->dimensions(); + if (embedding_len_ != -1) { + OP_REQUIRES(context, + (dims == embedding_len_), + errors::InvalidArgument("Actual embedding length (", dims, ") does not match provided length (", + embedding_len_, ")")); + } + + // Get input tensor and flatten + Tensor const &query_tensor = context->input(1); + auto query = query_tensor.flat(); + + // Set output shape: add new dim with dimensionality of embeddings + TensorShape out_shape(query_tensor.shape()); + out_shape.AddDim(((int64) dims)); + + // Create output tensor and flatten + Tensor *output_tensor = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output_tensor)); + auto output_flat = output_tensor->flat(); + + for (int i = 0; i < query.size(); i++) { + std::vector embedding = lookup->embedding(query(i)); + // optionally mask failed lookups and/or empty string. Generally, empty string will lead to a failed lookup. + if ((query(i).empty() && mask_empty_string_) || (mask_failed_lookup_ && embedding.empty())) { + std::memset(&output_flat(i * dims), 0., dims * sizeof(float)); + } else { + // if no masking attributes are set and the embedding is empty, return error. + OP_REQUIRES(context, !embedding.empty(), errors::InvalidArgument("Embedding lookup failed for: ", query(i))); + std::memcpy(&output_flat(i * dims), embedding.data(), dims * sizeof(float)); + } + } + } + +private: + bool mask_empty_string_; + bool mask_failed_lookup_; + int embedding_len_; +}; + +REGISTER_KERNEL_BUILDER( + Name("FFLookup").Device(DEVICE_CPU), + FFLookupOp); \ No newline at end of file diff --git a/finalfusion-tf/ops/FFLookupOps.cc b/finalfusion-tf/ops/FFLookupOps.cc index d455981..b978424 100644 --- a/finalfusion-tf/ops/FFLookupOps.cc +++ b/finalfusion-tf/ops/FFLookupOps.cc @@ -22,4 +22,27 @@ namespace tensorflow { REGISTER_OP("CloseFFEmbeddings") .Input("embeds: resource") .SetShapeFn(shape_inference::NoOutputs); + + REGISTER_OP("FFLookup") + .Input("embeds: resource") + .Input("query: string") + .Attr("embedding_len: int >= -1 = -1") + .Attr("mask_empty_string: bool = true") + .Attr("mask_failed_lookup: bool = true") + .Output("embeddings: float") + .SetShapeFn([]( + ::tensorflow::shape_inference::InferenceContext *c + ) { + ShapeHandle strings_shape = c->input(1); + ShapeHandle output_shape; + int embedding_len; + TF_RETURN_IF_ERROR(c->GetAttr("embedding_len", &embedding_len)); + TF_RETURN_IF_ERROR( + c->Concatenate(strings_shape, c->Vector(embedding_len), &output_shape) + ); + ShapeHandle embeds = c->output(0); + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &embeds)); + c->set_output(0, output_shape); + return Status::OK(); + }); } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 472d3ba..05e739b 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -3,7 +3,12 @@ include(CTest) file(COPY testdata/test.fifu DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/data) message(${CMAKE_CURRENT_BINARY_DIR}) -add_test(NAME python-init-close - COMMAND pytest ${CMAKE_CURRENT_SOURCE_DIR} +add_test(NAME eager-mode + COMMAND pytest ${CMAKE_CURRENT_SOURCE_DIR}/test_eager_mode.py + WORKING_DIRECTORY ${PROJECT_BINARY_DIR} + ) + +add_test(NAME graph-mode + COMMAND pytest ${CMAKE_CURRENT_SOURCE_DIR}/test_graph_mode.py WORKING_DIRECTORY ${PROJECT_BINARY_DIR} ) diff --git a/tests/conftest.py b/tests/conftest.py index 235a571..9bc0f8b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,22 +1,14 @@ -import os import platform import pytest import tensorflow as tf -tf.enable_eager_execution() - @pytest.fixture -def ff_lib(tests_root): +def ff_lib(): if platform.system() == "Darwin": LIB_SUFFIX = ".dylib" else: LIB_SUFFIX = ".so" yield tf.load_op_library("./finalfusion-tf/libfinalfusion_tf" + LIB_SUFFIX) - - -@pytest.fixture -def tests_root(): - yield os.path.dirname(__file__) \ No newline at end of file diff --git a/tests/test_eager_mode.py b/tests/test_eager_mode.py new file mode 100644 index 0000000..9f4dccd --- /dev/null +++ b/tests/test_eager_mode.py @@ -0,0 +1,71 @@ +import numpy as np +import pytest +import tensorflow as tf + +tf.enable_eager_execution() + + +def test_init_and_close(ff_lib): + embeddings = ff_lib.ff_embeddings() + ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", mmap=False) + ff_lib.close_ff_embeddings(embeddings) + + +def test_init_and_close_mmap(ff_lib): + embeddings = ff_lib.ff_embeddings() + ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", mmap=True) + ff_lib.close_ff_embeddings(embeddings) + + +def test_eager_lookup(ff_lib): + embeddings = ff_lib.ff_embeddings() + ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", mmap=False) + + ber = ff_lib.ff_lookup(embeddings, "Berlin", mask_empty_string=False, mask_failed_lookup=False) + ber_list = ff_lib.ff_lookup(embeddings, ["Berlin"], mask_empty_string=False, mask_failed_lookup=False) + ber_tensor = ff_lib.ff_lookup(embeddings, [["Berlin"]], mask_empty_string=False, mask_failed_lookup=False) + + assert ber.shape == (100,) + assert ber_list.shape == (1, 100) + assert ber_tensor.shape == (1, 1, 100) + + ff_lib.close_ff_embeddings(embeddings) + + +def test_eager_lookup_masked(ff_lib): + embeddings = ff_lib.ff_embeddings() + ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", False) + tuebingen_masked = ff_lib.ff_lookup(embeddings, "Tübingen", mask_empty_string=False, mask_failed_lookup=True, + embedding_len=100) + empty_masked = ff_lib.ff_lookup(embeddings, "", mask_empty_string=True, mask_failed_lookup=False, embedding_len=100) + empty_masked_through_fail = ff_lib.ff_lookup(embeddings, "", mask_empty_string=False, mask_failed_lookup=True, + embedding_len=100) + assert np.allclose(tuebingen_masked, 0.) + assert np.allclose(empty_masked, 0.) + assert np.allclose(empty_masked_through_fail, 0.) + ff_lib.close_ff_embeddings(embeddings) + + +def test_eager_errors(ff_lib): + embeddings = ff_lib.ff_embeddings() + with pytest.raises(tf.errors.UnknownError): + ff_lib.initialize_ff_embeddings(embeddings, "foo.fifu", False) + + ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", False) + + with pytest.raises(tf.errors.AlreadyExistsError): + ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", False) + + with pytest.raises(tf.errors.InvalidArgumentError): + ff_lib.ff_lookup(embeddings, "Tübingen", mask_empty_string=False, mask_failed_lookup=False, embedding_len=100) + + # shape mismatch, 10 vs. actual 100 + with pytest.raises(tf.errors.InvalidArgumentError): + ff_lib.ff_lookup(embeddings, "Berlin", mask_empty_string=False, mask_failed_lookup=False, embedding_len=10) + + with pytest.raises(tf.errors.InvalidArgumentError): + ff_lib.ff_lookup(embeddings, "", mask_empty_string=False, mask_failed_lookup=False, embedding_len=100) + + ff_lib.close_ff_embeddings(embeddings) + with pytest.raises(tf.errors.NotFoundError): + ff_lib.close_ff_embeddings(embeddings) diff --git a/tests/test_graph_mode.py b/tests/test_graph_mode.py new file mode 100644 index 0000000..d22431a --- /dev/null +++ b/tests/test_graph_mode.py @@ -0,0 +1,79 @@ +import numpy as np +import pytest +import tensorflow as tf + + +def test_graph_lookup(ff_lib): + embeddings = ff_lib.ff_embeddings() + init = ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", False) + + ber = ff_lib.ff_lookup(embeddings, "Berlin", mask_empty_string=False, mask_failed_lookup=False, embedding_len=100) + assert ber.shape == (100,) + + ber_list = ff_lib.ff_lookup(embeddings, ["Berlin"], mask_empty_string=False, mask_failed_lookup=False, + embedding_len=100) + assert ber_list.shape == (1, 100) + + ber_tensor = ff_lib.ff_lookup(embeddings, [["Berlin"]], mask_empty_string=False, mask_failed_lookup=False, + embedding_len=100) + assert ber_tensor.shape == (1, 1, 100) + + ber_no_shape = ff_lib.ff_lookup(embeddings, "Berlin", mask_empty_string=False, mask_failed_lookup=False) + assert ber_no_shape.shape.rank == 1 + assert ber_no_shape.shape[0].value is None + + ber_list_no_shape = ff_lib.ff_lookup(embeddings, ["Berlin"], mask_empty_string=False, mask_failed_lookup=False) + assert ber_list_no_shape.shape.rank == 2 + assert ber_list_no_shape.shape[0].value == tf.Dimension(1) + assert ber_list_no_shape.shape[1].value is None + + with tf.Session() as sess: + sess.run([init]) + res = sess.run([ber, ber_list, ber_tensor]) + assert res[0].shape == (100,) + assert res[1].shape == (1, 100) + assert res[2].shape == (1, 1, 100) + sess.run([ff_lib.close_ff_embeddings(embeddings)]) + + +def test_graph_lookup_masked(ff_lib): + embeddings = ff_lib.ff_embeddings() + init = ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", True) + tuebingen_masked = ff_lib.ff_lookup(embeddings, "Tübingen", mask_empty_string=False, mask_failed_lookup=True, + embedding_len=100) + empty_masked = ff_lib.ff_lookup(embeddings, "", mask_empty_string=True, mask_failed_lookup=False, embedding_len=100) + empty_masked_through_fail = ff_lib.ff_lookup(embeddings, "", mask_empty_string=False, mask_failed_lookup=True, + embedding_len=100) + with tf.Session() as sess: + sess.run([init]) + res = sess.run([tuebingen_masked, empty_masked, empty_masked_through_fail]) + assert np.allclose(res, 0.) + + +def test_graph_errors(ff_lib): + embeddings = ff_lib.ff_embeddings() + tuebingen_unmasked = ff_lib.ff_lookup(embeddings, "Tübingen", mask_empty_string=False, mask_failed_lookup=False, + embedding_len=100) + ber_bad_shape = ff_lib.ff_lookup(embeddings, "Berlin", mask_empty_string=False, mask_failed_lookup=False, + embedding_len=10) + assert ber_bad_shape.shape == (10,) + empty_unmasked = ff_lib.ff_lookup(embeddings, "", mask_empty_string=False, mask_failed_lookup=False, + embedding_len=100) + + with tf.Session() as sess: + with pytest.raises(tf.errors.UnknownError): + sess.run([ff_lib.initialize_ff_embeddings(embeddings, "foo.fifu", False)]) + + sess.run([ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", False)]) + + with pytest.raises(tf.errors.AlreadyExistsError): + sess.run([ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", False)]) + with pytest.raises(tf.errors.InvalidArgumentError): + sess.run([tuebingen_unmasked]) + with pytest.raises(tf.errors.InvalidArgumentError): + sess.run([empty_unmasked]) + with pytest.raises(tf.errors.InvalidArgumentError): + sess.run([ber_bad_shape]) + sess.run([ff_lib.close_ff_embeddings(embeddings)]) + with pytest.raises(tf.errors.NotFoundError): + sess.run([ff_lib.close_ff_embeddings(embeddings)]) diff --git a/tests/test_init_close.py b/tests/test_init_close.py deleted file mode 100644 index ea284c6..0000000 --- a/tests/test_init_close.py +++ /dev/null @@ -1,14 +0,0 @@ -import platform -import pytest - - -def test_init_and_close(ff_lib): - embeddings = ff_lib.ff_embeddings() - ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", mmap=False) - ff_lib.close_ff_embeddings(embeddings) - - -def test_init_and_close_mmap(ff_lib): - embeddings = ff_lib.ff_embeddings() - ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", mmap=True) - ff_lib.close_ff_embeddings(embeddings)