Skip to content

Commit

Permalink
feat: enhance .file for loading resources from diverse sources (#1155)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Feb 8, 2025
1 parent a491820 commit eb527ea
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 72 deletions.
138 changes: 86 additions & 52 deletions src/config/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ use crate::client::{
MessageContent, MessageContentPart, MessageContentToolCalls, MessageRole, Model,
};
use crate::function::ToolResult;
use crate::utils::{base64_encode, sha256, AbortSignal};
use crate::utils::{base64_encode, is_loader_protocol, sha256, AbortSignal};

use anyhow::{bail, Context, Result};
use path_absolutize::Absolutize;
use std::{collections::HashMap, fs::File, io::Read, path::Path};
use indexmap::IndexSet;
use std::{collections::HashMap, fs::File, io::Read};
use unicode_width::{UnicodeWidthChar, UnicodeWidthStr};

const IMAGE_EXTS: [&str; 5] = ["png", "jpeg", "jpg", "webp", "gif"];
Expand Down Expand Up @@ -60,38 +60,19 @@ impl Input {
paths: Vec<String>,
role: Option<Role>,
) -> Result<Self> {
let mut raw_paths = vec![];
let mut external_cmds = vec![];
let mut local_paths = vec![];
let mut remote_urls = vec![];
let loaders = config.read().document_loaders.clone();
let (raw_paths, local_paths, remote_urls, external_cmds, protocol_paths, with_last_reply) =
resolve_paths(&loaders, paths)?;
let mut last_reply = None;
let mut with_last_reply = false;
for path in paths {
match resolve_local_path(&path) {
Some(v) => {
if v == "%%" {
with_last_reply = true;
raw_paths.push(v);
} else if v.len() > 2 && v.starts_with('`') && v.ends_with('`') {
external_cmds.push(v[1..v.len() - 1].to_string());
raw_paths.push(v);
} else {
if let Ok(path) = Path::new(&v).absolutize() {
raw_paths.push(path.display().to_string());
}
local_paths.push(v);
}
}
None => {
raw_paths.push(path.clone());
remote_urls.push(path);
}
}
}
let (documents, medias, data_urls) =
load_documents(config, external_cmds, local_paths, remote_urls)
.await
.context("Failed to load files")?;
let (documents, medias, data_urls) = load_documents(
&loaders,
local_paths,
remote_urls,
external_cmds,
protocol_paths,
)
.await
.context("Failed to load files")?;
let mut texts = vec![];
if !raw_text.is_empty() {
texts.push(raw_text.to_string());
Expand Down Expand Up @@ -409,11 +390,65 @@ fn resolve_role(config: &Config, role: Option<Role>) -> (Role, bool, bool) {
}
}

type ResolvePathsOutput = (
Vec<String>,
Vec<String>,
Vec<String>,
Vec<String>,
Vec<String>,
bool,
);

fn resolve_paths(
loaders: &HashMap<String, String>,
paths: Vec<String>,
) -> Result<ResolvePathsOutput> {
let mut raw_paths = IndexSet::new();
let mut local_paths = IndexSet::new();
let mut remote_urls = IndexSet::new();
let mut external_cmds = IndexSet::new();
let mut protocol_paths = IndexSet::new();
let mut with_last_reply = false;
for path in paths {
if path == "%%" {
with_last_reply = true;
raw_paths.insert(path);
} else if path.starts_with('`') && path.len() > 2 && path.ends_with('`') {
external_cmds.insert(path[1..path.len() - 1].to_string());
raw_paths.insert(path);
} else if is_url(&path) {
if path.strip_suffix("**").is_some() {
bail!("Invalid website '{path}'");
}
remote_urls.insert(path.clone());
raw_paths.insert(path);
} else if is_loader_protocol(loaders, &path) {
protocol_paths.insert(path.clone());
raw_paths.insert(path);
} else {
let resolved_path = resolve_home_dir(&path);
let absolute_path = to_absolute_path(&resolved_path)
.with_context(|| format!("Invalid path '{path}'"))?;
local_paths.insert(resolved_path);
raw_paths.insert(absolute_path);
}
}
Ok((
raw_paths.into_iter().collect(),
local_paths.into_iter().collect(),
remote_urls.into_iter().collect(),
external_cmds.into_iter().collect(),
protocol_paths.into_iter().collect(),
with_last_reply,
))
}

async fn load_documents(
config: &GlobalConfig,
external_cmds: Vec<String>,
loaders: &HashMap<String, String>,
local_paths: Vec<String>,
remote_urls: Vec<String>,
external_cmds: Vec<String>,
protocol_paths: Vec<String>,
) -> Result<(
Vec<(&'static str, String, String)>,
Vec<String>,
Expand All @@ -422,6 +457,7 @@ async fn load_documents(
let mut files = vec![];
let mut medias = vec![];
let mut data_urls = HashMap::new();

for cmd in external_cmds {
let (success, stdout, stderr) =
run_command_with_output(&SHELL.cmd, &[&SHELL.arg, &cmd], None)?;
Expand All @@ -433,23 +469,22 @@ async fn load_documents(
}

let local_files = expand_glob_paths(&local_paths, true).await?;
let loaders = config.read().document_loaders.clone();
for file_path in local_files {
if is_image(&file_path) {
let contents = read_media_to_data_url(&file_path)
.with_context(|| format!("Unable to read media file '{file_path}'"))?;
.with_context(|| format!("Unable to read media '{file_path}'"))?;
data_urls.insert(sha256(&contents), file_path);
medias.push(contents)
} else {
let document = load_file(&loaders, &file_path)
let document = load_file(loaders, &file_path)
.await
.with_context(|| format!("Unable to read file '{file_path}'"))?;
files.push(("FILE", file_path, document.contents));
}
}

for file_url in remote_urls {
let (contents, extension) = fetch_with_loaders(&loaders, &file_url, true)
let (contents, extension) = fetch_with_loaders(loaders, &file_url, true)
.await
.with_context(|| format!("Failed to load url '{file_url}'"))?;
if extension == MEDIA_URL_EXTENSION {
Expand All @@ -459,6 +494,17 @@ async fn load_documents(
files.push(("URL", file_url, contents));
}
}

for protocol_path in protocol_paths {
let documents = load_protocol_path(loaders, &protocol_path)
.with_context(|| format!("Failed to load from '{protocol_path}'"))?;
files.extend(
documents
.into_iter()
.map(|document| ("FROM", document.path, document.contents)),
);
}

Ok((files, medias, data_urls))
}

Expand All @@ -474,18 +520,6 @@ pub fn resolve_data_url(data_urls: &HashMap<String, String>, data_url: String) -
}
}

fn resolve_local_path(path: &str) -> Option<String> {
if is_url(path) {
return None;
}
let new_path = if let (Some(file), Some(home)) = (path.strip_prefix("~/"), dirs::home_dir()) {
home.join(file).display().to_string()
} else {
path.to_string()
};
Some(new_path)
}

fn is_image(path: &str) -> bool {
get_patch_extension(path)
.map(|v| IMAGE_EXTS.contains(&v.as_str()))
Expand Down
55 changes: 43 additions & 12 deletions src/rag/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use hnsw_rs::prelude::*;
use indexmap::{IndexMap, IndexSet};
use inquire::{required, validator::Validation, Confirm, Select, Text};
use parking_lot::RwLock;
use path_absolutize::Absolutize;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::{collections::HashMap, env, fmt::Debug, fs, hash::Hash, path::Path, time::Duration};
Expand Down Expand Up @@ -321,8 +320,8 @@ impl Rag {
if let Some(spinner) = &spinner {
let _ = spinner.set_message(String::new());
}
let (document_paths, mut recursive_urls, mut urls, mut local_paths) =
resolve_paths(paths).await?;
let (document_paths, mut recursive_urls, mut urls, mut protocol_paths, mut local_paths) =
resolve_paths(&loaders, paths).await?;
let mut to_deleted: IndexMap<String, Vec<FileId>> = Default::default();
if refresh {
for (file_id, file) in &self.data.files {
Expand All @@ -342,6 +341,13 @@ impl Rag {
.into_iter()
.filter(|v| !self.data.document_paths.contains(&format!("{v}**")))
.collect();
let protocol_paths_cloned = protocol_paths.clone();
let match_protocol_path =
|v: &str| protocol_paths_cloned.iter().any(|root| v.starts_with(root));
protocol_paths = protocol_paths
.into_iter()
.filter(|v| !self.data.document_paths.contains(v))
.collect();
for (file_id, file) in &self.data.files {
if is_url(&file.path) {
if !urls.swap_remove(&file.path) && !match_recursive_url(&file.path) {
Expand All @@ -350,6 +356,13 @@ impl Rag {
.or_default()
.push(*file_id);
}
} else if is_loader_protocol(&loaders, &file.path) {
if !match_protocol_path(&file.path) {
to_deleted
.entry(file.hash.clone())
.or_default()
.push(*file_id);
}
} else if !local_paths.swap_remove(&file.path) {
to_deleted
.entry(file.hash.clone())
Expand All @@ -362,7 +375,7 @@ impl Rag {
let mut loaded_documents = vec![];
let mut has_error = false;
let mut index = 0;
let total = recursive_urls.len() + urls.len() + local_paths.len();
let total = recursive_urls.len() + urls.len() + protocol_paths.len() + local_paths.len();
let handle_error = |error: anyhow::Error, has_error: &mut bool| {
println!("{}", warning_text(&format!("⚠️ {error}")));
*has_error = true;
Expand All @@ -383,6 +396,14 @@ impl Rag {
Err(err) => handle_error(err, &mut has_error),
}
}
for protocol_path in protocol_paths {
index += 1;
println!("Load {protocol_path} [{index}/{total}]");
match load_protocol_path(&loaders, &protocol_path) {
Ok(v) => loaded_documents.extend(v),
Err(err) => handle_error(err, &mut has_error),
}
}
for local_path in local_paths {
index += 1;
println!("Load {local_path} [{index}/{total}]");
Expand Down Expand Up @@ -899,7 +920,7 @@ fn set_chunk_overlay(default_value: usize) -> Result<usize> {
fn add_documents() -> Result<Vec<String>> {
let text = Text::new("Add documents:")
.with_validator(required!("This field is required"))
.with_help_message("e.g. file;dir/;dir/**/*.{md,mdx};solo-url;site-url/**")
.with_help_message("e.g. file;dir/;dir/**/*.{md,mdx};loader:resource;url;website/**")
.prompt()?;
let paths = text
.split(';')
Expand All @@ -916,16 +937,19 @@ fn add_documents() -> Result<Vec<String>> {
}

async fn resolve_paths<T: AsRef<str>>(
loaders: &HashMap<String, String>,
paths: &[T],
) -> Result<(
IndexSet<String>,
IndexSet<String>,
IndexSet<String>,
IndexSet<String>,
IndexSet<String>,
)> {
let mut document_paths = IndexSet::new();
let mut recursive_urls = IndexSet::new();
let mut urls = IndexSet::new();
let mut protocol_paths = IndexSet::new();
let mut absolute_paths = vec![];
for path in paths {
let path = path.as_ref().trim();
Expand All @@ -936,18 +960,25 @@ async fn resolve_paths<T: AsRef<str>>(
urls.insert(path.to_string());
}
document_paths.insert(path.to_string());
} else if is_loader_protocol(loaders, path) {
protocol_paths.insert(path.to_string());
document_paths.insert(path.to_string());
} else {
let absolute_path = Path::new(path)
.absolutize()
.with_context(|| format!("Invalid path '{path}'"))?
.display()
.to_string();
absolute_paths.push(absolute_path.clone());
let resolved_path = resolve_home_dir(path);
let absolute_path = to_absolute_path(&resolved_path)
.with_context(|| format!("Invalid path '{path}'"))?;
absolute_paths.push(resolved_path);
document_paths.insert(absolute_path);
}
}
let local_paths = expand_glob_paths(&absolute_paths, false).await?;
Ok((document_paths, recursive_urls, urls, local_paths))
Ok((
document_paths,
recursive_urls,
urls,
protocol_paths,
local_paths,
))
}

fn progress(spinner: &Option<Spinner>, message: String) {
Expand Down
7 changes: 4 additions & 3 deletions src/repl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -579,14 +579,15 @@ pub async fn run_repl_command(
ask(config, abort_signal.clone(), input, true).await?;
}
None => println!(
r#"Usage: .file <file|dir|url|%%|cmd>... [-- <text>...]
r#"Usage: .file <file|dir|url|cmd|loader:resource|%%>... [-- <text>...]
.file /tmp/file.txt
.file src/ Cargo.toml -- analyze
.file https://example.com/file.txt -- summarize
.file https://example.com/image.png -- recognize text
.file %% -- translate last reply to english
.file `git diff` -- Generate git commit message"#
.file `git diff` -- Generate git commit message
.file jina:https://example.com
.file %% -- translate last reply to english"#
),
},
".continue" => {
Expand Down
5 changes: 2 additions & 3 deletions src/utils/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,8 @@ pub fn run_command_with_output<T: AsRef<OsStr>>(
}

pub fn run_loader_command(path: &str, extension: &str, loader_command: &str) -> Result<String> {
let cmd_args = shell_words::split(loader_command).with_context(|| {
anyhow!("Invalid rag document loader '{extension}': `{loader_command}`")
})?;
let cmd_args = shell_words::split(loader_command)
.with_context(|| anyhow!("Invalid document loader '{extension}': `{loader_command}`"))?;
let mut use_stdout = true;
let outpath = temp_file("-output-", "").display().to_string();
let cmd_args: Vec<_> = cmd_args
Expand Down
Loading

0 comments on commit eb527ea

Please sign in to comment.