diff --git a/src/storage/compression/zstd.cpp b/src/storage/compression/zstd.cpp index 0976925b0a88..e92f5836ba3a 100644 --- a/src/storage/compression/zstd.cpp +++ b/src/storage/compression/zstd.cpp @@ -90,17 +90,30 @@ idx_t ZSTDStorage::StringFinalAnalyze(AnalyzeState &state_p) { //===--------------------------------------------------------------------===// // Compress //===--------------------------------------------------------------------===// -struct StringMetadata { +struct string_metadata_t { + idx_t size; +}; + +struct dictionary_metadata_t { idx_t size; }; class ZSTDCompressionState : public CompressionState { public: + static constexpr int COMPRESSION_LEVEL = 3; + explicit ZSTDCompressionState(ColumnDataCheckpointer &checkpointer) : checkpointer(checkpointer), function(checkpointer.GetCompressionFunction(CompressionType::COMPRESSION_ZSTD)), - heap(BufferAllocator::Get(checkpointer.GetDatabase())) { + heap(BufferAllocator::Get(checkpointer.GetDatabase())), + zstd_cdict(nullptr) { CreateEmptySegment(checkpointer.GetRowGroup().start); + zstd_context = duckdb_zstd::ZSTD_createCCtx(); + } + + ~ZSTDCompressionState() override { + duckdb_zstd::ZSTD_freeCCtx(zstd_context); + duckdb_zstd::ZSTD_freeCDict(zstd_cdict); } ColumnDataCheckpointer &checkpointer; @@ -111,6 +124,10 @@ class ZSTDCompressionState : public CompressionState { BufferHandle current_handle; // ZSTDDictionary current_dictionary + + duckdb_zstd::ZSTD_CCtx *zstd_context; + duckdb_zstd::ZSTD_CDict *zstd_cdict; + // buffer for current segment idx_t total_data_size; StringHeap heap; @@ -125,6 +142,21 @@ class ZSTDCompressionState : public CompressionState { //! The offset within the current block // idx_t offset; + void CreateCompressionDictionary(const char *str, size_t size) { + + zstd_cdict = duckdb_zstd::ZSTD_createCDict(str, size, COMPRESSION_LEVEL); + + size_t dict_size = duckdb_zstd::ZSTD_sizeof_CDict(zstd_cdict); + + dictionary_metadata_t meta { + .size = dict_size + }; + + // write meta & dictionary + current_data_ptr = data_ptr_cast(memcpy(current_data_ptr, &meta, sizeof(dictionary_metadata_t))); + current_data_ptr = data_ptr_cast(memcpy(current_data_ptr, &zstd_cdict, dict_size)); + } + void CreateEmptySegment(idx_t row_start) { auto &db = checkpointer.GetDatabase(); auto &type = checkpointer.GetType(); @@ -158,7 +190,7 @@ class ZSTDCompressionState : public CompressionState { } void AddNull() { - // TODO: make this more efficient + // TODO: fix AddString(""); } @@ -179,23 +211,28 @@ class ZSTDCompressionState : public CompressionState { // } void AddString(const string_t &str) { + // TODO: train dictionary in a better way + if (!zstd_cdict) { + CreateCompressionDictionary(str.GetData(), str.GetSize()); + } + + // TODO: check space + size_t dst_capacity = SIZE_T_MAX; - // TODO: add to dictionary + auto data_dst = current_data_ptr + sizeof(string_metadata_t); + size_t compressed_size = duckdb_zstd::ZSTD_compress_usingCDict(zstd_context, data_dst, dst_capacity, str.GetData(), str.GetSize(), zstd_cdict); // Create metadata - StringMetadata meta { - .size = str.GetSize() + string_metadata_t meta { + .size = compressed_size }; - // TODO: check if there is space - // Write metadata - current_data_ptr = data_ptr_cast(memcpy(current_data_ptr, &meta, sizeof(StringMetadata))); - total_data_size += sizeof(StringMetadata); + memcpy(current_data_ptr, &meta, sizeof(string_metadata_t)); - // Write string - current_data_ptr = data_ptr_cast(memcpy(current_data_ptr, str.GetData(), str.GetSize())); - total_data_size += str.GetSize(); + // move data ptr + current_data_ptr = data_dst + compressed_size; + total_data_size += sizeof(string_metadata_t) + compressed_size; } }; @@ -234,12 +271,25 @@ void ZSTDStorage::FinalizeCompress(CompressionState &state_p) { //===--------------------------------------------------------------------===// struct ZSTDScanState : public StringScanState { BufferHandle handle; + + duckdb_zstd::ZSTD_DDict *zstd_ddict; + + data_ptr_t current_data_ptr; }; + unique_ptr ZSTDStorage::StringInitScan(ColumnSegment &segment) { auto result = make_uniq(); auto &buffer_manager = BufferManager::GetBufferManager(segment.db); result->handle = buffer_manager.Pin(segment.block); + + // load dictionary + auto data = result->handle.Ptr() + segment.GetBlockOffset(); + dictionary_metadata_t *dict_meta = reinterpret_cast(data); + + result->zstd_ddict = duckdb_zstd::ZSTD_createDDict(data + sizeof(dictionary_metadata_t), dict_meta->size); + result->current_data_ptr = data + dict_meta->size; + return std::move(result); } @@ -260,7 +310,32 @@ void ZSTDStorage::StringScanPartial(ColumnSegment &segment, ColumnScanState &sta } void ZSTDStorage::StringScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result) { - StringScanPartial(segment, state, scan_count, result, 0); + // StringScanPartial(segment, state, scan_count, result, 0); + + auto &scan_state = state.scan_state->Cast(); + auto &block_manager = segment.GetBlockManager(); + auto &buffer_manager = block_manager.buffer_manager; + + data_ptr_t src = scan_state.current_data_ptr; + auto result_data = FlatVector::GetData(result); + + duckdb_zstd::ZSTD_DCtx *zstd_context = duckdb_zstd::ZSTD_createDCtx(); + + // create temporary buffer + // TODO: fix this + char buffer[1024]; + + for (idx_t i = 0; i < scan_count; i++) { + + // get metadata + string_metadata_t *meta = reinterpret_cast(src); + size_t uncompressed_size = duckdb_zstd::ZSTD_decompress_usingDDict(zstd_context, buffer, 1024, src + sizeof(string_metadata_t), meta->size, scan_state.zstd_ddict); + + // ALLOCATE STRING? + result_data[i] = string_t(buffer, uncompressed_size); + } + + duckdb_zstd::ZSTD_freeDCtx(zstd_context); } //===--------------------------------------------------------------------===// diff --git a/test/sql/storage/compression/zstd/zstd.test b/test/sql/storage/compression/zstd/zstd.test index 99dbff1f1683..cf1db384ee0c 100644 --- a/test/sql/storage/compression/zstd/zstd.test +++ b/test/sql/storage/compression/zstd/zstd.test @@ -12,5 +12,16 @@ statement ok CREATE TABLE test (a VARCHAR); statement ok -INSERT INTO test VALUES ('11'), ('11'), ('12'), (NULL) +INSERT INTO test VALUES ('11'), ('11'), ('12'), (NULL); + +statement ok +checkpoint; + +query I +SELECT * FROM test; +---- +11 +11 +12 +NULL