Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add config dangerously_functions_filter #582

Merged
merged 1 commit into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ buffer_editor: null
# Controls the function calling feature. For setup instructions, visit https://github.com/sigoden/llm-functions
function_calling: false

# Regex for seletecting dangerous functions.
# User confirmation is required when executing these functions.
# e.g. 'execute_command|execute_js_code' 'execute_.*'
dangerously_functions: null

# Specifies the embedding model to use
embedding_model: null

Expand Down Expand Up @@ -253,4 +258,5 @@ bots:
- name: todo-sh
model: null
temperature: null
top_p: null
top_p: null
dangerously_functions: null
6 changes: 6 additions & 0 deletions src/config/bot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ impl Bot {
&self.name
}

pub fn config(&self) -> &BotConfig {
&self.config
}

pub fn functions(&self) -> &Functions {
&self.functions
}
Expand Down Expand Up @@ -170,6 +174,8 @@ pub struct BotConfig {
pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub dangerously_functions: Option<FunctionsFilter>,
}

impl BotConfig {
Expand Down
30 changes: 25 additions & 5 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,13 @@ use crate::client::{
create_client_config, list_chat_models, list_client_types, ClientConfig, Model,
OPENAI_COMPATIBLE_PLATFORMS,
};
use crate::function::{FunctionDeclaration, Functions, ToolCallResult};
use crate::function::{FunctionDeclaration, Functions, FunctionsFilter, ToolCallResult};
use crate::rag::Rag;
use crate::render::{MarkdownRender, RenderOptions};
use crate::utils::{
format_option_value, fuzzy_match, get_env_name, light_theme_from_colorfgbg, now, render_prompt,
set_text, warning_text, AbortSignal, IS_STDOUT_TERMINAL,
};
use crate::utils::*;

use anyhow::{anyhow, bail, Context, Result};
use fancy_regex::Regex;
use inquire::{Confirm, Select};
use parking_lot::RwLock;
use serde::Deserialize;
Expand Down Expand Up @@ -96,6 +94,7 @@ pub struct Config {
pub rag_top_k: usize,
pub rag_template: Option<String>,
pub function_calling: bool,
pub dangerously_functions: Option<FunctionsFilter>,
pub compress_threshold: usize,
pub summarize_prompt: Option<String>,
pub summary_prompt: Option<String>,
Expand Down Expand Up @@ -144,6 +143,7 @@ impl Default for Config {
rag_top_k: 4,
rag_template: None,
function_calling: false,
dangerously_functions: None,
compress_threshold: 4000,
summarize_prompt: None,
summary_prompt: None,
Expand Down Expand Up @@ -965,6 +965,26 @@ impl Config {
functions
}

pub fn is_dangerously_function(&self, name: &str) -> bool {
if get_env_bool("no_dangerously_functions") {
return false;
}
let dangerously_functions = match &self.bot {
Some(bot) => bot.config().dangerously_functions.as_ref(),
None => self.dangerously_functions.as_ref(),
};
match dangerously_functions {
None => false,
Some(regex) => {
let regex = match Regex::new(&format!("^({regex})$")) {
Ok(v) => v,
Err(_) => return false,
};
regex.is_match(name).unwrap_or(false)
}
}
}

pub fn buffer_editor(&self) -> Option<String> {
self.buffer_editor
.clone()
Expand Down
31 changes: 9 additions & 22 deletions src/function.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use crate::{
config::{Config, GlobalConfig},
utils::{
dimmed_text, get_env_bool, indent_text, run_command, run_command_with_output, warning_text,
IS_STDOUT_TERMINAL,
},
utils::*,
};

use anyhow::{anyhow, bail, Context, Result};
Expand All @@ -20,6 +17,7 @@ use std::{

pub const FUNCTION_ALL_MATCHER: &str = ".*";
pub type ToolResults = (Vec<ToolCallResult>, String);
pub type FunctionsFilter = String;

pub fn eval_tool_calls(
config: &GlobalConfig,
Expand Down Expand Up @@ -171,6 +169,7 @@ impl ToolCall {

pub fn eval(&self, config: &GlobalConfig) -> Result<Value> {
let function_name = self.name.clone();
let is_dangerously = config.read().is_dangerously_function(&function_name);
let (call_name, cmd_name, mut cmd_args) = match &config.read().bot {
Some(bot) => {
if !bot.functions().contains(&function_name) {
Expand Down Expand Up @@ -219,11 +218,11 @@ impl ToolCall {
#[cfg(windows)]
let cmd_name = polyfill_cmd_name(&cmd_name, &bin_dir);

let output = if self.is_execute() {
let output = if is_dangerously {
if *IS_STDOUT_TERMINAL {
println!("{prompt}");
let answer = Text::new("[1] Run, [2] Run & Retrieve, [3] Skip:")
.with_default("1")
.with_default("2")
.with_validator(|input: &str| match matches!(input, "1" | "2" | "3") {
true => Ok(Validation::Valid),
false => Ok(Validation::Invalid(
Expand All @@ -239,7 +238,7 @@ impl ToolCall {
}
Value::Null
}
"2" => run_and_retrieve(&cmd_name, &cmd_args, envs, &prompt)?,
"2" => run_and_retrieve(&cmd_name, &cmd_args, envs)?,
_ => Value::Null,
}
} else {
Expand All @@ -248,35 +247,23 @@ impl ToolCall {
}
} else {
println!("{}", dimmed_text(&prompt));
run_and_retrieve(&cmd_name, &cmd_args, envs, &prompt)?
run_and_retrieve(&cmd_name, &cmd_args, envs)?
};

Ok(output)
}

pub fn is_execute(&self) -> bool {
if get_env_bool("function_auto_execute") {
false
} else {
self.name.starts_with("may_") || self.name.contains("__may_")
}
}
}

fn run_and_retrieve(
cmd_name: &str,
cmd_args: &[String],
envs: HashMap<String, String>,
prompt: &str,
) -> Result<Value> {
let (success, stdout, stderr) = run_command_with_output(cmd_name, cmd_args, Some(envs))?;

if success {
if !stderr.is_empty() {
eprintln!(
"{}",
warning_text(&format!("{prompt}:\n{}", indent_text(&stderr, 4)))
);
eprintln!("{}", warning_text(&stderr));
}
let value = if !stdout.is_empty() {
serde_json::from_str(&stdout)
Expand All @@ -296,7 +283,7 @@ fn run_and_retrieve(
} else {
&stderr
};
bail!("{}", &format!("{prompt}:\n{}", indent_text(err, 4)));
bail!("{err}");
}
}

Expand Down
8 changes: 0 additions & 8 deletions src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,6 @@ pub fn dimmed_text(input: &str) -> String {
nu_ansi_term::Style::new().dimmed().paint(input).to_string()
}

pub fn indent_text(text: &str, spaces: usize) -> String {
let indent_size = " ".repeat(spaces);
text.lines()
.map(|line| format!("{}{}", indent_size, line))
.collect::<Vec<String>>()
.join("\n")
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down