From 2a262bf94c818919a6348bd232b63a238c7461fd Mon Sep 17 00:00:00 2001 From: Morten Grouleff Date: Wed, 3 Apr 2024 09:42:02 +0200 Subject: [PATCH] Add new constructor to ZstdDictCompress and ZstdDictDecompress that allows the byReference semantics for the provided byte buffer: If you set this to true, you avoid the copying of the dict data into a natively malloc'ed buffer, but then also have to promise that the byte buffer will not be modified before the CTX has been closed. --- .../github/luben/zstd/ZstdDictCompress.java | 28 +++++++++++++++++-- .../github/luben/zstd/ZstdDictDecompress.java | 26 +++++++++++++++-- src/main/native/jni_fast_zstd.c | 22 +++++++++++---- src/test/scala/ZstdDict.scala | 7 +++-- 4 files changed, 68 insertions(+), 15 deletions(-) diff --git a/src/main/java/com/github/luben/zstd/ZstdDictCompress.java b/src/main/java/com/github/luben/zstd/ZstdDictCompress.java index 1fed2ff..f1fb5ca 100644 --- a/src/main/java/com/github/luben/zstd/ZstdDictCompress.java +++ b/src/main/java/com/github/luben/zstd/ZstdDictCompress.java @@ -10,11 +10,14 @@ public class ZstdDictCompress extends SharedDictBase { } private long nativePtr = 0; + + private ByteBuffer sharedDict = null; + private int level = Zstd.defaultCompressionLevel(); private native void init(byte[] dict, int dict_offset, int dict_size, int level); - private native void initDirect(ByteBuffer dict, int dict_offset, int dict_size, int level); + private native void initDirect(ByteBuffer dict, int dict_offset, int dict_size, int level, int byReference); private native void free(); @@ -59,6 +62,18 @@ public ZstdDictCompress(byte[] dict, int offset, int length, int level) { * @param level compression level */ public ZstdDictCompress(ByteBuffer dict, int level) { + this(dict, level, false); + } + + /** + * Create a new dictionary for use with fast compress. + * If byReference is true, then the native code does not copy the data but keeps a reference to the byte buffer, which must then not be modified before this context has been closed. + * + * @param dict Direct ByteBuffer containing dictionary using position and limit to define range in buffer. + * @param level compression level + * @param byReference tell the native part to use the byte buffer directly and not copy the data when true. + */ + public ZstdDictCompress(ByteBuffer dict, int level, boolean byReference) { this.level = level; int length = dict.limit() - dict.position(); if (!dict.isDirect()) { @@ -67,11 +82,14 @@ public ZstdDictCompress(ByteBuffer dict, int level) { if (length < 0) { throw new IllegalArgumentException("dict cannot be empty."); } - initDirect(dict, dict.position(), length, level); + initDirect(dict, dict.position(), length, level, byReference ? 1 : 0); if (nativePtr == 0L) { throw new IllegalStateException("ZSTD_createCDict failed"); } + if (byReference) { + sharedDict = dict; // ensures the dict is not garbage collected while this object remains, and flags that we should not use native free. + } // Ensures that even if ZstdDictCompress is created and published through a race, no thread could observe // nativePtr == 0. storeFence(); @@ -85,7 +103,11 @@ int level() { @Override void doClose() { if (nativePtr != 0) { - free(); + if (sharedDict == null) { + free(); + } else { + sharedDict = null; + } nativePtr = 0; } } diff --git a/src/main/java/com/github/luben/zstd/ZstdDictDecompress.java b/src/main/java/com/github/luben/zstd/ZstdDictDecompress.java index 70aac1a..f547698 100644 --- a/src/main/java/com/github/luben/zstd/ZstdDictDecompress.java +++ b/src/main/java/com/github/luben/zstd/ZstdDictDecompress.java @@ -11,9 +11,11 @@ public class ZstdDictDecompress extends SharedDictBase { private long nativePtr = 0L; + private ByteBuffer sharedDict = null; + private native void init(byte[] dict, int dict_offset, int dict_size); - private native void initDirect(ByteBuffer dict, int dict_offset, int dict_size); + private native void initDirect(ByteBuffer dict, int dict_offset, int dict_size, int byReference); private native void free(); @@ -52,6 +54,17 @@ public ZstdDictDecompress(byte[] dict, int offset, int length) { * @param dict Direct ByteBuffer containing dictionary using position and limit to define range in buffer. */ public ZstdDictDecompress(ByteBuffer dict) { + this(dict, false); + } + + /** + * Create a new dictionary for use with fast decompress. + * If byReference is true, then the native code does not copy the data but keeps a reference to the byte buffer, which must then not be modified before this context has been closed. + * + * @param dict Direct ByteBuffer containing dictionary using position and limit to define range in buffer. + * @param byReference tell the native part to use the byte buffer directly and not copy the data when true. + */ + public ZstdDictDecompress(ByteBuffer dict, boolean byReference) { int length = dict.limit() - dict.position(); if (!dict.isDirect()) { @@ -60,11 +73,14 @@ public ZstdDictDecompress(ByteBuffer dict) { if (length < 0) { throw new IllegalArgumentException("dict cannot be empty."); } - initDirect(dict, dict.position(), length); + initDirect(dict, dict.position(), length, byReference ? 1 : 0); if (nativePtr == 0L) { throw new IllegalStateException("ZSTD_createDDict failed"); } + if (byReference) { + sharedDict = dict; // ensures the dict is not garbage collected while this object remains, and flags that we should not use native free. + } // Ensures that even if ZstdDictDecompress is created and published through a race, no thread could observe // nativePtr == 0. storeFence(); @@ -74,7 +90,11 @@ public ZstdDictDecompress(ByteBuffer dict) { @Override void doClose() { if (nativePtr != 0) { - free(); + if (sharedDict == null) { + free(); + } else { + sharedDict = null; + } nativePtr = 0; } } diff --git a/src/main/native/jni_fast_zstd.c b/src/main/native/jni_fast_zstd.c index 7e5f2bb..c488277 100644 --- a/src/main/native/jni_fast_zstd.c +++ b/src/main/native/jni_fast_zstd.c @@ -32,17 +32,22 @@ JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictCompress_init /* * Class: com_github_luben_zstd_ZstdDictCompress * Method: init - * Signature: (Ljava/nio/ByteBuffer;III)V + * Signature: (Ljava/nio/ByteBuffer;IIII)V */ JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictCompress_initDirect - (JNIEnv *env, jobject obj, jobject dict, jint dict_offset, jint dict_size, jint level) + (JNIEnv *env, jobject obj, jobject dict, jint dict_offset, jint dict_size, jint level, jint byReference) { jclass clazz = (*env)->GetObjectClass(env, obj); compress_dict = (*env)->GetFieldID(env, clazz, "nativePtr", "J"); if (NULL == dict) return; void *dict_buff = (*env)->GetDirectBufferAddress(env, dict); if (NULL == dict_buff) return; - ZSTD_CDict* cdict = ZSTD_createCDict(((char *)dict_buff) + dict_offset, dict_size, level); + ZSTD_CDict* cdict = NULL; + if (byReference == 0) { + cdict = ZSTD_createCDict(((char *)dict_buff) + dict_offset, dict_size, level); + } else { + cdict = ZSTD_createCDict_byReference(((char *)dict_buff) + dict_offset, dict_size, level); + } if (NULL == cdict) return; (*env)->SetLongField(env, obj, compress_dict, (jlong)(intptr_t) cdict); } @@ -85,17 +90,22 @@ JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictDecompress_init /* * Class: com_github_luben_zstd_ZstdDictDecompress * Method: initDirect - * Signature: (Ljava/nio/ByteBuffer;II)V + * Signature: (Ljava/nio/ByteBuffer;III)V */ JNIEXPORT void JNICALL Java_com_github_luben_zstd_ZstdDictDecompress_initDirect - (JNIEnv *env, jobject obj, jobject dict, jint dict_offset, jint dict_size) + (JNIEnv *env, jobject obj, jobject dict, jint dict_offset, jint dict_size, jint byReference) { jclass clazz = (*env)->GetObjectClass(env, obj); decompress_dict = (*env)->GetFieldID(env, clazz, "nativePtr", "J"); if (NULL == dict) return; void *dict_buff = (*env)->GetDirectBufferAddress(env, dict); - ZSTD_DDict* ddict = ZSTD_createDDict(((char *)dict_buff) + dict_offset, dict_size); + ZSTD_DDict* ddict = NULL; + if (byReference == 0) { + ddict = ZSTD_createDDict(((char *)dict_buff) + dict_offset, dict_size); + } else { + ddict = ZSTD_createDDict_byReference(((char *)dict_buff) + dict_offset, dict_size); + } if (NULL == ddict) return; (*env)->SetLongField(env, obj, decompress_dict, (jlong)(intptr_t) ddict); diff --git a/src/test/scala/ZstdDict.scala b/src/test/scala/ZstdDict.scala index a332775..aa2e902 100644 --- a/src/test/scala/ZstdDict.scala +++ b/src/test/scala/ZstdDict.scala @@ -104,17 +104,18 @@ class ZstdDictSpec extends AnyFlatSpec { assert(input.toSeq == decompressed.toSeq) } - it should s"round-trip compression/decompression ByteBuffers with fast dict at level $level with legacy $legacy" in { + it should s"round-trip compression/decompression ByteBuffers with fast dict at level $level with byReference $legacy" in { + val byReference = legacy // Reuse the variance flag here. val size = input.length val inBuf = ByteBuffer.allocateDirect(size) inBuf.put(input) inBuf.flip() - val cdict = new ZstdDictCompress(dictInDirectByteBuffer, level) + val cdict = new ZstdDictCompress(dictInDirectByteBuffer, level, byReference) val compressed = ByteBuffer.allocateDirect(Zstd.compressBound(size).toInt); Zstd.compress(compressed, inBuf, cdict) compressed.flip() cdict.close - val ddict = new ZstdDictDecompress(dictInDirectByteBuffer) + val ddict = new ZstdDictDecompress(dictInDirectByteBuffer, byReference) val decompressed = ByteBuffer.allocateDirect(size) Zstd.decompress(decompressed, compressed, ddict) decompressed.flip()