Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add huggingface extension #261

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
HF extension working
matthewmturner committed Jan 24, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit 7fb99371c9b1c291d1ff11917125818511867c9e
49 changes: 26 additions & 23 deletions src/extensions/huggingface.rs
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@ use datafusion_common::DataFusionError;
use log::info;
use std::sync::Arc;

use opendal::{services::Huggingface, Builder, Operator};
use opendal::{services::Huggingface, Operator};
use url::Url;

#[derive(Debug, Default)]
@@ -59,40 +59,43 @@ impl Extension for HuggingFaceExtension {
let mut hf_builder = Huggingface::default();
if let Some(repo_type) = &huggingface_config.repo_type {
hf_builder = hf_builder.repo_type(repo_type);
// url_parts.push(repo_type)
};
if let Some(repo_id) = &huggingface_config.repo_id {
hf_builder = hf_builder.repo_id(repo_id);
// url_parts.push(repo_id);
};
if let Some(revision) = &huggingface_config.revision {
hf_builder = hf_builder.revision(revision);
// url_parts.push("blob");
// url_parts.push(revision);
};
if let Some(root) = &huggingface_config.root {
hf_builder = hf_builder.root(root);
};
if let Some(token) = &huggingface_config.token {
hf_builder = hf_builder.token(token);
};
if let Some(repo_id) = &huggingface_config.repo_id {
hf_builder = hf_builder.repo_id(repo_id);

let operator = Operator::new(hf_builder)
.map_err(|e| {
datafusion_common::error::DataFusionError::External(e.to_string().into())
})?
.finish();

let operator = Operator::new(hf_builder)
.map_err(|e| {
datafusion_common::error::DataFusionError::External(e.to_string().into())
})?
.finish();
let store = object_store_opendal::OpendalStore::new(operator);

let store = object_store_opendal::OpendalStore::new(operator);
// let url = Url::parse(url_parts.join("/").as_str()).map_err(|e| {
// datafusion_common::error::DataFusionError::External(e.to_string().into())
// })?;
let url = Url::try_from("hf://")
.map_err(|e| DataFusionError::External(e.to_string().into()))?;
println!("Registering store for huggingface url: {url}");
builder
.runtime_env()
.register_object_store(&url, Arc::new(store));
// `repo_id` seems to always have a '/' to separate the organization and repo name
// but this causes issues with registering external tables and the URLs don't fully
// reflect the organization and repo name (it only shows the organization name).
// So we replace the '/' with a '-' so that the URL has both.
//
// An example URL to use is:
// 'hf://huggingfacetb-finemath/finemath-3plus/train-00000-of-00128.parquet'
//
// Where 'huggingfacetb' is the organization name and 'finemath' is the repo name
let url = Url::try_from(format!("hf://{}", repo_id.replace("/", "-")).as_str())
.map_err(|e| DataFusionError::External(e.to_string().into()))?;
info!("Registering store for huggingface url: {url}");
builder
.runtime_env()
.register_object_store(&url, Arc::new(store));
};
}

Ok(())
16 changes: 14 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -47,7 +47,21 @@ fn main() -> Result<()> {
runtime.block_on(entry_point)
}

/// Reports whether `env_logger` should be initialized for this invocation.
///
/// Returns `true` when `cli.serve` is set (checked only in builds with the
/// `experimental-flightsql-server` feature enabled) or when either
/// `cli.files` or `cli.commands` is non-empty; otherwise returns `false`.
fn should_init_env_logger(cli: &DftArgs) -> bool {
    // This check is compiled in only with the FlightSQL server feature.
    #[cfg(feature = "experimental-flightsql-server")]
    if cli.serve {
        return true;
    }

    // Logging is wanted as soon as there is at least one file or command.
    let nothing_to_run = cli.files.is_empty() && cli.commands.is_empty();
    !nothing_to_run
}

async fn app_entry_point(cli: DftArgs) -> Result<()> {
if should_init_env_logger(&cli) {
env_logger::init();
}
let state = state::initialize(cli.config_path());
let session_state_builder = DftSessionStateBuilder::new()
.with_execution_config(state.config.execution.clone())
@@ -56,7 +70,6 @@ async fn app_entry_point(cli: DftArgs) -> Result<()> {
#[cfg(feature = "experimental-flightsql-server")]
if cli.serve {
// FlightSQL Server mode: start a FlightSQL server
env_logger::init();
const DEFAULT_SERVER_ADDRESS: &str = "127.0.0.1:50051";
info!("Starting FlightSQL server on {}", DEFAULT_SERVER_ADDRESS);
let session_state = session_state_builder
@@ -83,7 +96,6 @@ async fn app_entry_point(cli: DftArgs) -> Result<()> {
}
if !cli.files.is_empty() || !cli.commands.is_empty() {
// CLI mode: executing commands from files or CLI arguments
env_logger::init();
let session_state = session_state_builder.with_app_type(AppType::Cli).build()?;
let execution_ctx =
ExecutionContext::try_new(&state.config.execution, session_state, AppType::Cli)?;