From 16bf38e2bbf46c24daa5af01b0f0d0849cbb0a02 Mon Sep 17 00:00:00 2001
From: v01dstar
Date: Sun, 19 May 2024 01:11:56 -0700
Subject: [PATCH] Add patch files

Signed-off-by: v01dstar
---
 .../0010-Implement-multi-batches-write.patch  | 1722 +++++++++++++++
 ...cryptedEnv-for-per-file-key-manageme.patch | 1909 +++++++++++++++++
 2 files changed, 3631 insertions(+)
 create mode 100644 tikv-rocksdb-patches/0010-Implement-multi-batches-write.patch
 create mode 100644 tikv-rocksdb-patches/0011-Add-KeyManagedEncryptedEnv-for-per-file-key-manageme.patch

diff --git a/tikv-rocksdb-patches/0010-Implement-multi-batches-write.patch b/tikv-rocksdb-patches/0010-Implement-multi-batches-write.patch
new file mode 100644
index 00000000000..e3d7af61d87
--- /dev/null
+++ b/tikv-rocksdb-patches/0010-Implement-multi-batches-write.patch
@@ -0,0 +1,1722 @@
+From 05caaee6d427be463c8f0918a303ea9949761edb Mon Sep 17 00:00:00 2001
+From: v01dstar
+Date: Sun, 19 May 2024 00:18:22 -0700
+Subject: [PATCH] Implement multi batches write
+
+Implement multi batches write
+
+Signed-off-by: v01dstar
+
+Fix SIGABRT caused by uninitialized mutex (#296) (#298)
+
+* Fix SIGABRT caused by uninitialized mutex
+
+Signed-off-by: Wenbo Zhang
+
+* Use spinlock instead of mutex to reduce writer ctor cost
+
+Signed-off-by: Wenbo Zhang
+
+* Update db/write_thread.h
+
+Co-authored-by: Xinye Tao
+Signed-off-by: Wenbo Zhang
+
+Co-authored-by: Xinye Tao
+Signed-off-by: Wenbo Zhang
+
+Co-authored-by: Xinye Tao
+---
+ db/column_family.cc              |   7 +
+ db/db_impl/db_impl.h             |  14 +-
+ db/db_impl/db_impl_open.cc       |  12 +-
+ db/db_impl/db_impl_secondary.cc  |   2 +-
+ db/db_impl/db_impl_write.cc      | 296 ++++++++++++++++++++++++++++---
+ db/db_kv_checksum_test.cc        |  17 +-
+ db/db_properties_test.cc         |   3 +
+ db/db_test_util.cc               |   6 +
+ db/db_test_util.h                |   1 +
+ db/db_write_test.cc              |  82 ++++++++-
+ db/external_sst_file_test.cc     |  30 ++++
+ db/write_batch.cc                |  25 ++-
+ db/write_batch_internal.h        |  13 +-
+ db/write_callback_test.cc        |   6 +
+ db/write_thread.cc               |  92 ++++++++--
+ db/write_thread.h                | 150 +++++++++++++++-
+ include/rocksdb/db.h             |   5 +
+ include/rocksdb/options.h        |  16 ++
+ options/db_options.cc            |   7 +
+ options/db_options.h             |   1 +
+ options/options_helper.cc        |   2 +
+ options/options_settable_test.cc |   1 +
+ tools/db_bench_tool.cc           |  98 ++++++++--
+ 23 files changed, 802 insertions(+), 84 deletions(-)
+
+diff --git a/db/column_family.cc b/db/column_family.cc
+index be9e1e5d5..8baab2942 100644
+--- a/db/column_family.cc
++++ b/db/column_family.cc
+@@ -1515,6 +1515,13 @@ Status ColumnFamilyData::ValidateOptions(
+       }
+     }
+   }
++
++  if (db_options.enable_multi_batch_write &&
++      cf_options.max_successive_merges > 0) {
++    return Status::NotSupported(
++        "Multi thread write is only supported with no successive merges");
++  }
++
+   return s;
+ }
+
+diff --git a/db/db_impl/db_impl.h b/db/db_impl/db_impl.h
+index 43902fc49..64a9c3c4b 100644
+--- a/db/db_impl/db_impl.h
++++ b/db/db_impl/db_impl.h
+@@ -232,6 +232,10 @@ class DBImpl : public DB {
+   virtual Status Write(const WriteOptions& options,
+                        WriteBatch* updates) override;
+
++  using DB::MultiBatchWrite;
++  virtual Status MultiBatchWrite(const WriteOptions& options,
++                                 std::vector<WriteBatch*>&& updates) override;
++
+   using DB::Get;
+   virtual Status Get(const ReadOptions& options,
+                      ColumnFamilyHandle* column_family, const Slice& key,
+@@ -1473,6 +1477,13 @@ class DBImpl : public DB {
+       PreReleaseCallback* pre_release_callback = nullptr,
+       PostMemTableCallback* post_memtable_callback = nullptr);
+
++  Status MultiBatchWriteImpl(const WriteOptions& write_options,
++                             std::vector<WriteBatch*>&& my_batch,
++                             WriteCallback* callback = nullptr,
++                             uint64_t* log_used = nullptr, uint64_t log_ref = 0,
++                             uint64_t* seq_used = nullptr);
++  void MultiBatchWriteCommit(CommitRequest* request);
++
+   Status PipelinedWriteImpl(const WriteOptions& options, WriteBatch* updates,
+                             WriteCallback* callback = nullptr,
+                             uint64_t* log_used = nullptr, uint64_t log_ref = 0,
+@@ -2003,7 +2014,8 @@ class DBImpl : public DB {
+       mutex_.Lock();
+     }
+
+-    if (!immutable_db_options_.unordered_write) {
++    if (!immutable_db_options_.unordered_write &&
++        !immutable_db_options_.enable_multi_batch_write) {
+       // Then the writes are finished before the next write group starts
+       return;
+     }
+diff --git a/db/db_impl/db_impl_open.cc b/db/db_impl/db_impl_open.cc
+index e2f1aea23..8705b78ef 100644
+--- a/db/db_impl/db_impl_open.cc
++++ b/db/db_impl/db_impl_open.cc
+@@ -151,6 +151,16 @@ DBOptions SanitizeOptions(const std::string& dbname, const DBOptions& src,
+     result.avoid_flush_during_recovery = false;
+   }
+
++  // multi thread write does not support two-write-queues or write in 2PC
++  if (result.two_write_queues || result.allow_2pc) {
++    result.enable_multi_batch_write = false;
++  }
++
++  if (result.enable_multi_batch_write) {
++    result.enable_pipelined_write = false;
++    result.allow_concurrent_memtable_write = true;
++  }
++
+   ImmutableDBOptions immutable_db_options(result);
+   if (!immutable_db_options.IsWalDirSameAsDBPath()) {
+     // Either the WAL dir and db_paths[0]/db_name are not the same, or we
+@@ -1289,7 +1299,7 @@ Status DBImpl::RecoverLogFiles(const std::vector<uint64_t>& wal_numbers,
+       bool has_valid_writes = false;
+       status = WriteBatchInternal::InsertInto(
+           batch_to_use, column_family_memtables_.get(), &flush_scheduler_,
+-          &trim_history_scheduler_, true, wal_number, this,
++          &trim_history_scheduler_, true, wal_number, 0, this,
+           false /* concurrent_memtable_writes */, next_sequence,
+           &has_valid_writes, seq_per_batch_, batch_per_txn_);
+       MaybeIgnoreError(&status);
+diff --git a/db/db_impl/db_impl_secondary.cc b/db/db_impl/db_impl_secondary.cc
+index 235a528ba..1ee6e9df0 100644
+--- a/db/db_impl/db_impl_secondary.cc
++++ b/db/db_impl/db_impl_secondary.cc
+@@ -285,7 +285,7 @@ Status DBImplSecondary::RecoverLogFiles(
+         status = WriteBatchInternal::InsertInto(
+             &batch, column_family_memtables_.get(),
+             nullptr /* flush_scheduler */, nullptr /* trim_history_scheduler*/,
+-            true, log_number, this, false /* concurrent_memtable_writes */,
++            true, log_number, 0, this, false /* concurrent_memtable_writes */,
+             next_sequence, &has_valid_writes, seq_per_batch_, batch_per_txn_);
+       }
+       // If column family was not found, it might mean that the WAL write
+diff --git a/db/db_impl/db_impl_write.cc b/db/db_impl/db_impl_write.cc
+index cbe6e8fbd..b5ea42729 100644
+--- a/db/db_impl/db_impl_write.cc
++++ b/db/db_impl/db_impl_write.cc
+@@ -174,6 +174,228 @@ Status DBImpl::WriteWithCallback(const WriteOptions& write_options,
+   return s;
+ }
+
++void DBImpl::MultiBatchWriteCommit(CommitRequest* request) {
++  write_thread_.ExitWaitSequenceCommit(request, &versions_->last_sequence_);
++  size_t pending_cnt = pending_memtable_writes_.fetch_sub(1) - 1;
++  if (pending_cnt == 0) {
++    // switch_cv_ waits until pending_memtable_writes_ = 0. Locking its mutex
++    // before notify ensures that cv is in waiting state when it is notified
++    // thus not missing the update to pending_memtable_writes_ even though it
++    // is not modified under the mutex.
++    std::lock_guard<std::mutex> lck(switch_mutex_);
++    switch_cv_.notify_all();
++  }
++}
++
++Status DBImpl::MultiBatchWrite(const WriteOptions& options,
++                               std::vector<WriteBatch*>&& updates) {
++  if (immutable_db_options_.enable_multi_batch_write) {
++    return MultiBatchWriteImpl(options, std::move(updates), nullptr, nullptr);
++  } else {
++    return Status::NotSupported();
++  }
++}
++
++// In this way, RocksDB will apply WriteBatch to memtable out of order but
++// commit them in order. (We borrow the idea from:
++// https://github.com/cockroachdb/pebble/blob/master/docs/rocksdb.md#commit-pipeline.
++// On this basis, we split the WriteBatch into a smaller-grained WriteBatch
++// vector, and when the WriteBatch sizes of multiple writers are not balanced,
++// writers that finish first need to help the front writer finish writing the
++// remaining WriteBatch to increase CPU usage and reduce overall latency)
++//
++// More details:
++//
++// Request Queue           WriteBatchVec
++// +--------------+        +---------------------+
++// | Front Writer |  ->    | WB1 | WB2 | WB3|... |
++// +--------------+        +---------------------+
++// |   Writer 2   |  ->    | WB1 |
++// +--------------+        +-----+
++// |   Writer 3   |  ->    | WB1 | WB2 |
++// +--------------+        +-----------+
++// |     ...      |  ->    |...|
++// +--------------+        +---+
++//
++// 1. Multi Writers enter the `Request queue` to determine the commit order.
++// 2. Then all writers write to the memtable in parallel (each thread iterates
++//    over its own write batch vector).
++// 3.1. If the Front Writer finishes writing and enters the commit phase
++//    first, it will pop itself from the `Request queue`, then this function
++//    will return to its caller, and the Writer 2 becomes the new front.
++// 3.2. If the Writer 2 or 3 finishes writing and enters the commit phase
++//    first, it will help the front writer complete its pending WBs one by one
++//    until all are done and wake up the Front Writer, then the Front Writer
++//    will traverse and pop completed writers; the first unfinished writer
++//    encountered will become the new front.
++//
++Status DBImpl::MultiBatchWriteImpl(const WriteOptions& write_options,
++                                   std::vector<WriteBatch*>&& my_batch,
++                                   WriteCallback* callback, uint64_t* log_used,
++                                   uint64_t log_ref, uint64_t* seq_used) {
++  PERF_TIMER_GUARD(write_pre_and_post_process_time);
++  StopWatch write_sw(immutable_db_options_.clock,
++                     immutable_db_options_.statistics.get(), DB_WRITE);
++  WriteThread::Writer writer(write_options, std::move(my_batch), callback,
++                             log_ref, false /*disable_memtable*/);
++  CommitRequest request(&writer);
++  writer.request = &request;
++  write_thread_.JoinBatchGroup(&writer);
++
++  WriteContext write_context;
++  if (writer.state == WriteThread::STATE_GROUP_LEADER) {
++    WriteThread::WriteGroup wal_write_group;
++    if (writer.callback && !writer.callback->AllowWriteBatching()) {
++      WaitForPendingWrites();
++    }
++    LogContext log_context(!write_options.disableWAL && write_options.sync);
++    PERF_TIMER_STOP(write_pre_and_post_process_time);
++    writer.status =
++        PreprocessWrite(write_options, &log_context, &write_context);
++    PERF_TIMER_START(write_pre_and_post_process_time);
++
++    // This can set non-OK status if the callback fails.
++    last_batch_group_size_ =
++        write_thread_.EnterAsBatchGroupLeader(&writer, &wal_write_group);
++    const SequenceNumber current_sequence =
++        write_thread_.UpdateLastSequence(versions_->LastSequence()) + 1;
++    size_t total_count = 0;
++    size_t total_byte_size = 0;
++    auto stats = default_cf_internal_stats_;
++    size_t memtable_write_cnt = 0;
++    IOStatus io_s;
++    io_s.PermitUncheckedError();  // Allow io_s to be uninitialized
++    if (writer.status.ok()) {
++      SequenceNumber next_sequence = current_sequence;
++      for (auto w : wal_write_group) {
++        if (w->CheckCallback(this)) {
++          if (w->ShouldWriteToMemtable()) {
++            w->sequence = next_sequence;
++            size_t count = WriteBatchInternal::Count(w->multi_batch.batches);
++            if (count > 0) {
++              auto sequence = w->sequence;
++              for (auto b : w->multi_batch.batches) {
++                WriteBatchInternal::SetSequence(b, sequence);
++                sequence += WriteBatchInternal::Count(b);
++              }
++              w->multi_batch.SetContext(
++                  versions_->GetColumnFamilySet(), &flush_scheduler_,
++                  &trim_history_scheduler_,
++                  write_options.ignore_missing_column_families, this);
++              w->request->commit_lsn = next_sequence + count - 1;
++              write_thread_.EnterCommitQueue(w->request);
++              next_sequence += count;
++              total_count += count;
++              memtable_write_cnt++;
++            }
++          }
++          total_byte_size = WriteBatchInternal::AppendedByteSize(
++              total_byte_size,
++              WriteBatchInternal::ByteSize(w->multi_batch.batches));
++        }
++      }
++      if (writer.disable_wal) {
++        has_unpersisted_data_.store(true, std::memory_order_relaxed);
++      }
++      write_thread_.UpdateLastSequence(current_sequence + total_count - 1);
++      stats->AddDBStats(InternalStats::kIntStatsNumKeysWritten, total_count);
++      RecordTick(stats_, NUMBER_KEYS_WRITTEN, total_count);
++      stats->AddDBStats(InternalStats::kIntStatsBytesWritten, total_byte_size);
++      RecordTick(stats_, BYTES_WRITTEN, total_byte_size);
++      RecordInHistogram(stats_, BYTES_PER_WRITE, total_byte_size);
++
++      PERF_TIMER_STOP(write_pre_and_post_process_time);
++      if (!write_options.disableWAL) {
++        PERF_TIMER_GUARD(write_wal_time);
++        stats->AddDBStats(InternalStats::kIntStatsWriteDoneBySelf, 1);
++        RecordTick(stats_, WRITE_DONE_BY_SELF, 1);
++        if (wal_write_group.size > 1) {
++          stats->AddDBStats(InternalStats::kIntStatsWriteDoneByOther,
++                            wal_write_group.size - 1);
++
++          RecordTick(stats_, WRITE_DONE_BY_OTHER, wal_write_group.size - 1);
++        }
++        assert(log_context.log_file_number_size);
++        LogFileNumberSize& log_file_number_size =
++            *(log_context.log_file_number_size);
++        io_s =
++            WriteToWAL(wal_write_group, log_context.writer, log_used,
++                       log_context.need_log_sync, log_context.need_log_dir_sync,
++                       current_sequence, log_file_number_size);
++        writer.status = io_s;
++      }
++    }
++    if (!io_s.ok()) {
++      // Check WriteToWAL status
++      IOStatusCheck(io_s);
++    } else if (!writer.CallbackFailed()) {
++      WriteStatusCheck(writer.status);
++    }
++
++    VersionEdit synced_wals;
++    if (log_context.need_log_sync) {
++      InstrumentedMutexLock l(&log_write_mutex_);
++      if (writer.status.ok()) {
++        MarkLogsSynced(logfile_number_, log_context.need_log_dir_sync,
++                       &synced_wals);
++      } else {
++        MarkLogsNotSynced(logfile_number_);
++      }
++    }
++    if (writer.status.ok() && synced_wals.IsWalAddition()) {
++      InstrumentedMutexLock l(&mutex_);
++      const ReadOptions read_options;
++      writer.status = ApplyWALToManifest(read_options, &synced_wals);
++    }
++    if (writer.status.ok()) {
++      pending_memtable_writes_ += memtable_write_cnt;
++    } else {
++      // The `pending_wb_cnt` must be reset to avoid other writers helping
++      // the front writer write its WBs after it failed to write the WAL.
++      writer.ResetPendingWBCnt();
++    }
++    write_thread_.ExitAsBatchGroupLeader(wal_write_group, writer.status);
++  }
++
++  if (seq_used != nullptr) {
++    *seq_used = writer.sequence;
++  }
++  TEST_SYNC_POINT("DBImpl::WriteImpl:CommitAfterWriteWAL");
++
++  if (writer.request->commit_lsn != 0 && writer.status.ok()) {
++    TEST_SYNC_POINT("DBImpl::WriteImpl:BeforePipelineWriteMemtable");
++    PERF_TIMER_GUARD(write_memtable_time);
++    size_t total_count = WriteBatchInternal::Count(my_batch);
++    InternalStats* stats = default_cf_internal_stats_;
++    stats->AddDBStats(InternalStats::kIntStatsNumKeysWritten, total_count);
++    RecordTick(stats_, NUMBER_KEYS_WRITTEN, total_count);
++
++    while (writer.ConsumeOne());
++    MultiBatchWriteCommit(writer.request);
++
++    WriteStatusCheck(writer.status);
++    if (!writer.FinalStatus().ok()) {
++      writer.status = writer.FinalStatus();
++    }
++  } else if (writer.request->commit_lsn != 0) {
++    // When the leader fails to write WAL, all writers in the group need to
++    // cancel the write to memtable.
++    writer.ResetPendingWBCnt();
++    MultiBatchWriteCommit(writer.request);
++  } else {
++    writer.ResetPendingWBCnt();
++  }
++  return writer.status;
++}
++
+ // The main write queue. This is the only write queue that updates LastSequence.
+ // When using one write queue, the same sequence also indicates the last
+ // published sequence.
+@@ -240,6 +462,10 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options,
+     return Status::NotSupported(
+         "pipelined_writes is not compatible with concurrent prepares");
+   }
++  if (two_write_queues_ && immutable_db_options_.enable_multi_batch_write) {
++    return Status::NotSupported(
++        "multi_batch_write is not compatible with concurrent prepares");
++  }
+   if (seq_per_batch_ && immutable_db_options_.enable_pipelined_write) {
+     // TODO(yiwu): update pipeline write with seq_per_batch and batch_cnt
+     return Status::NotSupported(
+@@ -308,6 +534,13 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options,
+     return status;
+   }
+
++  if (immutable_db_options_.enable_multi_batch_write && !disable_memtable) {
++    std::vector<WriteBatch*> updates(1);
++    updates[0] = my_batch;
++    return MultiBatchWriteImpl(write_options, std::move(updates), callback,
++                               log_used, log_ref, seq_used);
++  }
++
+   if (immutable_db_options_.enable_pipelined_write) {
+     return PipelinedWriteImpl(write_options, my_batch, callback, log_used,
+                               log_ref, disable_memtable, seq_used);
+@@ -436,10 +669,12 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options,
+       if (writer->CheckCallback(this)) {
+         valid_batches += writer->batch_cnt;
+         if (writer->ShouldWriteToMemtable()) {
+-          total_count += WriteBatchInternal::Count(writer->batch);
++          total_count +=
++              WriteBatchInternal::Count(writer->multi_batch.batches[0]);
+           total_byte_size = WriteBatchInternal::AppendedByteSize(
+-              total_byte_size, WriteBatchInternal::ByteSize(writer->batch));
+-          parallel = parallel && !writer->batch->HasMerge();
++              total_byte_size,
++              WriteBatchInternal::ByteSize(writer->multi_batch.batches[0]));
++          parallel = parallel && !writer->multi_batch.batches[0]->HasMerge();
+         }
+         if (writer->pre_release_callback) {
+           pre_release_callback_cnt++;
+@@ -456,7 +691,7 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options,
+           continue;
+         }
+         // TODO: maybe handle the tracing status?
+- tracer_->Write(writer->batch).PermitUncheckedError(); ++ tracer_->Write(writer->multi_batch.batches[0]).PermitUncheckedError(); + } + } + } +@@ -551,7 +786,8 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options, + assert(writer->batch_cnt); + next_sequence += writer->batch_cnt; + } else if (writer->ShouldWriteToMemtable()) { +- next_sequence += WriteBatchInternal::Count(writer->batch); ++ next_sequence += ++ WriteBatchInternal::Count(writer->multi_batch.batches[0]); + } + } + } +@@ -709,7 +945,8 @@ Status DBImpl::PipelinedWriteImpl(const WriteOptions& write_options, + if (tracer_ != nullptr && tracer_->IsWriteOrderPreserved()) { + for (auto* writer : wal_write_group) { + // TODO: maybe handle the tracing status? +- tracer_->Write(writer->batch).PermitUncheckedError(); ++ tracer_->Write(writer->multi_batch.batches[0]) ++ .PermitUncheckedError(); + } + } + } +@@ -719,9 +956,11 @@ Status DBImpl::PipelinedWriteImpl(const WriteOptions& write_options, + if (writer->CheckCallback(this)) { + if (writer->ShouldWriteToMemtable()) { + writer->sequence = next_sequence; +- size_t count = WriteBatchInternal::Count(writer->batch); ++ size_t count = ++ WriteBatchInternal::Count(writer->multi_batch.batches[0]); + total_byte_size = WriteBatchInternal::AppendedByteSize( +- total_byte_size, WriteBatchInternal::ByteSize(writer->batch)); ++ total_byte_size, ++ WriteBatchInternal::ByteSize(writer->multi_batch.batches[0])); + next_sequence += count; + total_count += count; + } +@@ -964,7 +1203,7 @@ Status DBImpl::WriteImplWALOnly( + if (tracer_ != nullptr && tracer_->IsWriteOrderPreserved()) { + for (auto* writer : write_group) { + // TODO: maybe handle the tracing status? +- tracer_->Write(writer->batch).PermitUncheckedError(); ++ tracer_->Write(writer->multi_batch.batches[0]).PermitUncheckedError(); + } + } + } +@@ -975,7 +1214,8 @@ Status DBImpl::WriteImplWALOnly( + assert(writer); + if (writer->CheckCallback(this)) { + total_byte_size = WriteBatchInternal::AppendedByteSize( +- total_byte_size, WriteBatchInternal::ByteSize(writer->batch)); ++ total_byte_size, ++ WriteBatchInternal::ByteSize(writer->multi_batch.batches[0])); + if (writer->pre_release_callback) { + pre_release_callback_cnt++; + } +@@ -1281,11 +1521,12 @@ Status DBImpl::MergeBatch(const WriteThread::WriteGroup& write_group, + auto* leader = write_group.leader; + assert(!leader->disable_wal); // Same holds for all in the batch group + if (write_group.size == 1 && !leader->CallbackFailed() && +- leader->batch->GetWalTerminationPoint().is_cleared()) { ++ leader->multi_batch.batches.size() == 1 && ++ leader->multi_batch.batches[0]->GetWalTerminationPoint().is_cleared()) { + // we simply write the first WriteBatch to WAL if the group only + // contains one batch, that batch should be written to the WAL, + // and the batch is not wanting to be truncated +- *merged_batch = leader->batch; ++ *merged_batch = leader->multi_batch.batches[0]; + if (WriteBatchInternal::IsLatestPersistentState(*merged_batch)) { + *to_be_cached_state = *merged_batch; + } +@@ -1297,17 +1538,19 @@ Status DBImpl::MergeBatch(const WriteThread::WriteGroup& write_group, + *merged_batch = tmp_batch; + for (auto writer : write_group) { + if (!writer->CallbackFailed()) { +- Status s = WriteBatchInternal::Append(*merged_batch, writer->batch, +- /*WAL_only*/ true); +- if (!s.ok()) { +- tmp_batch->Clear(); +- return s; +- } +- if (WriteBatchInternal::IsLatestPersistentState(writer->batch)) { +- // We only need to cache the last of such write batch +- *to_be_cached_state = 
writer->batch; ++ for (auto b : writer->multi_batch.batches) { ++ Status s = WriteBatchInternal::Append(*merged_batch, b, ++ /*WAL_only*/ true); ++ if (!s.ok()) { ++ tmp_batch->Clear(); ++ return s; ++ } ++ if (WriteBatchInternal::IsLatestPersistentState(b)) { ++ // We only need to cache the last of such write batch ++ *to_be_cached_state = b; ++ } ++ (*write_with_wal)++; + } +- (*write_with_wal)++; + } + } + } +@@ -1381,7 +1624,7 @@ IOStatus DBImpl::WriteToWAL(const WriteThread::WriteGroup& write_group, + return io_s; + } + +- if (merged_batch == write_group.leader->batch) { ++ if (merged_batch == write_group.leader->multi_batch.batches[0]) { + write_group.leader->log_used = logfile_number_; + } else if (write_with_wal > 1) { + for (auto writer : write_group) { +@@ -1481,7 +1724,7 @@ IOStatus DBImpl::ConcurrentWriteToWAL( + // We need to lock log_write_mutex_ since logs_ and alive_log_files might be + // pushed back concurrently + log_write_mutex_.Lock(); +- if (merged_batch == write_group.leader->batch) { ++ if (merged_batch == write_group.leader->multi_batch.batches[0]) { + write_group.leader->log_used = logfile_number_; + } else if (write_with_wal > 1) { + for (auto writer : write_group) { +@@ -1538,8 +1781,9 @@ Status DBImpl::WriteRecoverableState() { + auto status = WriteBatchInternal::InsertInto( + &cached_recoverable_state_, column_family_memtables_.get(), + &flush_scheduler_, &trim_history_scheduler_, true, +- 0 /*recovery_log_number*/, this, false /* concurrent_memtable_writes */, +- &next_seq, &dont_care_bool, seq_per_batch_); ++ 0 /*recovery_log_number*/, 0 /*log_ref*/, this, ++ false /* concurrent_memtable_writes */, &next_seq, &dont_care_bool, ++ seq_per_batch_); + auto last_seq = next_seq - 1; + if (two_write_queues_) { + versions_->FetchAddLastAllocatedSequence(last_seq - seq); +diff --git a/db/db_kv_checksum_test.cc b/db/db_kv_checksum_test.cc +index 614399243..8d39d72f8 100644 +--- a/db/db_kv_checksum_test.cc ++++ b/db/db_kv_checksum_test.cc +@@ -449,7 +449,7 @@ TEST_P(DbKvChecksumTestMergedBatch, WriteToWALCorrupted) { + // this writer joins the write group + ASSERT_NE(follower->state, WriteThread::STATE_GROUP_LEADER); + if (corrupt_byte_offset >= leader_batch_size) { +- Slice batch_content = follower->batch->Data(); ++ Slice batch_content = follower->multi_batch.batches[0]->Data(); + CorruptWriteBatch(&batch_content, + corrupt_byte_offset - leader_batch_size, + corrupt_byte_addend_); +@@ -473,9 +473,10 @@ TEST_P(DbKvChecksumTestMergedBatch, WriteToWALCorrupted) { + .IsCorruption()); + }); + +- ASSERT_EQ(leader->batch->GetDataSize(), leader_batch_size); ++ ASSERT_EQ(leader->multi_batch.batches[0]->GetDataSize(), ++ leader_batch_size); + if (corrupt_byte_offset < leader_batch_size) { +- Slice batch_content = leader->batch->Data(); ++ Slice batch_content = leader->multi_batch.batches[0]->Data(); + CorruptWriteBatch(&batch_content, corrupt_byte_offset, + corrupt_byte_addend_); + } +@@ -561,8 +562,8 @@ TEST_P(DbKvChecksumTestMergedBatch, WriteToWALWithColumnFamilyCorrupted) { + // this writer joins the write group + ASSERT_NE(follower->state, WriteThread::STATE_GROUP_LEADER); + if (corrupt_byte_offset >= leader_batch_size) { +- Slice batch_content = +- WriteBatchInternal::Contents(follower->batch); ++ Slice batch_content = WriteBatchInternal::Contents( ++ follower->multi_batch.batches[0]); + CorruptWriteBatch(&batch_content, + corrupt_byte_offset - leader_batch_size, + corrupt_byte_addend_); +@@ -585,9 +586,11 @@ TEST_P(DbKvChecksumTestMergedBatch, 
WriteToWALWithColumnFamilyCorrupted) { + .IsCorruption()); + }); + +- ASSERT_EQ(leader->batch->GetDataSize(), leader_batch_size); ++ ASSERT_EQ(leader->multi_batch.batches[0]->GetDataSize(), ++ leader_batch_size); + if (corrupt_byte_offset < leader_batch_size) { +- Slice batch_content = WriteBatchInternal::Contents(leader->batch); ++ Slice batch_content = ++ WriteBatchInternal::Contents(leader->multi_batch.batches[0]); + CorruptWriteBatch(&batch_content, corrupt_byte_offset, + corrupt_byte_addend_); + } +diff --git a/db/db_properties_test.cc b/db/db_properties_test.cc +index e761f96d9..337eadb73 100644 +--- a/db/db_properties_test.cc ++++ b/db/db_properties_test.cc +@@ -63,6 +63,9 @@ TEST_F(DBPropertiesTest, Empty) { + options.write_buffer_size = 100000; // Small write buffer + options.allow_concurrent_memtable_write = false; + options = CurrentOptions(options); ++ if (options.enable_multi_batch_write) { ++ continue; ++ } + CreateAndReopenWithCF({"pikachu"}, options); + + std::string num; +diff --git a/db/db_test_util.cc b/db/db_test_util.cc +index 3fb457676..d98c45019 100644 +--- a/db/db_test_util.cc ++++ b/db/db_test_util.cc +@@ -551,6 +551,12 @@ Options DBTestBase::GetOptions( + options.enable_pipelined_write = true; + break; + } ++ case kMultiBatchWrite: { ++ options.enable_multi_batch_write = true; ++ options.enable_pipelined_write = false; ++ options.two_write_queues = false; ++ break; ++ } + case kConcurrentWALWrites: { + // This options optimize 2PC commit path + options.two_write_queues = true; +diff --git a/db/db_test_util.h b/db/db_test_util.h +index 023784f61..34ef23b77 100644 +--- a/db/db_test_util.h ++++ b/db/db_test_util.h +@@ -998,6 +998,7 @@ class DBTestBase : public testing::Test { + kConcurrentSkipList = 27, + kPipelinedWrite = 28, + kConcurrentWALWrites = 29, ++ kMultiBatchWrite = 30, + kDirectIO, + kLevelSubcompactions, + kBlockBasedTableWithIndexRestartInterval, +diff --git a/db/db_write_test.cc b/db/db_write_test.cc +index 59c26eaaa..0c6fdf849 100644 +--- a/db/db_write_test.cc ++++ b/db/db_write_test.cc +@@ -39,8 +39,25 @@ class DBWriteTestUnparameterized : public DBTestBase { + : DBTestBase("pipelined_write_test", /*env_do_fsync=*/false) {} + }; + ++TEST_P(DBWriteTest, WriteEmptyBatch) { ++ Options options = GetOptions(); ++ options.write_buffer_size = 65536; ++ Reopen(options); ++ WriteOptions write_options; ++ WriteBatch batch; ++ Random rnd(301); ++ // Trigger a flush so that we will enter `WaitForPendingWrites`. ++ for (auto i = 0; i < 10; i++) { ++ batch.Clear(); ++ ASSERT_OK(dbfull()->Write(write_options, &batch)); ++ ASSERT_OK(batch.Put(std::to_string(i), rnd.RandomString(10240))); ++ ASSERT_OK(dbfull()->Write(write_options, &batch)); ++ } ++} ++ + // It is invalid to do sync write while disabling WAL. 
+ TEST_P(DBWriteTest, SyncAndDisableWAL) {
++  Reopen(GetOptions());
+   WriteOptions write_options;
+   write_options.sync = true;
+   write_options.disableWAL = true;
+@@ -780,10 +797,73 @@ TEST_P(DBWriteTest, ConcurrentlyDisabledWAL) {
+   ASSERT_LE(bytes_num, 1024 * 100);
+ }
+
++TEST_P(DBWriteTest, MultiThreadWrite) {
++  Options options = GetOptions();
++  std::unique_ptr<FaultInjectionTestEnv> mock_env(
++      new FaultInjectionTestEnv(env_));
++  if (!options.enable_multi_batch_write) {
++    return;
++  }
++  constexpr int kNumThreads = 4;
++  constexpr int kNumWrite = 4;
++  constexpr int kNumBatch = 8;
++  constexpr int kBatchSize = 16;
++  options.env = mock_env.get();
++  options.write_buffer_size = 1024 * 128;
++  Reopen(options);
++  std::vector<port::Thread> threads;
++  for (int t = 0; t < kNumThreads; t++) {
++    threads.push_back(port::Thread(
++        [&](int index) {
++          WriteOptions opt;
++          std::vector<WriteBatch> data(kNumBatch);
++          for (int j = 0; j < kNumWrite; j++) {
++            std::vector<WriteBatch*> batches;
++            for (int i = 0; i < kNumBatch; i++) {
++              WriteBatch* batch = &data[i];
++              batch->Clear();
++              for (int k = 0; k < kBatchSize; k++) {
++                batch->Put("key_" + std::to_string(index) + "_" +
++                               std::to_string(j) + "_" + std::to_string(i) +
++                               "_" + std::to_string(k),
++                           "value" + std::to_string(k));
++              }
++              batches.push_back(batch);
++            }
++            dbfull()->MultiBatchWrite(opt, std::move(batches));
++          }
++        },
++        t));
++  }
++  for (int i = 0; i < kNumThreads; i++) {
++    threads[i].join();
++  }
++  ReadOptions opt;
++  for (int t = 0; t < kNumThreads; t++) {
++    std::string value;
++    for (int i = 0; i < kNumWrite; i++) {
++      for (int j = 0; j < kNumBatch; j++) {
++        for (int k = 0; k < kBatchSize; k++) {
++          ASSERT_OK(dbfull()->Get(
++              opt,
++              "key_" + std::to_string(t) + "_" + std::to_string(i) + "_" +
++                  std::to_string(j) + "_" + std::to_string(k),
++              &value));
++          std::string expected_value = "value" + std::to_string(k);
++          ASSERT_EQ(expected_value, value);
++        }
++      }
++    }
++  }
++
++  Close();
++}
++
+ INSTANTIATE_TEST_CASE_P(DBWriteTestInstance, DBWriteTest,
+                         testing::Values(DBTestBase::kDefault,
+                                         DBTestBase::kConcurrentWALWrites,
+-                                        DBTestBase::kPipelinedWrite));
++                                        DBTestBase::kPipelinedWrite,
++                                        DBTestBase::kMultiBatchWrite));
+
+ }  // namespace ROCKSDB_NAMESPACE
+
+diff --git a/db/external_sst_file_test.cc b/db/external_sst_file_test.cc
+index ef4ab7fa5..250807baf 100644
+--- a/db/external_sst_file_test.cc
++++ b/db/external_sst_file_test.cc
+@@ -1735,6 +1735,36 @@ TEST_F(ExternalSSTFileTest, WithUnorderedWrite) {
+   SyncPoint::GetInstance()->ClearAllCallBacks();
+ }
+
++TEST_F(ExternalSSTFileTest, WithMultiBatchWrite) {
++  SyncPoint::GetInstance()->DisableProcessing();
++  SyncPoint::GetInstance()->LoadDependency(
++      {{"DBImpl::WriteImpl:CommitAfterWriteWAL",
++        "ExternalSSTFileTest::WithMultiBatchWrite:WaitWriteWAL"},
++       {"DBImpl::WaitForPendingWrites:BeforeBlock",
++        "DBImpl::WriteImpl:BeforePipelineWriteMemtable"}});
++  SyncPoint::GetInstance()->SetCallBack(
++      "DBImpl::IngestExternalFile:NeedFlush", [&](void* need_flush) {
++        ASSERT_TRUE(*reinterpret_cast<bool*>(need_flush));
++      });
++
++  Options options = CurrentOptions();
++  options.unordered_write = false;
++  options.enable_multi_batch_write = true;
++  DestroyAndReopen(options);
++  Put("foo", "v1");
++  SyncPoint::GetInstance()->EnableProcessing();
++  port::Thread writer([&]() { Put("bar", "v2"); });
++
++  TEST_SYNC_POINT("ExternalSSTFileTest::WithMultiBatchWrite:WaitWriteWAL");
++  ASSERT_OK(GenerateAndAddExternalFile(options, {{"bar", "v3"}}, -1,
++                                       true /* allow_global_seqno */));
++  ASSERT_EQ(Get("bar"), "v3");
++
++  writer.join();
++  SyncPoint::GetInstance()->DisableProcessing();
++  SyncPoint::GetInstance()->ClearAllCallBacks();
++}
++
+ #if !defined(ROCKSDB_VALGRIND_RUN) || defined(ROCKSDB_FULL_VALGRIND_RUN)
+ TEST_P(ExternalSSTFileTest, IngestFileWithGlobalSeqnoRandomized) {
+   env_->skip_fsync_ = true;
+diff --git a/db/write_batch.cc b/db/write_batch.cc
+index e734bb3f6..0b55cb4aa 100644
+--- a/db/write_batch.cc
++++ b/db/write_batch.cc
+@@ -732,6 +732,14 @@ uint32_t WriteBatchInternal::Count(const WriteBatch* b) {
+   return DecodeFixed32(b->rep_.data() + 8);
+ }
+
++uint32_t WriteBatchInternal::Count(const std::vector<WriteBatch*> b) {
++  uint32_t count = 0;
++  for (auto w : b) {
++    count += DecodeFixed32(w->rep_.data() + 8);
++  }
++  return count;
++}
++
+ void WriteBatchInternal::SetCount(WriteBatch* b, uint32_t n) {
+   EncodeFixed32(&b->rep_[8], n);
+ }
+@@ -2946,10 +2954,10 @@ Status WriteBatchInternal::InsertInto(
+       inserter.MaybeAdvanceSeq(true);
+       continue;
+     }
+-    SetSequence(w->batch, inserter.sequence());
++    SetSequence(w->multi_batch.batches[0], inserter.sequence());
+     inserter.set_log_number_ref(w->log_ref);
+-    inserter.set_prot_info(w->batch->prot_info_.get());
+-    w->status = w->batch->Iterate(&inserter);
++    inserter.set_prot_info(w->multi_batch.batches[0]->prot_info_.get());
++    w->status = w->multi_batch.batches[0]->Iterate(&inserter);
+     if (!w->status.ok()) {
+       return w->status;
+     }
+@@ -2976,10 +2984,10 @@ Status WriteBatchInternal::InsertInto(
+                             concurrent_memtable_writes, nullptr /* prot_info */,
+                             nullptr /*has_valid_writes*/, seq_per_batch,
+                             batch_per_txn, hint_per_batch);
+-  SetSequence(writer->batch, sequence);
++  SetSequence(writer->multi_batch.batches[0], sequence);
+   inserter.set_log_number_ref(writer->log_ref);
+-  inserter.set_prot_info(writer->batch->prot_info_.get());
+-  Status s = writer->batch->Iterate(&inserter);
++  inserter.set_prot_info(writer->multi_batch.batches[0]->prot_info_.get());
++  Status s = writer->multi_batch.batches[0]->Iterate(&inserter);
+   assert(!seq_per_batch || batch_cnt != 0);
+   assert(!seq_per_batch || inserter.sequence() - sequence == batch_cnt);
+   if (concurrent_memtable_writes) {
+@@ -2992,14 +3000,15 @@ Status WriteBatchInternal::InsertInto(
+     const WriteBatch* batch, ColumnFamilyMemTables* memtables,
+     FlushScheduler* flush_scheduler,
+     TrimHistoryScheduler* trim_history_scheduler,
+-    bool ignore_missing_column_families, uint64_t log_number, DB* db,
+-    bool concurrent_memtable_writes, SequenceNumber* next_seq,
++    bool ignore_missing_column_families, uint64_t log_number, uint64_t log_ref,
++    DB* db, bool concurrent_memtable_writes, SequenceNumber* next_seq,
+     bool* has_valid_writes, bool seq_per_batch, bool batch_per_txn) {
+   MemTableInserter inserter(Sequence(batch), memtables, flush_scheduler,
+                             trim_history_scheduler,
+                             ignore_missing_column_families, log_number, db,
+                             concurrent_memtable_writes, batch->prot_info_.get(),
+                             has_valid_writes, seq_per_batch, batch_per_txn);
++  inserter.set_log_number_ref(log_ref);
+   Status s = batch->Iterate(&inserter);
+   if (next_seq != nullptr) {
+     *next_seq = inserter.sequence();
+diff --git a/db/write_batch_internal.h b/db/write_batch_internal.h
+index 36e7f71f4..fcae19f0c 100644
+--- a/db/write_batch_internal.h
++++ b/db/write_batch_internal.h
+@@ -134,6 +134,8 @@ class WriteBatchInternal {
+   // Return the number of entries in the batch.
+   static uint32_t Count(const WriteBatch* batch);
+
++  static uint32_t Count(const std::vector<WriteBatch*> batch);
++
+   // Set the count for the number of entries in the batch.
+   static void SetCount(WriteBatch* batch, uint32_t n);
+
+@@ -152,6 +154,14 @@ class WriteBatchInternal {
+
+   static size_t ByteSize(const WriteBatch* batch) { return batch->rep_.size(); }
+
++  static size_t ByteSize(const std::vector<WriteBatch*> batch) {
++    size_t count = 0;
++    for (auto w : batch) {
++      count += w->rep_.size();
++    }
++    return count;
++  }
++
+   static Status SetContents(WriteBatch* batch, const Slice& contents);
+
+   static Status CheckSlicePartsLength(const SliceParts& key,
+@@ -189,7 +199,8 @@ class WriteBatchInternal {
+       FlushScheduler* flush_scheduler,
+       TrimHistoryScheduler* trim_history_scheduler,
+       bool ignore_missing_column_families = false, uint64_t log_number = 0,
+-      DB* db = nullptr, bool concurrent_memtable_writes = false,
++      uint64_t log_ref = 0, DB* db = nullptr,
++      bool concurrent_memtable_writes = false,
+       SequenceNumber* next_seq = nullptr, bool* has_valid_writes = nullptr,
+       bool seq_per_batch = false, bool batch_per_txn = true);
+
+diff --git a/db/write_callback_test.cc b/db/write_callback_test.cc
+index 7709257f0..6f4e108e0 100644
+--- a/db/write_callback_test.cc
++++ b/db/write_callback_test.cc
+@@ -160,12 +160,18 @@ TEST_P(WriteCallbackPTest, WriteWithCallbackTest) {
+   if (options.enable_pipelined_write && options.two_write_queues) {
+     continue;
+   }
++  if (options.enable_multi_batch_write && options.two_write_queues) {
++    continue;
++  }
+   if (options.unordered_write && !options.allow_concurrent_memtable_write) {
+     continue;
+   }
+   if (options.unordered_write && options.enable_pipelined_write) {
+     continue;
+   }
++  if (options.unordered_write && options.enable_multi_batch_write) {
++    continue;
++  }
+
+   ReadOptions read_options;
+   DB* db;
+diff --git a/db/write_thread.cc b/db/write_thread.cc
+index 798700775..b24d3667a 100644
+--- a/db/write_thread.cc
++++ b/db/write_thread.cc
+@@ -400,7 +400,7 @@ void WriteThread::WaitForStallEndedCount(uint64_t stall_count) {
+ static WriteThread::AdaptationContext jbg_ctx("JoinBatchGroup");
+ void WriteThread::JoinBatchGroup(Writer* w) {
+   TEST_SYNC_POINT_CALLBACK("WriteThread::JoinBatchGroup:Start", w);
+-  assert(w->batch != nullptr);
++  assert(!w->multi_batch.batches.empty());
+
+   bool linked_as_leader = LinkOne(w, &newest_writer_);
+
+@@ -437,10 +437,10 @@ void WriteThread::JoinBatchGroup(Writer* w) {
+ size_t WriteThread::EnterAsBatchGroupLeader(Writer* leader,
+                                             WriteGroup* write_group) {
+   assert(leader->link_older == nullptr);
+-  assert(leader->batch != nullptr);
++  assert(!leader->multi_batch.batches.empty());
+   assert(write_group != nullptr);
+
+-  size_t size = WriteBatchInternal::ByteSize(leader->batch);
++  size_t size = WriteBatchInternal::ByteSize(leader->multi_batch.batches);
+
+   // Allow the group to grow up to a maximum size, but if the
+   // original write is small, limit the growth so we do not slow
+@@ -498,7 +498,7 @@ size_t WriteThread::EnterAsBatchGroupLeader(Writer* leader,
+       break;
+     }
+
+-    if (w->batch == nullptr) {
++    if (w->multi_batch.batches.empty()) {
+       // Do not include those writes with nullptr batch. Those are not writes,
+       // those are something else. They want to be alone
+       break;
+     }
+@@ -509,7 +509,7 @@ size_t WriteThread::EnterAsBatchGroupLeader(Writer* leader,
+       break;
+     }
+
+-    auto batch_size = WriteBatchInternal::ByteSize(w->batch);
++    auto batch_size = WriteBatchInternal::ByteSize(w->multi_batch.batches);
+     if (size + batch_size > max_size) {
+       // Do not make batch too big
+       break;
+@@ -520,6 +520,7 @@ size_t WriteThread::EnterAsBatchGroupLeader(Writer* leader,
+     write_group->last_writer = w;
+     write_group->size++;
+   }
++
+   TEST_SYNC_POINT_CALLBACK("WriteThread::EnterAsBatchGroupLeader:End", w);
+   return size;
+ }
+@@ -528,10 +529,10 @@ void WriteThread::EnterAsMemTableWriter(Writer* leader,
+                                         WriteGroup* write_group) {
+   assert(leader != nullptr);
+   assert(leader->link_older == nullptr);
+-  assert(leader->batch != nullptr);
++  assert(!leader->multi_batch.batches.empty());
+   assert(write_group != nullptr);
+
+-  size_t size = WriteBatchInternal::ByteSize(leader->batch);
++  size_t size = WriteBatchInternal::ByteSize(leader->multi_batch.batches);
+
+   // Allow the group to grow up to a maximum size, but if the
+   // original write is small, limit the growth so we do not slow
+@@ -547,7 +548,8 @@ void WriteThread::EnterAsMemTableWriter(Writer* leader,
+   write_group->size = 1;
+   Writer* last_writer = leader;
+
+-  if (!allow_concurrent_memtable_write_ || !leader->batch->HasMerge()) {
++  if (!allow_concurrent_memtable_write_ ||
++      !leader->multi_batch.batches[0]->HasMerge()) {
+     Writer* newest_writer = newest_memtable_writer_.load();
+     CreateMissingNewerLinks(newest_writer);
+
+@@ -556,16 +558,16 @@ void WriteThread::EnterAsMemTableWriter(Writer* leader,
+       assert(w->link_newer);
+       w = w->link_newer;
+
+-      if (w->batch == nullptr) {
++      if (w->multi_batch.batches.empty()) {
+         break;
+       }
+
+-      if (w->batch->HasMerge()) {
++      if (w->multi_batch.batches[0]->HasMerge()) {
+         break;
+       }
+
+       if (!allow_concurrent_memtable_write_) {
+-        auto batch_size = WriteBatchInternal::ByteSize(w->batch);
++        auto batch_size = WriteBatchInternal::ByteSize(w->multi_batch.batches);
+         if (size + batch_size > max_size) {
+           // Do not make batch too big
+           break;
+@@ -581,7 +583,8 @@ void WriteThread::EnterAsMemTableWriter(Writer* leader,
+
+   write_group->last_writer = last_writer;
+   write_group->last_sequence =
+-      last_writer->sequence + WriteBatchInternal::Count(last_writer->batch) - 1;
++      last_writer->sequence +
++      WriteBatchInternal::Count(last_writer->multi_batch.batches) - 1;
+ }
+
+ void WriteThread::ExitAsMemTableWriter(Writer* /*self*/,
+@@ -802,7 +805,7 @@ void WriteThread::ExitAsBatchGroupLeader(WriteGroup& write_group,
+
+ static WriteThread::AdaptationContext eu_ctx("EnterUnbatched");
+ void WriteThread::EnterUnbatched(Writer* w, InstrumentedMutex* mu) {
+-  assert(w != nullptr && w->batch == nullptr);
++  assert(w != nullptr && w->multi_batch.batches.empty());
+   mu->Unlock();
+   bool linked_as_leader = LinkOne(w, &newest_writer_);
+   if (!linked_as_leader) {
+@@ -841,4 +844,67 @@ void WriteThread::WaitForMemTableWriters() {
+   newest_memtable_writer_.store(nullptr);
+ }
+
++RequestQueue::RequestQueue() {}
++
++RequestQueue::~RequestQueue() {}
++
++void RequestQueue::Enter(CommitRequest* req) {
++  std::unique_lock<std::mutex> guard(commit_mu_);
++  requests_.push_back(req);
++}
++
++void RequestQueue::CommitSequenceAwait(CommitRequest* req,
++                                       std::atomic<uint64_t>* commit_sequence) {
++  std::unique_lock<std::mutex> guard(commit_mu_);
++  while (!requests_.empty() && requests_.front() != req && !req->committed) {
++    // When the subsequent commit finds that the front writer has not yet
++    // submitted, it will help the front writer to perform some tasks.
++    auto front = requests_.front()->writer;
++    if (front->ConsumableOnOtherThreads()) {
++      auto claimed = front->Claim();
++      if (claimed < front->multi_batch.batches.size()) {
++        guard.unlock();
++        front->ConsumeOne(claimed);
++        guard.lock();
++        continue;
++      }
++    } else {
++      // The front writer may be waiting for this helper writer
++      commit_cv_.notify_all();
++    }
++    commit_cv_.wait(guard);
++  }
++  if (req->committed) {
++    return;
++  } else if (requests_.front() == req) {
++    // As the front writer, some write tasks can be stolen by other writers.
++    // Wait for them to finish.
++    while (req->writer->HasPendingWB()) {
++      commit_cv_.wait(guard);
++    }
++    while (!requests_.empty() && !requests_.front()->writer->HasPendingWB()) {
++      CommitRequest* current = requests_.front();
++      commit_sequence->store(current->commit_lsn, std::memory_order_release);
++      current->committed = true;
++      requests_.pop_front();
++    }
++    commit_cv_.notify_all();
++  }
++}
++
++void WriteThread::Writer::ConsumeOne(size_t claimed) {
++  assert(claimed < multi_batch.batches.size());
++  ColumnFamilyMemTablesImpl memtables(multi_batch.version_set);
++  Status s = WriteBatchInternal::InsertInto(
++      multi_batch.batches[claimed], &memtables, multi_batch.flush_scheduler,
++      multi_batch.trim_history_scheduler,
++      multi_batch.ignore_missing_column_families, 0, this->log_ref,
++      multi_batch.db, true);
++  if (!s.ok()) {
++    std::lock_guard<SpinMutex> guard(this->status_lock);
++    this->status = s;
++  }
++  multi_batch.pending_wb_cnt.fetch_sub(1, std::memory_order_acq_rel);
++}
++
+ }  // namespace ROCKSDB_NAMESPACE
+diff --git a/db/write_thread.h b/db/write_thread.h
+index 6e5805e37..b234fc15e 100644
+--- a/db/write_thread.h
++++ b/db/write_thread.h
+@@ -10,6 +10,7 @@
+ #include <atomic>
+ #include <chrono>
+ #include <condition_variable>
++#include <deque>
+ #include <mutex>
+ #include <type_traits>
+ #include <vector>
+@@ -17,6 +18,7 @@
+ #include "db/dbformat.h"
+ #include "db/post_memtable_callback.h"
+ #include "db/pre_release_callback.h"
++#include "db/trim_history_scheduler.h"
+ #include "db/write_callback.h"
+ #include "monitoring/instrumented_mutex.h"
+ #include "rocksdb/options.h"
+@@ -24,9 +26,29 @@
+ #include "rocksdb/types.h"
+ #include "rocksdb/write_batch.h"
+ #include "util/autovector.h"
++#include "util/mutexlock.h"
+
+ namespace ROCKSDB_NAMESPACE {
+
++struct CommitRequest;
++
++class ColumnFamilySet;
++class FlushScheduler;
++
++class RequestQueue {
++ public:
++  RequestQueue();
++  ~RequestQueue();
++  void Enter(CommitRequest* req);
++  void CommitSequenceAwait(CommitRequest* req,
++                           std::atomic<uint64_t>* commit_sequence);
++
++ private:
++  std::mutex commit_mu_;
++  std::condition_variable commit_cv_;
++  std::deque<CommitRequest*> requests_;
++};
++
+ class WriteThread {
+  public:
+   enum State : uint8_t {
+@@ -112,9 +134,49 @@ class WriteThread {
+     Iterator end() const { return Iterator(nullptr, nullptr); }
+   };
+
++  struct MultiBatch {
++    std::vector<WriteBatch*> batches;
++    std::atomic<size_t> claimed_cnt;
++    std::atomic<size_t> pending_wb_cnt;
++    ColumnFamilySet* version_set;
++    FlushScheduler* flush_scheduler;
++    TrimHistoryScheduler* trim_history_scheduler;
++    bool ignore_missing_column_families;
++    DB* db;
++
++    MultiBatch()
++        : claimed_cnt(0),
++          pending_wb_cnt(0),
++          version_set(nullptr),
++          flush_scheduler(nullptr),
++          trim_history_scheduler(nullptr),
++          ignore_missing_column_families(false),
++          db(nullptr) {}
++
++    explicit MultiBatch(std::vector<WriteBatch*>&& _batch)
++        : batches(_batch),
++          claimed_cnt(0),
++          pending_wb_cnt(_batch.size()),
++          version_set(nullptr),
++          flush_scheduler(nullptr),
++          trim_history_scheduler(nullptr),
++          ignore_missing_column_families(false),
++          db(nullptr) {}
++
++    void SetContext(ColumnFamilySet* _version_set,
++                    FlushScheduler* _flush_scheduler,
++                    TrimHistoryScheduler* _trim_history_scheduler,
++                    bool _ignore_missing_column_families, DB* _db) {
++      version_set = _version_set;
++      flush_scheduler = _flush_scheduler;
++      trim_history_scheduler = _trim_history_scheduler;
++      ignore_missing_column_families = _ignore_missing_column_families;
++      db = _db;
++    }
++  };
++
+   // Information kept for every waiting writer.
+   struct Writer {
+-    WriteBatch* batch;
+     bool sync;
+     bool no_slowdown;
+     bool disable_wal;
+@@ -130,8 +192,10 @@
+     bool made_waitable;          // records lazy construction of mutex and cv
+     std::atomic<uint8_t> state;  // write under StateMutex() or pre-link
+     WriteGroup* write_group;
++    CommitRequest* request;
+     SequenceNumber sequence;  // the sequence number to use for the first key
+-    Status status;
++    Status status;  // write protected by status_lock in multi batch write.
++    SpinMutex status_lock;
+     Status callback_status;  // status returned by callback->Callback()
+
+     std::aligned_storage<sizeof(std::mutex)>::type state_mutex_bytes;
+@@ -139,9 +203,10 @@
+     Writer* link_older;  // read/write only before linking, or as leader
+     Writer* link_newer;  // lazy, read/write only before linking, or as leader
+
++    MultiBatch multi_batch;
++
+     Writer()
+-        : batch(nullptr),
+-          sync(false),
++        : sync(false),
+           no_slowdown(false),
+           disable_wal(false),
+           rate_limiter_priority(Env::IOPriority::IO_TOTAL),
+@@ -156,6 +221,7 @@
+           made_waitable(false),
+           state(STATE_INIT),
+           write_group(nullptr),
++          request(nullptr),
+           sequence(kMaxSequenceNumber),
+           link_older(nullptr),
+           link_newer(nullptr) {}
+@@ -165,8 +231,7 @@
+            size_t _batch_cnt = 0,
+            PreReleaseCallback* _pre_release_callback = nullptr,
+            PostMemTableCallback* _post_memtable_callback = nullptr)
+-        : batch(_batch),
+-          sync(write_options.sync),
++        : sync(write_options.sync),
+           no_slowdown(write_options.no_slowdown),
+           disable_wal(write_options.disableWAL),
+           rate_limiter_priority(write_options.rate_limiter_priority),
+@@ -181,9 +246,36 @@
+           made_waitable(false),
+           state(STATE_INIT),
+           write_group(nullptr),
++          request(nullptr),
+           sequence(kMaxSequenceNumber),
+           link_older(nullptr),
+-          link_newer(nullptr) {}
++          link_newer(nullptr) {
++      multi_batch.batches.push_back(_batch);
++      multi_batch.pending_wb_cnt.fetch_add(1, std::memory_order_acq_rel);
++    }
++
++    Writer(const WriteOptions& write_options, std::vector<WriteBatch*>&& _batch,
++           WriteCallback* _callback, uint64_t _log_ref, bool _disable_memtable,
++           PreReleaseCallback* _pre_release_callback = nullptr,
++           PostMemTableCallback* _post_memtable_callback = nullptr)
++        : sync(write_options.sync),
++          no_slowdown(write_options.no_slowdown),
++          disable_wal(write_options.disableWAL),
++          disable_memtable(_disable_memtable),
++          batch_cnt(0),
++          pre_release_callback(_pre_release_callback),
++          post_memtable_callback(_post_memtable_callback),
++          log_used(0),
++          log_ref(_log_ref),
++          callback(_callback),
++          made_waitable(false),
++          state(STATE_INIT),
++          write_group(nullptr),
++          request(nullptr),
++          sequence(kMaxSequenceNumber),
++          link_older(nullptr),
++          link_newer(nullptr),
++          multi_batch(std::move(_batch)) {}
+
+     ~Writer() {
+       if (made_waitable) {
+@@ -256,6 +348,33 @@ class WriteThread {
+       return *static_cast<std::condition_variable*>(
+           static_cast<void*>(&state_cv_bytes));
+     }
++
++    bool ConsumableOnOtherThreads() {
++      return multi_batch.pending_wb_cnt.load(std::memory_order_acquire) > 1;
++    }
++
++    size_t Claim() {
++      return multi_batch.claimed_cnt.fetch_add(1, std::memory_order_acq_rel);
++    }
++
++    bool HasPendingWB() {
++      return multi_batch.pending_wb_cnt.load(std::memory_order_acquire) > 0;
++    }
++
++    void ResetPendingWBCnt() {
++      multi_batch.pending_wb_cnt.store(0, std::memory_order_release);
++    }
++
++    bool ConsumeOne() {
++      auto claimed = Claim();
++      if (claimed < multi_batch.batches.size()) {
++        ConsumeOne(claimed);
++        return true;
++      }
++      return false;
++    }
++
++    void ConsumeOne(size_t claimed);
+   };
+
+   struct AdaptationContext {
+@@ -374,6 +493,13 @@ class WriteThread {
+   // (Does not require db mutex held)
+   void WaitForStallEndedCount(uint64_t stall_count);
+
++  void EnterCommitQueue(CommitRequest* req) { return commit_queue_.Enter(req); }
++
++  void ExitWaitSequenceCommit(CommitRequest* req,
++                              std::atomic<uint64_t>* commit_sequence) {
++    commit_queue_.CommitSequenceAwait(req, commit_sequence);
++  }
++
+  private:
+   // See AwaitState.
+   const uint64_t max_yield_usec_;
+@@ -406,6 +532,7 @@ class WriteThread {
+   // at the tail of the writer queue by the leader, so newer writers can just
+   // check for this and bail
+   Writer write_stall_dummy_;
++  RequestQueue commit_queue_;
+
+   // Mutex and condvar for writers to block on a write stall. During a write
+   // stall, writers with no_slowdown set to false will wait on this rather
+@@ -461,4 +588,13 @@ class WriteThread {
+   void CompleteFollower(Writer* w, WriteGroup& write_group);
+ };
+
++struct CommitRequest {
++  WriteThread::Writer* writer;
++  uint64_t commit_lsn;
++  // protected by RequestQueue::commit_mu_
++  bool committed;
++  CommitRequest(WriteThread::Writer* w)
++      : writer(w), commit_lsn(0), committed(false) {}
++};
++
+ }  // namespace ROCKSDB_NAMESPACE
+diff --git a/include/rocksdb/db.h b/include/rocksdb/db.h
+index f41a491cc..1791c812a 100644
+--- a/include/rocksdb/db.h
++++ b/include/rocksdb/db.h
+@@ -542,6 +542,11 @@ class DB {
+   // Note: consider setting options.sync = true.
+   virtual Status Write(const WriteOptions& options, WriteBatch* updates) = 0;
+
++  virtual Status MultiBatchWrite(const WriteOptions& /*options*/,
++                                 std::vector<WriteBatch*>&& /*updates*/) {
++    return Status::NotSupported();
++  }
++
+   // If the column family specified by "column_family" contains an entry for
+   // "key", return the corresponding value in "*value". If the entry is a plain
+   // key-value, return the value as-is; if it is a wide-column entity, return
+diff --git a/include/rocksdb/options.h b/include/rocksdb/options.h
+index 1c2daed9a..a33f8eea4 100644
+--- a/include/rocksdb/options.h
++++ b/include/rocksdb/options.h
+@@ -1129,6 +1129,22 @@ struct DBOptions {
+   //
+   // Default: false
+   bool unordered_write = false;
++  // By default, a single write thread queue is maintained. The thread that
++  // gets to the head of the queue becomes the write batch group leader and is
++  // responsible for writing to the WAL.
++  //
++  // If enable_multi_batch_write is true, RocksDB will apply WriteBatch to
++  // memtable out of order but commit them in order. (We borrow the idea from
++  // https://github.com/cockroachdb/pebble/blob/master/docs/rocksdb.md#commit-pipeline.
++  // On this basis, we split the WriteBatch into a smaller-grained WriteBatch
++  // vector, and when the WriteBatch sizes of multiple writers are not
++  // balanced, writers that finish first need to help the front writer finish
++  // writing the remaining WriteBatch to increase CPU usage and reduce overall
++  // latency).
++ // ++ // Default: false ++ bool enable_multi_batch_write = false; + + // If true, allow multi-writers to update mem tables in parallel. + // Only some memtable_factory-s support concurrent writes; currently it +diff --git a/options/db_options.cc b/options/db_options.cc +index 9ef4ccac0..dd8630473 100644 +--- a/options/db_options.cc ++++ b/options/db_options.cc +@@ -327,6 +327,10 @@ static std::unordered_map + {offsetof(struct ImmutableDBOptions, enable_pipelined_write), + OptionType::kBoolean, OptionVerificationType::kNormal, + OptionTypeFlags::kNone}}, ++ {"enable_multi_batch_write", ++ {offsetof(struct ImmutableDBOptions, enable_multi_batch_write), ++ OptionType::kBoolean, OptionVerificationType::kNormal, ++ OptionTypeFlags::kNone}}, + {"unordered_write", + {offsetof(struct ImmutableDBOptions, unordered_write), + OptionType::kBoolean, OptionVerificationType::kNormal, +@@ -731,6 +735,7 @@ ImmutableDBOptions::ImmutableDBOptions(const DBOptions& options) + enable_thread_tracking(options.enable_thread_tracking), + enable_pipelined_write(options.enable_pipelined_write), + unordered_write(options.unordered_write), ++ enable_multi_batch_write(options.enable_multi_batch_write), + allow_concurrent_memtable_write(options.allow_concurrent_memtable_write), + enable_write_thread_adaptive_yield( + options.enable_write_thread_adaptive_yield), +@@ -882,6 +887,8 @@ void ImmutableDBOptions::Dump(Logger* log) const { + enable_pipelined_write); + ROCKS_LOG_HEADER(log, " Options.unordered_write: %d", + unordered_write); ++ ROCKS_LOG_HEADER(log, " Options.enable_multi_batch_write: %d", ++ enable_multi_batch_write); + ROCKS_LOG_HEADER(log, " Options.allow_concurrent_memtable_write: %d", + allow_concurrent_memtable_write); + ROCKS_LOG_HEADER(log, " Options.enable_write_thread_adaptive_yield: %d", +diff --git a/options/db_options.h b/options/db_options.h +index 701a83feb..86e17e967 100644 +--- a/options/db_options.h ++++ b/options/db_options.h +@@ -68,6 +68,7 @@ struct ImmutableDBOptions { + bool enable_thread_tracking; + bool enable_pipelined_write; + bool unordered_write; ++ bool enable_multi_batch_write; + bool allow_concurrent_memtable_write; + bool enable_write_thread_adaptive_yield; + uint64_t write_thread_max_yield_usec; +diff --git a/options/options_helper.cc b/options/options_helper.cc +index 362af2839..a9583dd1e 100644 +--- a/options/options_helper.cc ++++ b/options/options_helper.cc +@@ -133,6 +133,8 @@ DBOptions BuildDBOptions(const ImmutableDBOptions& immutable_db_options, + options.enable_thread_tracking = immutable_db_options.enable_thread_tracking; + options.delayed_write_rate = mutable_db_options.delayed_write_rate; + options.enable_pipelined_write = immutable_db_options.enable_pipelined_write; ++ options.enable_multi_batch_write = ++ immutable_db_options.enable_multi_batch_write; + options.unordered_write = immutable_db_options.unordered_write; + options.allow_concurrent_memtable_write = + immutable_db_options.allow_concurrent_memtable_write; +diff --git a/options/options_settable_test.cc b/options/options_settable_test.cc +index decd1c423..ced8597a9 100644 +--- a/options/options_settable_test.cc ++++ b/options/options_settable_test.cc +@@ -338,6 +338,7 @@ TEST_F(OptionsSettableTest, DBOptionsAllFieldsSettable) { + "advise_random_on_open=true;" + "fail_if_options_file_error=false;" + "enable_pipelined_write=false;" ++ "enable_multi_batch_write=false;" + "unordered_write=false;" + "allow_concurrent_memtable_write=true;" + "wal_recovery_mode=kPointInTimeRecovery;" +diff --git 
a/tools/db_bench_tool.cc b/tools/db_bench_tool.cc +index e177934b0..40132e69c 100644 +--- a/tools/db_bench_tool.cc ++++ b/tools/db_bench_tool.cc +@@ -243,9 +243,11 @@ DEFINE_string( + "operation includes a rare but possible retry in case it got " + "`Status::Incomplete()`. This happens upon encountering more keys than " + "have ever been seen by the thread (or eight initially)\n" +- "\tbackup -- Create a backup of the current DB and verify that a new backup is corrected. " ++ "\tbackup -- Create a backup of the current DB and verify that a new " ++ "backup is corrected. " + "Rate limit can be specified through --backup_rate_limit\n" +- "\trestore -- Restore the DB from the latest backup available, rate limit can be specified through --restore_rate_limit\n"); ++ "\trestore -- Restore the DB from the latest backup available, rate limit " ++ "can be specified through --restore_rate_limit\n"); + + DEFINE_int64(num, 1000000, "Number of key/values to place in database"); + +@@ -1001,6 +1003,9 @@ DEFINE_uint64(fifo_age_for_warm, 0, "age_for_warm for FIFO compaction."); + // Stacked BlobDB Options + DEFINE_bool(use_blob_db, false, "[Stacked BlobDB] Open a BlobDB instance."); + ++DEFINE_bool(use_multi_thread_write, false, ++ "Open a RocksDB with multi thread write pool"); ++ + DEFINE_bool( + blob_db_enable_gc, + ROCKSDB_NAMESPACE::blob_db::BlobDBOptions().enable_garbage_collection, +@@ -1049,7 +1054,6 @@ DEFINE_string( + static enum ROCKSDB_NAMESPACE::CompressionType + FLAGS_blob_db_compression_type_e = ROCKSDB_NAMESPACE::kSnappyCompression; + +- + // Integrated BlobDB options + DEFINE_bool( + enable_blob_files, +@@ -1120,7 +1124,6 @@ DEFINE_int32(prepopulate_blob_cache, 0, + "[Integrated BlobDB] Pre-populate hot/warm blobs in blob cache. 0 " + "to disable and 1 to insert during flush."); + +- + // Secondary DB instance Options + DEFINE_bool(use_secondary_db, false, + "Open a RocksDB secondary instance. A primary instance can be " +@@ -1134,14 +1137,12 @@ DEFINE_int32(secondary_update_interval, 5, + "Secondary instance attempts to catch up with the primary every " + "secondary_update_interval seconds."); + +- + DEFINE_bool(report_bg_io_stats, false, + "Measure times spents on I/Os while in compactions. "); + + DEFINE_bool(use_stderr_info_logger, false, + "Write info logs to stderr instead of to LOG file. "); + +- + DEFINE_string(trace_file, "", "Trace workload to a file. 
"); + + DEFINE_double(trace_replay_fast_forward, 1.0, +@@ -1798,6 +1799,60 @@ static Status CreateMemTableRepFactory( + return s; + } + ++class WriteBatchVec { ++ public: ++ explicit WriteBatchVec(uint32_t max_batch_size) ++ : max_batch_size_(max_batch_size), current_(0) {} ++ ~WriteBatchVec() { ++ for (auto w : batches_) { ++ delete w; ++ } ++ } ++ void Clear() { ++ for (size_t i = 0; i <= current_ && i < batches_.size(); i++) { ++ batches_[i]->Clear(); ++ } ++ current_ = 0; ++ } ++ ++ Status Put(const Slice& key, const Slice& value) { ++ if (current_ < batches_.size() && ++ batches_[current_]->Count() < max_batch_size_) { ++ return batches_[current_]->Put(key, value); ++ } else if (current_ + 1 >= batches_.size()) { ++ batches_.push_back(new WriteBatch); ++ } ++ if (current_ + 1 < batches_.size()) { ++ current_ += 1; ++ } ++ return batches_[current_]->Put(key, value); ++ } ++ ++ std::vector GetWriteBatch() const { ++ std::vector batches; ++ for (size_t i = 0; i < batches_.size(); i++) { ++ if (i > current_) { ++ break; ++ } ++ batches.push_back(batches_[i]); ++ } ++ return batches; ++ } ++ ++ uint32_t Count() const { ++ uint32_t count = 0; ++ for (size_t i = 0; i <= current_ && i < batches_.size(); i++) { ++ count += batches_[i]->Count(); ++ } ++ return count; ++ } ++ ++ private: ++ uint32_t max_batch_size_; ++ size_t current_; ++ std::vector batches_; ++}; ++ + } // namespace + + enum DistributionType : unsigned char { kFixed = 0, kUniform, kNormal }; +@@ -1962,11 +2017,7 @@ struct DBWithColumnFamilies { + std::vector cfh_idx_to_prob; // ith index holds probability of operating + // on cfh[i]. + +- DBWithColumnFamilies() +- : db(nullptr) +- , +- opt_txn_db(nullptr) +- { ++ DBWithColumnFamilies() : db(nullptr), opt_txn_db(nullptr) { + cfh.clear(); + num_created = 0; + num_hot = 0; +@@ -1978,8 +2029,7 @@ struct DBWithColumnFamilies { + opt_txn_db(other.opt_txn_db), + num_created(other.num_created.load()), + num_hot(other.num_hot), +- cfh_idx_to_prob(other.cfh_idx_to_prob) { +- } ++ cfh_idx_to_prob(other.cfh_idx_to_prob) {} + + void DeleteDBs() { + std::for_each(cfh.begin(), cfh.end(), +@@ -2734,6 +2784,7 @@ class Benchmark { + bool use_blob_db_; // Stacked BlobDB + bool read_operands_; // read via GetMergeOperands() + std::vector keys_; ++ bool use_multi_write_; + + class ErrorHandlerListener : public EventListener { + public: +@@ -3206,7 +3257,8 @@ class Benchmark { + merge_keys_(FLAGS_merge_keys < 0 ? FLAGS_num : FLAGS_merge_keys), + report_file_operations_(FLAGS_report_file_operations), + use_blob_db_(FLAGS_use_blob_db), // Stacked BlobDB +- read_operands_(false) { ++ read_operands_(false), ++ use_multi_write_(FLAGS_use_multi_thread_write) { + // use simcache instead of cache + if (FLAGS_simcache_size >= 0) { + if (FLAGS_cache_numshardbits >= 1) { +@@ -4845,6 +4897,9 @@ class Benchmark { + DBWithColumnFamilies* db) { + uint64_t open_start = FLAGS_report_open_timing ? FLAGS_env->NowNanos() : 0; + Status s; ++ if (use_multi_write_) { ++ options.enable_multi_batch_write = true; ++ } + // Open with column families if necessary. 
+ if (FLAGS_num_column_families > 1) { + size_t num_hot = FLAGS_num_column_families; +@@ -5108,6 +5163,7 @@ class Benchmark { + WriteBatch batch(/*reserved_bytes=*/0, /*max_bytes=*/0, + FLAGS_write_batch_protection_bytes_per_key, + user_timestamp_size_); ++ WriteBatchVec batches(32); + Status s; + int64_t bytes = 0; + +@@ -5242,6 +5298,7 @@ class Benchmark { + DBWithColumnFamilies* db_with_cfh = SelectDBWithCfh(id); + + batch.Clear(); ++ batches.Clear(); + int64_t batch_bytes = 0; + + for (int64_t j = 0; j < entries_per_batch_; j++) { +@@ -5357,7 +5414,9 @@ class Benchmark { + } else { + val = gen.Generate(); + } +- if (use_blob_db_) { ++ if (use_multi_write_) { ++ batches.Put(key, val); ++ } else if (use_blob_db_) { + // Stacked BlobDB + blob_db::BlobDB* blobdb = + static_cast(db_with_cfh->db); +@@ -5420,6 +5479,7 @@ class Benchmark { + batch.Delete(db_with_cfh->GetCfh(rand_num), + expanded_keys[offset]); + } ++ assert(!use_multi_write_); + } + } else { + GenerateKeyFromInt(begin_num, FLAGS_num, &begin_key); +@@ -5436,6 +5496,7 @@ class Benchmark { + batch.DeleteRange(db_with_cfh->GetCfh(rand_num), begin_key, + end_key); + } ++ assert(!use_multi_write_); + } + } + } +@@ -5458,7 +5519,10 @@ class Benchmark { + ErrorExit(); + } + } +- if (!use_blob_db_) { ++ if (use_multi_write_) { ++ s = db_with_cfh->db->MultiBatchWrite(write_options_, ++ batches.GetWriteBatch()); ++ } else if (!use_blob_db_) { + // Not stacked BlobDB + s = db_with_cfh->db->Write(write_options_, &batch); + } +@@ -8450,7 +8514,6 @@ class Benchmark { + } + } + +- + void Replay(ThreadState* thread) { + if (db_.db != nullptr) { + Replay(thread, &db_); +@@ -8538,7 +8601,6 @@ class Benchmark { + assert(s.ok()); + delete backup_engine; + } +- + }; + + int db_bench_tool(int argc, char** argv) { +-- +2.45.0 + diff --git a/tikv-rocksdb-patches/0011-Add-KeyManagedEncryptedEnv-for-per-file-key-manageme.patch b/tikv-rocksdb-patches/0011-Add-KeyManagedEncryptedEnv-for-per-file-key-manageme.patch new file mode 100644 index 00000000000..fc9fba0ab2e --- /dev/null +++ b/tikv-rocksdb-patches/0011-Add-KeyManagedEncryptedEnv-for-per-file-key-manageme.patch @@ -0,0 +1,1909 @@ +From ad7361d0e349926126788e88a4447f1f6f8dd65f Mon Sep 17 00:00:00 2001 +From: yiwu-arbug +Date: Mon, 16 Mar 2020 21:59:47 -0700 +Subject: [PATCH] Add KeyManagedEncryptedEnv for per file key management + +Add KeyManagedEncryptedEnv and AESBlockCipher (#151) + +Summary: +Introduce `KeyManagedEncryptedEnv` which wraps around `EncryptedEnv` but provides an `KeyManager` API to enable key management per file. Also implements `AESBlockCipher` with OpenSSL. + +Test Plan: +not tested yet. will update. + +Signed-off-by: Yi Wu +Signed-off-by: tabokie + +Fix build error when build wihtout openssl (#155) + +Summary: +Fix missing check for openssl in db_bench. + +Test Plan: +build without openssl + +Signed-off-by: Yi Wu +Signed-off-by: tabokie + +encryption: change to use openssl EVP API (#156) + +Summary: +Instead of using openssl's raw `AES_encrypt` and `AES_decrypt` API, which is a low level call to encrypt or decrypt exact one block (16 bytes), we change to use the `EVP_*` API. The former is deprecated, and will use the default C implementation without AES-NI support. Also the EVP API is capable of handing CTR mode on its own. + +Test Plan: +will add tests + +Signed-off-by: Yi Wu +Signed-off-by: tabokie + +Fix NewRandomRWFile and ReuseWritableFile in KeyManagedEncryptedEnv (#167) + +Summary: +Fix NewRandomRWFile and ReuseWritableFile misuse of `GetFile()` and `NewFile()`. 
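For reference, the intended contract (an illustrative sketch against the `KeyManager` interface added by this patch, not code taken from it):

    encryption::FileEncryptionInfo info;
    // Creating or truncating a file: register a fresh key/IV for it.
    Status s = key_manager->NewFile(fname, &info);
    // Reopening an existing file in place, e.g. NewRandomRWFile rewriting
    // the global seqno of an ingested SST: look up the key it was written
    // with instead of rotating it.
    s = key_manager->GetFile(fname, &info);
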
See inline comments. + +Test Plan: +manual test with tikv + +Signed-off-by: Yi Wu +Signed-off-by: tabokie + +Atomize RenameFile in KeyManagedEncryptedEnv (#222) + +* adjust logic in KeyManagedEncryptedEnv::RenameFile to avoid poweroff + +Signed-off-by: Xintao +Signed-off-by: tabokie + +Atomize Rename operation when encryption is enabled (#224) + +Signed-off-by: Xintao +Signed-off-by: tabokie + +Fix the bug that the key manager is not updated during the Rename (#227) + +Signed-off-by: Xintao +Signed-off-by: tabokie + +Add sm4 encryption (#295) (#299) + +* Add SM4-CTR encryption algorithm + +Signed-off-by: Jarvis Zheng + +* Adjust block size for sm4 encryption + +Signed-off-by: Jarvis Zheng + +* Add UT for SM4 encryption + +Signed-off-by: Jarvis Zheng + +* Adjust macros indentation for sm4 + +Signed-off-by: Jarvis Zheng + +* Fix format for adding sm4 + +Signed-off-by: Jarvis Zheng + +Check OPENSSL_NO_SM4 before using sm4 encryption (#302) + +Signed-off-by: Jarvis Zheng +Signed-off-by: v01dstar +--- + CMakeLists.txt | 11 + + Makefile | 3 + + TARGETS | 1 + + build_tools/build_detect_platform | 13 + + db/db_options_test.cc | 3 +- + db/db_properties_test.cc | 2 +- + db/db_test.cc | 4 + + db/db_test2.cc | 3 + + db/db_test_util.cc | 17 +- + db/db_test_util.h | 1 + + db/db_wal_test.cc | 15 + + encryption/encryption.cc | 553 +++++++++++++++++++++++++++++ + encryption/encryption.h | 139 ++++++++ + encryption/encryption_test.cc | 181 ++++++++++ + encryption/in_memory_key_manager.h | 83 +++++ + env/env_basic_test.cc | 63 +++- + file/filename.cc | 32 +- + file/filename.h | 8 + + include/rocksdb/encryption.h | 111 ++++++ + src.mk | 1 + + test_util/testutil.cc | 9 +- + test_util/testutil.h | 53 +++ + tools/db_bench_tool.cc | 54 ++- + 23 files changed, 1335 insertions(+), 25 deletions(-) + create mode 100644 encryption/encryption.cc + create mode 100644 encryption/encryption.h + create mode 100644 encryption/encryption_test.cc + create mode 100644 encryption/in_memory_key_manager.h + create mode 100644 include/rocksdb/encryption.h + +diff --git a/CMakeLists.txt b/CMakeLists.txt +index 6f9debfb5..b913d921a 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -70,6 +70,7 @@ option(WITH_SNAPPY "build with SNAPPY" OFF) + option(WITH_LZ4 "build with lz4" OFF) + option(WITH_ZLIB "build with zlib" OFF) + option(WITH_ZSTD "build with zstd" OFF) ++option(WITH_OPENSSL "build with openssl" OFF) + option(WITH_WINDOWS_UTF8_FILENAMES "use UTF8 as characterset for opening files, regardles of the system code page" OFF) + if (WITH_WINDOWS_UTF8_FILENAMES) + add_definitions(-DROCKSDB_WINDOWS_UTF8_FILENAMES) +@@ -174,6 +175,14 @@ else() + include_directories(${ZSTD_INCLUDE_DIR}) + list(APPEND THIRDPARTY_LIBS zstd::zstd) + endif() ++ ++ if(WITH_OPENSSL) ++ find_package(OpenSSL REQUIRED) ++ add_definitions(-DOPENSSL) ++ include_directories(${OPENSSL_INCLUDE_DIR}) ++ # Only the crypto library is needed. 
++ list(APPEND THIRDPARTY_LIBS ${OPENSSL_CRYPTO_LIBRARIES}) ++ endif() + endif() + + option(WITH_MD_LIBRARY "build with MD" ON) +@@ -721,6 +730,7 @@ set(SOURCES + db/write_controller.cc + db/write_stall_stats.cc + db/write_thread.cc ++ encryption/encryption.cc + env/composite_env.cc + env/env.cc + env/env_chroot.cc +@@ -1372,6 +1382,7 @@ if(WITH_TESTS) + db/write_batch_test.cc + db/write_callback_test.cc + db/write_controller_test.cc ++ encryption/encryption_test.cc + env/env_test.cc + env/io_posix_test.cc + env/mock_env_test.cc +diff --git a/Makefile b/Makefile +index 8829be9d8..6293dcdd4 100644 +--- a/Makefile ++++ b/Makefile +@@ -703,6 +703,7 @@ TESTS_PLATFORM_DEPENDENT := \ + crc32c_test \ + coding_test \ + inlineskiplist_test \ ++ encryption_test \ + env_basic_test \ + env_test \ + env_logger_test \ +@@ -1988,6 +1989,8 @@ wide_column_serialization_test: $(OBJ_DIR)/db/wide/wide_column_serialization_tes + $(AM_LINK) + + wide_columns_helper_test: $(OBJ_DIR)/db/wide/wide_columns_helper_test.o $(TEST_LIBRARY) $(LIBRARY) ++ ++encryption_test: encryption/encryption_test.o $(LIBOBJECTS) $(TESTHARNESS) + $(AM_LINK) + + #------------------------------------------------- +diff --git a/TARGETS b/TARGETS +index 3200813ae..f6959a9ed 100644 +--- a/TARGETS ++++ b/TARGETS +@@ -108,6 +108,7 @@ cpp_library_wrapper(name="rocksdb_lib", srcs=[ + "db/write_controller.cc", + "db/write_stall_stats.cc", + "db/write_thread.cc", ++ "encryption/encryption.cc", + "env/composite_env.cc", + "env/env.cc", + "env/env_chroot.cc", +diff --git a/build_tools/build_detect_platform b/build_tools/build_detect_platform +index a5e2b5aa2..51b60d74c 100755 +--- a/build_tools/build_detect_platform ++++ b/build_tools/build_detect_platform +@@ -485,6 +485,19 @@ EOF + fi + fi + ++ if ! test $ROCKSDB_DISABLE_OPENSSL; then ++ # Test whether OpenSSL library is installed ++ $CXX $CFLAGS -x c++ - -o /dev/null 2>/dev/null < ++ int main() {} ++EOF ++ if [ "$?" = 0 ]; then ++ COMMON_FLAGS="$COMMON_FLAGS -DOPENSSL" ++ PLATFORM_LDFLAGS="$PLATFORM_LDFLAGS -lcrypto" ++ JAVA_LDFLAGS="$JAVA_LDFLAGS -lcrypto" ++ fi ++ fi ++ + if ! test $ROCKSDB_DISABLE_PTHREAD_MUTEX_ADAPTIVE_NP; then + # Test whether PTHREAD_MUTEX_ADAPTIVE_NP mutex type is available + $CXX $PLATFORM_CXXFLAGS -x c++ - -o test.o 2>/dev/null <GetSystemClock()); + } + if (getenv("ENCRYPTED_ENV")) { +- std::shared_ptr provider; +- std::string provider_id = getenv("ENCRYPTED_ENV"); +- if (provider_id.find("=") == std::string::npos && +- !EndsWith(provider_id, "://test")) { +- provider_id = provider_id + "://test"; +- } +- EXPECT_OK(EncryptionProvider::CreateFromString(ConfigOptions(), provider_id, +- &provider)); +- encrypted_env_ = NewEncryptedEnv(mem_env_ ? mem_env_ : base_env, provider); ++#ifdef OPENSSL ++ std::shared_ptr key_manager( ++ new test::TestKeyManager); ++ encrypted_env_ = NewKeyManagedEncryptedEnv(Env::Default(), key_manager); ++#else ++ fprintf(stderr, "EncryptedEnv is not available without OpenSSL."); ++ assert(false); ++#endif + } + env_ = new SpecialEnv(encrypted_env_ ? encrypted_env_ + : (mem_env_ ? 
mem_env_ : base_env)); +diff --git a/db/db_test_util.h b/db/db_test_util.h +index 34ef23b77..dc34352dc 100644 +--- a/db/db_test_util.h ++++ b/db/db_test_util.h +@@ -29,6 +29,7 @@ + #include "rocksdb/compaction_filter.h" + #include "rocksdb/convenience.h" + #include "rocksdb/db.h" ++#include "rocksdb/encryption.h" + #include "rocksdb/env.h" + #include "rocksdb/file_system.h" + #include "rocksdb/filter_policy.h" +diff --git a/db/db_wal_test.cc b/db/db_wal_test.cc +index fbc01131e..18ae59b65 100644 +--- a/db/db_wal_test.cc ++++ b/db/db_wal_test.cc +@@ -1677,6 +1677,9 @@ INSTANTIATE_TEST_CASE_P( + // at the end of any of the logs + // - We do not expect to open the data store for corruption + TEST_P(DBWALTestWithParams, kTolerateCorruptedTailRecords) { ++ if (getenv("ENCRYPTED_ENV")) { ++ return; ++ } + bool trunc = std::get<0>(GetParam()); // Corruption style + // Corruption offset position + int corrupt_offset = std::get<1>(GetParam()); +@@ -1739,6 +1742,9 @@ TEST_P(DBWALTestWithParams, kAbsoluteConsistency) { + // We don't expect the data store to be opened if there is any inconsistency + // between WAL and SST files + TEST_F(DBWALTest, kPointInTimeRecoveryCFConsistency) { ++ if (getenv("ENCRYPTED_ENV")) { ++ return; ++ } + Options options = CurrentOptions(); + options.avoid_flush_during_recovery = true; + +@@ -1946,6 +1952,9 @@ TEST_F(DBWALTest, FixSyncWalOnObseletedWalWithNewManifestCausingMissingWAL) { + // - We expect to open data store under all circumstances + // - We expect only data upto the point where the first error was encountered + TEST_P(DBWALTestWithParams, kPointInTimeRecovery) { ++ if (getenv("ENCRYPTED_ENV")) { ++ return; ++ } + const int maxkeys = + RecoveryTestHelper::kWALFilesCount * RecoveryTestHelper::kKeysPerWALFile; + +@@ -2006,6 +2015,9 @@ TEST_P(DBWALTestWithParams, kPointInTimeRecovery) { + // - We expect to open the data store under all scenarios + // - We expect to have recovered records past the corruption zone + TEST_P(DBWALTestWithParams, kSkipAnyCorruptedRecords) { ++ if (getenv("ENCRYPTED_ENV")) { ++ return; ++ } + bool trunc = std::get<0>(GetParam()); // Corruption style + // Corruption offset position + int corrupt_offset = std::get<1>(GetParam()); +@@ -2215,6 +2227,9 @@ TEST_F(DBWALTest, RecoverWithoutFlushMultipleCF) { + // 4. Open again. See if it can correctly handle previous corruption. + TEST_P(DBWALTestWithParamsVaryingRecoveryMode, + RecoverFromCorruptedWALWithoutFlush) { ++ if (getenv("ENCRYPTED_ENV")) { ++ return; ++ } + const int kAppendKeys = 100; + Options options = CurrentOptions(); + options.avoid_flush_during_recovery = true; +diff --git a/encryption/encryption.cc b/encryption/encryption.cc +new file mode 100644 +index 000000000..02f7f1bdc +--- /dev/null ++++ b/encryption/encryption.cc +@@ -0,0 +1,553 @@ ++// Copyright 2020 TiKV Project Authors. Licensed under Apache-2.0. 
++
++#ifdef OPENSSL
++
++#include "encryption/encryption.h"
++
++#include <openssl/opensslv.h>
++
++#include <algorithm>
++#include <limits>
++
++#include "file/filename.h"
++#include "port/port.h"
++#include "test_util/sync_point.h"
++
++namespace ROCKSDB_NAMESPACE {
++namespace encryption {
++
++namespace {
++uint64_t GetBigEndian64(const unsigned char* buf) {
++  if (port::kLittleEndian) {
++    return (static_cast<uint64_t>(buf[0]) << 56) +
++           (static_cast<uint64_t>(buf[1]) << 48) +
++           (static_cast<uint64_t>(buf[2]) << 40) +
++           (static_cast<uint64_t>(buf[3]) << 32) +
++           (static_cast<uint64_t>(buf[4]) << 24) +
++           (static_cast<uint64_t>(buf[5]) << 16) +
++           (static_cast<uint64_t>(buf[6]) << 8) +
++           (static_cast<uint64_t>(buf[7]));
++  } else {
++    return *(reinterpret_cast<const uint64_t*>(buf));
++  }
++}
++
++void PutBigEndian64(uint64_t value, unsigned char* buf) {
++  if (port::kLittleEndian) {
++    buf[0] = static_cast<unsigned char>((value >> 56) & 0xff);
++    buf[1] = static_cast<unsigned char>((value >> 48) & 0xff);
++    buf[2] = static_cast<unsigned char>((value >> 40) & 0xff);
++    buf[3] = static_cast<unsigned char>((value >> 32) & 0xff);
++    buf[4] = static_cast<unsigned char>((value >> 24) & 0xff);
++    buf[5] = static_cast<unsigned char>((value >> 16) & 0xff);
++    buf[6] = static_cast<unsigned char>((value >> 8) & 0xff);
++    buf[7] = static_cast<unsigned char>(value & 0xff);
++  } else {
++    *(reinterpret_cast<uint64_t*>(buf)) = value;
++  }
++}
++}  // anonymous namespace
++
++// AESCTRCipherStream uses the OpenSSL EVP API with CTR mode to encrypt and
++// decrypt data, instead of using the CTR implementation provided by
++// BlockAccessCipherStream. Benefits:
++//
++// 1. The EVP API automatically figures out whether AES-NI can be enabled.
++// 2. It keeps the data format consistent with OpenSSL (e.g. how the IV is
++//    interpreted as a block counter).
++//
++// References for the OpenSSL EVP API:
++// * man page: https://www.openssl.org/docs/man1.1.1/man3/EVP_EncryptUpdate.html
++// * SO answer for random access: https://stackoverflow.com/a/57147140/11014942
++// *
++// https://medium.com/@amit.kulkarni/encrypting-decrypting-a-file-using-openssl-evp-b26e0e4d28d4
++Status AESCTRCipherStream::Cipher(uint64_t file_offset, char* data,
++                                  size_t data_size, bool is_encrypt) {
++#if OPENSSL_VERSION_NUMBER < 0x01000200f
++  (void)file_offset;
++  (void)data;
++  (void)data_size;
++  (void)is_encrypt;
++  return Status::NotSupported("OpenSSL version < 1.0.2");
++#else
++  int ret = 1;
++  EVP_CIPHER_CTX* ctx = nullptr;
++  InitCipherContext(ctx);
++  if (ctx == nullptr) {
++    return Status::IOError("Failed to create cipher context.");
++  }
++
++  const size_t block_size = BlockSize();
++
++  uint64_t block_index = file_offset / block_size;
++  uint64_t block_offset = file_offset % block_size;
++
++  // In CTR mode, the OpenSSL EVP API treats the IV as a 128-bit big-endian
++  // counter, and increases it by 1 for each block.
++  //
++  // Unsigned integer overflow in C++ wraps around modulo the type's range,
++  // meaning only the lowest bits of the result are kept.
++  // http://www.cplusplus.com/articles/DE18T05o/
++  uint64_t iv_high = initial_iv_high_;
++  uint64_t iv_low = initial_iv_low_ + block_index;
++  if (std::numeric_limits<uint64_t>::max() - block_index < initial_iv_low_) {
++    iv_high++;
++  }
++  unsigned char iv[block_size];
++  PutBigEndian64(iv_high, iv);
++  PutBigEndian64(iv_low, iv + sizeof(uint64_t));
++
++  ret = EVP_CipherInit(ctx, cipher_,
++                       reinterpret_cast<const unsigned char*>(key_.data()), iv,
++                       (is_encrypt ? 1 : 0));
++  if (ret != 1) {
++    // Free the context on this error path too, to avoid leaking it.
++    FreeCipherContext(ctx);
++    return Status::IOError("Failed to init cipher.");
++  }
++
++  // Disable padding. After disabling padding, the data size should always be
++  // a multiple of the block size. 
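++  //
++  // Worked example (editorial, illustrative values only): with
++  // block_size = 16 and file_offset = 36, block_index = 2 and
++  // block_offset = 4, so the 128-bit counter handed to EVP_CipherInit above
++  // starts at initial_iv + 2. The first EVP_CipherUpdate below then copies
++  // the 4-byte misaligned head into partial_block to fake a full block.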
++  ret = EVP_CIPHER_CTX_set_padding(ctx, 0);
++  if (ret != 1) {
++    FreeCipherContext(ctx);
++    return Status::IOError("Failed to disable padding for cipher context.");
++  }
++
++  uint64_t data_offset = 0;
++  size_t remaining_data_size = data_size;
++  int output_size = 0;
++  unsigned char partial_block[block_size];
++
++  // In the following we assume EVP_CipherUpdate allows the in and out buffers
++  // to be the same, to save one memcpy. This is not specified in the official
++  // man page.
++
++  // Handle the partial block at the beginning. The partial block is copied to
++  // a buffer to fake a full block.
++  if (block_offset > 0) {
++    size_t partial_block_size =
++        std::min(block_size - block_offset, remaining_data_size);
++    memcpy(partial_block + block_offset, data, partial_block_size);
++    ret = EVP_CipherUpdate(ctx, partial_block, &output_size, partial_block,
++                           static_cast<int>(block_size));
++    if (ret != 1) {
++      FreeCipherContext(ctx);
++      return Status::IOError("Crypter failed for first block, offset " +
++                             std::to_string(file_offset));
++    }
++    if (output_size != static_cast<int>(block_size)) {
++      FreeCipherContext(ctx);
++      return Status::IOError(
++          "Unexpected crypter output size for first block, expected " +
++          std::to_string(block_size) + " vs actual " +
++          std::to_string(output_size));
++    }
++    memcpy(data, partial_block + block_offset, partial_block_size);
++    data_offset += partial_block_size;
++    remaining_data_size -= partial_block_size;
++  }
++
++  // Handle full blocks in the middle.
++  if (remaining_data_size >= block_size) {
++    size_t actual_data_size =
++        remaining_data_size - remaining_data_size % block_size;
++    unsigned char* full_blocks =
++        reinterpret_cast<unsigned char*>(data) + data_offset;
++    ret = EVP_CipherUpdate(ctx, full_blocks, &output_size, full_blocks,
++                           static_cast<int>(actual_data_size));
++    if (ret != 1) {
++      FreeCipherContext(ctx);
++      return Status::IOError("Crypter failed at offset " +
++                             std::to_string(file_offset + data_offset));
++    }
++    if (output_size != static_cast<int>(actual_data_size)) {
++      FreeCipherContext(ctx);
++      return Status::IOError("Unexpected crypter output size, expected " +
++                             std::to_string(actual_data_size) + " vs actual " +
++                             std::to_string(output_size));
++    }
++    data_offset += actual_data_size;
++    remaining_data_size -= actual_data_size;
++  }
++
++  // Handle the partial block at the end. The partial block is copied to a
++  // buffer to fake a full block.
++  if (remaining_data_size > 0) {
++    assert(remaining_data_size < block_size);
++    memcpy(partial_block, data + data_offset, remaining_data_size);
++    ret = EVP_CipherUpdate(ctx, partial_block, &output_size, partial_block,
++                           static_cast<int>(block_size));
++    if (ret != 1) {
++      FreeCipherContext(ctx);
++      return Status::IOError("Crypter failed for last block, offset " +
++                             std::to_string(file_offset + data_offset));
++    }
++    if (output_size != static_cast<int>(block_size)) {
++      FreeCipherContext(ctx);
++      return Status::IOError(
++          "Unexpected crypter output size for last block, expected " +
++          std::to_string(block_size) + " vs actual " +
++          std::to_string(output_size));
++    }
++    memcpy(data + data_offset, partial_block, remaining_data_size);
++  }
++
++  // Since padding is disabled, and each EVP_CipherUpdate call above passes a
++  // multiple of the block size, there is no need to call EVP_CipherFinal_ex
++  // to finish the last block cipher. 
++  // See the implementation of EVP_CipherFinal_ex:
++  // https://github.com/openssl/openssl/blob/OpenSSL_1_1_1-stable/crypto/evp/evp_enc.c#L219
++  FreeCipherContext(ctx);
++  return Status::OK();
++#endif
++}
++
++Status NewAESCTRCipherStream(EncryptionMethod method, const std::string& key,
++                             const std::string& iv,
++                             std::unique_ptr<AESCTRCipherStream>* result) {
++  assert(result != nullptr);
++  const EVP_CIPHER* cipher = nullptr;
++  switch (method) {
++    case EncryptionMethod::kAES128_CTR:
++      cipher = EVP_aes_128_ctr();
++      break;
++    case EncryptionMethod::kAES192_CTR:
++      cipher = EVP_aes_192_ctr();
++      break;
++    case EncryptionMethod::kAES256_CTR:
++      cipher = EVP_aes_256_ctr();
++      break;
++    case EncryptionMethod::kSM4_CTR:
++#if OPENSSL_VERSION_NUMBER < 0x1010100fL || defined(OPENSSL_NO_SM4)
++      return Status::InvalidArgument(
++          "Unsupported SM4 encryption method under OpenSSL version: " +
++          std::string(OPENSSL_VERSION_TEXT));
++#else
++      // OpenSSL supports SM4 since release 1.1.1.
++      cipher = EVP_sm4_ctr();
++      break;
++#endif
++    default:
++      return Status::InvalidArgument("Unsupported encryption method: " +
++                                     std::to_string(static_cast<int>(method)));
++  }
++  if (key.size() != KeySize(method)) {
++    return Status::InvalidArgument(
++        "Encryption key size mismatch. " + std::to_string(key.size()) +
++        "(actual) vs. " + std::to_string(KeySize(method)) + "(expected).");
++  }
++  if (iv.size() != AES_BLOCK_SIZE) {
++    return Status::InvalidArgument(
++        "iv size not equal to block cipher block size: " +
++        std::to_string(iv.size()) + "(actual) vs. " +
++        std::to_string(AES_BLOCK_SIZE) + "(expected).");
++  }
++  Slice iv_slice(iv);
++  uint64_t iv_high =
++      GetBigEndian64(reinterpret_cast<const unsigned char*>(iv.data()));
++  uint64_t iv_low = GetBigEndian64(
++      reinterpret_cast<const unsigned char*>(iv.data() + sizeof(uint64_t)));
++  result->reset(new AESCTRCipherStream(cipher, key, iv_high, iv_low));
++  return Status::OK();
++}
++
++Status AESEncryptionProvider::CreateCipherStream(
++    const std::string& fname, const EnvOptions& /*options*/, Slice& /*prefix*/,
++    std::unique_ptr<BlockAccessCipherStream>* result) {
++  assert(result != nullptr);
++  FileEncryptionInfo file_info;
++  Status s = key_manager_->GetFile(fname, &file_info);
++  if (!s.ok()) {
++    return s;
++  }
++  std::unique_ptr<AESCTRCipherStream> cipher_stream;
++  s = NewAESCTRCipherStream(file_info.method, file_info.key, file_info.iv,
++                            &cipher_stream);
++  if (!s.ok()) {
++    return s;
++  }
++  *result = std::move(cipher_stream);
++  return Status::OK();
++}
++
++KeyManagedEncryptedEnv::KeyManagedEncryptedEnv(
++    Env* base_env, std::shared_ptr<KeyManager>& key_manager,
++    std::shared_ptr<EncryptionProvider>& provider,
++    std::unique_ptr<Env>&& encrypted_env)
++    : EnvWrapper(base_env),
++      key_manager_(key_manager),
++      provider_(provider),
++      encrypted_env_(std::move(encrypted_env)) {}
++
++KeyManagedEncryptedEnv::~KeyManagedEncryptedEnv() = default;
++
++Status KeyManagedEncryptedEnv::NewSequentialFile(
++    const std::string& fname, std::unique_ptr<SequentialFile>* result,
++    const EnvOptions& options) {
++  FileEncryptionInfo file_info;
++  Status s = key_manager_->GetFile(fname, &file_info);
++  if (!s.ok()) {
++    return s;
++  }
++  switch (file_info.method) {
++    case EncryptionMethod::kPlaintext:
++      s = target()->NewSequentialFile(fname, result, options);
++      break;
++    case EncryptionMethod::kAES128_CTR:
++    case EncryptionMethod::kAES192_CTR:
++    case EncryptionMethod::kAES256_CTR:
++    case EncryptionMethod::kSM4_CTR:
++      s = encrypted_env_->NewSequentialFile(fname, result, options);
++      // Hack: when upgrading from TiKV <= v5.0.0-rc, the old current
++      // file is encrypted but it could be 
replaced with a plaintext ++ // current file. The operation below guarantee that the current ++ // file is read correctly. ++ if (s.ok() && IsCurrentFile(fname)) { ++ if (!IsValidCurrentFile(std::move(*result))) { ++ s = target()->NewSequentialFile(fname, result, options); ++ } else { ++ s = encrypted_env_->NewSequentialFile(fname, result, options); ++ } ++ } ++ break; ++ default: ++ s = Status::InvalidArgument( ++ "Unsupported encryption method: " + ++ std::to_string(static_cast(file_info.method))); ++ } ++ return s; ++} ++ ++Status KeyManagedEncryptedEnv::NewRandomAccessFile( ++ const std::string& fname, std::unique_ptr* result, ++ const EnvOptions& options) { ++ FileEncryptionInfo file_info; ++ Status s = key_manager_->GetFile(fname, &file_info); ++ if (!s.ok()) { ++ return s; ++ } ++ switch (file_info.method) { ++ case EncryptionMethod::kPlaintext: ++ s = target()->NewRandomAccessFile(fname, result, options); ++ break; ++ case EncryptionMethod::kAES128_CTR: ++ case EncryptionMethod::kAES192_CTR: ++ case EncryptionMethod::kAES256_CTR: ++ case EncryptionMethod::kSM4_CTR: ++ s = encrypted_env_->NewRandomAccessFile(fname, result, options); ++ break; ++ default: ++ s = Status::InvalidArgument( ++ "Unsupported encryption method: " + ++ std::to_string(static_cast(file_info.method))); ++ } ++ return s; ++} ++ ++Status KeyManagedEncryptedEnv::NewWritableFile( ++ const std::string& fname, std::unique_ptr* result, ++ const EnvOptions& options) { ++ FileEncryptionInfo file_info; ++ Status s; ++ bool skipped = IsCurrentFile(fname); ++ TEST_SYNC_POINT_CALLBACK("KeyManagedEncryptedEnv::NewWritableFile", &skipped); ++ if (!skipped) { ++ s = key_manager_->NewFile(fname, &file_info); ++ if (!s.ok()) { ++ return s; ++ } ++ } else { ++ file_info.method = EncryptionMethod::kPlaintext; ++ } ++ ++ switch (file_info.method) { ++ case EncryptionMethod::kPlaintext: ++ s = target()->NewWritableFile(fname, result, options); ++ break; ++ case EncryptionMethod::kAES128_CTR: ++ case EncryptionMethod::kAES192_CTR: ++ case EncryptionMethod::kAES256_CTR: ++ case EncryptionMethod::kSM4_CTR: ++ s = encrypted_env_->NewWritableFile(fname, result, options); ++ break; ++ default: ++ s = Status::InvalidArgument( ++ "Unsupported encryption method: " + ++ std::to_string(static_cast(file_info.method))); ++ } ++ if (!s.ok() && !skipped) { ++ // Ignore error ++ key_manager_->DeleteFile(fname); ++ } ++ return s; ++} ++ ++Status KeyManagedEncryptedEnv::ReopenWritableFile( ++ const std::string& fname, std::unique_ptr* result, ++ const EnvOptions& options) { ++ FileEncryptionInfo file_info; ++ Status s = key_manager_->GetFile(fname, &file_info); ++ if (!s.ok()) { ++ return s; ++ } ++ switch (file_info.method) { ++ case EncryptionMethod::kPlaintext: ++ s = target()->ReopenWritableFile(fname, result, options); ++ break; ++ case EncryptionMethod::kAES128_CTR: ++ case EncryptionMethod::kAES192_CTR: ++ case EncryptionMethod::kAES256_CTR: ++ case EncryptionMethod::kSM4_CTR: ++ s = encrypted_env_->ReopenWritableFile(fname, result, options); ++ break; ++ default: ++ s = Status::InvalidArgument( ++ "Unsupported encryption method: " + ++ std::to_string(static_cast(file_info.method))); ++ } ++ return s; ++} ++ ++Status KeyManagedEncryptedEnv::ReuseWritableFile( ++ const std::string& fname, const std::string& old_fname, ++ std::unique_ptr* result, const EnvOptions& options) { ++ FileEncryptionInfo file_info; ++ // ReuseWritableFile is only used in the context of rotating WAL file and ++ // reuse them. 
Old content is discardable and new WAL records are to ++ // overwrite the file. So NewFile() should be called. ++ Status s = key_manager_->NewFile(fname, &file_info); ++ if (!s.ok()) { ++ return s; ++ } ++ switch (file_info.method) { ++ case EncryptionMethod::kPlaintext: ++ s = target()->ReuseWritableFile(fname, old_fname, result, options); ++ break; ++ case EncryptionMethod::kAES128_CTR: ++ case EncryptionMethod::kAES192_CTR: ++ case EncryptionMethod::kAES256_CTR: ++ case EncryptionMethod::kSM4_CTR: ++ s = encrypted_env_->ReuseWritableFile(fname, old_fname, result, options); ++ break; ++ default: ++ s = Status::InvalidArgument( ++ "Unsupported encryption method: " + ++ std::to_string(static_cast(file_info.method))); ++ } ++ if (!s.ok()) { ++ return s; ++ } ++ s = key_manager_->LinkFile(old_fname, fname); ++ if (!s.ok()) { ++ return s; ++ } ++ s = key_manager_->DeleteFile(old_fname); ++ return s; ++} ++ ++Status KeyManagedEncryptedEnv::NewRandomRWFile( ++ const std::string& fname, std::unique_ptr* result, ++ const EnvOptions& options) { ++ FileEncryptionInfo file_info; ++ // NewRandomRWFile is only used in the context of external file ingestion, ++ // for rewriting global seqno. So it should call GetFile() instead of ++ // NewFile(). ++ Status s = key_manager_->GetFile(fname, &file_info); ++ if (!s.ok()) { ++ return s; ++ } ++ switch (file_info.method) { ++ case EncryptionMethod::kPlaintext: ++ s = target()->NewRandomRWFile(fname, result, options); ++ break; ++ case EncryptionMethod::kAES128_CTR: ++ case EncryptionMethod::kAES192_CTR: ++ case EncryptionMethod::kAES256_CTR: ++ case EncryptionMethod::kSM4_CTR: ++ s = encrypted_env_->NewRandomRWFile(fname, result, options); ++ break; ++ default: ++ s = Status::InvalidArgument( ++ "Unsupported encryption method: " + ++ std::to_string(static_cast(file_info.method))); ++ } ++ if (!s.ok()) { ++ // Ignore error ++ key_manager_->DeleteFile(fname); ++ } ++ return s; ++} ++ ++Status KeyManagedEncryptedEnv::DeleteFile(const std::string& fname) { ++ // Try deleting the file from file system before updating key_manager. ++ Status s = target()->DeleteFile(fname); ++ if (!s.ok()) { ++ return s; ++ } ++ return key_manager_->DeleteFile(fname); ++} ++ ++Status KeyManagedEncryptedEnv::LinkFile(const std::string& src_fname, ++ const std::string& dst_fname) { ++ if (IsCurrentFile(dst_fname)) { ++ assert(IsCurrentFile(src_fname)); ++ Status s = target()->LinkFile(src_fname, dst_fname); ++ return s; ++ } else { ++ assert(!IsCurrentFile(src_fname)); ++ } ++ Status s = key_manager_->LinkFile(src_fname, dst_fname); ++ if (!s.ok()) { ++ return s; ++ } ++ s = target()->LinkFile(src_fname, dst_fname); ++ if (!s.ok()) { ++ Status delete_status __attribute__((__unused__)) = ++ key_manager_->DeleteFile(dst_fname); ++ assert(delete_status.ok()); ++ } ++ return s; ++} ++ ++Status KeyManagedEncryptedEnv::RenameFile(const std::string& src_fname, ++ const std::string& dst_fname) { ++ if (IsCurrentFile(dst_fname)) { ++ assert(IsCurrentFile(src_fname)); ++ Status s = target()->RenameFile(src_fname, dst_fname); ++ // Replacing with plaintext requires deleting the info in the key manager. ++ // The stale current file info exists when upgrading from TiKV <= v5.0.0-rc. 
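++    // Editorial note: even if this delete were to fail, the stale
++    // "encrypted" entry is tolerated on the read side, because
++    // NewSequentialFile falls back to a plaintext read whenever
++    // IsValidCurrentFile() rejects the decrypted content.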
++ Status delete_status __attribute__((__unused__)) = ++ key_manager_->DeleteFile(dst_fname); ++ assert(delete_status.ok()); ++ return s; ++ } else { ++ assert(!IsCurrentFile(src_fname)); ++ } ++ // Link(copy)File instead of RenameFile to avoid losing src_fname info when ++ // failed to rename the src_fname in the file system. ++ Status s = key_manager_->LinkFile(src_fname, dst_fname); ++ if (!s.ok()) { ++ return s; ++ } ++ s = target()->RenameFile(src_fname, dst_fname); ++ if (s.ok()) { ++ s = key_manager_->DeleteFile(src_fname); ++ } else { ++ Status delete_status __attribute__((__unused__)) = ++ key_manager_->DeleteFile(dst_fname); ++ assert(delete_status.ok()); ++ } ++ return s; ++} ++ ++Env* NewKeyManagedEncryptedEnv(Env* base_env, ++ std::shared_ptr& key_manager) { ++ std::shared_ptr provider( ++ new AESEncryptionProvider(key_manager.get())); ++ std::unique_ptr encrypted_env(NewEncryptedEnv(base_env, provider)); ++ return new KeyManagedEncryptedEnv(base_env, key_manager, provider, ++ std::move(encrypted_env)); ++} ++ ++} // namespace encryption ++} // namespace ROCKSDB_NAMESPACE ++ ++#endif // OPENSSL +diff --git a/encryption/encryption.h b/encryption/encryption.h +new file mode 100644 +index 000000000..b62c75fc3 +--- /dev/null ++++ b/encryption/encryption.h +@@ -0,0 +1,139 @@ ++// Copyright 2020 TiKV Project Authors. Licensed under Apache-2.0. ++ ++#pragma once ++#ifdef OPENSSL ++#include ++#include ++ ++#include ++ ++#include "rocksdb/encryption.h" ++#include "rocksdb/env_encryption.h" ++#include "util/string_util.h" ++ ++namespace ROCKSDB_NAMESPACE { ++namespace encryption { ++ ++#if OPENSSL_VERSION_NUMBER < 0x01010000f ++ ++#define InitCipherContext(ctx) \ ++ EVP_CIPHER_CTX ctx##_var; \ ++ ctx = &ctx##_var; \ ++ EVP_CIPHER_CTX_init(ctx); ++ ++// do nothing ++#define FreeCipherContext(ctx) ++ ++#else ++ ++#define InitCipherContext(ctx) \ ++ ctx = EVP_CIPHER_CTX_new(); \ ++ if (ctx != nullptr) { \ ++ if (EVP_CIPHER_CTX_reset(ctx) != 1) { \ ++ ctx = nullptr; \ ++ } \ ++ } ++ ++#define FreeCipherContext(ctx) EVP_CIPHER_CTX_free(ctx); ++ ++#endif ++ ++// TODO: OpenSSL Lib does not export SM4_BLOCK_SIZE by now. ++// Need to remove SM4_BLOCK_Size once Openssl lib support the definition. ++// SM4 uses 128-bit block size as AES. ++// Ref: ++// https://github.com/openssl/openssl/blob/OpenSSL_1_1_1-stable/include/crypto/sm4.h#L24 ++#define SM4_BLOCK_SIZE 16 ++ ++class AESCTRCipherStream : public BlockAccessCipherStream { ++ public: ++ AESCTRCipherStream(const EVP_CIPHER* cipher, const std::string& key, ++ uint64_t iv_high, uint64_t iv_low) ++ : cipher_(cipher), ++ key_(key), ++ initial_iv_high_(iv_high), ++ initial_iv_low_(iv_low) {} ++ ++ ~AESCTRCipherStream() = default; ++ ++ size_t BlockSize() override { ++ // Openssl support SM4 after 1.1.1 release version. ++#if OPENSSL_VERSION_NUMBER >= 0x1010100fL && !defined(OPENSSL_NO_SM4) ++ if (EVP_CIPHER_nid(cipher_) == NID_sm4_ctr) { ++ return SM4_BLOCK_SIZE; ++ } ++#endif ++ return AES_BLOCK_SIZE; // 16 ++ } ++ ++ Status Encrypt(uint64_t file_offset, char* data, size_t data_size) override { ++ return Cipher(file_offset, data, data_size, true /*is_encrypt*/); ++ } ++ ++ Status Decrypt(uint64_t file_offset, char* data, size_t data_size) override { ++ return Cipher(file_offset, data, data_size, false /*is_encrypt*/); ++ } ++ ++ protected: ++ // Following methods required by BlockAccessCipherStream is unused. ++ ++ void AllocateScratch(std::string& /*scratch*/) override { ++ // should not be called. 
++ assert(false); ++ } ++ ++ Status EncryptBlock(uint64_t /*block_index*/, char* /*data*/, ++ char* /*scratch*/) override { ++ return Status::NotSupported("EncryptBlock should not be called."); ++ } ++ ++ Status DecryptBlock(uint64_t /*block_index*/, char* /*data*/, ++ char* /*scratch*/) override { ++ return Status::NotSupported("DecryptBlock should not be called."); ++ } ++ ++ private: ++ Status Cipher(uint64_t file_offset, char* data, size_t data_size, ++ bool is_encrypt); ++ ++ const EVP_CIPHER* cipher_; ++ const std::string key_; ++ const uint64_t initial_iv_high_; ++ const uint64_t initial_iv_low_; ++}; ++ ++extern Status NewAESCTRCipherStream( ++ EncryptionMethod method, const std::string& key, const std::string& iv, ++ std::unique_ptr* result); ++ ++class AESEncryptionProvider : public EncryptionProvider { ++ public: ++ AESEncryptionProvider(KeyManager* key_manager) : key_manager_(key_manager) {} ++ virtual ~AESEncryptionProvider() = default; ++ ++ const char* Name() const override { return "AESEncryptionProvider"; } ++ ++ size_t GetPrefixLength() const override { return 0; } ++ ++ Status CreateNewPrefix(const std::string& /*fname*/, char* /*prefix*/, ++ size_t /*prefix_length*/) const override { ++ return Status::OK(); ++ } ++ ++ Status AddCipher(const std::string& /*descriptor*/, const char* /*cipher*/, ++ size_t /*len*/, bool /*for_write*/) override { ++ return Status::NotSupported(); ++ } ++ ++ Status CreateCipherStream( ++ const std::string& fname, const EnvOptions& options, Slice& prefix, ++ std::unique_ptr* result) override; ++ ++ private: ++ KeyManager* key_manager_; ++}; ++ ++} // namespace encryption ++} // namespace ROCKSDB_NAMESPACE ++ ++#endif // OPENSSL +diff --git a/encryption/encryption_test.cc b/encryption/encryption_test.cc +new file mode 100644 +index 000000000..4902fa065 +--- /dev/null ++++ b/encryption/encryption_test.cc +@@ -0,0 +1,181 @@ ++// Copyright 2020 TiKV Project Authors. Licensed under Apache-2.0. ++ ++#include "encryption/encryption.h" ++ ++#include "port/stack_trace.h" ++#include "test_util/testharness.h" ++#include "test_util/testutil.h" ++ ++#ifdef OPENSSL ++ ++namespace ROCKSDB_NAMESPACE { ++namespace encryption { ++ ++const unsigned char KEY[33] = ++ "\xe4\x3e\x8e\xca\x2a\x83\xe1\x88\xfb\xd8\x02\xdc\xf3\x62\x65\x3e" ++ "\x00\xee\x31\x39\xe7\xfd\x1d\x92\x20\xb1\x62\xae\xb2\xaf\x0f\x1a"; ++const unsigned char IV_RANDOM[17] = ++ "\x77\x9b\x82\x72\x26\xb5\x76\x50\xf7\x05\xd2\xd6\xb8\xaa\xa9\x2c"; ++const unsigned char IV_OVERFLOW_LOW[17] = ++ "\x77\x9b\x82\x72\x26\xb5\x76\x50\xff\xff\xff\xff\xff\xff\xff\xff"; ++const unsigned char IV_OVERFLOW_FULL[17] = ++ "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"; ++ ++constexpr size_t MAX_SIZE = 16 * 10; ++ ++// Test to make sure output of AESCTRCipherStream is the same as output from ++// OpenSSL EVP API. ++class EncryptionTest ++ : public testing::TestWithParam> { ++ public: ++ unsigned char plaintext[MAX_SIZE]; ++ // Reserve a bit more room to make sure OpenSSL have enough buffer. 
++ unsigned char ciphertext[MAX_SIZE + 16 * 2]; ++ ++ void GenerateCiphertext(const unsigned char* iv) { ++ Random rnd(666); ++ std::string random_string = ++ rnd.HumanReadableString(static_cast(MAX_SIZE)); ++ memcpy(plaintext, random_string.data(), MAX_SIZE); ++ ++ int ret = 1; ++ EVP_CIPHER_CTX* ctx; ++ InitCipherContext(ctx); ++ assert(ctx != nullptr); ++ ++ const EVP_CIPHER* cipher = nullptr; ++ EncryptionMethod method = std::get<1>(GetParam()); ++ switch (method) { ++ case EncryptionMethod::kAES128_CTR: ++ cipher = EVP_aes_128_ctr(); ++ break; ++ case EncryptionMethod::kAES192_CTR: ++ cipher = EVP_aes_192_ctr(); ++ break; ++ case EncryptionMethod::kAES256_CTR: ++ cipher = EVP_aes_256_ctr(); ++ break; ++#if OPENSSL_VERSION_NUMBER >= 0x1010100fL && !defined(OPENSSL_NO_SM4) ++ // Openssl support SM4 after 1.1.1 release version. ++ case EncryptionMethod::kSM4_CTR: ++ cipher = EVP_sm4_ctr(); ++ break; ++#endif ++ default: ++ assert(false); ++ } ++ assert(cipher != nullptr); ++ ++ ret = EVP_EncryptInit(ctx, cipher, KEY, iv); ++ assert(ret == 1); ++ int output_size = 0; ++ ret = EVP_EncryptUpdate(ctx, ciphertext, &output_size, plaintext, ++ static_cast(MAX_SIZE)); ++ assert(ret == 1); ++ int final_output_size = 0; ++ ret = EVP_EncryptFinal(ctx, ciphertext + output_size, &final_output_size); ++ assert(ret == 1); ++ assert(output_size + final_output_size == MAX_SIZE); ++ FreeCipherContext(ctx); ++ } ++ ++ void TestEncryptionImpl(size_t start, size_t end, const unsigned char* iv, ++ bool* success) { ++ assert(start < end && end <= MAX_SIZE); ++ GenerateCiphertext(iv); ++ ++ EncryptionMethod method = std::get<1>(GetParam()); ++ std::string key_str(reinterpret_cast(KEY), KeySize(method)); ++ std::string iv_str(reinterpret_cast(iv), 16); ++ std::unique_ptr cipher_stream; ++ ASSERT_OK(NewAESCTRCipherStream(method, key_str, iv_str, &cipher_stream)); ++ ++ size_t data_size = end - start; ++ // Allocate exact size. AESCTRCipherStream should make sure there will be ++ // no memory corruption. ++ std::unique_ptr data(new char[data_size]); ++ ++ if (std::get<0>(GetParam())) { ++ // Encrypt ++ memcpy(data.get(), plaintext + start, data_size); ++ ASSERT_OK(cipher_stream->Encrypt(start, data.get(), data_size)); ++ ASSERT_EQ(0, memcmp(ciphertext + start, data.get(), data_size)); ++ } else { ++ // Decrypt ++ memcpy(data.get(), ciphertext + start, data_size); ++ ASSERT_OK(cipher_stream->Decrypt(start, data.get(), data_size)); ++ ASSERT_EQ(0, memcmp(plaintext + start, data.get(), data_size)); ++ } ++ ++ *success = true; ++ } ++ ++ bool TestEncryption(size_t start, size_t end, ++ const unsigned char* iv = IV_RANDOM) { ++ // Workaround failure of ASSERT_* result in return immediately. ++ bool success = false; ++ TestEncryptionImpl(start, end, iv, &success); ++ return success; ++ } ++}; ++ ++TEST_P(EncryptionTest, EncryptionTest) { ++ // One full block. ++ EXPECT_TRUE(TestEncryption(0, 16)); ++ // One block in the middle. ++ EXPECT_TRUE(TestEncryption(16 * 5, 16 * 6)); ++ // Multiple aligned blocks. ++ EXPECT_TRUE(TestEncryption(16 * 5, 16 * 8)); ++ ++ // Random byte at the beginning of a block. ++ EXPECT_TRUE(TestEncryption(16 * 5, 16 * 5 + 1)); ++ // Random byte in the middle of a block. ++ EXPECT_TRUE(TestEncryption(16 * 5 + 4, 16 * 5 + 5)); ++ // Random byte at the end of a block. ++ EXPECT_TRUE(TestEncryption(16 * 5 + 15, 16 * 6)); ++ ++ // Partial block aligned at the beginning. ++ EXPECT_TRUE(TestEncryption(16 * 5, 16 * 5 + 15)); ++ // Partial block aligned at the end. 
++ EXPECT_TRUE(TestEncryption(16 * 5 + 1, 16 * 6)); ++ // Multiple blocks with a partial block at the end. ++ EXPECT_TRUE(TestEncryption(16 * 5, 16 * 8 + 15)); ++ // Multiple blocks with a partial block at the beginning. ++ EXPECT_TRUE(TestEncryption(16 * 5 + 1, 16 * 8)); ++ // Partial block at both ends. ++ EXPECT_TRUE(TestEncryption(16 * 5 + 1, 16 * 8 + 15)); ++ ++ // Lower bits of IV overflow. ++ EXPECT_TRUE(TestEncryption(16, 16 * 2, IV_OVERFLOW_LOW)); ++ // Full IV overflow. ++ EXPECT_TRUE(TestEncryption(16, 16 * 2, IV_OVERFLOW_FULL)); ++} ++ ++// Openssl support SM4 after 1.1.1 release version. ++#if OPENSSL_VERSION_NUMBER < 0x1010100fL || defined(OPENSSL_NO_SM4) ++INSTANTIATE_TEST_CASE_P( ++ EncryptionTestInstance, EncryptionTest, ++ testing::Combine(testing::Bool(), ++ testing::Values(EncryptionMethod::kAES128_CTR, ++ EncryptionMethod::kAES192_CTR, ++ EncryptionMethod::kAES256_CTR))); ++#else ++INSTANTIATE_TEST_CASE_P( ++ EncryptionTestInstance, EncryptionTest, ++ testing::Combine(testing::Bool(), ++ testing::Values(EncryptionMethod::kAES128_CTR, ++ EncryptionMethod::kAES192_CTR, ++ EncryptionMethod::kAES256_CTR, ++ EncryptionMethod::kSM4_CTR))); ++#endif ++ ++} // namespace encryption ++} // namespace ROCKSDB_NAMESPACE ++ ++#endif // OPENSSL ++ ++int main(int argc, char** argv) { ++ rocksdb::port::InstallStackTraceHandler(); ++ ::testing::InitGoogleTest(&argc, argv); ++ return RUN_ALL_TESTS(); ++} +diff --git a/encryption/in_memory_key_manager.h b/encryption/in_memory_key_manager.h +new file mode 100644 +index 000000000..dc216b6db +--- /dev/null ++++ b/encryption/in_memory_key_manager.h +@@ -0,0 +1,83 @@ ++#pragma once ++#ifdef OPENSSL ++#include ++ ++#include ++#include ++ ++#include "encryption/encryption.h" ++#include "port/port.h" ++#include "test_util/testutil.h" ++#include "util/mutexlock.h" ++ ++namespace ROCKSDB_NAMESPACE { ++namespace encryption { ++ ++// KeyManager store metadata in memory. It is used in tests and db_bench only. 
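++// Illustrative usage (editorial sketch, not part of the original patch):
++//
++//   std::shared_ptr<KeyManager> km =
++//       std::make_shared<InMemoryKeyManager>(EncryptionMethod::kAES256_CTR);
++//   Env* env = NewKeyManagedEncryptedEnv(Env::Default(), km);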
++class InMemoryKeyManager final : public KeyManager { ++ public: ++ InMemoryKeyManager(EncryptionMethod method) ++ : rnd_(42), ++ method_(method), ++ key_(rnd_.HumanReadableString(static_cast(KeySize(method)))) { ++ assert(method != EncryptionMethod::kUnknown); ++ } ++ ++ virtual ~InMemoryKeyManager() = default; ++ ++ Status GetFile(const std::string& fname, ++ FileEncryptionInfo* file_info) override { ++ assert(file_info != nullptr); ++ MutexLock l(&mu_); ++ if (files_.count(fname) == 0) { ++ return Status::Corruption("File not found: " + fname); ++ } ++ file_info->method = method_; ++ file_info->key = key_; ++ file_info->iv = files_[fname]; ++ return Status::OK(); ++ } ++ ++ Status NewFile(const std::string& fname, ++ FileEncryptionInfo* file_info) override { ++ assert(file_info != nullptr); ++ MutexLock l(&mu_); ++ std::string iv = rnd_.HumanReadableString(AES_BLOCK_SIZE); ++ files_[fname] = iv; ++ file_info->method = method_; ++ file_info->key = key_; ++ file_info->iv = iv; ++ return Status::OK(); ++ } ++ ++ Status DeleteFile(const std::string& fname) override { ++ MutexLock l(&mu_); ++ if (files_.count(fname) == 0) { ++ return Status::Corruption("File not found: " + fname); ++ } ++ files_.erase(fname); ++ return Status::OK(); ++ } ++ ++ Status LinkFile(const std::string& src_fname, ++ const std::string& dst_fname) override { ++ MutexLock l(&mu_); ++ if (files_.count(src_fname) == 0) { ++ return Status::Corruption("File not found: " + src_fname); ++ } ++ files_[dst_fname] = files_[src_fname]; ++ return Status::OK(); ++ } ++ ++ private: ++ mutable port::Mutex mu_; ++ Random rnd_; ++ const EncryptionMethod method_; ++ const std::string key_; ++ std::unordered_map files_; ++}; ++ ++} // namespace encryption ++} // namespace ROCKSDB_NAMESPACE ++ ++#endif // OPENSSL +diff --git a/env/env_basic_test.cc b/env/env_basic_test.cc +index 3a5472203..6211f2193 100644 +--- a/env/env_basic_test.cc ++++ b/env/env_basic_test.cc +@@ -9,6 +9,7 @@ + #include + #include + ++#include "db/db_test_util.h" + #include "env/mock_env.h" + #include "file/file_util.h" + #include "rocksdb/convenience.h" +@@ -117,6 +118,16 @@ static Env* GetInspectedEnv() { + return inspected_env.get(); + } + ++#ifdef OPENSSL ++static Env* GetKeyManagedEncryptedEnv() { ++ static std::shared_ptr key_manager( ++ new test::TestKeyManager); ++ static std::unique_ptr key_managed_encrypted_env( ++ NewKeyManagedEncryptedEnv(Env::Default(), key_manager)); ++ return key_managed_encrypted_env.get(); ++} ++#endif // OPENSSL ++ + } // namespace + class EnvBasicTestWithParam + : public testing::Test, +@@ -157,8 +168,12 @@ INSTANTIATE_TEST_CASE_P(MemEnv, EnvBasicTestWithParam, + INSTANTIATE_TEST_CASE_P(InspectedEnv, EnvBasicTestWithParam, + ::testing::Values(&GetInspectedEnv)); + +-namespace { ++#ifdef OPENSSL ++INSTANTIATE_TEST_CASE_P(KeyManagedEncryptedEnv, EnvBasicTestWithParam, ++ ::testing::Values(&GetKeyManagedEncryptedEnv)); ++#endif // OPENSSL + ++namespace { + // Returns a vector of 0 or 1 Env*, depending whether an Env is registered for + // TEST_ENV_URI. + // +@@ -185,6 +200,52 @@ INSTANTIATE_TEST_CASE_P(CustomEnv, EnvBasicTestWithParam, + INSTANTIATE_TEST_CASE_P(CustomEnv, EnvMoreTestWithParam, + ::testing::ValuesIn(GetCustomEnvs())); + ++TEST_P(EnvBasicTestWithParam, RenameCurrent) { ++ if (!getenv("ENCRYPTED_ENV")) { ++ return; ++ } ++ Slice result; ++ char scratch[100]; ++ std::unique_ptr seq_file; ++ std::unique_ptr writable_file; ++ std::vector children; ++ ++ // Create an encrypted `CURRENT` file so it shouldn't be skipped . 
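++  // (The sync point below forces `skipped` to false inside
++  // KeyManagedEncryptedEnv::NewWritableFile, so this CURRENT file really is
++  // written through the encrypted env rather than the plaintext fallback.)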
++ SyncPoint::GetInstance()->SetCallBack( ++ "KeyManagedEncryptedEnv::NewWritableFile", [&](void* arg) { ++ bool* skip = static_cast(arg); ++ *skip = false; ++ }); ++ SyncPoint::GetInstance()->EnableProcessing(); ++ ASSERT_OK( ++ env_->NewWritableFile(test_dir_ + "/CURRENT", &writable_file, soptions_)); ++ SyncPoint::GetInstance()->ClearAllCallBacks(); ++ SyncPoint::GetInstance()->DisableProcessing(); ++ ASSERT_OK(writable_file->Append("MANIFEST-0")); ++ ASSERT_OK(writable_file->Close()); ++ writable_file.reset(); ++ ++ ASSERT_OK( ++ env_->NewSequentialFile(test_dir_ + "/CURRENT", &seq_file, soptions_)); ++ ASSERT_OK(seq_file->Read(100, &result, scratch)); ++ ASSERT_EQ(0, result.compare("MANIFEST-0")); ++ ++ // Create a plaintext `CURRENT` temp file. ++ ASSERT_OK(env_->NewWritableFile(test_dir_ + "/current.dbtmp.plain", ++ &writable_file, soptions_)); ++ ASSERT_OK(writable_file->Append("MANIFEST-1")); ++ ASSERT_OK(writable_file->Close()); ++ writable_file.reset(); ++ ++ ASSERT_OK(env_->RenameFile(test_dir_ + "/current.dbtmp.plain", ++ test_dir_ + "/CURRENT")); ++ ++ ASSERT_OK( ++ env_->NewSequentialFile(test_dir_ + "/CURRENT", &seq_file, soptions_)); ++ ASSERT_OK(seq_file->Read(100, &result, scratch)); ++ ASSERT_EQ(0, result.compare("MANIFEST-1")); ++} ++ + TEST_P(EnvBasicTestWithParam, Basics) { + uint64_t file_size; + std::unique_ptr writable_file; +diff --git a/file/filename.cc b/file/filename.cc +index fb7d25472..35059622d 100644 +--- a/file/filename.cc ++++ b/file/filename.cc +@@ -29,6 +29,32 @@ static const std::string kRocksDbTFileExt = "sst"; + static const std::string kLevelDbTFileExt = "ldb"; + static const std::string kRocksDBBlobFileExt = "blob"; + static const std::string kArchivalDirName = "archive"; ++static const std::string kUnencryptedTempFileNameSuffix = "dbtmp.plain"; ++ ++bool IsCurrentFile(const std::string& fname) { ++ // skip CURRENT file. ++ size_t current_length = strlen("CURRENT"); ++ if (fname.length() >= current_length && ++ !fname.compare(fname.length() - current_length, current_length, ++ "CURRENT")) { ++ return true; ++ } ++ // skip temporary file for CURRENT file. ++ size_t temp_length = kUnencryptedTempFileNameSuffix.length(); ++ if (fname.length() >= temp_length && ++ !fname.compare(fname.length() - temp_length, temp_length, ++ kUnencryptedTempFileNameSuffix)) { ++ return true; ++ } ++ return false; ++} ++ ++bool IsValidCurrentFile(std::unique_ptr seq_file) { ++ Slice result; ++ char scratch[64]; ++ seq_file->Read(8, &result, scratch); ++ return result.compare("MANIFEST") == 0; ++} + + // Given a path, flatten the path name by replacing all chars not in + // {[0-9,a-z,A-Z,-,_,.]} with _. And append '_LOG\0' at the end. 
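For intuition, the `IsCurrentFile` predicate added above exempts both the final `CURRENT` file and its plaintext temp file from encryption (illustrative paths, not from the patch):

    IsCurrentFile("/db/CURRENT");             // true
    IsCurrentFile("/db/000007.dbtmp.plain");  // true
    IsCurrentFile("/db/000007.dbtmp");        // false, ordinary temp file
    // IsValidCurrentFile() returns true iff the first 8 bytes read "MANIFEST".
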
+@@ -182,6 +208,10 @@ std::string TempFileName(const std::string& dbname, uint64_t number) { + return MakeFileName(dbname, number, kTempFileNameSuffix.c_str()); + } + ++std::string TempPlainFileName(const std::string& dbname, uint64_t number) { ++ return MakeFileName(dbname, number, kUnencryptedTempFileNameSuffix.c_str()); ++} ++ + InfoLogPrefix::InfoLogPrefix(bool has_log_dir, + const std::string& db_absolute_path) { + if (!has_log_dir) { +@@ -392,7 +422,7 @@ IOStatus SetCurrentFile(FileSystem* fs, const std::string& dbname, + Slice contents = manifest; + assert(contents.starts_with(dbname + "/")); + contents.remove_prefix(dbname.size() + 1); +- std::string tmp = TempFileName(dbname, descriptor_number); ++ std::string tmp = TempPlainFileName(dbname, descriptor_number); + IOStatus s = WriteStringToFile(fs, contents.ToString() + "\n", tmp, true); + TEST_SYNC_POINT_CALLBACK("SetCurrentFile:BeforeRename", &s); + if (s.ok()) { +diff --git a/file/filename.h b/file/filename.h +index 2eb125b6a..44501128a 100644 +--- a/file/filename.h ++++ b/file/filename.h +@@ -37,6 +37,14 @@ constexpr char kFilePathSeparator = '\\'; + constexpr char kFilePathSeparator = '/'; + #endif + ++// Some non-sensitive files are not encrypted to preserve atomicity of file ++// operations. ++extern bool IsCurrentFile(const std::string& fname); ++ ++// Determine if the content is read from the valid current file. ++extern bool IsValidCurrentFile( ++ std::unique_ptr seq_file); ++ + // Return the name of the log file with the specified number + // in the db named by "dbname". The result will be prefixed with + // "dbname". +diff --git a/include/rocksdb/encryption.h b/include/rocksdb/encryption.h +new file mode 100644 +index 000000000..b8f5e91e9 +--- /dev/null ++++ b/include/rocksdb/encryption.h +@@ -0,0 +1,111 @@ ++// Copyright 2020 TiKV Project Authors. Licensed under Apache-2.0. ++ ++#pragma once ++#ifdef OPENSSL ++ ++#include ++#include ++ ++#include "rocksdb/env.h" ++ ++namespace ROCKSDB_NAMESPACE { ++namespace encryption { ++ ++class AESEncryptionProvider; ++ ++enum class EncryptionMethod : int { ++ kUnknown = 0, ++ kPlaintext = 1, ++ kAES128_CTR = 2, ++ kAES192_CTR = 3, ++ kAES256_CTR = 4, ++ kSM4_CTR = 5, ++}; ++ ++inline size_t KeySize(EncryptionMethod method) { ++ switch (method) { ++ case EncryptionMethod::kAES128_CTR: ++ return 16; ++ case EncryptionMethod::kAES192_CTR: ++ return 24; ++ case EncryptionMethod::kAES256_CTR: ++ return 32; ++ case EncryptionMethod::kSM4_CTR: ++ return 16; ++ default: ++ return 0; ++ }; ++} ++ ++struct FileEncryptionInfo { ++ EncryptionMethod method = EncryptionMethod::kUnknown; ++ std::string key; ++ std::string iv; ++}; ++ ++// Interface to manage encryption keys for files. KeyManagedEncryptedEnv ++// will query KeyManager for the key being used for each file, and update ++// KeyManager when it creates a new file or moving files around. ++class KeyManager { ++ public: ++ virtual ~KeyManager() = default; ++ ++ virtual Status GetFile(const std::string& fname, ++ FileEncryptionInfo* file_info) = 0; ++ virtual Status NewFile(const std::string& fname, ++ FileEncryptionInfo* file_info) = 0; ++ virtual Status DeleteFile(const std::string& fname) = 0; ++ virtual Status LinkFile(const std::string& src_fname, ++ const std::string& dst_fname) = 0; ++}; ++ ++// An Env with underlying files being encrypted. It holds a reference to an ++// external KeyManager for encryption key management. 
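++// Typical wiring (editorial sketch, not part of the original patch):
++//
++//   std::shared_ptr<encryption::KeyManager> key_manager = ...;
++//   Env* encrypted = encryption::NewKeyManagedEncryptedEnv(Env::Default(),
++//                                                          key_manager);
++//   DBOptions db_options;
++//   db_options.env = encrypted;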
++class KeyManagedEncryptedEnv : public EnvWrapper { ++ public: ++ KeyManagedEncryptedEnv(Env* base_env, ++ std::shared_ptr& key_manager, ++ std::shared_ptr& provider, ++ std::unique_ptr&& encrypted_env); ++ ++ virtual ~KeyManagedEncryptedEnv(); ++ ++ Status NewSequentialFile(const std::string& fname, ++ std::unique_ptr* result, ++ const EnvOptions& options) override; ++ Status NewRandomAccessFile(const std::string& fname, ++ std::unique_ptr* result, ++ const EnvOptions& options) override; ++ Status NewWritableFile(const std::string& fname, ++ std::unique_ptr* result, ++ const EnvOptions& options) override; ++ Status ReopenWritableFile(const std::string& fname, ++ std::unique_ptr* result, ++ const EnvOptions& options) override; ++ Status ReuseWritableFile(const std::string& fname, ++ const std::string& old_fname, ++ std::unique_ptr* result, ++ const EnvOptions& options) override; ++ Status NewRandomRWFile(const std::string& fname, ++ std::unique_ptr* result, ++ const EnvOptions& options) override; ++ ++ Status DeleteFile(const std::string& fname) override; ++ Status LinkFile(const std::string& src_fname, ++ const std::string& dst_fname) override; ++ Status RenameFile(const std::string& src_fname, ++ const std::string& dst_fname) override; ++ ++ private: ++ const std::shared_ptr key_manager_; ++ const std::shared_ptr provider_; ++ const std::unique_ptr encrypted_env_; ++}; ++ ++extern Env* NewKeyManagedEncryptedEnv(Env* base_env, ++ std::shared_ptr& key_manager); ++ ++} // namespace encryption ++} // namespace ROCKSDB_NAMESPACE ++ ++#endif // OPENSSL +diff --git a/src.mk b/src.mk +index d68733fba..a870f4447 100644 +--- a/src.mk ++++ b/src.mk +@@ -101,6 +101,7 @@ LIB_SOURCES = \ + db/write_controller.cc \ + db/write_stall_stats.cc \ + db/write_thread.cc \ ++ encryption/encryption.cc \ + env/composite_env.cc \ + env/env.cc \ + env/env_chroot.cc \ +diff --git a/test_util/testutil.cc b/test_util/testutil.cc +index ce221e79b..334bfe6f6 100644 +--- a/test_util/testutil.cc ++++ b/test_util/testutil.cc +@@ -37,6 +37,14 @@ void RegisterCustomObjects(int /*argc*/, char** /*argv*/) {} + namespace ROCKSDB_NAMESPACE { + namespace test { + ++#ifdef OPENSSL ++const std::string TestKeyManager::default_key = ++ "\x12\x34\x56\x78\x12\x34\x56\x78\x12\x34\x56\x78\x12\x34\x56\x78\x12\x34" ++ "\x56\x78\x12\x34\x56\x78"; ++const std::string TestKeyManager::default_iv = ++ "\xaa\xbb\xcc\xdd\xaa\xbb\xcc\xdd\xaa\xbb\xcc\xdd\xaa\xbb\xcc\xdd"; ++#endif ++ + const uint32_t kDefaultFormatVersion = BlockBasedTableOptions().format_version; + const std::set kFooterFormatVersionsToTest{ + // Non-legacy, before big footer changes +@@ -739,7 +747,6 @@ int RegisterTestObjects(ObjectLibrary& library, const std::string& arg) { + return static_cast(library.GetFactoryCount(&num_types)); + } + +- + void RegisterTestLibrary(const std::string& arg) { + static bool registered = false; + if (!registered) { +diff --git a/test_util/testutil.h b/test_util/testutil.h +index eca1ff794..5a173ca40 100644 +--- a/test_util/testutil.h ++++ b/test_util/testutil.h +@@ -10,12 +10,17 @@ + #pragma once + #include + #include ++#include ++#include + #include + #include + + #include "env/composite_env_wrapper.h" ++#include "file/filename.h" + #include "file/writable_file_writer.h" + #include "rocksdb/compaction_filter.h" ++#include "rocksdb/db.h" ++#include "rocksdb/encryption.h" + #include "rocksdb/env.h" + #include "rocksdb/iterator.h" + #include "rocksdb/merge_operator.h" +@@ -43,6 +48,54 @@ class SequentialFileReader; + + namespace test { + ++// 
+diff --git a/src.mk b/src.mk
+index d68733fba..a870f4447 100644
+--- a/src.mk
++++ b/src.mk
+@@ -101,6 +101,7 @@ LIB_SOURCES = \
+   db/write_controller.cc \
+   db/write_stall_stats.cc \
+   db/write_thread.cc \
++  encryption/encryption.cc \
+   env/composite_env.cc \
+   env/env.cc \
+   env/env_chroot.cc \
+diff --git a/test_util/testutil.cc b/test_util/testutil.cc
+index ce221e79b..334bfe6f6 100644
+--- a/test_util/testutil.cc
++++ b/test_util/testutil.cc
+@@ -37,6 +37,14 @@ void RegisterCustomObjects(int /*argc*/, char** /*argv*/) {}
+ namespace ROCKSDB_NAMESPACE {
+ namespace test {
+ 
++#ifdef OPENSSL
++const std::string TestKeyManager::default_key =
++    "\x12\x34\x56\x78\x12\x34\x56\x78\x12\x34\x56\x78\x12\x34\x56\x78\x12\x34"
++    "\x56\x78\x12\x34\x56\x78";
++const std::string TestKeyManager::default_iv =
++    "\xaa\xbb\xcc\xdd\xaa\xbb\xcc\xdd\xaa\xbb\xcc\xdd\xaa\xbb\xcc\xdd";
++#endif
++
+ const uint32_t kDefaultFormatVersion = BlockBasedTableOptions().format_version;
+ const std::set<uint32_t> kFooterFormatVersionsToTest{
+     // Non-legacy, before big footer changes
+@@ -739,7 +747,6 @@ int RegisterTestObjects(ObjectLibrary& library, const std::string& arg) {
+   return static_cast<int>(library.GetFactoryCount(&num_types));
+ }
+ 
+-
+ void RegisterTestLibrary(const std::string& arg) {
+   static bool registered = false;
+   if (!registered) {
+diff --git a/test_util/testutil.h b/test_util/testutil.h
+index eca1ff794..5a173ca40 100644
+--- a/test_util/testutil.h
++++ b/test_util/testutil.h
+@@ -10,12 +10,17 @@
+ #pragma once
+ #include <algorithm>
+ #include <deque>
++#include <mutex>
++#include <set>
+ #include <string>
+ #include <vector>
+ 
+ #include "env/composite_env_wrapper.h"
++#include "file/filename.h"
+ #include "file/writable_file_writer.h"
+ #include "rocksdb/compaction_filter.h"
++#include "rocksdb/db.h"
++#include "rocksdb/encryption.h"
+ #include "rocksdb/env.h"
+ #include "rocksdb/iterator.h"
+ #include "rocksdb/merge_operator.h"
+@@ -43,6 +48,54 @@ class SequentialFileReader;
+ 
+ namespace test {
+ 
++// TODO(yiwu): Use InMemoryKeyManager instead for tests.
++#ifdef OPENSSL
++class TestKeyManager : public encryption::KeyManager {
++ public:
++  virtual ~TestKeyManager() = default;
++
++  static const std::string default_key;
++  static const std::string default_iv;
++  std::mutex mutex;
++  std::set<std::string> file_set;
++
++  Status GetFile(const std::string& fname,
++                 encryption::FileEncryptionInfo* file_info) override {
++    std::lock_guard<std::mutex> l(mutex);
++    if (file_set.find(fname) == file_set.end()) {
++      file_info->method = encryption::EncryptionMethod::kPlaintext;
++    } else {
++      file_info->method = encryption::EncryptionMethod::kAES192_CTR;
++    }
++    file_info->key = default_key;
++    file_info->iv = default_iv;
++    return Status::OK();
++  }
++
++  Status NewFile(const std::string& fname,
++                 encryption::FileEncryptionInfo* file_info) override {
++    std::lock_guard<std::mutex> l(mutex);
++    file_info->method = encryption::EncryptionMethod::kAES192_CTR;
++    file_info->key = default_key;
++    file_info->iv = default_iv;
++    file_set.insert(fname);
++    return Status::OK();
++  }
++
++  Status DeleteFile(const std::string& fname) override {
++    std::lock_guard<std::mutex> l(mutex);
++    file_set.erase(fname);
++    return Status::OK();
++  }
++
++  Status LinkFile(const std::string& /*src*/, const std::string& dst) override {
++    std::lock_guard<std::mutex> l(mutex);
++    file_set.insert(dst);
++    return Status::OK();
++  }
++};
++#endif
++
+ extern const uint32_t kDefaultFormatVersion;
+ extern const std::set<uint32_t> kFooterFormatVersionsToTest;
+ 
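TestKeyManager marks every file it creates as AES-192-CTR and reports unknown files as plaintext, so data written before the env was stacked stays readable. A test would typically wire it up like this; a sketch assuming an OPENSSL build:

// Sketch: building a key-managed encrypted Env for tests.
#include <memory>

#include "rocksdb/encryption.h"
#include "test_util/testutil.h"

ROCKSDB_NAMESPACE::Env* MakeTestEncryptedEnv() {
  std::shared_ptr<ROCKSDB_NAMESPACE::encryption::KeyManager> key_manager =
      std::make_shared<ROCKSDB_NAMESPACE::test::TestKeyManager>();
  // Wraps Env::Default(); reads and writes go through the encrypted Env,
  // with keys served per file by the TestKeyManager.
  return ROCKSDB_NAMESPACE::encryption::NewKeyManagedEncryptedEnv(
      ROCKSDB_NAMESPACE::Env::Default(), key_manager);
}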
+diff --git a/tools/db_bench_tool.cc b/tools/db_bench_tool.cc
+index 40132e69c..53c56fc5e 100644
+--- a/tools/db_bench_tool.cc
++++ b/tools/db_bench_tool.cc
+@@ -42,6 +42,7 @@
+ #include "db/db_impl/db_impl.h"
+ #include "db/malloc_stats.h"
+ #include "db/version_set.h"
++#include "encryption/in_memory_key_manager.h"
+ #include "monitoring/histogram.h"
+ #include "monitoring/statistics_impl.h"
+ #include "options/cf_options.h"
+@@ -50,6 +51,7 @@
+ #include "rocksdb/cache.h"
+ #include "rocksdb/convenience.h"
+ #include "rocksdb/db.h"
++#include "rocksdb/encryption.h"
+ #include "rocksdb/env.h"
+ #include "rocksdb/filter_policy.h"
+ #include "rocksdb/memtablerep.h"
+@@ -1773,6 +1775,10 @@ DEFINE_bool(build_info, false,
+ DEFINE_bool(track_and_verify_wals_in_manifest, false,
+             "If true, enable WAL tracking in the MANIFEST");
+ 
++DEFINE_string(
++    encryption_method, "",
++    "If non-empty, enable encryption with the specified encryption method.");
++
+ namespace ROCKSDB_NAMESPACE {
+ namespace {
+ static Status CreateMemTableRepFactory(
+@@ -2189,7 +2195,7 @@ enum OperationType : unsigned char {
+   kOthers
+ };
+ 
+-static std::unordered_map<OperationType, std::string, std::hash<unsigned char>>
++static std::unordered_map<OperationType, std::string, std::hash<unsigned char> >
+     OperationTypeString = {{kRead, "read"}, {kWrite, "write"},
+                            {kDelete, "delete"}, {kSeek, "seek"},
+                            {kMerge, "merge"}, {kUpdate, "update"},
+@@ -2213,7 +2219,7 @@ class Stats {
+   uint64_t last_op_finish_;
+   uint64_t last_report_finish_;
+   std::unordered_map<OperationType, std::shared_ptr<HistogramImpl>,
+-                     std::hash<unsigned char>>
++                     std::hash<unsigned char> >
+       hist_;
+   std::string message_;
+   bool exclude_from_merge_;
+@@ -5137,7 +5143,7 @@ class Benchmark {
+     if (db_.db == nullptr) {
+       num_key_gens = multi_dbs_.size();
+     }
+-    std::vector<std::unique_ptr<KeyGenerator>> key_gens(num_key_gens);
++    std::vector<std::unique_ptr<KeyGenerator> > key_gens(num_key_gens);
+     int64_t max_ops = num_ops * num_key_gens;
+     int64_t ops_per_stage = max_ops;
+     if (FLAGS_num_column_families > 1 && FLAGS_num_hot_column_families > 0) {
+@@ -5242,11 +5248,11 @@
+     std::string random_value;
+     // Queue that stores scheduled timestamp of disposable entries deletes,
+     // along with starting index of disposable entry keys to delete.
+-    std::vector<std::queue<std::pair<uint64_t, uint64_t>>> disposable_entries_q(
+-        num_key_gens);
++    std::vector<std::queue<std::pair<uint64_t, uint64_t> > >
++        disposable_entries_q(num_key_gens);
+     // --- End of variables used in disposable/persistent keys simulation.
+ 
+-    std::vector<std::unique_ptr<const char[]>> expanded_key_guards;
++    std::vector<std::unique_ptr<const char[]> > expanded_key_guards;
+     std::vector<const char*> expanded_keys;
+     if (FLAGS_expand_range_tombstones) {
+       expanded_key_guards.resize(range_tombstone_width_);
+@@ -5604,7 +5610,8 @@
+     auto num_db = db_list.size();
+     size_t num_levels = static_cast<size_t>(open_options_.num_levels);
+     size_t output_level = open_options_.num_levels - 1;
+-    std::vector<std::vector<std::vector<SstFileMetaData>>> sorted_runs(num_db);
++    std::vector<std::vector<std::vector<SstFileMetaData> > > sorted_runs(
++        num_db);
+     std::vector<size_t> num_files_at_level0(num_db, 0);
+     if (compaction_style == kCompactionStyleLevel) {
+       if (num_levels == 0) {
+@@ -6216,7 +6223,7 @@
+     int64_t found = 0;
+     ReadOptions options = read_options_;
+     std::vector<Slice> keys;
+-    std::vector<std::unique_ptr<const char[]>> key_guards;
++    std::vector<std::unique_ptr<const char[]> > key_guards;
+     std::vector<std::string> values(entries_per_batch_);
+     PinnableSlice* pin_values = new PinnableSlice[entries_per_batch_];
+     std::unique_ptr<PinnableSlice[]> pin_values_guard(pin_values);
+@@ -6313,9 +6320,9 @@
+     const size_t batch_size = entries_per_batch_;
+     std::vector<Range> ranges;
+     std::vector<Slice> lkeys;
+-    std::vector<std::unique_ptr<const char[]>> lkey_guards;
++    std::vector<std::unique_ptr<const char[]> > lkey_guards;
+     std::vector<Slice> rkeys;
+-    std::vector<std::unique_ptr<const char[]>> rkey_guards;
++    std::vector<std::unique_ptr<const char[]> > rkey_guards;
+     std::vector<uint64_t> sizes;
+     while (ranges.size() < batch_size) {
+       // Ugly without C++17 return from emplace_back
+@@ -7025,7 +7032,7 @@
+     std::unique_ptr<const char[]> end_key_guard;
+     Slice end_key = AllocateKey(&end_key_guard);
+     uint64_t num_range_deletions = 0;
+-    std::vector<std::unique_ptr<const char[]>> expanded_key_guards;
++    std::vector<std::unique_ptr<const char[]> > expanded_key_guards;
+     std::vector<const char*> expanded_keys;
+     if (FLAGS_expand_range_tombstones) {
+       expanded_key_guards.resize(range_tombstone_width_);
+@@ -8715,6 +8722,31 @@ int db_bench_tool(int argc, char** argv) {
+     exit(1);
+   }
+ 
++#ifdef OPENSSL
++  if (!FLAGS_encryption_method.empty()) {
++    ROCKSDB_NAMESPACE::encryption::EncryptionMethod method =
++        ROCKSDB_NAMESPACE::encryption::EncryptionMethod::kUnknown;
++    if (!strcasecmp(FLAGS_encryption_method.c_str(), "AES128CTR")) {
++      method = ROCKSDB_NAMESPACE::encryption::EncryptionMethod::kAES128_CTR;
++    } else if (!strcasecmp(FLAGS_encryption_method.c_str(), "AES192CTR")) {
++      method = ROCKSDB_NAMESPACE::encryption::EncryptionMethod::kAES192_CTR;
++    } else if (!strcasecmp(FLAGS_encryption_method.c_str(), "AES256CTR")) {
++      method = ROCKSDB_NAMESPACE::encryption::EncryptionMethod::kAES256_CTR;
++    } else if (!strcasecmp(FLAGS_encryption_method.c_str(), "SM4CTR")) {
++      method = ROCKSDB_NAMESPACE::encryption::EncryptionMethod::kSM4_CTR;
++    }
++    if (method == ROCKSDB_NAMESPACE::encryption::EncryptionMethod::kUnknown) {
++      fprintf(stderr, "Unknown encryption method %s\n",
++              FLAGS_encryption_method.c_str());
++      exit(1);
++    }
++    std::shared_ptr<ROCKSDB_NAMESPACE::encryption::KeyManager> key_manager(
++        new ROCKSDB_NAMESPACE::encryption::InMemoryKeyManager(method));
++    FLAGS_env = ROCKSDB_NAMESPACE::encryption::NewKeyManagedEncryptedEnv(
++        FLAGS_env, key_manager);
++  }
++#endif  // OPENSSL
++
+   if (!strcasecmp(FLAGS_compaction_fadvice.c_str(), "NONE")) {
+     FLAGS_compaction_fadvice_e = ROCKSDB_NAMESPACE::Options::NONE;
+   } else if (!strcasecmp(FLAGS_compaction_fadvice.c_str(), "NORMAL")) {
+-- 
+2.45.0
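The new benchmark flag is exercised as --encryption_method=AES192CTR (or AES128CTR, AES256CTR, SM4CTR). The equivalent programmatic wiring, sketched under the assumption of an OPENSSL build and the same InMemoryKeyManager the benchmark uses:

// Sketch: opening a DB with per-file encryption, mirroring the db_bench
// wiring above. The returned Env is intentionally long-lived here.
#include <memory>
#include <string>

#include "encryption/in_memory_key_manager.h"
#include "rocksdb/db.h"
#include "rocksdb/encryption.h"

ROCKSDB_NAMESPACE::Status OpenEncryptedDB(const std::string& path,
                                          ROCKSDB_NAMESPACE::DB** db) {
  namespace enc = ROCKSDB_NAMESPACE::encryption;
  std::shared_ptr<enc::KeyManager> key_manager(
      new enc::InMemoryKeyManager(enc::EncryptionMethod::kAES192_CTR));
  ROCKSDB_NAMESPACE::Options options;
  options.create_if_missing = true;
  // All files except CURRENT go through the key-managed encrypted Env.
  options.env = enc::NewKeyManagedEncryptedEnv(
      ROCKSDB_NAMESPACE::Env::Default(), key_manager);
  return ROCKSDB_NAMESPACE::DB::Open(options, path, db);
}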