Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PS-9148: implemented lazy query_cache initial population #2

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ namespace masking_functions {
// mysql_command_query{
// mysql_service_mysql_command_query,
// mysql_service_mysql_command_query_result,
// mysql_service_mysql_command_field_info,
// mysql_service_mysql_command_options,
// mysql_service_mysql_command_factory
// };
Expand All @@ -43,6 +44,7 @@ namespace masking_functions {
struct command_service_tuple {
SERVICE_TYPE(mysql_command_query) * query;
SERVICE_TYPE(mysql_command_query_result) * query_result;
SERVICE_TYPE(mysql_command_field_info) * field_info;
SERVICE_TYPE(mysql_command_options) * options;
SERVICE_TYPE(mysql_command_factory) * factory;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,15 @@ class query_cache {
bool remove(const std::string &dictionary_name, const std::string &term);
bool insert(const std::string &dictionary_name, const std::string &term);

bool load_cache();
void reload_cache();

private:
bookshelf_ptr m_dict_cache;

query_builder_ptr m_query_builder;

// TODO: in c++20 change this to std::atomic<bookshelf_ptr> and
// remove deprecated atomic_load() / atomic_store()
mutable bookshelf_ptr m_dict_cache;

std::uint64_t m_flusher_interval_seconds;
std::atomic<bool> m_is_flusher_stopped;
std::mutex m_flusher_mutex;
Expand All @@ -72,6 +74,10 @@ class query_cache {
void dict_flusher() noexcept;

static void *run_dict_flusher(void *arg);

bookshelf_ptr create_dict_cache_internal() const;
// returning deliberately by value to increase reference counter
bookshelf_ptr get_pinned_dict_cache_internal() const;
};

} // namespace masking_functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
#ifndef MASKING_FUNCTIONS_SQL_CONTEXT_HPP
#define MASKING_FUNCTIONS_SQL_CONTEXT_HPP

#include <algorithm>
#include <array>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <string_view>

#include "masking_functions/bookshelf_fwd.hpp"
#include "masking_functions/command_service_tuple_fwd.hpp"

namespace masking_functions {
Expand All @@ -32,6 +33,13 @@ namespace masking_functions {
// construction.
class sql_context {
public:
template <std::size_t NumberOfFields>
using field_value_container = std::array<std::string_view, NumberOfFields>;

template <std::size_t NumberOfFields>
using row_callback =
std::function<void(const field_value_container<NumberOfFields> &)>;

explicit sql_context(const command_service_tuple &services);

sql_context(sql_context const &) = delete;
Expand All @@ -46,9 +54,23 @@ class sql_context {
return *impl_.get_deleter().services;
}

bookshelf_ptr query_list(std::string_view query);
template <std::size_t NumberOfFields>
void execute_select(std::string_view query,
const row_callback<NumberOfFields> &callback) {
execute_select_internal(
query, NumberOfFields,
[&callback](char **field_values, std::size_t *lengths) {
field_value_container<NumberOfFields> wrapped_field_values;
std::transform(field_values, field_values + NumberOfFields, lengths,
std::begin(wrapped_field_values),
[](char *str, std::size_t len) {
return std::string_view{str, len};
});
callback(wrapped_field_values);
});
}

bool execute(std::string_view query);
bool execute_dml(std::string_view query);

private:
struct deleter {
Expand All @@ -57,6 +79,11 @@ class sql_context {
};
using impl_type = std::unique_ptr<void, deleter>;
impl_type impl_;

using row_internal_callback = std::function<void(char **, std::size_t *)>;
void execute_select_internal(std::string_view query,
std::size_t number_of_fields,
const row_internal_callback &callback);
};

} // namespace masking_functions
Expand Down
3 changes: 3 additions & 0 deletions components/masking_functions/src/component.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ REQUIRES_SERVICE_PLACEHOLDER(mysql_string_compare);

REQUIRES_SERVICE_PLACEHOLDER(mysql_command_query);
REQUIRES_SERVICE_PLACEHOLDER(mysql_command_query_result);
REQUIRES_SERVICE_PLACEHOLDER(mysql_command_field_info);
REQUIRES_SERVICE_PLACEHOLDER(mysql_command_options);
REQUIRES_SERVICE_PLACEHOLDER(mysql_command_factory);

Expand Down Expand Up @@ -118,6 +119,7 @@ static mysql_service_status_t component_init() {
// TODO: convert this to designated initializers in c++20
mysql_service_mysql_command_query,
mysql_service_mysql_command_query_result,
mysql_service_mysql_command_field_info,
mysql_service_mysql_command_options,
mysql_service_mysql_command_factory};
masking_functions::primitive_singleton<
Expand Down Expand Up @@ -227,6 +229,7 @@ BEGIN_COMPONENT_REQUIRES(CURRENT_COMPONENT_NAME)

REQUIRES_SERVICE(mysql_command_query),
REQUIRES_SERVICE(mysql_command_query_result),
REQUIRES_SERVICE(mysql_command_field_info),
REQUIRES_SERVICE(mysql_command_options),
REQUIRES_SERVICE(mysql_command_factory),

Expand Down
76 changes: 57 additions & 19 deletions components/masking_functions/src/masking_functions/query_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,15 @@ namespace masking_functions {
query_cache::query_cache(query_builder_ptr query_builder,
std::uint64_t flusher_interval_seconds)
: m_query_builder{std::move(query_builder)},
m_dict_cache{},
m_flusher_interval_seconds{flusher_interval_seconds},
m_is_flusher_stopped{true} {
load_cache();
// we do not initialize m_dict_cache with create_dict_cache_internal() here
// as this constructor is called from the component initialization method
// and any call to mysql_command_query service may mess up with current THD

// the cache will be loaded during the first call to one of the dictionary
// functions or by the flusher thread
if (m_flusher_interval_seconds > 0) {
PSI_thread_info thread_info{&m_psi_flusher_thread_key,
flusher_thd_psi_name,
Expand Down Expand Up @@ -82,63 +87,62 @@ query_cache::~query_cache() {
}
}

bool query_cache::load_cache() {
masking_functions::sql_context sql_ctx{global_command_services::instance()};
auto query = m_query_builder->select_all_from_dictionary();
auto result = sql_ctx.query_list(query);

if (result) {
// TODO: in c++20 change to m_dict_cache to std::atomic<bookshelf_ptr>
std::atomic_store(&m_dict_cache, result);
void query_cache::reload_cache() {
auto local_dict_cache{create_dict_cache_internal()};
if (!local_dict_cache) {
throw std::runtime_error{"Cannot load dictionary cache"};
}

return static_cast<bool>(result);
std::atomic_store(&m_dict_cache, local_dict_cache);
}

bool query_cache::contains(const std::string &dictionary_name,
const std::string &term) const {
return m_dict_cache->contains(dictionary_name, term);
return get_pinned_dict_cache_internal()->contains(dictionary_name, term);
}

optional_string query_cache::get_random(
const std::string &dictionary_name) const {
return m_dict_cache->get_random(dictionary_name);
return get_pinned_dict_cache_internal()->get_random(dictionary_name);
}

bool query_cache::remove(const std::string &dictionary_name) {
auto local_dict_cache{get_pinned_dict_cache_internal()};
masking_functions::sql_context sql_ctx{global_command_services::instance()};
auto query = m_query_builder->delete_for_dictionary(dictionary_name);

if (!sql_ctx.execute(query)) {
if (!sql_ctx.execute_dml(query)) {
return false;
}

return m_dict_cache->remove(dictionary_name);
return local_dict_cache->remove(dictionary_name);
}

bool query_cache::remove(const std::string &dictionary_name,
const std::string &term) {
auto local_dict_cache{get_pinned_dict_cache_internal()};
masking_functions::sql_context sql_ctx{global_command_services::instance()};
auto query =
m_query_builder->delete_for_dictionary_and_term(dictionary_name, term);

if (!sql_ctx.execute(query)) {
if (!sql_ctx.execute_dml(query)) {
return false;
}

return m_dict_cache->remove(dictionary_name, term);
return local_dict_cache->remove(dictionary_name, term);
}

bool query_cache::insert(const std::string &dictionary_name,
const std::string &term) {
auto local_dict_cache{get_pinned_dict_cache_internal()};
masking_functions::sql_context sql_ctx{global_command_services::instance()};
auto query = m_query_builder->insert_ignore_record(dictionary_name, term);

if (!sql_ctx.execute(query)) {
if (!sql_ctx.execute_dml(query)) {
return false;
}

return m_dict_cache->insert(dictionary_name, term);
return local_dict_cache->insert(dictionary_name, term);
}

void query_cache::init_thd() noexcept {
Expand Down Expand Up @@ -177,7 +181,10 @@ void query_cache::dict_flusher() noexcept {
});

if (!m_is_flusher_stopped) {
load_cache();
auto local_dict_cache{create_dict_cache_internal()};
if (local_dict_cache) {
std::atomic_store(&m_dict_cache, local_dict_cache);
}

DBUG_EXECUTE_IF("masking_functions_signal_on_cache_reload", {
const char act[] = "now SIGNAL masking_functions_cache_reload_done";
Expand All @@ -195,4 +202,35 @@ void *query_cache::run_dict_flusher(void *arg) {
return nullptr;
}

bookshelf_ptr query_cache::create_dict_cache_internal() const {
bookshelf_ptr result;
try {
masking_functions::sql_context sql_ctx{global_command_services::instance()};
auto query = m_query_builder->select_all_from_dictionary();
auto local_dict_cache{std::make_shared<bookshelf>()};
sql_context::row_callback<2> result_inserter{[&terms = *local_dict_cache](
const auto &field_values) {
terms.insert(std::string{field_values[0]}, std::string{field_values[1]});
}};
sql_ctx.execute_select(query, result_inserter);
result = local_dict_cache;
} catch (...) {
}

return result;
}

bookshelf_ptr query_cache::get_pinned_dict_cache_internal() const {
auto local_dict_cache{std::atomic_load(&m_dict_cache)};
if (!local_dict_cache) {
local_dict_cache = create_dict_cache_internal();
if (!local_dict_cache) {
throw std::runtime_error{"Cannot load dictionary cache"};
}
std::atomic_store(&m_dict_cache, local_dict_cache);
}

return local_dict_cache;
}

} // namespace masking_functions
Original file line number Diff line number Diff line change
Expand Up @@ -1061,9 +1061,7 @@ class masking_dictionaries_flush_impl {

mysqlpp::udf_result_t<STRING_RESULT> calculate(const mysqlpp::udf_context &ctx
[[maybe_unused]]) {
if (!global_query_cache::instance()->load_cache()) {
return std::nullopt;
}
global_query_cache::instance()->reload_cache();

return "1";
}
Expand Down
61 changes: 35 additions & 26 deletions components/masking_functions/src/masking_functions/sql_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include <stdexcept>
#include <type_traits>

#include "masking_functions/bookshelf.hpp"
#include "masking_functions/command_service_tuple.hpp"
#include "masking_functions/sql_context.hpp"

Expand Down Expand Up @@ -82,12 +81,38 @@ sql_context::sql_context(const command_service_tuple &services)
}
}

bookshelf_ptr sql_context::query_list(std::string_view query) {
bool sql_context::execute_dml(std::string_view query) {
if ((*get_services().query->query)(to_mysql_h(impl_.get()), query.data(),
query.length()) != 0) {
return false;
}
std::uint64_t row_count = 0;
if ((*get_services().query->affected_rows)(to_mysql_h(impl_.get()),
&row_count) != 0) {
return false;
}
return row_count > 0;
}

void sql_context::execute_select_internal(
std::string_view query, std::size_t expected_number_of_fields,
const row_internal_callback &callback) {
if ((*get_services().query->query)(to_mysql_h(impl_.get()), query.data(),
query.length()) != 0) {
throw std::runtime_error{"Error while executing SQL query"};
}

unsigned int actual_number_of_fields;
if ((*get_services().field_info->field_count)(
to_mysql_h(impl_.get()), &actual_number_of_fields) != 0) {
throw std::runtime_error{"Couldn't get number of fields"};
}

if (actual_number_of_fields != expected_number_of_fields) {
throw std::runtime_error{
"Micmatch between actual and expected number of fields"};
}

MYSQL_RES_H mysql_res = nullptr;
if ((*get_services().query_result->store_result)(to_mysql_h(impl_.get()),
&mysql_res) != 0) {
Expand All @@ -106,7 +131,7 @@ bookshelf_ptr sql_context::query_list(std::string_view query) {
std::unique_ptr<mysql_res_type, decltype(mysql_res_deleter)>;

mysql_res_ptr mysql_res_guard(mysql_res, std::move(mysql_res_deleter));
uint64_t row_count = 0;
std::uint64_t row_count = 0;
// As the 'affected_rows()' method of the 'mysql_command_query' MySQL
// service is implementted via 'mysql_affected_rows()' MySQL client
// function, it is OK to use it for SELECT statements as well, because
Expand All @@ -115,35 +140,19 @@ bookshelf_ptr sql_context::query_list(std::string_view query) {
&row_count) != 0)
throw std::runtime_error{"Couldn't query row count"};

bookshelf_ptr result{std::make_shared<bookshelf>()};

for (auto i = row_count; i > 0; --i) {
MYSQL_ROW_H row = nullptr;
ulong *length = nullptr;
MYSQL_ROW_H field_values = nullptr;
ulong *field_value_lengths = nullptr;

if ((*get_services().query_result->fetch_row)(mysql_res, &row) != 0)
if ((*get_services().query_result->fetch_row)(mysql_res, &field_values) !=
0)
throw std::runtime_error{"Couldn't fetch length"};
if ((*get_services().query_result->fetch_lengths)(mysql_res, &length) != 0)
if ((*get_services().query_result->fetch_lengths)(
mysql_res, &field_value_lengths) != 0)
throw std::runtime_error{"Couldn't fetch length"};

result->insert(std::string{row[0], length[0]},
std::string{row[1], length[1]});
}

return result;
}

bool sql_context::execute(std::string_view query) {
if ((*get_services().query->query)(to_mysql_h(impl_.get()), query.data(),
query.length()) != 0) {
return false;
callback(field_values, field_value_lengths);
}
uint64_t row_count = 0;
if ((*get_services().query->affected_rows)(to_mysql_h(impl_.get()),
&row_count) != 0) {
return false;
}
return row_count > 0;
}

} // namespace masking_functions
Loading
Loading