Skip to content

Commit

Permalink
Add public API WriteWithCallback to support custom callbacks (faceb…
Browse files Browse the repository at this point in the history
…ook#12603)

Summary:
This PR adds a `DB::WriteWithCallback` API that does the same thing as `DB::Write` but also takes a `UserWriteCallback` argument, which is used to execute custom callback functions during the write.

We currently support two types of callback functions: `OnWriteEnqueued` and `OnWalWriteFinish`. The former is invoked after the write is enqueued, and the latter is invoked after the WAL write finishes, when applicable.

These callback functions are intended to help users improve synchronization between concurrent writes. They execute on the write's critical path, so they will impact the write's latency if not used properly. The documentation for the callback interface mentions this and suggests that users keep the implementation of these callback functions minimal.

Although writes made through the transaction interfaces don't yet allow the user to specify such a user write callback argument, the `DBImpl::Write*` family of APIs does not differentiate between regular DB writes and writes coming from the transaction layer when it comes to supporting this `UserWriteCallback`. These callbacks work for all the write modes, including: default write mode, Options.two_write_queues, Options.unordered_write, Options.enable_pipelined_write.

Pull Request resolved: facebook#12603

Test Plan: Added unit test in ./write_callback_test

Reviewed By: anand1976

Differential Revision: D58044638

Pulled By: jowlyzhang

fbshipit-source-id: 87a84a0221df8f589ec8fc4d74597e72ce97e4cd
  • Loading branch information
jowlyzhang authored and facebook-github-bot committed Jun 1, 2024
1 parent f3b7e95 commit fc59d8f
Show file tree
Hide file tree
Showing 12 changed files with 209 additions and 64 deletions.
13 changes: 11 additions & 2 deletions db/db_impl/db_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
#include "rocksdb/status.h"
#include "rocksdb/trace_reader_writer.h"
#include "rocksdb/transaction_log.h"
#include "rocksdb/user_write_callback.h"
#include "rocksdb/utilities/replayer.h"
#include "rocksdb/write_buffer_manager.h"
#include "table/merging_iterator.h"
Expand Down Expand Up @@ -231,6 +232,10 @@ class DBImpl : public DB {
using DB::Write;
Status Write(const WriteOptions& options, WriteBatch* updates) override;

using DB::WriteWithCallback;
// Public-API override: writes `updates` and additionally takes a
// `UserWriteCallback` whose hooks are run during the write. Per the commit
// description the hooks are `OnWriteEnqueued` (after the write is enqueued)
// and `OnWalWriteFinish` (after the WAL write finishes, when applicable);
// they execute on the write's critical path, so implementations should be
// kept minimal.
Status WriteWithCallback(const WriteOptions& options, WriteBatch* updates,
UserWriteCallback* user_write_cb) override;

using DB::Get;
Status Get(const ReadOptions& _read_options,
ColumnFamilyHandle* column_family, const Slice& key,
Expand Down Expand Up @@ -688,7 +693,8 @@ class DBImpl : public DB {
// thread to determine whether it is safe to perform the write.
virtual Status WriteWithCallback(const WriteOptions& write_options,
WriteBatch* my_batch,
WriteCallback* callback);
WriteCallback* callback,
UserWriteCallback* user_write_cb = nullptr);

// Returns the sequence number that is guaranteed to be smaller than or equal
// to the sequence number of any key that could be inserted into the current
Expand Down Expand Up @@ -1497,6 +1503,7 @@ class DBImpl : public DB {
// batch that does not have duplicate keys.
Status WriteImpl(const WriteOptions& options, WriteBatch* updates,
WriteCallback* callback = nullptr,
UserWriteCallback* user_write_cb = nullptr,
uint64_t* log_used = nullptr, uint64_t log_ref = 0,
bool disable_memtable = false, uint64_t* seq_used = nullptr,
size_t batch_cnt = 0,
Expand All @@ -1505,6 +1512,7 @@ class DBImpl : public DB {

Status PipelinedWriteImpl(const WriteOptions& options, WriteBatch* updates,
WriteCallback* callback = nullptr,
UserWriteCallback* user_write_cb = nullptr,
uint64_t* log_used = nullptr, uint64_t log_ref = 0,
bool disable_memtable = false,
uint64_t* seq_used = nullptr);
Expand All @@ -1531,7 +1539,8 @@ class DBImpl : public DB {
// marks start of a new sub-batch.
Status WriteImplWALOnly(
WriteThread* write_thread, const WriteOptions& options,
WriteBatch* updates, WriteCallback* callback, uint64_t* log_used,
WriteBatch* updates, WriteCallback* callback,
UserWriteCallback* user_write_cb, uint64_t* log_used,
const uint64_t log_ref, uint64_t* seq_used, const size_t sub_batch_cnt,
PreReleaseCallback* pre_release_callback, const AssignOrder assign_order,
const PublishLastSeq publish_last_seq, const bool disable_memtable);
Expand Down
75 changes: 52 additions & 23 deletions db/db_impl/db_impl_write.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,21 +155,36 @@ Status DBImpl::Write(const WriteOptions& write_options, WriteBatch* my_batch) {
}
if (s.ok()) {
s = WriteImpl(write_options, my_batch, /*callback=*/nullptr,
/*user_write_cb=*/nullptr,
/*log_used=*/nullptr);
}
return s;
}

// Writes `my_batch`, forwarding both the internal `callback` and the
// user-supplied `user_write_cb` to WriteImpl().
//
// When `write_options.protection_bytes_per_key` is non-zero, the batch's
// protection info is updated first; if that step fails, its status is
// returned and WriteImpl() is never reached.
Status DBImpl::WriteWithCallback(const WriteOptions& write_options,
                                 WriteBatch* my_batch, WriteCallback* callback,
                                 UserWriteCallback* user_write_cb) {
  if (write_options.protection_bytes_per_key > 0) {
    Status protection_status = WriteBatchInternal::UpdateProtectionInfo(
        my_batch, write_options.protection_bytes_per_key);
    if (!protection_status.ok()) {
      // Do not attempt the write with a batch that could not be protected.
      return protection_status;
    }
  }
  return WriteImpl(write_options, my_batch, callback, user_write_cb);
}

Status DBImpl::WriteWithCallback(const WriteOptions& write_options,
WriteBatch* my_batch,
WriteCallback* callback) {
UserWriteCallback* user_write_cb) {
Status s;
if (write_options.protection_bytes_per_key > 0) {
s = WriteBatchInternal::UpdateProtectionInfo(
my_batch, write_options.protection_bytes_per_key);
}
if (s.ok()) {
s = WriteImpl(write_options, my_batch, callback, nullptr);
s = WriteImpl(write_options, my_batch, /*callback=*/nullptr, user_write_cb);
}
return s;
}
Expand All @@ -179,9 +194,9 @@ Status DBImpl::WriteWithCallback(const WriteOptions& write_options,
// published sequence.
Status DBImpl::WriteImpl(const WriteOptions& write_options,
WriteBatch* my_batch, WriteCallback* callback,
uint64_t* log_used, uint64_t log_ref,
bool disable_memtable, uint64_t* seq_used,
size_t batch_cnt,
UserWriteCallback* user_write_cb, uint64_t* log_used,
uint64_t log_ref, bool disable_memtable,
uint64_t* seq_used, size_t batch_cnt,
PreReleaseCallback* pre_release_callback,
PostMemTableCallback* post_memtable_callback) {
assert(!seq_per_batch_ || batch_cnt != 0);
Expand Down Expand Up @@ -288,10 +303,10 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options,
seq_per_batch_ ? kDoAssignOrder : kDontAssignOrder;
// Otherwise it is WAL-only Prepare batches in WriteCommitted policy and
// they don't consume sequence.
return WriteImplWALOnly(&nonmem_write_thread_, write_options, my_batch,
callback, log_used, log_ref, seq_used, batch_cnt,
pre_release_callback, assign_order,
kDontPublishLastSeq, disable_memtable);
return WriteImplWALOnly(
&nonmem_write_thread_, write_options, my_batch, callback, user_write_cb,
log_used, log_ref, seq_used, batch_cnt, pre_release_callback,
assign_order, kDontPublishLastSeq, disable_memtable);
}

if (immutable_db_options_.unordered_write) {
Expand All @@ -303,9 +318,9 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options,
// Use a write thread to i) optimize for WAL write, ii) publish last
// sequence in increasing order, iii) call pre_release_callback serially
Status status = WriteImplWALOnly(
&write_thread_, write_options, my_batch, callback, log_used, log_ref,
&seq, sub_batch_cnt, pre_release_callback, kDoAssignOrder,
kDoPublishLastSeq, disable_memtable);
&write_thread_, write_options, my_batch, callback, user_write_cb,
log_used, log_ref, &seq, sub_batch_cnt, pre_release_callback,
kDoAssignOrder, kDoPublishLastSeq, disable_memtable);
TEST_SYNC_POINT("DBImpl::WriteImpl:UnorderedWriteAfterWriteWAL");
if (!status.ok()) {
return status;
Expand All @@ -322,14 +337,14 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options,
}

if (immutable_db_options_.enable_pipelined_write) {
return PipelinedWriteImpl(write_options, my_batch, callback, log_used,
log_ref, disable_memtable, seq_used);
return PipelinedWriteImpl(write_options, my_batch, callback, user_write_cb,
log_used, log_ref, disable_memtable, seq_used);
}

PERF_TIMER_GUARD(write_pre_and_post_process_time);
WriteThread::Writer w(write_options, my_batch, callback, log_ref,
disable_memtable, batch_cnt, pre_release_callback,
post_memtable_callback);
WriteThread::Writer w(write_options, my_batch, callback, user_write_cb,
log_ref, disable_memtable, batch_cnt,
pre_release_callback, post_memtable_callback);
StopWatch write_sw(immutable_db_options_.clock, stats_, DB_WRITE);

write_thread_.JoinBatchGroup(&w);
Expand Down Expand Up @@ -686,15 +701,16 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options,

Status DBImpl::PipelinedWriteImpl(const WriteOptions& write_options,
WriteBatch* my_batch, WriteCallback* callback,
UserWriteCallback* user_write_cb,
uint64_t* log_used, uint64_t log_ref,
bool disable_memtable, uint64_t* seq_used) {
PERF_TIMER_GUARD(write_pre_and_post_process_time);
StopWatch write_sw(immutable_db_options_.clock, stats_, DB_WRITE);

WriteContext write_context;

WriteThread::Writer w(write_options, my_batch, callback, log_ref,
disable_memtable, /*_batch_cnt=*/0,
WriteThread::Writer w(write_options, my_batch, callback, user_write_cb,
log_ref, disable_memtable, /*_batch_cnt=*/0,
/*_pre_release_callback=*/nullptr);
write_thread_.JoinBatchGroup(&w);
TEST_SYNC_POINT("DBImplWrite::PipelinedWriteImpl:AfterJoinBatchGroup");
Expand Down Expand Up @@ -875,7 +891,8 @@ Status DBImpl::UnorderedWriteMemtable(const WriteOptions& write_options,
PERF_TIMER_GUARD(write_pre_and_post_process_time);
StopWatch write_sw(immutable_db_options_.clock, stats_, DB_WRITE);

WriteThread::Writer w(write_options, my_batch, callback, log_ref,
WriteThread::Writer w(write_options, my_batch, callback,
/*user_write_cb=*/nullptr, log_ref,
false /*disable_memtable*/);

if (w.CheckCallback(this) && w.ShouldWriteToMemtable()) {
Expand Down Expand Up @@ -925,13 +942,15 @@ Status DBImpl::UnorderedWriteMemtable(const WriteOptions& write_options,
// applicable in a two-queue setting.
Status DBImpl::WriteImplWALOnly(
WriteThread* write_thread, const WriteOptions& write_options,
WriteBatch* my_batch, WriteCallback* callback, uint64_t* log_used,
WriteBatch* my_batch, WriteCallback* callback,
UserWriteCallback* user_write_cb, uint64_t* log_used,
const uint64_t log_ref, uint64_t* seq_used, const size_t sub_batch_cnt,
PreReleaseCallback* pre_release_callback, const AssignOrder assign_order,
const PublishLastSeq publish_last_seq, const bool disable_memtable) {
PERF_TIMER_GUARD(write_pre_and_post_process_time);
WriteThread::Writer w(write_options, my_batch, callback, log_ref,
disable_memtable, sub_batch_cnt, pre_release_callback);
WriteThread::Writer w(write_options, my_batch, callback, user_write_cb,
log_ref, disable_memtable, sub_batch_cnt,
pre_release_callback);
StopWatch write_sw(immutable_db_options_.clock, stats_, DB_WRITE);

write_thread->JoinBatchGroup(&w);
Expand Down Expand Up @@ -1498,6 +1517,11 @@ IOStatus DBImpl::WriteToWAL(const WriteThread::WriteGroup& write_group,
RecordTick(stats_, WAL_FILE_BYTES, log_size);
stats->AddDBStats(InternalStats::kIntStatsWriteWithWal, write_with_wal);
RecordTick(stats_, WRITE_WITH_WAL, write_with_wal);
for (auto* writer : write_group) {
if (!writer->CallbackFailed()) {
writer->CheckPostWalWriteCallback();
}
}
}
return io_s;
}
Expand Down Expand Up @@ -1562,6 +1586,11 @@ IOStatus DBImpl::ConcurrentWriteToWAL(
stats->AddDBStats(InternalStats::kIntStatsWriteWithWal, write_with_wal,
concurrent);
RecordTick(stats_, WRITE_WITH_WAL, write_with_wal);
for (auto* writer : write_group) {
if (!writer->CallbackFailed()) {
writer->CheckPostWalWriteCallback();
}
}
}
return io_s;
}
Expand Down
51 changes: 46 additions & 5 deletions db/write_callback_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
// This source code is licensed under both the GPLv2 (found in the
// COPYING file in the root directory) and Apache 2.0 License
// (found in the LICENSE.Apache file in the root directory).


#include "db/write_callback.h"

#include <atomic>
Expand All @@ -15,6 +13,7 @@
#include "db/db_impl/db_impl.h"
#include "port/port.h"
#include "rocksdb/db.h"
#include "rocksdb/user_write_callback.h"
#include "rocksdb/write_batch.h"
#include "test_util/sync_point.h"
#include "test_util/testharness.h"
Expand Down Expand Up @@ -84,6 +83,28 @@ class MockWriteCallback : public WriteCallback {
bool AllowWriteBatching() override { return allow_batching_; }
};

// Test double for UserWriteCallback. Each hook only raises a flag, so a test
// can inspect afterwards which hooks were invoked for a given write.
class MockUserWriteCallback : public UserWriteCallback {
 public:
  MockUserWriteCallback() = default;

  // std::atomic is not copyable, so the copy takes a snapshot of each flag's
  // current value instead.
  MockUserWriteCallback(const MockUserWriteCallback& src)
      : write_enqueued_(src.write_enqueued_.load()),
        wal_write_done_(src.wal_write_done_.load()) {}

  // Records that the enqueue hook was invoked.
  void OnWriteEnqueued() override { write_enqueued_.store(true); }

  // Records that the WAL-write-finished hook was invoked.
  void OnWalWriteFinish() override { wal_write_done_.store(true); }

  // Clears both flags so the same instance can be reused for another write.
  void Reset() {
    wal_write_done_.store(false);
    write_enqueued_.store(false);
  }

  // Set by OnWriteEnqueued().
  std::atomic<bool> write_enqueued_{false};
  // Set by OnWalWriteFinish().
  std::atomic<bool> wal_write_done_{false};
};

#if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
class WriteCallbackPTest
: public WriteCallbackTest,
Expand Down Expand Up @@ -119,9 +140,11 @@ TEST_P(WriteCallbackPTest, WriteWithCallbackTest) {
kvs_.clear();
write_batch_.Clear();
callback_.was_called_.store(false);
user_write_cb_.Reset();
}

MockWriteCallback callback_;
MockUserWriteCallback user_write_cb_;
WriteBatch write_batch_;
std::vector<std::pair<string, string>> kvs_;
};
Expand Down Expand Up @@ -327,18 +350,26 @@ TEST_P(WriteCallbackPTest, WriteWithCallbackTest) {
ASSERT_OK(WriteBatchInternal::InsertNoop(&write_op.write_batch_));
const size_t ONE_BATCH = 1;
s = db_impl->WriteImpl(woptions, &write_op.write_batch_,
&write_op.callback_, nullptr, 0, false, nullptr,
ONE_BATCH,
&write_op.callback_, &write_op.user_write_cb_,
nullptr, 0, false, nullptr, ONE_BATCH,
two_queues_ ? &publish_seq_callback : nullptr);
} else {
s = db_impl->WriteWithCallback(woptions, &write_op.write_batch_,
&write_op.callback_);
&write_op.callback_,
&write_op.user_write_cb_);
}

ASSERT_TRUE(write_op.user_write_cb_.write_enqueued_.load());
if (write_op.callback_.should_fail_) {
ASSERT_TRUE(s.IsBusy());
ASSERT_FALSE(write_op.user_write_cb_.wal_write_done_.load());
} else {
ASSERT_OK(s);
if (enable_WAL_) {
ASSERT_TRUE(write_op.user_write_cb_.wal_write_done_.load());
} else {
ASSERT_FALSE(write_op.user_write_cb_.wal_write_done_.load());
}
}
};

Expand Down Expand Up @@ -440,6 +471,16 @@ TEST_F(WriteCallbackTest, WriteCallBackTest) {
ASSERT_OK(s);
ASSERT_EQ("value.a2", value);

MockUserWriteCallback user_write_cb;
WriteBatch wb4;
ASSERT_OK(wb4.Put("a", "value.a4"));

ASSERT_OK(db->WriteWithCallback(write_options, &wb4, &user_write_cb));
ASSERT_OK(db->Get(read_options, "a", &value));
ASSERT_EQ(value, "value.a4");
ASSERT_TRUE(user_write_cb.write_enqueued_.load());
ASSERT_TRUE(user_write_cb.wal_write_done_.load());

delete db;
ASSERT_OK(DestroyDB(dbname, options));
}
Expand Down
2 changes: 2 additions & 0 deletions db/write_thread.cc
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,8 @@ void WriteThread::JoinBatchGroup(Writer* w) {

bool linked_as_leader = LinkOne(w, &newest_writer_);

w->CheckWriteEnqueuedCallback();

if (linked_as_leader) {
SetState(w, STATE_GROUP_LEADER);
}
Expand Down
Loading

0 comments on commit fc59d8f

Please sign in to comment.