From 6dd83be026d145c5428380df41ef5f162a723aef Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Tue, 17 Oct 2023 13:03:05 -0400 Subject: [PATCH 1/2] Allow for multiple input files per table instead of a single file --- datafusion/input/location.py | 16 +++++----------- datafusion/tests/test_input.py | 2 +- src/common/schema.rs | 16 ++++++++-------- 3 files changed, 14 insertions(+), 20 deletions(-) diff --git a/datafusion/input/location.py b/datafusion/input/location.py index efbc82f23..25ee60ddb 100644 --- a/datafusion/input/location.py +++ b/datafusion/input/location.py @@ -16,21 +16,18 @@ # under the License. import os +import glob from typing import Any from datafusion.common import DataTypeMap, SqlTable from datafusion.input.base import BaseInputSource - - class LocationInputPlugin(BaseInputSource): """ Input Plugin for everything, which can be read in from a file (on disk, remote etc.) """ - def is_correct_input(self, input_item: Any, table_name: str, **kwargs): return isinstance(input_item, str) - def build_table( self, input_file: str, @@ -41,14 +38,11 @@ def build_table( format = extension.lstrip(".").lower() num_rows = 0 # Total number of rows in the file. Used for statistics columns = [] - if format == "parquet": import pyarrow.parquet as pq - # Read the Parquet metadata metadata = pq.read_metadata(input_file) num_rows = metadata.num_rows - # Iterate through the schema and build the SqlTable for col in metadata.schema: columns.append( @@ -57,10 +51,8 @@ def build_table( DataTypeMap.from_parquet_type_str(col.physical_type), ) ) - elif format == "csv": import csv - # Consume header row and count number of rows for statistics. # TODO: Possibly makes sense to have the eager number of rows # calculated as a configuration since you must read the entire file @@ -73,7 +65,6 @@ def build_table( print(header_row) for _ in reader: num_rows += 1 - # TODO: Need to actually consume this row into resonable columns raise RuntimeError( "TODO: Currently unable to support CSV input files." @@ -84,4 +75,7 @@ def build_table( Only Parquet and CSV." ) - return SqlTable(table_name, columns, num_rows, input_file) + # Input could possibly be multiple files. Create a list if so + input_files = glob.glob(input_file) + + return SqlTable(table_name, columns, num_rows, input_files) diff --git a/datafusion/tests/test_input.py b/datafusion/tests/test_input.py index 1e2ef4166..5b1decf26 100644 --- a/datafusion/tests/test_input.py +++ b/datafusion/tests/test_input.py @@ -30,4 +30,4 @@ def test_location_input(): tbl = location_input.build_table(input_file, table_name) assert "blog" == tbl.name assert 3 == len(tbl.columns) - assert "blogs.parquet" in tbl.filepath + assert "blogs.parquet" in tbl.filepaths[0] diff --git a/src/common/schema.rs b/src/common/schema.rs index a003d0ca1..77b0ce2ba 100644 --- a/src/common/schema.rs +++ b/src/common/schema.rs @@ -56,7 +56,7 @@ pub struct SqlTable { #[pyo3(get, set)] pub statistics: SqlStatistics, #[pyo3(get, set)] - pub filepath: Option, + pub filepaths: Option>, } #[pymethods] @@ -66,7 +66,7 @@ impl SqlTable { table_name: String, columns: Vec<(String, DataTypeMap)>, row_count: f64, - filepath: Option, + filepaths: Option>, ) -> Self { Self { name: table_name, @@ -76,7 +76,7 @@ impl SqlTable { indexes: Vec::new(), constraints: Vec::new(), statistics: SqlStatistics::new(row_count), - filepath, + filepaths, } } } @@ -124,7 +124,7 @@ impl SqlSchema { pub struct SqlTableSource { schema: SchemaRef, statistics: Option, - filepath: Option, + filepaths: Option>, } impl SqlTableSource { @@ -132,12 +132,12 @@ impl SqlTableSource { pub fn new( schema: SchemaRef, statistics: Option, - filepath: Option, + filepaths: Option>, ) -> Self { Self { schema, statistics, - filepath, + filepaths, } } @@ -148,8 +148,8 @@ impl SqlTableSource { /// Access optional filepath associated with this table source #[allow(dead_code)] - pub fn filepath(&self) -> Option<&String> { - self.filepath.as_ref() + pub fn filepaths(&self) -> Option<&Vec> { + self.filepaths.as_ref() } } From c3d5dc5d3a00a9587d15eadc50834d7ea4ed8ecc Mon Sep 17 00:00:00 2001 From: Jeremy Dyer Date: Tue, 17 Oct 2023 13:27:46 -0400 Subject: [PATCH 2/2] linter fixes --- datafusion/input/location.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/datafusion/input/location.py b/datafusion/input/location.py index 25ee60ddb..939c7f415 100644 --- a/datafusion/input/location.py +++ b/datafusion/input/location.py @@ -21,13 +21,17 @@ from datafusion.common import DataTypeMap, SqlTable from datafusion.input.base import BaseInputSource + + class LocationInputPlugin(BaseInputSource): """ Input Plugin for everything, which can be read in from a file (on disk, remote etc.) """ + def is_correct_input(self, input_item: Any, table_name: str, **kwargs): return isinstance(input_item, str) + def build_table( self, input_file: str, @@ -40,6 +44,7 @@ def build_table( columns = [] if format == "parquet": import pyarrow.parquet as pq + # Read the Parquet metadata metadata = pq.read_metadata(input_file) num_rows = metadata.num_rows @@ -53,6 +58,7 @@ def build_table( ) elif format == "csv": import csv + # Consume header row and count number of rows for statistics. # TODO: Possibly makes sense to have the eager number of rows # calculated as a configuration since you must read the entire file