Skip to content

Commit

Permalink
make it actually work
Browse files Browse the repository at this point in the history
Signed-off-by: Marc Handalian <[email protected]>
  • Loading branch information
mch2 committed Jan 28, 2025
1 parent 7c9d9c5 commit 5d0ee1d
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.opensearch.datafusion;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowFileReader;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
Expand Down Expand Up @@ -48,14 +49,14 @@ public CompletableFuture<ArrowReader> collect(BufferAllocator allocator) {
}

// return a stream over the dataframe
public CompletableFuture<RecordBatchStream> getStream(BufferAllocator allocator) {
public CompletableFuture<RecordBatchStream> getStream(BufferAllocator allocator, VectorSchemaRoot root) {
CompletableFuture<RecordBatchStream> result = new CompletableFuture<>();
long runtimePointer = ctx.getRuntime();
DataFusion.executeStream(runtimePointer, ptr, (String errString, long streamId) -> {
if (errString != null && errString.isEmpty() == false) {
result.completeExceptionally(new RuntimeException(errString));
} else {
result.complete(new RecordBatchStream(ctx, streamId, allocator));
result.complete(new RecordBatchStream(ctx, streamId, allocator, root));
}
});
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public class DataFrameStreamProducer implements PartitionedStreamProducer {

private StreamTicket rootTicket;
private Set<StreamTicket> partitions;
VectorSchemaRoot root;

public DataFrameStreamProducer(Function<StreamProducer, StreamTicket> streamRegistrar, Set<StreamTicket> partitions, Function<StreamTicket, CompletableFuture<DataFrame>> frameSupplier) {
logger.info("Constructed DataFrameFlightProducer");
Expand All @@ -57,7 +58,8 @@ public VectorSchemaRoot createRoot(BufferAllocator allocator) {
arrowFields.put("count", countField);
arrowFields.put("ord", new Field("ord", FieldType.nullable(new ArrowType.Utf8()), null));
Schema schema = new Schema(arrowFields.values());
return VectorSchemaRoot.create(schema, allocator);
root = VectorSchemaRoot.create(schema, allocator);
return root;
}

@Override
Expand All @@ -72,9 +74,9 @@ public void run(VectorSchemaRoot root, FlushSignal flushSignal) {
try {
assert rootTicket != null;
df = frameSupplier.apply(rootTicket).join();
recordBatchStream = df.getStream(allocator).get();
recordBatchStream = df.getStream(allocator, root).get();
while (recordBatchStream.loadNextBatch().join()) {
logger.info(recordBatchStream.getVectorSchemaRoot().getRowCount());
// logger.info(recordBatchStream.getVectorSchemaRoot().getRowCount());
// wait for a signal to load the next batch
flushSignal.awaitConsumption(1000);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ public class RecordBatchStream implements AutoCloseable {
private final BufferAllocator allocator;
private final CDataDictionaryProvider dictionaryProvider;

public RecordBatchStream(SessionContext ctx, long streamId, BufferAllocator allocator) {
// Constructs a wrapper over a native DataFusion record-batch stream.
// NOTE(review): this is the post-change constructor from the diff — it now
// accepts a caller-supplied VectorSchemaRoot instead of deriving one itself,
// so the producer's root (created in DataFrameStreamProducer.createRoot) is
// shared with the stream. Presumably the native side populates this root on
// each loadNextBatch() — TODO confirm against the JNI implementation.
public RecordBatchStream(SessionContext ctx, long streamId, BufferAllocator allocator, VectorSchemaRoot root) {
this.context = ctx; // session context; supplies the native runtime pointer used by loadNextBatch()
this.ptr = streamId; // opaque native stream handle, released via destroy(long) on close
this.allocator = allocator; // Arrow allocator backing vectors imported from the native side
this.dictionaryProvider = new CDataDictionaryProvider(); // holds dictionaries for C-Data-interface imports
this.vectorSchemaRoot = root; // externally owned root — this class must not be the one to close it; verify ownership
}

private static native void destroy(long pointer);
Expand All @@ -50,7 +51,7 @@ public void close() throws Exception {
public static Logger logger = LogManager.getLogger(RecordBatchStream.class);

public CompletableFuture<Boolean> loadNextBatch() {
ensureInitialized();
// ensureInitialized();
long runtimePointer = context.getRuntime();
CompletableFuture<Boolean> result = new CompletableFuture<>();
next(
Expand Down Expand Up @@ -111,6 +112,7 @@ private Schema getSchema() {
result.complete(schema);
// The FFI schema will be released from rust when it is dropped
} catch (Exception e) {
logger.error("WTF", e);
result.completeExceptionally(e);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ public ReducedQueryPhase reducedAggsFromStream(List<StreamSearchResult> list) th
new DataFrameStreamProducer((p -> streamManager.registerStream(p, TaskId.EMPTY_TASK_ID)), streamTickets, (t) -> DataFusion.agg(t.toBytes())));

logger.info("Register stream at coordinator");
AccessController.doPrivileged((PrivilegedAction<ReducedQueryPhase>) () -> {
return AccessController.doPrivileged((PrivilegedAction<ReducedQueryPhase>) () -> {
StreamReader streamIterator = streamManager.getStreamReader(producer.getRootTicket());
logger.info("Finished register stream at coordinator");

Expand All @@ -760,6 +760,7 @@ public ReducedQueryPhase reducedAggsFromStream(List<StreamSearchResult> list) th
UInt8Vector count = (UInt8Vector) root.getVector("count");

Long bucketCount = (Long) getValue(count, row);
logger.info("Got data from DF {} {}", ordName, bucketCount);
buckets.add(new StringTerms.Bucket(new BytesRef(ordName.getBytes()), bucketCount.longValue(), new InternalAggregations(List.of()), false, 0, DocValueFormat.RAW));
}
}
Expand All @@ -782,8 +783,9 @@ public ReducedQueryPhase reducedAggsFromStream(List<StreamSearchResult> list) th
// buckets.add(new StringTerms.Bucket(new BytesRef(ordName.getBytes()), bucketCount.longValue(), new InternalAggregations(List.of()), false, 0, DocValueFormat.RAW));
// }
// }
logger.info("Buckets are {}", buckets);
aggs.add(new StringTerms(
root.getSchema().getCustomMetadata().get("name"),
"category",
InternalOrder.key(true),
InternalOrder.key(true),
null,
Expand Down Expand Up @@ -818,7 +820,6 @@ public ReducedQueryPhase reducedAggsFromStream(List<StreamSearchResult> list) th
list.stream().flatMap(ssr -> ssr.getFlightTickets().stream()).collect(Collectors.toList())
);
});
return null;
}

public ReducedQueryPhase reducedFromStream(List<StreamSearchResult> list) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,13 @@ public void collect(int doc, long owningBucketOrd) throws IOException {
currentRow[0]++;
if (currentRow[0] == batchSize) {
flushBatch();
flushSignal.awaitConsumption(10000000);
}
}

@Override
// Flushes any partially filled final batch for this segment once collection ends.
// Only flushes when at least one row was buffered since the last flushBatch().
public void finish() throws IOException {
if (currentRow[0] > 0) {
flushBatch();
// NOTE(review): per the hunk header (-96,15 +96,13) this awaitConsumption
// call is one of the two lines the commit DELETES — the post-change finish()
// no longer blocks waiting for the consumer here. Shown inline because this
// is a flat diff rendering; confirm against the merged file.
flushSignal.awaitConsumption(10000000);
logger.info("Flushed last batch for segment {}", context.toString());
}
}
Expand Down

0 comments on commit 5d0ee1d

Please sign in to comment.