Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add example of catalog API usage (#5291) #5326

Merged
merged 8 commits into from
Feb 26, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
catalog example
jaylmiller committed Feb 18, 2023
commit c5899f97577a5a27a72f02d68aedf35e2e1db5de
227 changes: 227 additions & 0 deletions datafusion-examples/examples/catalog.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
use async_trait::async_trait;
use datafusion::{
arrow::util::pretty,
catalog::{
catalog::{CatalogList, CatalogProvider},
schema::SchemaProvider,
},
datasource::{
file_format::{csv::CsvFormat, parquet::ParquetFormat, FileFormat},
listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl},
TableProvider,
},
error::Result,
execution::context::SessionState,
prelude::SessionContext,
};
use std::sync::RwLock;
use std::{
any::Any,
collections::HashMap,
path::{Path, PathBuf},
sync::Arc,
};

#[tokio::main]
async fn main() -> Result<()> {
let repo_dir = std::fs::canonicalize(
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
// parent dir of datafusion-examples = repo root
.join(".."),
)
.unwrap();
let mut ctx = SessionContext::new();
let state = ctx.state();
let catlist = Arc::new(CustomCatalogList::new());
// use our custom catalog list for context. each context has a single catalog list.
// context will by default have MemoryCatalogList
ctx.register_catalog_list(catlist.clone());

// intitialize our catalog and schemas
let catalog = DirCatalog::new();
let parquet_schema = DirSchema::create(
&state,
DirSchemaOpts {
format: Arc::new(ParquetFormat::default()),
dir: &repo_dir.join("parquet-testing").join("data"),
ext: "parquet",
},
)
.await?;
let csv_schema = DirSchema::create(
&state,
DirSchemaOpts {
format: Arc::new(CsvFormat::default()),
dir: &repo_dir.join("testing").join("data").join("csv"),
ext: "csv",
},
)
.await?;
// register schemas into catalog
catalog.register_schema("parquet", parquet_schema.clone())?;
catalog.register_schema("csv", csv_schema.clone())?;
// register our catalog in the context
ctx.register_catalog("dircat", Arc::new(catalog));
// catalog was passed down into our custom catalog list since we overide the ctx's default
let catalogs = catlist.catalogs.read().unwrap();
assert!(catalogs.contains_key("dircat"));
// tables are now available to be queried in the context
for table in parquet_schema.tables.keys().take(5) {
println!("querying table {table} from parquet schema");
let df = ctx
.sql(&format!("select * from dircat.parquet.\"{table}\" "))
.await
.unwrap()
.limit(0, Some(5))
.unwrap();
let result = df.collect().await;
match result {
Ok(batches) => {
pretty::print_batches(&batches).unwrap();
}
Err(e) => {
println!("table '{table}' query failed due to {e}");
}
}
}
Ok(())
}

struct DirSchemaOpts<'a> {
ext: &'a str,
dir: &'a Path,
format: Arc<dyn FileFormat>,
}
/// Schema where every file with extension `ext` in a given `dir` is a table.
struct DirSchema {
ext: String,
tables: HashMap<String, Arc<dyn TableProvider>>,
}
impl DirSchema {
async fn create(state: &SessionState, opts: DirSchemaOpts<'_>) -> Result<Arc<Self>> {
let DirSchemaOpts { ext, dir, format } = opts;
let mut tables = HashMap::new();
let listdir = std::fs::read_dir(&dir).unwrap();
for res in listdir {
let entry = res.unwrap();
let filename = entry.file_name().to_str().unwrap().to_string();
if !filename.ends_with(ext) {
continue;
}

let table_path = ListingTableUrl::parse(entry.path().to_str().unwrap())?;
let opts = ListingOptions::new(format.clone());
let conf = ListingTableConfig::new(table_path)
.with_listing_options(opts)
.infer_schema(state)
.await?;
let table = ListingTable::try_new(conf)?;
tables.insert(filename, Arc::new(table) as Arc<dyn TableProvider>);
}
Ok(Arc::new(Self {
tables,
ext: ext.to_string(),
}))
}
#[allow(unused)]
fn name(&self) -> &str {
&self.ext
}
}

#[async_trait]
impl SchemaProvider for DirSchema {
fn as_any(&self) -> &dyn Any {
self
}

fn table_names(&self) -> Vec<String> {
self.tables.keys().cloned().collect::<Vec<_>>()
}

async fn table(&self, name: &str) -> Option<Arc<dyn TableProvider>> {
self.tables.get(name).cloned()
}

fn table_exist(&self, name: &str) -> bool {
self.tables.contains_key(name)
}
}
/// Catalog holds multiple schemas
struct DirCatalog {
schemas: RwLock<HashMap<String, Arc<dyn SchemaProvider>>>,
}
impl DirCatalog {
fn new() -> Self {
Self {
schemas: RwLock::new(HashMap::new()),
}
}
}
impl CatalogProvider for DirCatalog {
fn as_any(&self) -> &dyn Any {
self
}
fn register_schema(
&self,
name: &str,
schema: Arc<dyn SchemaProvider>,
) -> Result<Option<Arc<dyn SchemaProvider>>> {
let mut schema_map = self.schemas.write().unwrap();
schema_map.insert(name.to_owned(), schema.clone());
Ok(Some(schema))
}

fn schema_names(&self) -> Vec<String> {
let schemas = self.schemas.read().unwrap();
schemas.keys().cloned().collect()
}

fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
let schemas = self.schemas.read().unwrap();
let maybe_schema = schemas.get(name);
if let Some(schema) = maybe_schema {
let schema = schema.clone() as Arc<dyn SchemaProvider>;
Some(schema)
} else {
None
}
}
}
/// Catalog lists holds multiple catalogs. Each context has a single catalog list.
struct CustomCatalogList {
catalogs: RwLock<HashMap<String, Arc<dyn CatalogProvider>>>,
}
impl CustomCatalogList {
fn new() -> Self {
Self {
catalogs: RwLock::new(HashMap::new()),
}
}
}
impl CatalogList for CustomCatalogList {
fn as_any(&self) -> &dyn Any {
self
}
fn register_catalog(
&self,
name: String,
catalog: Arc<dyn CatalogProvider>,
) -> Option<Arc<dyn CatalogProvider>> {
let mut cats = self.catalogs.write().unwrap();
cats.insert(name.to_owned(), catalog.clone());
Some(catalog)
}

/// Retrieves the list of available catalog names
fn catalog_names(&self) -> Vec<String> {
let cats = self.catalogs.read().unwrap();
cats.keys().cloned().collect()
}

/// Retrieves a specific catalog by name, provided it exists.
fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
let cats = self.catalogs.read().unwrap();
cats.get(name).cloned()
}
}