Skip to content

Commit

Permalink
Allow users to pass in a BufferPool implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
yuzawa-san authored and luben committed Oct 20, 2020
1 parent 92c6058 commit dd2588e
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 78 deletions.
55 changes: 14 additions & 41 deletions src/main/java/com/github/luben/zstd/BufferPool.java
Original file line number Diff line number Diff line change
@@ -1,50 +1,23 @@
package com.github.luben.zstd;

import java.lang.ref.SoftReference;
import java.util.Map;
import java.util.HashMap;
import java.util.Deque;
import java.util.ArrayDeque;
import java.nio.ByteBuffer;

/**
* An pool of buffers which uses a simple reference queue to recycle old buffers.
* An an interface that allows users to customize how buffers are recycled.
*/
class BufferPool {
private static final Map<Integer, SoftReference<BufferPool>> pools = new HashMap<Integer, SoftReference<BufferPool>>();
public interface BufferPool {

static BufferPool get(int length) {
synchronized (pools) {
SoftReference<BufferPool> poolReference = pools.get(length);
BufferPool pool;
if (poolReference == null || (pool = poolReference.get()) == null) {
pool = new BufferPool(length);
poolReference = new SoftReference<BufferPool>(pool);
pools.put(length, poolReference);
}
return pool;
}
}
/**
* Fetch a buffer from the pool.
* @param capacity the desired size of the buffer
* @return a heap buffer with arrayOffset of 0
*/
ByteBuffer get(int capacity);

private final int length;
private final Deque<byte[]> queue;
/**
* Return a buffer to the pool.
* @param buffer the buffer to return
*/
void release(ByteBuffer buffer);

BufferPool(int length) {
this.length = length;
this.queue = new ArrayDeque<byte[]>();
}

synchronized byte[] checkOut() {
byte[] buffer = queue.pollFirst();
if (buffer == null) {
buffer = new byte[length];
}
return buffer;
}

synchronized void checkIn(byte[] buffer) {
if (length != buffer.length) {
throw new IllegalStateException("buffer size mismatch");
}
queue.addLast(buffer);
}
}
46 changes: 46 additions & 0 deletions src/main/java/com/github/luben/zstd/RecyclingBufferPool.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package com.github.luben.zstd;

import java.lang.ref.SoftReference;
import java.nio.ByteBuffer;
import java.util.Map;
import java.util.HashMap;
import java.util.Deque;
import java.util.ArrayDeque;

/**
* An pool of buffers which uses a simple reference queue to recycle old buffers.
*/
public class RecyclingBufferPool implements BufferPool {
public static final BufferPool INSTANCE = new RecyclingBufferPool();

private final Map<Integer, SoftReference<Deque<ByteBuffer>>> pools;

private RecyclingBufferPool() {
this.pools = new HashMap<Integer, SoftReference<Deque<ByteBuffer>>>();
}

private Deque<ByteBuffer> getDeque(int capacity) {
SoftReference<Deque<ByteBuffer>> dequeReference = pools.get(capacity);
Deque<ByteBuffer> deque;
if (dequeReference == null || (deque = dequeReference.get()) == null) {
deque = new ArrayDeque<ByteBuffer>();
dequeReference = new SoftReference<Deque<ByteBuffer>>(deque);
pools.put(capacity, dequeReference);
}
return deque;
}

@Override
public synchronized ByteBuffer get(int capacity) {
ByteBuffer buffer = getDeque(capacity).pollFirst();
if (buffer == null) {
buffer = ByteBuffer.allocate(capacity);
}
return buffer;
}

@Override
public synchronized void release(ByteBuffer buffer) {
getDeque(buffer.capacity()).addLast(buffer);
}
}
7 changes: 7 additions & 0 deletions src/main/java/com/github/luben/zstd/Zstd.java
Original file line number Diff line number Diff line change
Expand Up @@ -1272,4 +1272,11 @@ public static ByteBuffer decompress(ByteBuffer srcBuff, ZstdDictDecompress dict,
ctx.close();
}
}

