Skip to content

Commit

Permalink
Make SchemaProvider async (#3777)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Jan 3, 2023
1 parent 34a8b86 commit 6482032
Show file tree
Hide file tree
Showing 22 changed files with 212 additions and 160 deletions.
12 changes: 12 additions & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion datafusion-examples/examples/dataframe_in_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async fn main() -> Result<()> {

// declare a table in memory. In spark API, this corresponds to createDataFrame(...).
ctx.register_batch("t", batch)?;
let df = ctx.table("t")?;
let df = ctx.table("t").await?;

// construct an expression corresponding to "SELECT a, b FROM t WHERE b = 10" in SQL
let filter = col("b").eq(lit(10));
Expand Down
2 changes: 1 addition & 1 deletion datafusion-examples/examples/simple_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ async fn main() -> Result<()> {

// get a DataFrame from the context
// this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0.
let df = ctx.table("t")?;
let df = ctx.table("t").await?;

// perform the aggregation
let df = df.aggregate(vec![], vec![geometric_mean.call(vec![col("a")])])?;
Expand Down
2 changes: 1 addition & 1 deletion datafusion-examples/examples/simple_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ async fn main() -> Result<()> {
let expr = pow.call(vec![col("a"), col("b")]);

// get a DataFrame from the context
let df = ctx.table("t")?;
let df = ctx.table("t").await?;

// if we do not have `pow` in the scope and we registered it, we can get it from the registry
let pow = df.registry().udf("pow")?;
Expand Down
10 changes: 9 additions & 1 deletion datafusion/common/src/table_reference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
// specific language governing permissions and limitations
// under the License.

use std::fmt::{Display, Formatter};

/// A resolved path to a table of the form "catalog.schema.table"
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, PartialEq, Hash)]
pub struct ResolvedTableReference<'a> {
/// The catalog (aka database) containing the table
pub catalog: &'a str,
Expand All @@ -26,6 +28,12 @@ pub struct ResolvedTableReference<'a> {
pub table: &'a str,
}

impl<'a> Display for ResolvedTableReference<'a> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}.{}.{}", self.catalog, self.schema, self.table)
}
}

/// Represents a path to a table that may require further resolution
#[derive(Debug, Clone, Copy)]
pub enum TableReference<'a> {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ pyo3 = { version = "0.17.1", optional = true }
rand = "0.8"
rayon = { version = "1.5", optional = true }
smallvec = { version = "1.6", features = ["union"] }
sqlparser = "0.30"
sqlparser = { version = "0.30", features = ["visitor"] }
tempfile = "3"
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] }
tokio-stream = "0.1"
Expand Down
25 changes: 15 additions & 10 deletions datafusion/core/src/catalog/information_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
//!
//! Information Schema]<https://en.wikipedia.org/wiki/Information_schema>
use async_trait::async_trait;
use std::{any::Any, sync::Arc};

