Skip to content

Commit

Permalink
[api] Refactor PublisherBytesSupplier.java
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Oct 31, 2023
1 parent 3927867 commit adbca4c
Showing 1 changed file with 15 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,9 @@

import ai.djl.ndarray.BytesSupplier;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

/**
Expand All @@ -29,16 +25,12 @@
*/
public class PublisherBytesSupplier implements BytesSupplier {

private final List<byte[]> allData;
private final AtomicBoolean completed;
private Consumer<byte[]> subscriber;
private final AtomicInteger dataPushed;
private CountDownLatch latch;

/** Constructs a {@link PublisherBytesSupplier}. */
public PublisherBytesSupplier() {
allData = new ArrayList<>();
completed = new AtomicBoolean();
dataPushed = new AtomicInteger();
latch = new CountDownLatch(1);
}

/**
Expand All @@ -48,13 +40,19 @@ public PublisherBytesSupplier() {
* @param lastChunk true if this is the last chunk
*/
public void appendContent(byte[] data, boolean lastChunk) {
synchronized (allData) {
allData.add(data);
if (subscriber == null) {
try {
if (!latch.await(2, TimeUnit.MINUTES)) {
throw new IllegalStateException("Wait for subscriber timeout.");
}
} catch (InterruptedException e) {
throw new IllegalStateException("Append content interrupted.", e);
}
}
subscriber.accept(data);
if (lastChunk) {
completed.set(true);
subscriber.accept(null);
}
pushData();
}

/**
Expand All @@ -69,62 +67,11 @@ public void subscribe(Consumer<byte[]> subscriber) {
"The PublisherBytesSupplier only allows a single Subscriber");
}
this.subscriber = subscriber;
pushData();
}

private void pushData() {
if (subscriber == null) {
return;
}

int dataAvailable;
synchronized (allData) {
dataAvailable = allData.size();
}

int sent = dataPushed.getAndSet(dataAvailable);
if (sent < dataAvailable) {
synchronized (this) {
for (; sent < dataAvailable; sent++) {
subscriber.accept(allData.get(sent));
}
if (completed.get()) {
subscriber.accept(null);
}
}
}
}

/** Waits until completed before passing thread (BLOCKS THREAD!). */
@SuppressWarnings("PMD.EmptyControlStatement")
public void waitToRead() {
// Block until complete!!!
while (!completed.get()) {
// Do nothing
}
}

/** {@inheritDoc} */
@Override
public byte[] getAsBytes() {
if (!completed.get()) {
throw new IllegalStateException(
"PublisherByteSupplier must be completely filled before reading.");
}

try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) {
for (byte[] data : allData) {
bos.write(data);
}
return bos.toByteArray();
} catch (IOException e) {
throw new AssertionError("Failed to read BytesSupplier", e);
}
}

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
return ByteBuffer.wrap(getAsBytes());
throw new UnsupportedOperationException("Not supported.");
}
}

0 comments on commit adbca4c

Please sign in to comment.