Skip to content

Commit

Permalink
Add Lookup Op and unit tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
sebpuetz committed Sep 20, 2019
1 parent d1f0283 commit 0c5e049
Show file tree
Hide file tree
Showing 11 changed files with 263 additions and 36 deletions.
26 changes: 19 additions & 7 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,26 +1,38 @@
dist: bionic
addons:
apt:
packages:
- cmake
- python3.6-dev
- python3.6-venv

matrix:
fast_finish: true
include:
- language: rust
rust: stable
os: osx
osx_image: xcode10.1
- language: rust
- language: python
python: 3.6
os: linux
rust: stable
dist: bionic
addons:
apt:
sources:
- ubuntu-toolchain-r-test
packages:
- g++-4.8
env:
- CC=gcc-4.8
- CXX=g++-4.8
- language: python
python: 3.7
os: linux
dist: xenial

install:
- |
if [ "$TRAVIS_OS_NAME" == "linux" ]; then
python3.6 -m venv venv
source venv/bin/activate
curl -sSf https://build.travis-ci.org/files/rustup-init.sh | sh -s -- -y
source $HOME/.cargo/env
pip install tensorflow virtualenv pytest
fi
- |
Expand Down
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ project(finalfusion_tf)

enable_testing()

set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

file(COPY tf_flags.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
Expand Down
2 changes: 1 addition & 1 deletion ci/script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ cd build
cmake ..
make

ctest
ctest -V
2 changes: 1 addition & 1 deletion finalfusion-cxx
61 changes: 60 additions & 1 deletion finalfusion-tf/kernel/FFLookupKernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,63 @@ class CloseFFEmbeddingsOp : public OpKernel {
};

REGISTER_KERNEL_BUILDER(
Name("CloseFFEmbeddings").Device(DEVICE_CPU), CloseFFEmbeddingsOp);
Name("CloseFFEmbeddings").Device(DEVICE_CPU),
CloseFFEmbeddingsOp);

class FFLookupOp : public OpKernel {
public:
explicit FFLookupOp(OpKernelConstruction *context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("mask_empty_string", &mask_empty_string_));
OP_REQUIRES_OK(context, context->GetAttr("mask_failed_lookup", &mask_failed_lookup_));
OP_REQUIRES_OK(context, context->GetAttr("embedding_len", &embedding_len_));
}

void Compute(OpKernelContext *context) override {
FFLookup *lookup;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &lookup));
core::ScopedUnref unref(lookup);

// verify length from construction with actual length
size_t const dims = lookup->dimensions();
if (embedding_len_ != -1) {
OP_REQUIRES(context,
(dims == embedding_len_),
errors::InvalidArgument("Actual embedding length (", dims, ") does not match provided length (",
embedding_len_, ")"));
}

// Get input tensor and flatten
Tensor const &query_tensor = context->input(1);
auto query = query_tensor.flat<string>();

// Set output shape: add new dim with dimensionality of embeddings
TensorShape out_shape(query_tensor.shape());
out_shape.AddDim(((int64) dims));

// Create output tensor and flatten
Tensor *output_tensor = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output_tensor));
auto output_flat = output_tensor->flat<float>();

for (int i = 0; i < query.size(); i++) {
std::vector<float> embedding = lookup->embedding(query(i));
// optionally mask failed lookups and/or empty string. Generally, empty string will lead to a failed lookup.
if ((query(i).empty() && mask_empty_string_) || (mask_failed_lookup_ && embedding.empty())) {
std::memset(&output_flat(i * dims), 0., dims * sizeof(float));
} else {
// if no masking attributes are set and the embedding is empty, return error.
OP_REQUIRES(context, !embedding.empty(), errors::InvalidArgument("Embedding lookup failed for: ", query(i)));
std::memcpy(&output_flat(i * dims), embedding.data(), dims * sizeof(float));
}
}
}

private:
bool mask_empty_string_;
bool mask_failed_lookup_;
int embedding_len_;
};

