Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(prompts): add prompts as first class citizens #145

Merged
merged 14 commits into from
Jul 12, 2024
160 changes: 157 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
[workspace]
members = ["swiftide", "examples", "benchmarks"]
resolver = "2"

[profile.dev.package]
insta.opt-level = 3
similar.opt-level = 3
4 changes: 4 additions & 0 deletions swiftide/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ strum = "0.26"
strum_macros = "0.26"
num_cpus = "1.16"
pin-project-lite = "0.2"
tera = { version = "1", default-features = false }
lazy_static = { version = "1.5.0" }
uuid = { version = "1.10", features = ["v4"] }

# Integrations
async-openai = { version = "0.23", optional = true }
Expand Down Expand Up @@ -94,6 +97,7 @@ mockall = "0.12.1"
temp-dir = "0.1.13"
wiremock = "0.6.0"
test-case = "3.3.1"
insta = { version = "1.39.0", features = ["yaml"] }

[lints.rust]
unsafe_code = "forbid"
Expand Down
9 changes: 4 additions & 5 deletions swiftide/src/indexing/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ impl Debug for Node {
&self
.vectors
.iter()
.map(HashMap::iter)
.flatten()
.flat_map(HashMap::iter)
.map(|(embed_type, vec)| format!("'{embed_type}': {}", vec.len()))
.join(","),
)
Expand Down Expand Up @@ -99,15 +98,15 @@ impl Node {

if self.embed_mode == EmbedMode::PerField || self.embed_mode == EmbedMode::Both {
embeddables.push((EmbeddedField::Chunk, self.chunk.clone()));
for (name, value) in self.metadata.iter() {
for (name, value) in &self.metadata {
embeddables.push((EmbeddedField::Metadata(name.clone()), value.clone()));
}
}

embeddables
}

/// Converts the node into an [self::EmbeddedField::Combined] type of embeddable.
/// Converts the node into an [`self::EmbeddedField::Combined`] type of embeddable.
///
/// This embeddable format consists of the metadata formatted as key-value pairs, each on a new line,
/// followed by the data chunk.
Expand Down Expand Up @@ -153,7 +152,7 @@ impl Hash for Node {

/// Embed mode of the pipeline.
///
/// See also [super::pipeline::Pipeline::with_embed_mode].
/// See also [`super::pipeline::Pipeline::with_embed_mode`].
#[derive(Copy, Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
pub enum EmbedMode {
#[default]
Expand Down
3 changes: 2 additions & 1 deletion swiftide/src/indexing/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ impl Pipeline {

/// Sets the embed mode for the pipeline.
///
/// See also [super::node::EmbedMode].
/// See also [`super::node::EmbedMode`].
///
/// # Arguments
///
Expand All @@ -95,6 +95,7 @@ impl Pipeline {
/// # Returns
///
/// An instance of `Pipeline` with the updated embed mode.
#[must_use]
pub fn with_embed_mode(mut self, embed_mode: EmbedMode) -> Self {
self.stream = self
.stream
Expand Down
10 changes: 5 additions & 5 deletions swiftide/src/integrations/aws_bedrock/simple_prompt.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::SimplePrompt;
use crate::{prompt::Prompt, SimplePrompt};
use anyhow::Result;
use async_trait::async_trait;
use aws_sdk_bedrockruntime::primitives::Blob;
Expand All @@ -8,10 +8,10 @@ use super::AwsBedrock;
#[async_trait]
impl SimplePrompt for AwsBedrock {
#[tracing::instrument(skip_all, err)]
async fn prompt(&self, prompt: &str) -> Result<String> {
async fn prompt(&self, prompt: Prompt) -> Result<String> {
let blob = self
.model_family
.build_request_to_bytes(prompt, &self.model_config)
.build_request_to_bytes(prompt.render().await?, &self.model_config)
.map(Blob::new)?;

let response_bytes = self.client.prompt_u8(&self.model_id, blob).await?;
Expand Down Expand Up @@ -53,7 +53,7 @@ mod test {
.build()
.unwrap();

let response = bedrock.prompt("Hello").await.unwrap();
let response = bedrock.prompt("Hello".into()).await.unwrap();

assert_eq!(response, "Hello, world!");
}
Expand Down Expand Up @@ -84,7 +84,7 @@ mod test {
.test_client(bedrock_mock)
.build()
.unwrap();
let response = bedrock.prompt("Hello").await.unwrap();
let response = bedrock.prompt("Hello".into()).await.unwrap();
assert_eq!(response, "Hello, world!");
}
}
6 changes: 3 additions & 3 deletions swiftide/src/integrations/openai/simple_prompt.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! This module provides an implementation of the `SimplePrompt` trait for the `OpenAI` struct.
//! It defines an asynchronous function to interact with the `OpenAI` API, allowing prompt processing
//! and generating responses as part of the Swiftide system.
use crate::SimplePrompt;
use crate::{prompt::Prompt, SimplePrompt};
use async_openai::types::{ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs};
use async_trait::async_trait;

Expand All @@ -25,7 +25,7 @@ impl SimplePrompt for OpenAI {
/// - Returns an error if the request to the OpenAI API fails.
/// - Returns an error if the response does not contain the expected content.
#[tracing::instrument(skip_all, err)]
async fn prompt(&self, prompt: &str) -> Result<String> {
async fn prompt(&self, prompt: Prompt) -> Result<String> {
// Retrieve the model from the default options, returning an error if not set.
let model = self
.default_options
Expand All @@ -37,7 +37,7 @@ impl SimplePrompt for OpenAI {
let request = CreateChatCompletionRequestArgs::default()
.model(model)
.messages(vec![ChatCompletionRequestUserMessageArgs::default()
.content(prompt)
.content(prompt.render().await?)
.build()?
.into()])
.build()?;
Expand Down
Loading