static final byte[] extractArray(ByteBuffer buffer) {
if (!buffer.hasArray() || buffer.arrayOffset() != 0) {
throw new IllegalArgumentException("provided ByteBuffer lacks array or has non-zero arrayOffset");
}
return buffer.array();
}
}
26 changes: 20 additions & 6 deletions src/main/java/com/github/luben/zstd/ZstdInputStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.io.FilterInputStream;
import java.io.IOException;
import java.lang.IndexOutOfBoundsException;
import java.nio.ByteBuffer;

import com.github.luben.zstd.util.Native;

Expand All @@ -28,7 +29,8 @@ public class ZstdInputStream extends FilterInputStream {
private long srcSize = 0;
private boolean needRead = true;
private boolean finalize = true;
private BufferPool srcPool;
private final BufferPool bufferPool;
private final ByteBuffer srcByteBuffer;
private final byte[] src;
private static final int srcBuffSize = (int) recommendedDInSize();

Expand All @@ -44,12 +46,25 @@ public class ZstdInputStream extends FilterInputStream {
private native int initDStream(long stream);
private native int decompressStream(long stream, byte[] dst, int dst_size, byte[] src, int src_size);

// The main constructor / legacy version dispatcher
/**
* create a new decompressing InputStream
* @param inStream the stream to wrap
*/
public ZstdInputStream(InputStream inStream) {
this(inStream, RecyclingBufferPool.INSTANCE);
}

/**
* create a new decompressing InputStream
* @param inStream the stream to wrap
* @param bufferPool the pool to fetch and return buffers
*/
public ZstdInputStream(InputStream inStream, BufferPool bufferPool) {
// FilterInputStream constructor
super(inStream);
this.srcPool = BufferPool.get(srcBuffSize);
this.src = srcPool.checkOut();
this.bufferPool = bufferPool;
this.srcByteBuffer = bufferPool.get(srcBuffSize);
this.src = Zstd.extractArray(srcByteBuffer);
// memory barrier
synchronized(this) {
this.stream = createDStream();
Expand Down Expand Up @@ -238,8 +253,7 @@ public synchronized void close() throws IOException {
return;
}
isClosed = true;
srcPool.checkIn(src);
srcPool = null;
bufferPool.release(srcByteBuffer);
freeDStream(stream);
in.close();
}
Expand Down
36 changes: 25 additions & 11 deletions src/main/java/com/github/luben/zstd/ZstdOutputStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import java.io.OutputStream;
import java.io.FilterOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;

import com.github.luben.zstd.util.Native;

Expand All @@ -19,7 +20,8 @@ public class ZstdOutputStream extends FilterOutputStream {
private final long stream;
private long srcPos = 0;
private long dstPos = 0;
private BufferPool dstPool;
private final BufferPool bufferPool;
private final ByteBuffer dstByteBuffer;
private final byte[] dst;
private boolean isClosed = false;
private boolean finalize = true;
Expand Down Expand Up @@ -64,15 +66,28 @@ public ZstdOutputStream(OutputStream outStream, int level) throws IOException {
this.closeFrameOnFlush = false;
Zstd.setCompressionLevel(this.stream, level);
}

/* The constructor */

/**
* create a new compressing OutputStream
* @param outStream the stream to wrap
*/
public ZstdOutputStream(OutputStream outStream) throws IOException {
this(outStream, RecyclingBufferPool.INSTANCE);
}

/**
* create a new compressing OutputStream
* @param outStream the stream to wrap
* @param bufferPool the pool to fetch and return buffers
*/
public ZstdOutputStream(OutputStream outStream, BufferPool bufferPool) throws IOException {
super(outStream);
// create compression context
this.stream = createCStream();
this.closeFrameOnFlush = false;
this.dstPool = BufferPool.get(dstSize);
this.dst = dstPool.checkOut();
this.bufferPool = bufferPool;
this.dstByteBuffer = bufferPool.get(dstSize);
this.dst = Zstd.extractArray(dstByteBuffer);
}

public synchronized ZstdOutputStream setChecksum(boolean useChecksums) throws IOException {
Expand Down Expand Up @@ -164,7 +179,7 @@ public synchronized void write(byte[] src, int offset, int len) throws IOExcepti
int srcSize = offset + len;
srcPos = offset;
while (srcPos < srcSize) {
int size = compressStream(stream, dst, dstSize, src, srcSize);
int size = compressStream(stream, dst, dst.length, src, srcSize);
if (Zstd.isError(size)) {
throw new IOException("Compression error: " + Zstd.getErrorName(size));
}
Expand Down Expand Up @@ -192,7 +207,7 @@ public synchronized void flush() throws IOException {
// compress the remaining output and close the frame
int size;
do {
size = endStream(stream, dst, dstSize);
size = endStream(stream, dst, dst.length);
if (Zstd.isError(size)) {
throw new IOException("Compression error: " + Zstd.getErrorName(size));
}
Expand All @@ -203,7 +218,7 @@ public synchronized void flush() throws IOException {
// compress the remaining input
int size;
do {
size = flushStream(stream, dst, dstSize);
size = flushStream(stream, dst, dst.length);
if (Zstd.isError(size)) {
throw new IOException("Compression error: " + Zstd.getErrorName(size));
}
Expand All @@ -223,7 +238,7 @@ public synchronized void close() throws IOException {
// compress the remaining input and close the frame
int size;
do {
size = endStream(stream, dst, dstSize);
size = endStream(stream, dst, dst.length);
if (Zstd.isError(size)) {
throw new IOException("Compression error: " + Zstd.getErrorName(size));
}
Expand All @@ -234,8 +249,7 @@ public synchronized void close() throws IOException {
} finally {
// release the resources even if underlying stream throw an exception
isClosed = true;
dstPool.checkIn(dst);
dstPool = null;
bufferPool.release(dstByteBuffer);
freeCStream(stream);
}
}
Expand Down
55 changes: 35 additions & 20 deletions src/test/scala/Zstd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -842,25 +842,40 @@ class ZstdSpec extends FlatSpec with Checkers {
assert(zis.read(buf, 0, 1) == -1)
}

"BufferPool" should "recycle buffers" in {
val pool = BufferPool.get(10)
val largeBuf1 = pool.checkOut()
val largeBuf2 = pool.checkOut()
val largeBuf3 = pool.checkOut()
pool.checkIn(largeBuf1)
pool.checkIn(largeBuf2)
val largeBuf4 = pool.checkOut()
val largeBuf5 = pool.checkOut()
val largeBuf6 = pool.checkOut()
assert(largeBuf1 != largeBuf2)
assert(largeBuf1 != largeBuf3)
assert(largeBuf2 != largeBuf3)
assert(largeBuf4 == largeBuf1)
assert(largeBuf5 == largeBuf2)
assert(largeBuf6 != largeBuf1)
assert(largeBuf6 != largeBuf2)
assert(largeBuf6 != largeBuf3)
assert(largeBuf6 != largeBuf4)
assert(largeBuf6 != largeBuf5)
"RecyclingBufferPool" should "recycle buffers" in {
val pool = RecyclingBufferPool.INSTANCE
val largeBuf1 = pool.get(10)
val largeBuf2 = pool.get(10)
val largeBuf3 = pool.get(10)
pool.release(largeBuf1)
pool.release(largeBuf2)
val largeBuf4 = pool.get(10)
val largeBuf5 = pool.get(10)
val largeBuf6 = pool.get(10)
assert(!largeBuf1.eq(largeBuf2))
assert(!largeBuf1.eq(largeBuf3))
assert(!largeBuf2.eq(largeBuf3))
assert(largeBuf4.eq(largeBuf1))
assert(largeBuf5.eq(largeBuf2))
assert(!largeBuf6.eq(largeBuf1))
assert(!largeBuf6.eq(largeBuf2))
assert(!largeBuf6.eq(largeBuf3))
assert(!largeBuf6.eq(largeBuf4))
assert(!largeBuf6.eq(largeBuf5))
assert(largeBuf6.hasArray)
assert(largeBuf6.arrayOffset == 0)
assert(largeBuf6.capacity == 10)
assert(largeBuf6.array.length == 10)
}

"Zstd" should "validate when extracting backing arrays from ByteBuffers" in {
assertThrows[IllegalArgumentException] {
Zstd.extractArray(ByteBuffer.allocateDirect(10))
}
assertThrows[IllegalArgumentException] {
val buf = ByteBuffer.allocate(10);
buf.putInt(1);
Zstd.extractArray(buf.slice)
}
}
}

0 comments on commit dd2588e

Please sign in to comment.