Skip to content

Commit

Permalink
feat(search): Add dynamic SQL query generation for pgvector search
Browse files Browse the repository at this point in the history
Implements flexible query generation for PostgreSQL vector similarity search with:
- Introduces PgVecCustomStrategy for custom vector search implementations
- Adds builder pattern support for configuring search parameters and filters
- Enables dynamic SQL generation with metadata filtering and vector similarity

This change allows users to create customized vector similarity searches while
maintaining type safety and query optimization capabilities.

Signed-off-by: shamb0 <[email protected]>
  • Loading branch information
shamb0 committed Dec 20, 2024
1 parent 299bcc9 commit 4c9b793
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 94 deletions.
123 changes: 52 additions & 71 deletions swiftide-core/src/search_strategies/custom_strategy.rs
Original file line number Diff line number Diff line change
@@ -1,44 +1,63 @@
//! Generic vector search strategy framework for customizable query generation.
//!
//! Provides core abstractions for vector similarity search:
//! - Generic query type parameter for storage-specific implementations
//! - Configurable vector field selection and result limits
//! - Generic query type parameter for retriever-specific implementations
//! - Flexible query generation through closure-based configuration
//!
//! This module serves as the foundation for implementing custom vector
//! search strategies across different storage backends, ensuring type
//! safety and consistent behavior while allowing maximum flexibility
//! in query generation.
//! This module implements a strategy pattern for vector similarity search,
//! allowing different retrieval backends to provide their own query generation
//! logic while maintaining a consistent interface. The framework emphasizes
//! composition over inheritance, enabling configuration through closures
//! rather than struct fields.
use crate::{
indexing::EmbeddedField,
querying::{self, states, Query},
};
use crate::querying::{self, states, Query};
use anyhow::{anyhow, Result};
use std::marker::PhantomData;
use std::sync::Arc;

/// A type alias to simplify the query generation function type
type QueryGenerator<Q, T> = Arc<dyn Fn(&T, &Query<states::Pending>) -> Result<Q> + Send + Sync>;
/// A type alias for query generation functions.
///
/// The query generator takes a pending query state and produces a
/// retriever-specific query type. All configuration parameters should
/// be captured in the closure's environment.
type QueryGenerator<Q> = Arc<dyn Fn(&Query<states::Pending>) -> Result<Q> + Send + Sync>;

