Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lua: Manage imported public keys in stream handle #12664

Merged
merged 7 commits into from
Aug 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions source/extensions/filters/http/lua/lua_filter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -555,25 +555,30 @@ int StreamHandleWrapper::luaLogCritical(lua_State* state) {
}

int StreamHandleWrapper::luaVerifySignature(lua_State* state) {
// Step 1: get hash function
// Step 1: Get hash function.
absl::string_view hash = luaL_checkstring(state, 2);

// Step 2: get key pointer
auto ptr = lua_touserdata(state, 3);
// Step 2: Get the key pointer.
auto key = luaL_checkstring(state, 3);
auto ptr = public_key_storage_.find(key);
if (ptr == public_key_storage_.end()) {
luaL_error(state, "invalid public key");
return 0;
}

// Step 3: get signature
// Step 3: Get signature from args.
const char* signature = luaL_checkstring(state, 4);
int sig_len = luaL_checknumber(state, 5);
const std::vector<uint8_t> sig_vec(signature, signature + sig_len);

// Step 4: get clear text
// Step 4: Get clear text from args.
const char* clear_text = luaL_checkstring(state, 6);
int text_len = luaL_checknumber(state, 7);
const std::vector<uint8_t> text_vec(clear_text, clear_text + text_len);
// Step 5: verify signature
auto crypto = reinterpret_cast<Envoy::Common::Crypto::CryptoObject*>(ptr);

// Step 5: Verify signature.
auto& crypto_util = Envoy::Common::Crypto::UtilitySingleton::get();
auto output = crypto_util.verifySignature(hash, *crypto, sig_vec, text_vec);
auto output = crypto_util.verifySignature(hash, *ptr->second, sig_vec, text_vec);
lua_pushboolean(state, output.result_);
if (output.result_) {
lua_pushnil(state);
Expand All @@ -593,7 +598,18 @@ int StreamHandleWrapper::luaImportPublicKey(lua_State* state) {
} else {
auto& crypto_util = Envoy::Common::Crypto::UtilitySingleton::get();
Envoy::Common::Crypto::CryptoObjectPtr crypto_ptr = crypto_util.importPublicKey(key);
public_key_wrapper_.reset(PublicKeyWrapper::create(state, std::move(crypto_ptr)), true);
auto wrapper = Envoy::Common::Crypto::Access::getTyped<Envoy::Common::Crypto::PublicKeyObject>(
*crypto_ptr);
EVP_PKEY* pkey = wrapper->getEVP_PKEY();
if (pkey == nullptr) {
// TODO(dio): Call luaL_error here instead of failing silently. However, the current behavior
// is to return nil (when calling get() to the wrapped object, hence we create a wrapper
// initialized by an empty string here) when importing a public key is failed.
public_key_wrapper_.reset(PublicKeyWrapper::create(state, EMPTY_STRING), true);
}

public_key_storage_.insert({std::string(str).substr(0, n), std::move(crypto_ptr)});
public_key_wrapper_.reset(PublicKeyWrapper::create(state, str), true);
}

return 1;
Expand Down
3 changes: 3 additions & 0 deletions source/extensions/filters/http/lua/lua_filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,9 @@ class StreamHandleWrapper : public Filters::Common::Lua::BaseLuaObject<StreamHan
State state_{State::Running};
std::function<void()> yield_callback_;
Http::AsyncClient::Request* http_request_{};

// The inserted crypto object pointers will not be removed from this map.
absl::flat_hash_map<std::string, Envoy::Common::Crypto::CryptoObjectPtr> public_key_storage_;
mattklein123 marked this conversation as resolved.
Show resolved Hide resolved
};

/**
Expand Down
10 changes: 2 additions & 8 deletions source/extensions/filters/http/lua/wrappers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,16 +190,10 @@ int DynamicMetadataMapWrapper::luaPairs(lua_State* state) {
}

int PublicKeyWrapper::luaGet(lua_State* state) {
if (public_key_ == nullptr || public_key_.get() == nullptr) {
if (public_key_.empty()) {
lua_pushnil(state);
} else {
auto wrapper = Common::Crypto::Access::getTyped<Common::Crypto::PublicKeyObject>(*public_key_);
EVP_PKEY* pkey = wrapper->getEVP_PKEY();
if (pkey == nullptr) {
lua_pushnil(state);
} else {
lua_pushlightuserdata(state, public_key_.get());
mattklein123 marked this conversation as resolved.
Show resolved Hide resolved
}
lua_pushstring(state, public_key_.c_str());
}
return 1;
}
Expand Down
10 changes: 5 additions & 5 deletions source/extensions/filters/http/lua/wrappers.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,21 +218,21 @@ class StreamInfoWrapper : public Filters::Common::Lua::BaseLuaObject<StreamInfoW
};

/**
* Lua wrapper for EVP_PKEY.
* Lua wrapper for key for accessing the imported public keys.
*/
class PublicKeyWrapper : public Filters::Common::Lua::BaseLuaObject<PublicKeyWrapper> {
public:
PublicKeyWrapper(Envoy::Common::Crypto::CryptoObjectPtr key) : public_key_(std::move(key)) {}
explicit PublicKeyWrapper(absl::string_view key) : public_key_(key) {}
static ExportedFunctions exportedFunctions() { return {{"get", static_luaGet}}; }

private:
/**
* Get a pointer to public key.
* @return pointer to public key.
* Get public key value.
* @return public key value or nil if key is empty.
*/
DECLARE_LUA_FUNCTION(PublicKeyWrapper, luaGet);

Envoy::Common::Crypto::CryptoObjectPtr public_key_;
const std::string public_key_;
};

} // namespace Lua
Expand Down
17 changes: 11 additions & 6 deletions test/extensions/filters/common/lua/lua_wrappers.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,41 @@ namespace Filters {
namespace Common {
namespace Lua {

// A helper to be called inside the registered closure.
class Printer {
public:
MOCK_CONST_METHOD1(testPrint, void(const std::string&));
};

const Printer& getPrinter() { CONSTRUCT_ON_FIRST_USE(Printer); }

template <class T> class LuaWrappersTestBase : public testing::Test {
public:
virtual void setup(const std::string& code) {
coroutine_.reset();
state_ = std::make_unique<ThreadLocalState>(code, tls_);
state_->registerType<T>();
coroutine_ = state_->createCoroutine();
lua_pushlightuserdata(coroutine_->luaState(), this);
lua_pushcclosure(coroutine_->luaState(), luaTestPrint, 1);
lua_setglobal(coroutine_->luaState(), "testPrint");
testing::Mock::AllowLeak(&printer_);
}

void start(const std::string& method) {
coroutine_->start(state_->getGlobalRef(state_->registerGlobal(method)), 1, yield_callback_);
}

static int luaTestPrint(lua_State* state) {
LuaWrappersTestBase* test =
static_cast<LuaWrappersTestBase*>(lua_touserdata(state, lua_upvalueindex(1)));
const char* message = luaL_checkstring(state, 1);
test->testPrint(message);
getPrinter().testPrint(message);
return 0;
}

MOCK_METHOD(void, testPrint, (const std::string&));

NiceMock<ThreadLocal::MockInstance> tls_;
ThreadLocalStatePtr state_;
std::function<void()> yield_callback_;
CoroutinePtr coroutine_;
const Printer& printer_{getPrinter()};
};

} // namespace Lua
Expand Down
44 changes: 22 additions & 22 deletions test/extensions/filters/common/lua/wrappers_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ class LuaConnectionWrapperTest : public LuaWrappersTestBase<ConnectionWrapper> {
EXPECT_CALL(Const(connection_), ssl()).WillOnce(Return(secure ? ssl_ : nullptr));

ConnectionWrapper::create(coroutine_->luaState(), &connection_);
EXPECT_CALL(*this, testPrint(secure ? "secure" : "plain"));
EXPECT_CALL(printer_, testPrint(secure ? "secure" : "plain"));
EXPECT_CALL(Const(connection_), ssl()).WillOnce(Return(secure ? ssl_ : nullptr));
EXPECT_CALL(*this, testPrint(secure ? "userdata" : "nil"));
EXPECT_CALL(printer_, testPrint(secure ? "userdata" : "nil"));
start("callMe");
}

Expand All @@ -82,9 +82,9 @@ TEST_F(LuaBufferWrapperTest, Methods) {
setup(SCRIPT);
Buffer::OwnedImpl data("hello world");
BufferWrapper::create(coroutine_->luaState(), data);
EXPECT_CALL(*this, testPrint("11"));
EXPECT_CALL(*this, testPrint("he"));
EXPECT_CALL(*this, testPrint("world"));
EXPECT_CALL(printer_, testPrint("11"));
EXPECT_CALL(printer_, testPrint("he"));
EXPECT_CALL(printer_, testPrint("world"));
start("callMe");
}

Expand Down Expand Up @@ -169,23 +169,23 @@ TEST_F(LuaMetadataMapWrapperTest, Methods) {
const auto filter_metadata = metadata.filter_metadata().at("envoy.filters.http.lua");
MetadataMapWrapper::create(coroutine_->luaState(), filter_metadata);

EXPECT_CALL(*this, testPrint("pulla"));
EXPECT_CALL(*this, testPrint("finland"));
EXPECT_CALL(printer_, testPrint("pulla"));
EXPECT_CALL(printer_, testPrint("finland"));

EXPECT_CALL(*this, testPrint("true"));
EXPECT_CALL(*this, testPrint("false"));
EXPECT_CALL(printer_, testPrint("true"));
EXPECT_CALL(printer_, testPrint("false"));

EXPECT_CALL(*this, testPrint("5"));
EXPECT_CALL(*this, testPrint("30.5"));
EXPECT_CALL(printer_, testPrint("5"));
EXPECT_CALL(printer_, testPrint("30.5"));

EXPECT_CALL(*this, testPrint("grass_fed"));
EXPECT_CALL(*this, testPrint("false"));
EXPECT_CALL(printer_, testPrint("grass_fed"));
EXPECT_CALL(printer_, testPrint("false"));

EXPECT_CALL(*this, testPrint("flour"));
EXPECT_CALL(*this, testPrint("milk"));
EXPECT_CALL(printer_, testPrint("flour"));
EXPECT_CALL(printer_, testPrint("milk"));

EXPECT_CALL(*this, testPrint("nil"));
EXPECT_CALL(*this, testPrint("0"));
EXPECT_CALL(printer_, testPrint("nil"));
EXPECT_CALL(printer_, testPrint("0"));

start("callMe");
}
Expand Down Expand Up @@ -225,11 +225,11 @@ TEST_F(LuaMetadataMapWrapperTest, Iterators) {
const auto filter_metadata = metadata.filter_metadata().at("envoy.filters.http.lua");
MetadataMapWrapper::create(coroutine_->luaState(), filter_metadata);

EXPECT_CALL(*this, testPrint("'make.delicious.bread' 'pulla'"));
EXPECT_CALL(*this, testPrint("'make.delicious.cookie' 'chewy'"));
EXPECT_CALL(*this, testPrint("'make.nothing0' 'nothing'"));
EXPECT_CALL(*this, testPrint("'make.nothing1' 'nothing'"));
EXPECT_CALL(*this, testPrint("'make.nothing2' 'nothing'"));
EXPECT_CALL(printer_, testPrint("'make.delicious.bread' 'pulla'"));
EXPECT_CALL(printer_, testPrint("'make.delicious.cookie' 'chewy'"));
EXPECT_CALL(printer_, testPrint("'make.nothing0' 'nothing'"));
EXPECT_CALL(printer_, testPrint("'make.nothing1' 'nothing'"));
EXPECT_CALL(printer_, testPrint("'make.nothing2' 'nothing'"));

start("callMe");
}
Expand Down
58 changes: 29 additions & 29 deletions test/extensions/filters/http/lua/wrappers_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ TEST_F(LuaHeaderMapWrapperTest, Methods) {

Http::TestRequestHeaderMapImpl headers;
HeaderMapWrapper::create(coroutine_->luaState(), headers, []() { return true; });
EXPECT_CALL(*this, testPrint("WORLD"));
EXPECT_CALL(*this, testPrint("'hello' 'WORLD'"));
EXPECT_CALL(*this, testPrint("'header1' ''"));
EXPECT_CALL(*this, testPrint("'header2' 'foo'"));
EXPECT_CALL(*this, testPrint("'hello' 'WORLD'"));
EXPECT_CALL(*this, testPrint("'header2' 'foo'"));
EXPECT_CALL(printer_, testPrint("WORLD"));
EXPECT_CALL(printer_, testPrint("'hello' 'WORLD'"));
EXPECT_CALL(printer_, testPrint("'header1' ''"));
EXPECT_CALL(printer_, testPrint("'header2' 'foo'"));
EXPECT_CALL(printer_, testPrint("'hello' 'WORLD'"));
EXPECT_CALL(printer_, testPrint("'header2' 'foo'"));
start("callMe");
}

Expand Down Expand Up @@ -169,9 +169,9 @@ TEST_F(LuaHeaderMapWrapperTest, ModifyAfterIteration) {

Http::TestRequestHeaderMapImpl headers{{"foo", "bar"}};
HeaderMapWrapper::create(coroutine_->luaState(), headers, []() { return true; });
EXPECT_CALL(*this, testPrint("'foo' 'bar'"));
EXPECT_CALL(*this, testPrint("'foo' 'bar'"));
EXPECT_CALL(*this, testPrint("'hello' 'world'"));
EXPECT_CALL(printer_, testPrint("'foo' 'bar'"));
EXPECT_CALL(printer_, testPrint("'foo' 'bar'"));
EXPECT_CALL(printer_, testPrint("'hello' 'world'"));
start("callMe");
}

Expand Down Expand Up @@ -242,7 +242,7 @@ class LuaStreamInfoWrapperTest
ON_CALL(stream_info, protocol()).WillByDefault(ReturnPointee(&protocol));
Filters::Common::Lua::LuaDeathRef<StreamInfoWrapper> wrapper(
StreamInfoWrapper::create(coroutine_->luaState(), stream_info), true);
EXPECT_CALL(*this,
EXPECT_CALL(printer_,
testPrint(fmt::format("'{}'", Http::Utility::getProtocolString(protocol.value()))));
start("callMe");
wrapper.reset();
Expand Down Expand Up @@ -295,12 +295,12 @@ TEST_F(LuaStreamInfoWrapperTest, SetGetAndIterateDynamicMetadata) {
EXPECT_EQ(0, stream_info.dynamicMetadata().filter_metadata_size());
Filters::Common::Lua::LuaDeathRef<StreamInfoWrapper> wrapper(
StreamInfoWrapper::create(coroutine_->luaState(), stream_info), true);
EXPECT_CALL(*this, testPrint("userdata"));
EXPECT_CALL(*this, testPrint("bar"));
EXPECT_CALL(*this, testPrint("cool"));
EXPECT_CALL(*this, testPrint("'foo' 'bar'"));
EXPECT_CALL(*this, testPrint("'so' 'cool'"));
EXPECT_CALL(*this, testPrint("0"));
EXPECT_CALL(printer_, testPrint("userdata"));
EXPECT_CALL(printer_, testPrint("bar"));
EXPECT_CALL(printer_, testPrint("cool"));
EXPECT_CALL(printer_, testPrint("'foo' 'bar'"));
EXPECT_CALL(printer_, testPrint("'so' 'cool'"));
EXPECT_CALL(printer_, testPrint("0"));
start("callMe");

EXPECT_EQ(1, stream_info.dynamicMetadata().filter_metadata_size());
Expand Down Expand Up @@ -337,13 +337,13 @@ TEST_F(LuaStreamInfoWrapperTest, SetGetComplexDynamicMetadata) {
EXPECT_EQ(0, stream_info.dynamicMetadata().filter_metadata_size());
Filters::Common::Lua::LuaDeathRef<StreamInfoWrapper> wrapper(
StreamInfoWrapper::create(coroutine_->luaState(), stream_info), true);
EXPECT_CALL(*this, testPrint("1234"));
EXPECT_CALL(*this, testPrint("baz"));
EXPECT_CALL(*this, testPrint("true"));
EXPECT_CALL(*this, testPrint("cool"));
EXPECT_CALL(*this, testPrint("and"));
EXPECT_CALL(*this, testPrint("dynamic"));
EXPECT_CALL(*this, testPrint("true"));
EXPECT_CALL(printer_, testPrint("1234"));
EXPECT_CALL(printer_, testPrint("baz"));
EXPECT_CALL(printer_, testPrint("true"));
EXPECT_CALL(printer_, testPrint("cool"));
EXPECT_CALL(printer_, testPrint("and"));
EXPECT_CALL(printer_, testPrint("dynamic"));
EXPECT_CALL(printer_, testPrint("true"));
start("callMe");

EXPECT_EQ(1, stream_info.dynamicMetadata().filter_metadata_size());
Expand Down Expand Up @@ -440,12 +440,12 @@ TEST_F(LuaStreamInfoWrapperTest, ModifyAfterIterationForDynamicMetadata) {
EXPECT_EQ(0, stream_info.dynamicMetadata().filter_metadata_size());
Filters::Common::Lua::LuaDeathRef<StreamInfoWrapper> wrapper(
StreamInfoWrapper::create(coroutine_->luaState(), stream_info), true);
EXPECT_CALL(*this, testPrint("envoy.lb"));
EXPECT_CALL(*this, testPrint("'hello' 'world'"));
EXPECT_CALL(*this, testPrint("envoy.proxy"));
EXPECT_CALL(*this, testPrint("'proto' 'grpc'"));
EXPECT_CALL(*this, testPrint("envoy.lb"));
EXPECT_CALL(*this, testPrint("'hello' 'envoy'"));
EXPECT_CALL(printer_, testPrint("envoy.lb"));
EXPECT_CALL(printer_, testPrint("'hello' 'world'"));
EXPECT_CALL(printer_, testPrint("envoy.proxy"));
EXPECT_CALL(printer_, testPrint("'proto' 'grpc'"));
EXPECT_CALL(printer_, testPrint("envoy.lb"));
EXPECT_CALL(printer_, testPrint("'hello' 'envoy'"));
start("callMe");
}

Expand Down
5 changes: 5 additions & 0 deletions tools/code_format/check_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,11 @@ def checkSourceLine(line, file_path, reportError):
reportError("Don't introduce throws into exception-free files, use error " +
"statuses instead.")

if "lua_pushlightuserdata" in line:
reportError("Don't use lua_pushlightuserdata, since it can cause unprotected error in call to" +
"Lua API (bad light userdata pointer) on ARM64 architecture. See " +
"https://github.com/LuaJIT/LuaJIT/issues/450#issuecomment-433659873 for details.")


def checkBuildLine(line, file_path, reportError):
if "@bazel_tools" in line and not (isSkylarkFile(file_path) or file_path.startswith("./bazel/") or
Expand Down