REGISTER_KERNEL_BUILDER(
Name("FFLookup").Device(DEVICE_CPU),
FFLookupOp);
23 changes: 23 additions & 0 deletions finalfusion-tf/ops/FFLookupOps.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,27 @@ namespace tensorflow {
REGISTER_OP("CloseFFEmbeddings")
.Input("embeds: resource")
.SetShapeFn(shape_inference::NoOutputs);

REGISTER_OP("FFLookup")
.Input("embeds: resource")
.Input("query: string")
.Attr("embedding_len: int >= -1 = -1")
.Attr("mask_empty_string: bool = true")
.Attr("mask_failed_lookup: bool = true")
.Output("embeddings: float")
.SetShapeFn([](
::tensorflow::shape_inference::InferenceContext *c
) {
ShapeHandle strings_shape = c->input(1);
ShapeHandle output_shape;
int embedding_len;
TF_RETURN_IF_ERROR(c->GetAttr("embedding_len", &embedding_len));
TF_RETURN_IF_ERROR(
c->Concatenate(strings_shape, c->Vector(embedding_len), &output_shape)
);
ShapeHandle embeds = c->output(0);
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &embeds));
c->set_output(0, output_shape);
return Status::OK();
});
}
9 changes: 7 additions & 2 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@ include(CTest)
file(COPY testdata/test.fifu DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/data)
message(${CMAKE_CURRENT_BINARY_DIR})