/// `CustomQuery` provides a flexible way to generate provider-specific search queries.
/// `CustomStrategy` provides a flexible way to generate retriever-specific search queries.
///
/// This struct implements a strategy pattern for vector similarity search, allowing
/// different retrieval backends to provide their own query generation logic. Configuration
/// is managed through the query generation closure, promoting a more flexible and
/// composable design.
///
/// # Type Parameters
/// * `Q` - The provider-specific query type (e.g., `sqlx::QueryBuilder` for `PostgreSQL`)
/// * `Q` - The retriever-specific query type (e.g., `sqlx::QueryBuilder` for `PostgreSQL`)
///
/// # Examples
/// ```
/// let strategy = CustomQuery::from_query(|strategy, query_node| {
/// // Query construction logic
/// Ok(provider_specific_query)
/// ```rust
/// // Define search configuration
/// const MAX_SEARCH_RESULTS: i64 = 5;
///
/// // Create a custom search strategy
/// let strategy = CustomStrategy::from_query(|query_node| {
/// let mut builder = QueryBuilder::new();
///
/// // Configure search parameters within the closure
/// builder.push(" LIMIT ");
/// builder.push_bind(MAX_SEARCH_RESULTS);
///
/// Ok(builder)
/// });
/// ```
///
/// # Implementation Notes
/// - Search configuration (like result limits and vector fields) should be defined
/// in the closure's scope
/// - Implementers are responsible for validating configuration values
/// - The query generator has access to the full query state for maximum flexibility
pub struct CustomStrategy<Q> {
/// The query generation function now returns a `Q`
query: Option<QueryGenerator<Q, Self>>,
/// Maximum number of results to return
top_k: u64,
/// Field to use for vector similarity search
vector_field: EmbeddedField,
query: Option<QueryGenerator<Q>>,

/// `PhantomData` to handle the generic parameter
_marker: PhantomData<Q>,
}
Expand All @@ -49,8 +68,6 @@ impl<Q> Default for CustomStrategy<Q> {
fn default() -> Self {
Self {
query: None,
top_k: super::DEFAULT_TOP_K,
vector_field: EmbeddedField::Combined,
_marker: PhantomData,
}
}
Expand All @@ -61,22 +78,25 @@ impl<Q> Clone for CustomStrategy<Q> {
fn clone(&self) -> Self {
Self {
query: self.query.clone(), // Arc clone is fine
top_k: self.top_k,
vector_field: self.vector_field.clone(),
_marker: PhantomData,
}
}
}

impl<Q: Send + Sync + 'static> CustomStrategy<Q> {
/// Creates a new `CustomQuery` with a query generation function
/// Creates a new `CustomStrategy` with a query generation function.
///
/// The provided closure should contain all necessary configuration for
/// query generation. This design allows for more flexible configuration
/// management compared to struct-level fields.
///
/// # Parameters
/// * `query` - A closure that generates retriever-specific queries
pub fn from_query(
query: impl Fn(&Self, &Query<states::Pending>) -> Result<Q> + Send + Sync + 'static,
query: impl Fn(&Query<states::Pending>) -> Result<Q> + Send + Sync + 'static,
) -> Self {
Self {
query: Some(Arc::new(query)),
top_k: super::DEFAULT_TOP_K,
vector_field: EmbeddedField::Combined,
_marker: PhantomData,
}
}
Expand All @@ -89,49 +109,10 @@ impl<Q: Send + Sync + 'static> CustomStrategy<Q> {
/// - The query function fails while processing the provided `query_node`.
pub fn build_query(&self, query_node: &Query<states::Pending>) -> Result<Q> {
match &self.query {
Some(query_fn) => Ok(query_fn(self, query_node)?),
Some(query_fn) => Ok(query_fn(query_node)?),
None => Err(anyhow!(
"No query function has been set. Use from_query() to set a query function."
)),
}
}

/// Sets the maximum number of results to return
///
/// # Panics
/// This function will panic if:
/// - `top_k` is greater than the maximum value for a Postgres `bigint` (i.e., `i64::MAX`).
/// - `top_k` is not positive (i.e., `top_k <= 0`).
#[must_use]
pub fn with_top_k(mut self, top_k: u64) -> Self {
// Ensure top_k is within Postgres bigint bounds
assert!(
i64::try_from(top_k).is_ok(),
"{}",
format!(
"top_k value {top_k} exceeds maximum allowed value {:#?}",
i64::MAX
)
);
assert!(top_k > 0, "top_k must be positive, got {top_k}");

self.top_k = top_k;
self
}

/// Sets the vector field to use for similarity search
#[must_use]
pub fn with_vector_field(mut self, vector_field: impl Into<EmbeddedField>) -> Self {
self.vector_field = vector_field.into();
self
}

// Accessor methods
pub fn top_k(&self) -> u64 {
self.top_k
}

pub fn vector_field(&self) -> &EmbeddedField {
&self.vector_field
}
}
34 changes: 11 additions & 23 deletions swiftide/tests/pgvector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ struct VectorSearchResult {
/// - Performs a similarity-based vector search on the database and validates the retrieved results.
///
/// Ensures correctness of end-to-end data flow, including table management, vector storage, and query execution.
#[ignore]
#[test_log::test(tokio::test)]
async fn test_pgvector_indexing() {
// Setup temporary directory and file for testing
Expand Down Expand Up @@ -134,7 +133,6 @@ async fn test_pgvector_indexing() {
/// a mock OpenAI client, configures `PgVector`, and executes a query
/// to ensure the pipeline retrieves the correct data and generates
/// an expected response.
#[ignore]
#[test_log::test(tokio::test)]
async fn test_pgvector_retrieve() {
// Setup temporary directory and file for testing
Expand Down Expand Up @@ -256,11 +254,11 @@ async fn test_pgvector_retrieve() {
/// - Transforms results into a meaningful summary
/// - Produces a final answer
///
/// # Implementation Notes
/// The test uses mock servers to simulate API responses, allowing for
/// reproducible testing without external dependencies. It demonstrates
/// the integration between different components: document processing,
/// vector storage, similarity search, and result transformation.
/// # Configuration Pattern
/// The test demonstrates the recommended configuration approach:
/// - Define search parameters as constants in the implementation scope
/// - Pass configuration through the query generator closure
/// - Keep the strategy struct minimal and focused on query generation
#[test_log::test(tokio::test)]
async fn test_pgvector_retrieve_dynamic_search() {
// Setup temporary directory and file for testing
Expand Down Expand Up @@ -331,7 +329,8 @@ async fn test_pgvector_retrieve_dynamic_search() {
// Configure search strategy
// Create a custom query generator with metadata filtering
let custom_strategy = query::search_strategies::CustomStrategy::from_query(
move |strategy, query_node| -> Result<sqlx::QueryBuilder<'static, sqlx::Postgres>> {
move |query_node| -> Result<sqlx::QueryBuilder<'static, sqlx::Postgres>> {
const CUSTOM_STRATEGY_MAX_RESULTS: i64 = 5;
let mut builder = sqlx::QueryBuilder::new("");
let table: &str = pgv_storage_for_closure.get_table_name();

Expand All @@ -353,7 +352,7 @@ async fn test_pgvector_retrieve_dynamic_search() {
builder.push("'{\"filter\": \"true\"}'::jsonb");

// Add vector similarity ordering
let vector_field = VectorConfig::from(strategy.vector_field().clone()).field;
let vector_field = VectorConfig::from(EmbeddedField::Combined).field;
builder.push(" ORDER BY ");
builder.push(vector_field);
builder.push(" <=> ");
Expand All @@ -369,23 +368,12 @@ async fn test_pgvector_retrieve_dynamic_search() {

// Add LIMIT clause
builder.push(" LIMIT ");
let top_k_i64 = i64::try_from(strategy.top_k()).map_err(|_| {
anyhow!(
"top_k value {} is too large for Postgres bigint",
strategy.top_k()
)
})?;

if top_k_i64 <= 0 {
return Err(anyhow!("top_k must be positive, got {}", top_k_i64));
}
builder.push_bind(top_k_i64);

builder.push_bind(CUSTOM_STRATEGY_MAX_RESULTS);

Ok(builder)
},
)
.with_top_k(5)
.with_vector_field(EmbeddedField::Combined);
);

let query_pipeline = query::Pipeline::from_search_strategy(custom_strategy)
.then_transform_query(query_transformers::GenerateSubquestions::from_client(
Expand Down

0 comments on commit 4c9b793

Please sign in to comment.