diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanTask.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanTask.java
index e4764236dadbd..7747dd6034016 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanTask.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanTask.java
@@ -25,6 +25,7 @@
  * id via {@link JniWrapper}, thus we allow only one-time execution of method {@link #execute()}. If a re-scan
  * operation is expected, call {@link NativeDataset#newScan} to create a new scanner instance.
  */
+@Deprecated
 public class NativeScanTask implements ScanTask {
 
   private final NativeScanner scanner;
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanner.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanner.java
index de18f9e5e0bcb..8ca8e5cf50eac 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanner.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanner.java
@@ -68,6 +68,19 @@ ArrowReader execute() {
   }
 
   @Override
+  public ArrowReader scanBatches() {
+    if (closed) {
+      throw new NativeInstanceReleasedException();
+    }
+    if (!executed.compareAndSet(false, true)) {
+      throw new UnsupportedOperationException("NativeScanner can only be executed once. Create a "
+          + "new scanner instead");
+    }
+    return new NativeReader(context.getAllocator());
+  }
+
+  @Override
+  @Deprecated
   public Iterable<? extends NativeScanTask> scan() {
     if (closed) {
       throw new NativeInstanceReleasedException();
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanTask.java b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanTask.java
index 434f5c9a6fa5a..16b8aeefb61f9 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanTask.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanTask.java
@@ -26,6 +26,7 @@
  * ScanTask is meant to be a unit of work to be dispatched. The implementation
  * must be thread and concurrent safe.
  */
+@Deprecated
 public interface ScanTask extends AutoCloseable {
 
   /**
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/Scanner.java b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/Scanner.java
index 93a1b08f36609..43749b7db8ec2 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/Scanner.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/Scanner.java
@@ -17,6 +17,7 @@
 
 package org.apache.arrow.dataset.scanner;
 
+import org.apache.arrow.vector.ipc.ArrowReader;
 import org.apache.arrow.vector.types.pojo.Schema;
 
 /**
@@ -24,12 +25,21 @@
  */
 public interface Scanner extends AutoCloseable {
 
+  /**
+   * Read the dataset as a stream of record batches.
+   *
+   * @return a {@link ArrowReader}.
+   */
+  ArrowReader scanBatches();
+
   /**
    * Perform the scan operation.
    *
    * @return a iterable set of {@link ScanTask}s. Each task is considered independent and it is allowed
    *     to execute the tasks concurrently to gain better performance.
+   * @deprecated use {@link #scanBatches()} instead.
    */
+  @Deprecated
   Iterable<? extends ScanTask> scan();
 
   /**
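For context, this is how the new interface reads a dataset end to end. It is a sketch rather than part of the change: the Parquet URI is a placeholder, and the allocator and memory-pool setup mirrors what the tests below do with rootAllocator() and NativeMemoryPool.getDefault().

```java
import org.apache.arrow.dataset.file.FileFormat;
import org.apache.arrow.dataset.file.FileSystemDatasetFactory;
import org.apache.arrow.dataset.jni.NativeMemoryPool;
import org.apache.arrow.dataset.scanner.ScanOptions;
import org.apache.arrow.dataset.scanner.Scanner;
import org.apache.arrow.dataset.source.Dataset;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.ipc.ArrowReader;

public class ScanBatchesExample {
  public static void main(String[] args) throws Exception {
    String uri = "file:///tmp/example.parquet"; // placeholder path
    try (RootAllocator allocator = new RootAllocator();
         FileSystemDatasetFactory factory = new FileSystemDatasetFactory(
             allocator, NativeMemoryPool.getDefault(), FileFormat.PARQUET, uri);
         Dataset dataset = factory.finish();
         Scanner scanner = dataset.newScan(new ScanOptions(/*batchSize*/ 32768));
         // scanBatches() replaces the scan()/ScanTask indirection with a single reader.
         ArrowReader reader = scanner.scanBatches()) {
      while (reader.loadNextBatch()) {
        // Each batch is loaded into the reader's VectorSchemaRoot in place.
        System.out.println(reader.getVectorSchemaRoot().getRowCount());
      }
    }
  }
}
```

Note that NativeScanner keeps its one-shot contract: a second scanBatches() call trips the compareAndSet guard and throws UnsupportedOperationException, as the test changes below exercise.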
diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/TestDataset.java b/java/dataset/src/test/java/org/apache/arrow/dataset/TestDataset.java
index 15224534d2873..2516c409593ba 100644
--- a/java/dataset/src/test/java/org/apache/arrow/dataset/TestDataset.java
+++ b/java/dataset/src/test/java/org/apache/arrow/dataset/TestDataset.java
@@ -28,7 +28,6 @@
 import java.util.stream.StreamSupport;
 
 import org.apache.arrow.dataset.scanner.ScanOptions;
-import org.apache.arrow.dataset.scanner.ScanTask;
 import org.apache.arrow.dataset.scanner.Scanner;
 import org.apache.arrow.dataset.source.Dataset;
 import org.apache.arrow.dataset.source.DatasetFactory;
@@ -63,9 +62,7 @@ protected List<ArrowRecordBatch> collectResultFromFactory(DatasetFactory factory
     final Dataset dataset = factory.finish();
     final Scanner scanner = dataset.newScan(options);
     try {
-      final List<ArrowRecordBatch> ret = stream(scanner.scan())
-          .flatMap(t -> stream(collectTaskData(t)))
-          .collect(Collectors.toList());
+      final List<ArrowRecordBatch> ret = collectTaskData(scanner);
       AutoCloseables.close(scanner, dataset);
       return ret;
     } catch (RuntimeException e) {
@@ -75,8 +72,8 @@ protected List<ArrowRecordBatch> collectResultFromFactory(DatasetFactory factory
     }
   }
 
-  protected List<ArrowRecordBatch> collectTaskData(ScanTask scanTask) {
-    try (ArrowReader reader = scanTask.execute()) {
+  protected List<ArrowRecordBatch> collectTaskData(Scanner scan) {
+    try (ArrowReader reader = scan.scanBatches()) {
       List<ArrowRecordBatch> batches = new ArrayList<>();
       while (reader.loadNextBatch()) {
         VectorSchemaRoot root = reader.getVectorSchemaRoot();
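The helper above drains the single reader into a list of materialized batches. A standalone sketch of that pattern follows; the VectorUnloader step, which the truncated hunk context does not show, is assumed to be what detaches each batch from the reader's root:

```java
import java.util.ArrayList;
import java.util.List;

import org.apache.arrow.dataset.scanner.Scanner;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;

final class BatchCollector {
  /** Drains a scanner; the caller owns, and must eventually close, every returned batch. */
  static List<ArrowRecordBatch> collectBatches(Scanner scanner) throws Exception {
    try (ArrowReader reader = scanner.scanBatches()) {
      List<ArrowRecordBatch> batches = new ArrayList<>();
      while (reader.loadNextBatch()) {
        VectorSchemaRoot root = reader.getVectorSchemaRoot();
        // VectorUnloader snapshots the root's current contents as an ArrowRecordBatch.
        batches.add(new VectorUnloader(root).getRecordBatch());
      }
      return batches;
    }
  }
}
```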
Create a new scanner instead", uoe.getMessage()); - AutoCloseables.close(taskList1); - AutoCloseables.close(taskList2); AutoCloseables.close(scanner, dataset, factory); } @Test - public void testScanInOtherThread() throws Exception { + public void testScanBatchesInOtherThread() throws Exception { ExecutorService executor = Executors.newSingleThreadExecutor(); ParquetWriteSupport writeSupport = ParquetWriteSupport.writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 1, "a"); @@ -268,17 +260,14 @@ public void testScanInOtherThread() throws Exception { NativeDataset dataset = factory.finish(); ScanOptions options = new ScanOptions(100); NativeScanner scanner = dataset.newScan(options); - List taskList = collect(scanner.scan()); - NativeScanTask task = taskList.get(0); - List datum = executor.submit(() -> collectTaskData(task)).get(); + List datum = executor.submit(() -> collectTaskData(scanner)).get(); AutoCloseables.close(datum); - AutoCloseables.close(taskList); AutoCloseables.close(scanner, dataset, factory); } @Test - public void testErrorThrownWhenScanAfterScannerClose() throws Exception { + public void testErrorThrownWhenScanBatchesAfterScannerClose() throws Exception { ParquetWriteSupport writeSupport = ParquetWriteSupport.writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 1, "a"); FileSystemDatasetFactory factory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), @@ -287,28 +276,13 @@ public void testErrorThrownWhenScanAfterScannerClose() throws Exception { ScanOptions options = new ScanOptions(100); NativeScanner scanner = dataset.newScan(options); scanner.close(); - assertThrows(NativeInstanceReleasedException.class, scanner::scan); - AutoCloseables.close(factory); - } - - @Test - public void testErrorThrownWhenExecuteTaskAfterTaskClose() throws Exception { - ParquetWriteSupport writeSupport = ParquetWriteSupport.writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 1, "a"); + assertThrows(NativeInstanceReleasedException.class, scanner::scanBatches); - FileSystemDatasetFactory factory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), - FileFormat.PARQUET, writeSupport.getOutputURI()); - NativeDataset dataset = factory.finish(); - ScanOptions options = new ScanOptions(100); - NativeScanner scanner = dataset.newScan(options); - List tasks = collect(scanner.scan()); - NativeScanTask task = tasks.get(0); - task.close(); - assertThrows(NativeInstanceReleasedException.class, task::execute); AutoCloseables.close(factory); } @Test - public void testErrorThrownWhenIterateOnIteratorAfterTaskClose() throws Exception { + public void testErrorThrownWhenReadAfterNativeReaderClose() throws Exception { ParquetWriteSupport writeSupport = ParquetWriteSupport.writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 1, "a"); FileSystemDatasetFactory factory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), @@ -316,11 +290,10 @@ public void testErrorThrownWhenIterateOnIteratorAfterTaskClose() throws Exceptio NativeDataset dataset = factory.finish(); ScanOptions options = new ScanOptions(100); NativeScanner scanner = dataset.newScan(options); - List tasks = collect(scanner.scan()); - NativeScanTask task = tasks.get(0); - ArrowReader reader = task.execute(); - task.close(); + ArrowReader reader = scanner.scanBatches(); + scanner.close(); assertThrows(NativeInstanceReleasedException.class, reader::loadNextBatch); + AutoCloseables.close(factory); } @@ -348,7 +321,7 @@ public void testBaseArrowIpcRead() throws Exception { Schema schema = 
     List<ArrowRecordBatch> datum = collectResultFromFactory(factory, options);
 
-    assertSingleTaskProduced(factory, options);
+    assertScanBatchesProduced(factory, options);
     assertEquals(1, datum.size());
     assertEquals(1, schema.getFields().size());
     assertEquals("ints", schema.getFields().get(0).getName());
@@ -376,7 +349,7 @@ public void testBaseOrcRead() throws Exception {
     Schema schema = inferResultSchemaFromFactory(factory, options);
     List<ArrowRecordBatch> datum = collectResultFromFactory(factory, options);
 
-    assertSingleTaskProduced(factory, options);
+    assertScanBatchesProduced(factory, options);
     assertEquals(1, datum.size());
     assertEquals(1, schema.getFields().size());
     assertEquals("ints", schema.getFields().get(0).getName());
diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/jni/TestNativeDataset.java b/java/dataset/src/test/java/org/apache/arrow/dataset/jni/TestNativeDataset.java
index 2a86a25688309..d0f91769096d8 100644
--- a/java/dataset/src/test/java/org/apache/arrow/dataset/jni/TestNativeDataset.java
+++ b/java/dataset/src/test/java/org/apache/arrow/dataset/jni/TestNativeDataset.java
@@ -25,9 +25,9 @@
 import org.junit.Assert;
 
 public abstract class TestNativeDataset extends TestDataset {
-  protected void assertSingleTaskProduced(DatasetFactory factory, ScanOptions options) {
+  protected void assertScanBatchesProduced(DatasetFactory factory, ScanOptions options) {
     final Dataset dataset = factory.finish();
     final Scanner scanner = dataset.newScan(options);
-    Assert.assertEquals(1L, stream(scanner.scan()).count());
+    Assert.assertNotNull(scanner.scanBatches());
   }
 }
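For downstream users, the migration this deprecation implies looks roughly as follows. Both shapes appear in the test changes above; the scanner parameter stands for any Scanner obtained from Dataset#newScan, and since NativeScanner is single-use, each method assumes a fresh scanner:

```java
import org.apache.arrow.dataset.scanner.ScanTask;
import org.apache.arrow.dataset.scanner.Scanner;
import org.apache.arrow.vector.ipc.ArrowReader;

final class MigrationSketch {
  /** Before: iterate ScanTasks and execute each one (now deprecated). */
  @SuppressWarnings("deprecation")
  static void oldApi(Scanner scanner) throws Exception {
    for (ScanTask task : scanner.scan()) {
      try (ArrowReader reader = task.execute()) {
        while (reader.loadNextBatch()) {
          // consume reader.getVectorSchemaRoot()
        }
      }
      task.close();
    }
  }

  /** After: a single reader covers the whole scan. */
  static void newApi(Scanner scanner) throws Exception {
    try (ArrowReader reader = scanner.scanBatches()) {
      while (reader.loadNextBatch()) {
        // consume reader.getVectorSchemaRoot()
      }
    }
  }
}
```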