From 3775a3c08b3f42e4a51e597caf65a687c749564f Mon Sep 17 00:00:00 2001 From: haixuanTao Date: Wed, 9 Oct 2024 12:01:50 +0200 Subject: [PATCH 1/2] Make downloading file use `content_dosposition` header or the filename within the url to get the path avoiding to create confusion on the name of the file. --- binaries/daemon/src/spawn.rs | 10 ++-- binaries/runtime/src/operator/python.rs | 7 +-- binaries/runtime/src/operator/shared_lib.rs | 13 ++---- libraries/extensions/download/src/lib.rs | 52 ++++++++++++++------- 4 files changed, 43 insertions(+), 39 deletions(-) diff --git a/binaries/daemon/src/spawn.rs b/binaries/daemon/src/spawn.rs index 78302f0df..87eca5a3b 100644 --- a/binaries/daemon/src/spawn.rs +++ b/binaries/daemon/src/spawn.rs @@ -27,7 +27,6 @@ use dora_node_api::{ }; use eyre::{ContextCompat, WrapErr}; use std::{ - env::consts::EXE_EXTENSION, path::{Path, PathBuf}, process::Stdio, sync::Arc, @@ -101,13 +100,10 @@ pub async fn spawn_node( source => { let resolved_path = if source_is_url(source) { // try to download the shared library - let target_path = Path::new("build") - .join(node_id.to_string()) - .with_extension(EXE_EXTENSION); - download_file(source, &target_path) + let target_dir = Path::new("build"); + download_file(source, &target_dir) .await - .wrap_err("failed to download custom node")?; - target_path.clone() + .wrap_err("failed to download custom node")? } else { resolve_path(source, working_dir).wrap_err_with(|| { format!("failed to resolve node source `{}`", source) diff --git a/binaries/runtime/src/operator/python.rs b/binaries/runtime/src/operator/python.rs index c511f4634..4885dc37b 100644 --- a/binaries/runtime/src/operator/python.rs +++ b/binaries/runtime/src/operator/python.rs @@ -42,16 +42,13 @@ pub fn run( dataflow_descriptor: &Descriptor, ) -> eyre::Result<()> { let path = if source_is_url(&python_source.source) { - let target_path = Path::new("build") - .join(node_id.to_string()) - .join(format!("{}.py", operator_id)); + let target_path = Path::new("build"); // try to download the shared library let rt = tokio::runtime::Builder::new_current_thread() .enable_all() .build()?; rt.block_on(download_file(&python_source.source, &target_path)) - .wrap_err("failed to download Python operator")?; - target_path + .wrap_err("failed to download Python operator")? } else { Path::new(&python_source.source).to_owned() }; diff --git a/binaries/runtime/src/operator/shared_lib.rs b/binaries/runtime/src/operator/shared_lib.rs index 70fccff4c..e7ec30686 100644 --- a/binaries/runtime/src/operator/shared_lib.rs +++ b/binaries/runtime/src/operator/shared_lib.rs @@ -27,26 +27,21 @@ use tokio::sync::{mpsc::Sender, oneshot}; use tracing::{field, span}; pub fn run( - node_id: &NodeId, - operator_id: &OperatorId, + _node_id: &NodeId, + _operator_id: &OperatorId, source: &str, events_tx: Sender, incoming_events: flume::Receiver, init_done: oneshot::Sender>, ) -> eyre::Result<()> { let path = if source_is_url(source) { - let target_path = adjust_shared_library_path( - &Path::new("build") - .join(node_id.to_string()) - .join(operator_id.to_string()), - )?; + let target_path = &Path::new("build"); // try to download the shared library let rt = tokio::runtime::Builder::new_current_thread() .enable_all() .build()?; rt.block_on(download_file(source, &target_path)) - .wrap_err("failed to download shared library operator")?; - target_path + .wrap_err("failed to download shared library operator")? } else { adjust_shared_library_path(Path::new(source))? }; diff --git a/libraries/extensions/download/src/lib.rs b/libraries/extensions/download/src/lib.rs index b0843c83f..1e49fee3e 100644 --- a/libraries/extensions/download/src/lib.rs +++ b/libraries/extensions/download/src/lib.rs @@ -1,35 +1,51 @@ -use eyre::Context; +use eyre::{Context, ContextCompat}; #[cfg(unix)] use std::os::unix::prelude::PermissionsExt; -use std::path::Path; +use std::path::{Path, PathBuf}; use tokio::io::AsyncWriteExt; -use tracing::info; -pub async fn download_file(url: T, target_path: &Path) -> Result<(), eyre::ErrReport> -where - T: reqwest::IntoUrl + std::fmt::Display + Copy, -{ - if target_path.exists() { - info!("Using cache: {:?}", target_path.to_str()); - return Ok(()); +fn get_filename(response: &reqwest::Response) -> Option { + if let Some(content_disposition) = response.headers().get("content-disposition") { + if let Ok(filename) = content_disposition.to_str() { + if let Some(name) = filename.split("filename=").nth(1) { + return Some(name.trim_matches('"').to_string()); + } + } } - if let Some(parent) = target_path.parent() { - tokio::fs::create_dir_all(parent) - .await - .wrap_err("failed to create parent folder")?; + // If Content-Disposition header is not available, extract from URL + let path = Path::new(response.url().as_str()); + if let Some(name) = path.file_name() { + if let Some(filename) = name.to_str() { + return Some(filename.to_string()); + } } + None +} + +pub async fn download_file(url: T, target_dir: &Path) -> Result +where + T: reqwest::IntoUrl + std::fmt::Display + Copy, +{ + tokio::fs::create_dir_all(&target_dir) + .await + .wrap_err("failed to create parent folder")?; + let response = reqwest::get(url) .await - .wrap_err_with(|| format!("failed to request operator from `{url}`"))? + .wrap_err_with(|| format!("failed to request operator from `{url}`"))?; + + let filename = get_filename(&response).context("Could not find a filename")?; + let bytes = response .bytes() .await .wrap_err("failed to read operator from `{uri}`")?; - let mut file = tokio::fs::File::create(target_path) + let path = target_dir.join(filename); + let mut file = tokio::fs::File::create(&path) .await .wrap_err("failed to create target file")?; - file.write_all(&response) + file.write_all(&bytes) .await .wrap_err("failed to write downloaded operator to file")?; file.sync_all().await.wrap_err("failed to `sync_all`")?; @@ -39,5 +55,5 @@ where .await .wrap_err("failed to make downloaded file executable")?; - Ok(()) + Ok(path.to_path_buf()) } From b494a47d85dcacf59f1a49fc6f1386555cf8a4ea Mon Sep 17 00:00:00 2001 From: haixuanTao Date: Wed, 9 Oct 2024 12:08:38 +0200 Subject: [PATCH 2/2] use `download_file` when calling build or start for a url dataflow --- Cargo.lock | 1 + binaries/cli/Cargo.toml | 1 + binaries/cli/src/build.rs | 7 +++++-- binaries/cli/src/main.rs | 32 +++++++++++++++++++++++++------- 4 files changed, 32 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f9811f522..7e70b2d34 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2284,6 +2284,7 @@ dependencies = [ "dora-coordinator", "dora-core", "dora-daemon", + "dora-download", "dora-message", "dora-node-api-c", "dora-operator-api-c", diff --git a/binaries/cli/Cargo.toml b/binaries/cli/Cargo.toml index 42494a85a..73b4c8589 100644 --- a/binaries/cli/Cargo.toml +++ b/binaries/cli/Cargo.toml @@ -24,6 +24,7 @@ dora-core = { workspace = true } dora-message = { workspace = true } dora-node-api-c = { workspace = true } dora-operator-api-c = { workspace = true } +dora-download = { workspace = true } serde = { version = "1.0.136", features = ["derive"] } serde_yaml = "0.9.11" webbrowser = "0.8.3" diff --git a/binaries/cli/src/build.rs b/binaries/cli/src/build.rs index 496402a81..62b38f6fb 100644 --- a/binaries/cli/src/build.rs +++ b/binaries/cli/src/build.rs @@ -5,8 +5,11 @@ use dora_core::{ use eyre::{eyre, Context}; use std::{path::Path, process::Command}; -pub fn build(dataflow: &Path) -> eyre::Result<()> { - let descriptor = Descriptor::blocking_read(dataflow)?; +use crate::resolve_dataflow; + +pub fn build(dataflow: String) -> eyre::Result<()> { + let dataflow = resolve_dataflow(dataflow).context("could not resolve dataflow")?; + let descriptor = Descriptor::blocking_read(&dataflow)?; let dataflow_absolute = if dataflow.is_relative() { std::env::current_dir().unwrap().join(dataflow) } else { diff --git a/binaries/cli/src/main.rs b/binaries/cli/src/main.rs index 120cf847c..f6197724e 100644 --- a/binaries/cli/src/main.rs +++ b/binaries/cli/src/main.rs @@ -4,13 +4,14 @@ use colored::Colorize; use communication_layer_request_reply::{RequestReplyLayer, TcpLayer, TcpRequestReplyConnection}; use dora_coordinator::Event; use dora_core::{ - descriptor::Descriptor, + descriptor::{source_is_url, Descriptor}, topics::{ DORA_COORDINATOR_PORT_CONTROL_DEFAULT, DORA_COORDINATOR_PORT_DEFAULT, DORA_DAEMON_LOCAL_LISTEN_PORT_DEFAULT, }, }; use dora_daemon::Daemon; +use dora_download::download_file; use dora_message::{ cli_to_coordinator::ControlRequest, coordinator_to_cli::{ControlRequestReply, DataflowList, DataflowResult, DataflowStatus}, @@ -21,7 +22,7 @@ use dora_tracing::set_up_tracing_opts; use duration_str::parse; use eyre::{bail, Context}; use formatting::FormatDataflowError; -use std::{io::Write, net::SocketAddr}; +use std::{env::current_dir, io::Write, net::SocketAddr}; use std::{ net::{IpAddr, Ipv4Addr}, path::PathBuf, @@ -80,8 +81,8 @@ enum Command { /// Run build commands provided in the given dataflow. Build { /// Path to the dataflow descriptor file - #[clap(value_name = "PATH", value_hint = clap::ValueHint::FilePath)] - dataflow: PathBuf, + #[clap(value_name = "PATH")] + dataflow: String, }, /// Generate a new project or node. Choose the language between Rust, Python, C or C++. New { @@ -111,8 +112,8 @@ enum Command { /// Start the given dataflow path. Attach a name to the running dataflow by using --name. Start { /// Path to the dataflow descriptor file - #[clap(value_name = "PATH", value_hint = clap::ValueHint::FilePath)] - dataflow: PathBuf, + #[clap(value_name = "PATH")] + dataflow: String, /// Assign a name to the dataflow #[clap(long)] name: Option, @@ -324,7 +325,7 @@ fn run() -> eyre::Result<()> { graph::create(dataflow, mermaid, open)?; } Command::Build { dataflow } => { - build::build(&dataflow)?; + build::build(dataflow)?; } Command::New { args, @@ -366,6 +367,7 @@ fn run() -> eyre::Result<()> { detach, hot_reload, } => { + let dataflow = resolve_dataflow(dataflow).context("could not resolve dataflow")?; let dataflow_descriptor = Descriptor::blocking_read(&dataflow).wrap_err("Failed to read yaml dataflow")?; let working_dir = dataflow @@ -656,3 +658,19 @@ fn connect_to_coordinator( ) -> std::io::Result> { TcpLayer::new().connect(coordinator_addr) } + +fn resolve_dataflow(dataflow: String) -> eyre::Result { + let dataflow = if source_is_url(&dataflow) { + // try to download the shared library + let target_path = current_dir().context("Could not access the current dir")?; + let rt = Builder::new_current_thread() + .enable_all() + .build() + .context("tokio runtime failed")?; + rt.block_on(async { download_file(&dataflow, &target_path).await }) + .wrap_err("failed to download dataflow yaml file")? + } else { + PathBuf::from(dataflow) + }; + Ok(dataflow) +}