use arrow::{
Expand All @@ -43,6 +44,9 @@ pub const VIEWS: &str = "views";
pub const COLUMNS: &str = "columns";
pub const DF_SETTINGS: &str = "df_settings";

/// All information schema tables
pub const INFORMATION_SCHEMA_TABLES: &[&str] = &[TABLES, VIEWS, COLUMNS, DF_SETTINGS];

/// Implements the `information_schema` virtual schema and tables
///
/// The underlying tables in the `information_schema` are created on
Expand All @@ -69,7 +73,7 @@ struct InformationSchemaConfig {

impl InformationSchemaConfig {
/// Construct the `information_schema.tables` virtual table
fn make_tables(&self, builder: &mut InformationSchemaTablesBuilder) {
async fn make_tables(&self, builder: &mut InformationSchemaTablesBuilder) {
// create a mem table with the names of tables

for catalog_name in self.catalog_list.catalog_names() {
Expand All @@ -79,7 +83,7 @@ impl InformationSchemaConfig {
if schema_name != INFORMATION_SCHEMA {
let schema = catalog.schema(&schema_name).unwrap();
for table_name in schema.table_names() {
let table = schema.table(&table_name).unwrap();
let table = schema.table(&table_name).await.unwrap();
builder.add_table(
&catalog_name,
&schema_name,
Expand Down Expand Up @@ -108,15 +112,15 @@ impl InformationSchemaConfig {
}
}

fn make_views(&self, builder: &mut InformationSchemaViewBuilder) {
async fn make_views(&self, builder: &mut InformationSchemaViewBuilder) {
for catalog_name in self.catalog_list.catalog_names() {
let catalog = self.catalog_list.catalog(&catalog_name).unwrap();

for schema_name in catalog.schema_names() {
if schema_name != INFORMATION_SCHEMA {
let schema = catalog.schema(&schema_name).unwrap();
for table_name in schema.table_names() {
let table = schema.table(&table_name).unwrap();
let table = schema.table(&table_name).await.unwrap();
builder.add_view(
&catalog_name,
&schema_name,
Expand All @@ -130,15 +134,15 @@ impl InformationSchemaConfig {
}

/// Construct the `information_schema.columns` virtual table
fn make_columns(&self, builder: &mut InformationSchemaColumnsBuilder) {
async fn make_columns(&self, builder: &mut InformationSchemaColumnsBuilder) {
for catalog_name in self.catalog_list.catalog_names() {
let catalog = self.catalog_list.catalog(&catalog_name).unwrap();

for schema_name in catalog.schema_names() {
if schema_name != INFORMATION_SCHEMA {
let schema = catalog.schema(&schema_name).unwrap();
for table_name in schema.table_names() {
let table = schema.table(&table_name).unwrap();
let table = schema.table(&table_name).await.unwrap();
for (i, field) in table.schema().fields().iter().enumerate() {
builder.add_column(
&catalog_name,
Expand Down Expand Up @@ -168,6 +172,7 @@ impl InformationSchemaConfig {
}
}

#[async_trait]
impl SchemaProvider for InformationSchemaProvider {
fn as_any(&self) -> &(dyn Any + 'static) {
self
Expand All @@ -182,7 +187,7 @@ impl SchemaProvider for InformationSchemaProvider {
]
}

fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>> {
async fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>> {
let config = self.config.clone();
let table: Arc<dyn PartitionStream> = if name.eq_ignore_ascii_case("tables") {
Arc::new(InformationSchemaTables::new(config))
Expand Down Expand Up @@ -246,7 +251,7 @@ impl PartitionStream for InformationSchemaTables {
self.schema.clone(),
// TODO: Stream this
futures::stream::once(async move {
config.make_tables(&mut builder);
config.make_tables(&mut builder).await;
Ok(builder.finish())
}),
))
Expand Down Expand Up @@ -337,7 +342,7 @@ impl PartitionStream for InformationSchemaViews {
self.schema.clone(),
// TODO: Stream this
futures::stream::once(async move {
config.make_views(&mut builder);
config.make_views(&mut builder).await;
Ok(builder.finish())
}),
))
Expand Down Expand Up @@ -451,7 +456,7 @@ impl PartitionStream for InformationSchemaColumns {
self.schema.clone(),
// TODO: Stream this
futures::stream::once(async move {
config.make_columns(&mut builder);
config.make_columns(&mut builder).await;
Ok(builder.finish())
}),
))
Expand Down
4 changes: 3 additions & 1 deletion datafusion/core/src/catalog/listing_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use crate::catalog::schema::SchemaProvider;
use crate::datasource::datasource::TableProviderFactory;
use crate::datasource::TableProvider;
use crate::execution::context::SessionState;
use async_trait::async_trait;
use datafusion_common::parsers::CompressionTypeVariant;
use datafusion_common::{DFSchema, DataFusionError, OwnedTableReference};
use datafusion_expr::CreateExternalTable;
Expand Down Expand Up @@ -156,6 +157,7 @@ impl ListingSchemaProvider {
}
}

#[async_trait]
impl SchemaProvider for ListingSchemaProvider {
fn as_any(&self) -> &dyn Any {
self
Expand All @@ -170,7 +172,7 @@ impl SchemaProvider for ListingSchemaProvider {
.collect()
}

fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>> {
async fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>> {
self.tables
.lock()
.expect("Can't lock tables")
Expand Down
7 changes: 5 additions & 2 deletions datafusion/core/src/catalog/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
//! Describes the interface and built-in implementations of schemas,
//! representing collections of named tables.
use async_trait::async_trait;
use dashmap::DashMap;
use std::any::Any;
use std::sync::Arc;
Expand All @@ -26,6 +27,7 @@ use crate::datasource::TableProvider;
use crate::error::{DataFusionError, Result};

/// Represents a schema, comprising a number of named tables.
#[async_trait]
pub trait SchemaProvider: Sync + Send {
/// Returns the schema provider as [`Any`](std::any::Any)
/// so that it can be downcast to a specific implementation.
Expand All @@ -35,7 +37,7 @@ pub trait SchemaProvider: Sync + Send {
fn table_names(&self) -> Vec<String>;

/// Retrieves a specific table from the schema by name, provided it exists.
fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>>;
async fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>>;

/// If supported by the implementation, adds a new table to this schema.
/// If a table of the same name existed before, it returns "Table already exists" error.
Expand Down Expand Up @@ -85,6 +87,7 @@ impl Default for MemorySchemaProvider {
}
}

#[async_trait]
impl SchemaProvider for MemorySchemaProvider {
fn as_any(&self) -> &dyn Any {
self
Expand All @@ -97,7 +100,7 @@ impl SchemaProvider for MemorySchemaProvider {
.collect()
}

fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>> {
async fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>> {
self.tables.get(name).map(|table| table.value().clone())
}

Expand Down
27 changes: 18 additions & 9 deletions datafusion/core/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,7 @@ mod tests {
let ctx = SessionContext::new();
ctx.register_batch("t", batch)?;

let df = ctx.table("t")?.select_columns(&["f.c1"])?;
let df = ctx.table("t").await?.select_columns(&["f.c1"])?;

let df_results = df.collect().await?;

Expand Down Expand Up @@ -1036,7 +1036,7 @@ mod tests {
));

// build query with a UDF using DataFrame API
let df = ctx.table("aggregate_test_100")?;
let df = ctx.table("aggregate_test_100").await?;

let expr = df.registry().udf("my_fn")?.call(vec![col("c12")]);
let df = df.select(vec![expr])?;
Expand Down Expand Up @@ -1101,7 +1101,7 @@ mod tests {
ctx.register_table("test_table", Arc::new(df_impl.clone()))?;

// pull the table out
let table = ctx.table("test_table")?;
let table = ctx.table("test_table").await?;

let group_expr = vec![col("c1")];
let aggr_expr = vec![sum(col("c12"))];
Expand Down Expand Up @@ -1161,7 +1161,7 @@ mod tests {
async fn test_table_with_name(name: &str) -> Result<DataFrame> {
let mut ctx = SessionContext::new();
register_aggregate_csv(&mut ctx, name).await?;
ctx.table(name)
ctx.table(name).await
}

async fn test_table() -> Result<DataFrame> {
Expand Down Expand Up @@ -1301,8 +1301,15 @@ mod tests {
ctx.register_table("t1", table.clone())?;
ctx.register_table("t2", table)?;
let df = ctx
.table("t1")?
.join(ctx.table("t2")?, JoinType::Inner, &["c1"], &["c1"], None)?
.table("t1")
.await?
.join(
ctx.table("t2").await?,
JoinType::Inner,
&["c1"],
&["c1"],
None,
)?
.sort(vec![
// make the test deterministic
col("t1.c1").sort(true, true),
Expand Down Expand Up @@ -1379,10 +1386,11 @@ mod tests {
)
.await?;

ctx.register_table("t1", Arc::new(ctx.table("test")?))?;
ctx.register_table("t1", Arc::new(ctx.table("test").await?))?;

let df = ctx
.table("t1")?
.table("t1")
.await?
.filter(col("id").eq(lit(1)))?
.select_columns(&["bool_col", "int_col"])?;

Expand Down Expand Up @@ -1463,7 +1471,8 @@ mod tests {
ctx.register_batch("t", batch)?;

let df = ctx
.table("t")?
.table("t")
.await?
// try and create a column with a '.' in it
.with_column("f.c2", lit("hello"))?;

Expand Down
10 changes: 6 additions & 4 deletions datafusion/core/src/datasource/view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -428,12 +428,13 @@ mod tests {
)
.await?;

ctx.register_table("t1", Arc::new(ctx.table("test")?))?;
ctx.register_table("t1", Arc::new(ctx.table("test").await?))?;

ctx.sql("CREATE VIEW t2 as SELECT * FROM t1").await?;

let df = ctx
.table("t2")?
.table("t2")
.await?
.filter(col("id").eq(lit(1)))?
.select_columns(&["bool_col", "int_col"])?;

Expand All @@ -457,12 +458,13 @@ mod tests {
)
.await?;

ctx.register_table("t1", Arc::new(ctx.table("test")?))?;
ctx.register_table("t1", Arc::new(ctx.table("test").await?))?;

ctx.sql("CREATE VIEW t2 as SELECT * FROM t1").await?;

let df = ctx
.table("t2")?
.table("t2")
.await?
.limit(0, Some(10))?
.select_columns(&["bool_col", "int_col"])?;

Expand Down
Loading

0 comments on commit 6482032

Please sign in to comment.