Skip to content

Commit

Permalink
DownloadedFile supports file: URLs.
Browse files Browse the repository at this point in the history
Fixes pantsbuild#11295

# Building wheels and fs_util will be skipped. Delete if not intended.
[ci skip-build-wheels]
  • Loading branch information
jsirois committed Jan 27, 2021
1 parent a2832b5 commit eb4b8f6
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 25 deletions.
1 change: 1 addition & 0 deletions src/rust/engine/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/rust/engine/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ crossbeam-channel = "0.4"
fnv = "1.0.5"
fs = { path = "fs" }
futures = "0.3"
futures-core = "^0.3.0"
graph = { path = "graph" }
hashing = { path = "hashing" }
indexmap = "1.4"
Expand Down
121 changes: 96 additions & 25 deletions src/rust/engine/src/nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ use process_execution::{
self, CacheDest, CacheName, MultiPlatformProcess, Platform, Process, ProcessCacheScope,
};

use bytes::Bytes;
use graph::{Entry, Node, NodeError, NodeVisualizer};
use hashing::Digest;
use reqwest::Error;
use std::pin::Pin;
use store::{self, StoreFileByDigest};
use workunit_store::{
with_workunit, Level, UserMetadataItem, UserMetadataPyValue, WorkunitMetadata,
Expand Down Expand Up @@ -692,9 +695,90 @@ impl From<Snapshot> for NodeKey {
}
}

#[async_trait]
trait StreamingDownload: Send {
async fn next(&mut self) -> Option<Result<Bytes, String>>;
}

struct NetDownload {
stream: futures_core::stream::BoxStream<'static, Result<Bytes, Error>>,
}

impl NetDownload {
async fn start(core: &Arc<Core>, url: Url, file_name: String) -> Result<NetDownload, String> {
// TODO: Retry failures
let response = core
.http_client
.get(url.clone())
.send()
.await
.map_err(|err| format!("Error downloading file: {}", err))?;

// Handle common HTTP errors.
if response.status().is_server_error() {
return Err(format!(
"Server error ({}) downloading file {} from {}",
response.status().as_str(),
file_name,
url,
));
} else if response.status().is_client_error() {
return Err(format!(
"Client error ({}) downloading file {} from {}",
response.status().as_str(),
file_name,
url,
));
}
let byte_stream = Pin::new(Box::new(response.bytes_stream()));
Ok(NetDownload {
stream: byte_stream,
})
}
}

#[async_trait]
impl StreamingDownload for NetDownload {
async fn next(&mut self) -> Option<Result<Bytes, String>> {
self
.stream
.next()
.await
.map(|result| result.map_err(|err| err.to_string()))
}
}

struct FileDownload {
stream: tokio::io::ReaderStream<tokio::fs::File>,
}

impl FileDownload {
async fn start(path: &str, file_name: String) -> Result<FileDownload, String> {
let file = tokio::fs::File::open(path).await.map_err(|e| {
format!(
"Error ({}) opening file at {} for download to {}",
e, path, file_name
)
})?;
let stream = tokio::io::reader_stream(file);
Ok(FileDownload { stream })
}
}

#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct DownloadedFile(pub Key);

#[async_trait]
impl StreamingDownload for FileDownload {
async fn next(&mut self) -> Option<Result<Bytes, String>> {
self
.stream
.next()
.await
.map(|result| result.map_err(|err| err.to_string()))
}
}

impl DownloadedFile {
async fn load_or_download(
&self,
Expand All @@ -720,36 +804,24 @@ impl DownloadedFile {
core.store().snapshot_of_one_file(path, digest, true).await
}

async fn start_download(
core: &Arc<Core>,
url: Url,
file_name: String,
) -> Result<Box<dyn StreamingDownload>, String> {
if url.scheme() == "file" {
return Ok(Box::new(FileDownload::start(url.path(), file_name).await?));
}
Ok(Box::new(NetDownload::start(&core, url, file_name).await?))
}

async fn download(
core: Arc<Core>,
url: Url,
file_name: String,
expected_digest: hashing::Digest,
) -> Result<(), String> {
// TODO: Retry failures
let response = core
.http_client
.get(url.clone())
.send()
.await
.map_err(|err| format!("Error downloading file: {}", err))?;

// Handle common HTTP errors.
if response.status().is_server_error() {
return Err(format!(
"Server error ({}) downloading file {} from {}",
response.status().as_str(),
file_name,
url,
));
} else if response.status().is_client_error() {
return Err(format!(
"Client error ({}) downloading file {} from {}",
response.status().as_str(),
file_name,
url,
));
}
let mut response_stream = DownloadedFile::start_download(&core, url, file_name).await?;

let (actual_digest, bytes) = {
struct SizeLimiter<W: std::io::Write> {
Expand Down Expand Up @@ -784,7 +856,6 @@ impl DownloadedFile {
size_limit: expected_digest.1,
});

let mut response_stream = response.bytes_stream();
while let Some(next_chunk) = response_stream.next().await {
let chunk =
next_chunk.map_err(|err| format!("Error reading URL fetch response: {}", err))?;
Expand Down

0 comments on commit eb4b8f6

Please sign in to comment.