diff --git a/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrame.java b/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrame.java
index 6285ae06c4995..aad45291aeae0 100644
--- a/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrame.java
+++ b/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrame.java
@@ -9,6 +9,7 @@
 package org.opensearch.datafusion;
 
 import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.ipc.ArrowFileReader;
 import org.apache.arrow.vector.ipc.ArrowReader;
 import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
@@ -48,14 +49,14 @@ public CompletableFuture collect(BufferAllocator allocator) {
     }
 
     // return a stream over the dataframe
-    public CompletableFuture getStream(BufferAllocator allocator) {
+    public CompletableFuture getStream(BufferAllocator allocator, VectorSchemaRoot root) {
         CompletableFuture result = new CompletableFuture<>();
         long runtimePointer = ctx.getRuntime();
         DataFusion.executeStream(runtimePointer, ptr, (String errString, long streamId) -> {
             if (errString != null && errString.isEmpty() == false) {
                 result.completeExceptionally(new RuntimeException(errString));
             } else {
-                result.complete(new RecordBatchStream(ctx, streamId, allocator));
+                result.complete(new RecordBatchStream(ctx, streamId, allocator, root));
             }
         });
         return result;
diff --git a/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrameStreamProducer.java b/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrameStreamProducer.java
index 9bc47c0157453..a287304634dd3 100644
--- a/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrameStreamProducer.java
+++ b/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrameStreamProducer.java
@@ -41,6 +41,7 @@ public class DataFrameStreamProducer implements PartitionedStreamProducer {
 
     private StreamTicket rootTicket;
     private Set partitions;
+    VectorSchemaRoot root;
 
     public DataFrameStreamProducer(Function streamRegistrar, Set partitions, Function> frameSupplier) {
         logger.info("Constructed DataFrameFlightProducer");
@@ -57,7 +58,8 @@ public VectorSchemaRoot createRoot(BufferAllocator allocator) {
         arrowFields.put("count", countField);
         arrowFields.put("ord", new Field("ord", FieldType.nullable(new ArrowType.Utf8()), null));
         Schema schema = new Schema(arrowFields.values());
-        return VectorSchemaRoot.create(schema, allocator);
+        root = VectorSchemaRoot.create(schema, allocator);
+        return root;
     }
 
     @Override
@@ -72,9 +74,9 @@ public void run(VectorSchemaRoot root, FlushSignal flushSignal) {
         try {
             assert rootTicket != null;
             df = frameSupplier.apply(rootTicket).join();
-            recordBatchStream = df.getStream(allocator).get();
+            recordBatchStream = df.getStream(allocator, root).get();
             while (recordBatchStream.loadNextBatch().join()) {
-                logger.info(recordBatchStream.getVectorSchemaRoot().getRowCount());
+                // logger.info(recordBatchStream.getVectorSchemaRoot().getRowCount());
                 // wait for a signal to load the next batch
                 flushSignal.awaitConsumption(1000);
             }
diff --git a/libs/datafusion/src/main/java/org.opensearch.datafusion/RecordBatchStream.java b/libs/datafusion/src/main/java/org.opensearch.datafusion/RecordBatchStream.java
index 9dae2a2a0ff91..77b4decb98839 100644
--- a/libs/datafusion/src/main/java/org.opensearch.datafusion/RecordBatchStream.java
+++ b/libs/datafusion/src/main/java/org.opensearch.datafusion/RecordBatchStream.java
@@ -27,11 +27,12 @@ public class RecordBatchStream implements AutoCloseable {
 
     private final BufferAllocator allocator;
     private final CDataDictionaryProvider dictionaryProvider;
 
-    public RecordBatchStream(SessionContext ctx, long streamId, BufferAllocator allocator) {
+    public RecordBatchStream(SessionContext ctx, long streamId, BufferAllocator allocator, VectorSchemaRoot root) {
         this.context = ctx;
         this.ptr = streamId;
         this.allocator = allocator;
         this.dictionaryProvider = new CDataDictionaryProvider();
+        this.vectorSchemaRoot = root;
     }
 
     private static native void destroy(long pointer);
@@ -50,7 +51,7 @@ public void close() throws Exception {
 
     public static Logger logger = LogManager.getLogger(RecordBatchStream.class);
 
     public CompletableFuture loadNextBatch() {
-        ensureInitialized();
+        // ensureInitialized();
         long runtimePointer = context.getRuntime();
         CompletableFuture result = new CompletableFuture<>();
         next(
@@ -111,6 +112,7 @@ private Schema getSchema() {
                 result.complete(schema);
                 // The FFI schema will be released from rust when it is dropped
             } catch (Exception e) {
+                logger.error("Failed to read schema from record batch stream", e);
                 result.completeExceptionally(e);
             }
         }
diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java
index 705739f324dd3..77cbd180c997c 100644
--- a/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java
+++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java
@@ -737,7 +737,7 @@ public ReducedQueryPhase reducedAggsFromStream(List list) th
             new DataFrameStreamProducer((p -> streamManager.registerStream(p, TaskId.EMPTY_TASK_ID)), streamTickets, (t) -> DataFusion.agg(t.toBytes())));
 
         logger.info("Register stream at coordinator");
-        AccessController.doPrivileged((PrivilegedAction) () -> {
+        return AccessController.doPrivileged((PrivilegedAction) () -> {
             StreamReader streamIterator = streamManager.getStreamReader(producer.getRootTicket());
             logger.info("Finished register stream at coordinator");
 
@@ -760,6 +760,7 @@ public ReducedQueryPhase reducedAggsFromStream(List list) th
 
                         UInt8Vector count = (UInt8Vector) root.getVector("count");
                         Long bucketCount = (Long) getValue(count, row);
+                        logger.info("Got data from DF {} {}", ordName, bucketCount);
                         buckets.add(new StringTerms.Bucket(new BytesRef(ordName.getBytes()), bucketCount.longValue(), new InternalAggregations(List.of()), false, 0, DocValueFormat.RAW));
                     }
                 }
@@ -782,8 +783,9 @@ public ReducedQueryPhase reducedAggsFromStream(List list) th
 //                        buckets.add(new StringTerms.Bucket(new BytesRef(ordName.getBytes()), bucketCount.longValue(), new InternalAggregations(List.of()), false, 0, DocValueFormat.RAW));
 //                    }
 //                }
+            logger.info("Buckets are {}", buckets);
             aggs.add(new StringTerms(
-                root.getSchema().getCustomMetadata().get("name"),
+                "category",
                 InternalOrder.key(true),
                 InternalOrder.key(true),
                 null,
@@ -818,7 +820,6 @@ public ReducedQueryPhase reducedAggsFromStream(List list) th
             list.stream().flatMap(ssr -> ssr.getFlightTickets().stream()).collect(Collectors.toList())
         );
         });
-        return null;
     }
 
     public ReducedQueryPhase reducedFromStream(List list) {
diff --git a/server/src/main/java/org/opensearch/search/aggregations/support/StreamingAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/support/StreamingAggregator.java
index 059a1ee76e976..e7ec11c89d720 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/support/StreamingAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/support/StreamingAggregator.java
@@ -96,7 +96,6 @@ public void collect(int doc, long owningBucketOrd) throws IOException {
                 currentRow[0]++;
                 if (currentRow[0] == batchSize) {
                     flushBatch();
-                    flushSignal.awaitConsumption(10000000);
                 }
             }
 
@@ -104,7 +103,6 @@ public void collect(int doc, long owningBucketOrd) throws IOException {
             public void finish() throws IOException {
                 if (currentRow[0] > 0) {
                     flushBatch();
-                    flushSignal.awaitConsumption(10000000);
                     logger.info("Flushed last batch for segment {}", context.toString());
                 }
             }
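
Note (not part of the patch): a minimal usage sketch of the reworked streaming API, assuming getStream() completes with a RecordBatchStream and loadNextBatch() completes with a Boolean, as the hunks above suggest; the class and method names below are illustrative only. The point of the change is that the caller now owns the VectorSchemaRoot and hands it to getStream(), so the batches loaded from the native stream land in the same root that the Flight producer flushes.

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.opensearch.datafusion.DataFrame;
import org.opensearch.datafusion.RecordBatchStream;

// Hypothetical consumer of the patched API; not part of this change.
public class StreamUsageSketch {
    public static void drain(DataFrame df, BufferAllocator allocator, VectorSchemaRoot root) throws Exception {
        // The stream now fills the caller-supplied root instead of creating its own.
        RecordBatchStream stream = df.getStream(allocator, root).get();
        try {
            while (stream.loadNextBatch().join()) {
                // Each loaded batch is visible through the shared root.
                System.out.println("rows in batch: " + root.getRowCount());
            }
        } finally {
            stream.close();
        }
    }
}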