add_test(NAME python-init-close
COMMAND pytest ${CMAKE_CURRENT_SOURCE_DIR}
add_test(NAME eager-mode
COMMAND pytest ${CMAKE_CURRENT_SOURCE_DIR}/test_eager_mode.py
WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
)

add_test(NAME graph-mode
COMMAND pytest ${CMAKE_CURRENT_SOURCE_DIR}/test_graph_mode.py
WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
)
10 changes: 1 addition & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,14 @@
import os
import platform

import pytest
import tensorflow as tf

tf.enable_eager_execution()


@pytest.fixture
def ff_lib(tests_root):
def ff_lib():
if platform.system() == "Darwin":
LIB_SUFFIX = ".dylib"
else:
LIB_SUFFIX = ".so"

yield tf.load_op_library("./finalfusion-tf/libfinalfusion_tf" + LIB_SUFFIX)


@pytest.fixture
def tests_root():
yield os.path.dirname(__file__)
71 changes: 71 additions & 0 deletions tests/test_eager_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import numpy as np
import pytest
import tensorflow as tf

tf.enable_eager_execution()


def test_init_and_close(ff_lib):
embeddings = ff_lib.ff_embeddings()
ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", mmap=False)
ff_lib.close_ff_embeddings(embeddings)


def test_init_and_close_mmap(ff_lib):
embeddings = ff_lib.ff_embeddings()
ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", mmap=True)
ff_lib.close_ff_embeddings(embeddings)


def test_eager_lookup(ff_lib):
embeddings = ff_lib.ff_embeddings()
ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", mmap=False)

ber = ff_lib.ff_lookup(embeddings, "Berlin", mask_empty_string=False, mask_failed_lookup=False)
ber_list = ff_lib.ff_lookup(embeddings, ["Berlin"], mask_empty_string=False, mask_failed_lookup=False)
ber_tensor = ff_lib.ff_lookup(embeddings, [["Berlin"]], mask_empty_string=False, mask_failed_lookup=False)

assert ber.shape == (100,)
assert ber_list.shape == (1, 100)
assert ber_tensor.shape == (1, 1, 100)

ff_lib.close_ff_embeddings(embeddings)


def test_eager_lookup_masked(ff_lib):
embeddings = ff_lib.ff_embeddings()
ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", False)
tuebingen_masked = ff_lib.ff_lookup(embeddings, "Tübingen", mask_empty_string=False, mask_failed_lookup=True,
embedding_len=100)
empty_masked = ff_lib.ff_lookup(embeddings, "", mask_empty_string=True, mask_failed_lookup=False, embedding_len=100)
empty_masked_through_fail = ff_lib.ff_lookup(embeddings, "", mask_empty_string=False, mask_failed_lookup=True,
embedding_len=100)
assert np.allclose(tuebingen_masked, 0.)
assert np.allclose(empty_masked, 0.)
assert np.allclose(empty_masked_through_fail, 0.)
ff_lib.close_ff_embeddings(embeddings)


def test_eager_errors(ff_lib):
embeddings = ff_lib.ff_embeddings()
with pytest.raises(tf.errors.UnknownError):
ff_lib.initialize_ff_embeddings(embeddings, "foo.fifu", False)

ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", False)

with pytest.raises(tf.errors.AlreadyExistsError):
ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", False)

with pytest.raises(tf.errors.InvalidArgumentError):
ff_lib.ff_lookup(embeddings, "Tübingen", mask_empty_string=False, mask_failed_lookup=False, embedding_len=100)

# shape mismatch, 10 vs. actual 100
with pytest.raises(tf.errors.InvalidArgumentError):
ff_lib.ff_lookup(embeddings, "Berlin", mask_empty_string=False, mask_failed_lookup=False, embedding_len=10)

with pytest.raises(tf.errors.InvalidArgumentError):
ff_lib.ff_lookup(embeddings, "", mask_empty_string=False, mask_failed_lookup=False, embedding_len=100)

ff_lib.close_ff_embeddings(embeddings)
with pytest.raises(tf.errors.NotFoundError):
ff_lib.close_ff_embeddings(embeddings)
79 changes: 79 additions & 0 deletions tests/test_graph_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import numpy as np
import pytest
import tensorflow as tf


def test_graph_lookup(ff_lib):
embeddings = ff_lib.ff_embeddings()
init = ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", False)

ber = ff_lib.ff_lookup(embeddings, "Berlin", mask_empty_string=False, mask_failed_lookup=False, embedding_len=100)
assert ber.shape == (100,)

ber_list = ff_lib.ff_lookup(embeddings, ["Berlin"], mask_empty_string=False, mask_failed_lookup=False,
embedding_len=100)
assert ber_list.shape == (1, 100)

ber_tensor = ff_lib.ff_lookup(embeddings, [["Berlin"]], mask_empty_string=False, mask_failed_lookup=False,
embedding_len=100)
assert ber_tensor.shape == (1, 1, 100)

ber_no_shape = ff_lib.ff_lookup(embeddings, "Berlin", mask_empty_string=False, mask_failed_lookup=False)
assert ber_no_shape.shape.rank == 1
assert ber_no_shape.shape[0].value is None

ber_list_no_shape = ff_lib.ff_lookup(embeddings, ["Berlin"], mask_empty_string=False, mask_failed_lookup=False)
assert ber_list_no_shape.shape.rank == 2
assert ber_list_no_shape.shape[0].value == tf.Dimension(1)
assert ber_list_no_shape.shape[1].value is None

with tf.Session() as sess:
sess.run([init])
res = sess.run([ber, ber_list, ber_tensor])
assert res[0].shape == (100,)
assert res[1].shape == (1, 100)
assert res[2].shape == (1, 1, 100)
sess.run([ff_lib.close_ff_embeddings(embeddings)])


def test_graph_lookup_masked(ff_lib):
embeddings = ff_lib.ff_embeddings()
init = ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", True)
tuebingen_masked = ff_lib.ff_lookup(embeddings, "Tübingen", mask_empty_string=False, mask_failed_lookup=True,
embedding_len=100)
empty_masked = ff_lib.ff_lookup(embeddings, "", mask_empty_string=True, mask_failed_lookup=False, embedding_len=100)
empty_masked_through_fail = ff_lib.ff_lookup(embeddings, "", mask_empty_string=False, mask_failed_lookup=True,
embedding_len=100)
with tf.Session() as sess:
sess.run([init])
res = sess.run([tuebingen_masked, empty_masked, empty_masked_through_fail])
assert np.allclose(res, 0.)


def test_graph_errors(ff_lib):
embeddings = ff_lib.ff_embeddings()
tuebingen_unmasked = ff_lib.ff_lookup(embeddings, "Tübingen", mask_empty_string=False, mask_failed_lookup=False,
embedding_len=100)
ber_bad_shape = ff_lib.ff_lookup(embeddings, "Berlin", mask_empty_string=False, mask_failed_lookup=False,
embedding_len=10)
assert ber_bad_shape.shape == (10,)
empty_unmasked = ff_lib.ff_lookup(embeddings, "", mask_empty_string=False, mask_failed_lookup=False,
embedding_len=100)

with tf.Session() as sess:
with pytest.raises(tf.errors.UnknownError):
sess.run([ff_lib.initialize_ff_embeddings(embeddings, "foo.fifu", False)])

sess.run([ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", False)])

with pytest.raises(tf.errors.AlreadyExistsError):
sess.run([ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", False)])
with pytest.raises(tf.errors.InvalidArgumentError):
sess.run([tuebingen_unmasked])
with pytest.raises(tf.errors.InvalidArgumentError):
sess.run([empty_unmasked])
with pytest.raises(tf.errors.InvalidArgumentError):
sess.run([ber_bad_shape])
sess.run([ff_lib.close_ff_embeddings(embeddings)])
with pytest.raises(tf.errors.NotFoundError):
sess.run([ff_lib.close_ff_embeddings(embeddings)])
14 changes: 0 additions & 14 deletions tests/test_init_close.py

This file was deleted.

0 comments on commit 0c5e049

Please sign in to comment.