From 3774dca60079047b2242424e21bb363030221497 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 19 Dec 2024 15:26:28 -0500 Subject: [PATCH] Add BigQueryVectorSearchEnrichmentHandler. --- .../apache_beam/ml/rag/enrichment/__init__.py | 20 + .../rag/enrichment/bigquery_vector_search.py | 238 ++++++++++ .../bigquery_vector_search_it_test.py | 405 ++++++++++++++++++ 3 files changed, 663 insertions(+) create mode 100644 sdks/python/apache_beam/ml/rag/enrichment/__init__.py create mode 100644 sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search.py create mode 100644 sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search_it_test.py diff --git a/sdks/python/apache_beam/ml/rag/enrichment/__init__.py b/sdks/python/apache_beam/ml/rag/enrichment/__init__.py new file mode 100644 index 000000000000..efcb5ac31950 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/enrichment/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Enrichment components for RAG pipelines. +This module provides components for vector search enrichment in RAG pipelines. +""" diff --git a/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search.py b/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search.py new file mode 100644 index 000000000000..cf2d0a579127 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search.py @@ -0,0 +1,238 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from dataclasses import dataclass +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +from google.cloud import bigquery + +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Embedding +from apache_beam.transforms.enrichment import EnrichmentSourceHandler + + +@dataclass +class BigQueryVectorSearchParameters: + """Parameters for configuring BigQuery vector similarity search. + + This class encapsulates the configuration needed to perform vector similarity + search using BigQuery's VECTOR_SEARCH function. It handles formatting the + query with proper embedding vectors and metadata restrictions. + + Args: + table_name: Fully qualified BigQuery table name containing vectors + embedding_column: Column name containing the embedding vectors + columns: List of columns to retrieve from matched vectors + neighbor_count: Number of similar vectors to return (top-k) + metadata_restriction_template: Template string for filtering vectors by + metadata. Use Python string format syntax, e.g. + "metadata.type = '{doc_type}'" + distance_type: Optional distance metric to use. Supported values: + COSINE_DISTANCE (default), EUCLIDEAN_DISTANCE, or DOT_PRODUCT + options: Optional dictionary of additional VECTOR_SEARCH options + + Example: + ```python + params = BigQueryVectorSearchParameters( + table_name='project.dataset.embeddings', + embedding_column='embedding', + columns=['content', 'url', 'date'], + neighbor_count=5, + metadata_restriction_template="type = '{doc_type}'", + distance_type='COSINE_DISTANCE' + ) + ``` + """ + table_name: str + embedding_column: str + columns: List[str] + neighbor_count: int + metadata_restriction_template: str + distance_type: Optional[str] = None + options: Optional[Dict[str, Any]] = None + + def format_query(self, chunks: List[Chunk]) -> str: + """Format the vector search query template.""" + base_columns_str = ", ".join(f"base.{col}" for col in self.columns) + columns_str = ", ".join(self.columns) + distance_clause = ( + f", distance_type => '{self.distance_type}'" + if self.distance_type else "") + options_clause = (f", options => {self.options}" if self.options else "") + + # Create metadata check function + metadata_fn = """ + CREATE TEMP FUNCTION check_metadata( + metadata ARRAY>, + search_key STRING, + search_value STRING + ) + AS (( + SELECT COUNT(*) > 0 + FROM UNNEST(metadata) + WHERE key = search_key AND value = search_value + )); + """ + + # Union embeddings with IDs + embedding_unions = [] + for chunk in chunks: + if chunk.embedding is None or chunk.embedding.dense_embedding is None: + raise ValueError(f"Chunk {chunk.id} missing embedding") + embedding_str = ( + f"SELECT '{chunk.id}' as id, " + f"{[float(x) for x in chunk.embedding.dense_embedding]} as embedding") + embedding_unions.append(embedding_str) + embeddings_query = " UNION ALL ".join(embedding_unions) + + # Format metadata restrictions for each embedding + metadata_restrictions = [ + f"({self.metadata_restriction_template.format(**chunk.metadata)})" + for chunk in chunks + ] + combined_restrictions = " OR ".join(metadata_restrictions) + + return f""" + {metadata_fn} + + WITH query_embeddings AS ({embeddings_query}) + + SELECT + query.id, + ARRAY_AGG( + STRUCT({base_columns_str}) + ) as chunks + FROM VECTOR_SEARCH( + (SELECT {columns_str}, {self.embedding_column} + FROM `{self.table_name}` + WHERE {combined_restrictions}), + '{self.embedding_column}', + TABLE query_embeddings, + top_k => {self.neighbor_count} + {distance_clause} + {options_clause} + ) + GROUP BY query.id + """ + + +class BigQueryVectorSearchEnrichmentHandler( + EnrichmentSourceHandler[Union[Chunk, List[Chunk]], + List[Tuple[Chunk, Dict[str, Any]]]]): + """Enrichment handler that performs vector similarity search using BigQuery. + + This handler enriches Chunks by finding similar vectors in a BigQuery table + using the VECTOR_SEARCH function. It supports batching requests for efficiency + and preserves the original Chunk metadata while adding the search results. + + Args: + project: GCP project ID containing the BigQuery dataset + vector_search_parameters: Configuration for the vector search query + min_batch_size: Minimum number of chunks to batch before processing + max_batch_size: Maximum number of chunks to process in one batch + **kwargs: Additional arguments passed to bigquery.Client + + Example: + ```python + params = BigQueryVectorSearchParameters(...) + handler = BigQueryVectorSearchEnrichmentHandler( + project='my-project', + vector_search_parameters=params, + min_batch_size=100, + max_batch_size=1000 + ) + + with beam.Pipeline() as p: + enriched = ( + p + | beam.Create([chunk1, chunk2]) + | beam.ParDo(handler) + ) + ``` + + The handler will: + 1. Batch incoming chunks according to batch size parameters + 2. Format and execute vector search query for each batch + 3. Join results back to original chunks + 4. Return tuples of (original_chunk, search_results) + """ + def __init__( + self, + project: str, + vector_search_parameters: BigQueryVectorSearchParameters, + *, + min_batch_size: int = 1, + max_batch_size: int = 1000, + **kwargs): + self.project = project + self.vector_search_parameters = vector_search_parameters + self.kwargs = kwargs + self._batching_kwargs = { + 'min_batch_size': min_batch_size, 'max_batch_size': max_batch_size + } + self.join_fn = join_fn + self.use_custom_types = True + + def __enter__(self): + self.client = bigquery.Client(project=self.project, **self.kwargs) + + def __call__(self, request: Union[Chunk, List[Chunk]], *args, + **kwargs) -> List[Tuple[Chunk, Dict[str, Any]]]: + """Process request(s) using BigQuery vector search. + + Args: + request: Single Chunk with embedding or list of Chunk's with + embeddings to process + + Returns: + Chunk(s) where chunk.metadata['enrichment_output'] contains the + data retrieved via BigQuery VECTOR_SEARCH. + """ + # Convert single request to list for uniform processing + requests = request if isinstance(request, list) else [request] + + # Generate and execute query + query = self.vector_search_parameters.format_query(requests) + query_job = self.client.query(query) + results = query_job.result() + + # Map results back to embeddings + id_to_embedding = {emb.id: emb for emb in requests} + response = [] + for result_row in results: + result_dict = dict(result_row.items()) + embedding = id_to_embedding[result_row.id] + response.append((embedding, result_dict)) + + return response + + def __exit__(self, exc_type, exc_val, exc_tb): + self.client.close() + + def batch_elements_kwargs(self) -> Dict[str, int]: + """Returns kwargs for beam.BatchElements.""" + return self._batching_kwargs + + +def join_fn(left: Embedding, right: Dict[str, Any]) -> Embedding: + left.metadata['enrichment_data'] = right + return left diff --git a/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search_it_test.py b/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search_it_test.py new file mode 100644 index 000000000000..bec6adf9748a --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/enrichment/bigquery_vector_search_it_test.py @@ -0,0 +1,405 @@ +import logging +import secrets +import time +import unittest + +import apache_beam as beam +from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper +from apache_beam.io.gcp.internal.clients import bigquery +from apache_beam.ml.rag.enrichment.bigquery_vector_search import BigQueryVectorSearchEnrichmentHandler +from apache_beam.ml.rag.enrichment.bigquery_vector_search import BigQueryVectorSearchParameters +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import Embedding +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to +from apache_beam.transforms.enrichment import Enrichment + +try: + from google.api_core.exceptions import BadRequest +except ImportError: + raise unittest.SkipTest('BigQuery dependencies not installed') + +_LOGGER = logging.getLogger(__name__) + + +class BigQueryVectorSearchIT(unittest.TestCase): + bigquery_dataset_id = 'python_vector_search_test_' + project = "dataflow-test" + + @classmethod + def setUpClass(cls): + cls.bigquery_client = BigQueryWrapper() + cls.dataset_id = '%s%d%s' % ( + cls.bigquery_dataset_id, int(time.time()), secrets.token_hex(3)) + cls.bigquery_client.get_or_create_dataset(cls.project, cls.dataset_id) + _LOGGER.info( + "Created dataset %s in project %s", cls.dataset_id, cls.project) + + @classmethod + def tearDownClass(cls): + request = bigquery.BigqueryDatasetsDeleteRequest( + projectId=cls.project, datasetId=cls.dataset_id, deleteContents=True) + try: + cls.bigquery_client.client.datasets.Delete(request) + except Exception: + _LOGGER.warning( + 'Failed to clean up dataset %s in project %s', + cls.dataset_id, + cls.project) + + +class TestBigQueryVectorSearchIT(BigQueryVectorSearchIT): + # Test data with embeddings + table_data = [{ + "id": "doc1", + "content": "This is a test document", + "embedding": [0.1, 0.2, 0.3], + "metadata": [{ + "key": "language", "value": "en" + }] + }, + { + "id": "doc2", + "content": "Another test document", + "embedding": [0.2, 0.3, 0.4], + "metadata": [{ + "key": "language", "value": "en" + }] + }, + { + "id": "doc3", + "content": "Un document de test", + "embedding": [0.3, 0.4, 0.5], + "metadata": [{ + "key": "language", "value": "fr" + }] + }] + + @classmethod + def create_table(cls, table_name): + fields = [('id', 'STRING'), ('content', 'STRING'), + ('embedding', 'FLOAT64', 'REPEATED'), + ( + 'metadata', + 'RECORD', + 'REPEATED', [('key', 'STRING'), ('value', 'STRING')])] + table_schema = bigquery.TableSchema() + for field_def in fields: + field = bigquery.TableFieldSchema() + field.name = field_def[0] + field.type = field_def[1] + if len(field_def) > 2: + field.mode = field_def[2] + if len(field_def) > 3: + for subfield_def in field_def[3]: + subfield = bigquery.TableFieldSchema() + subfield.name = subfield_def[0] + subfield.type = subfield_def[1] + field.fields.append(subfield) + table_schema.fields.append(field) + + table = bigquery.Table( + tableReference=bigquery.TableReference( + projectId=cls.project, datasetId=cls.dataset_id, + tableId=table_name), + schema=table_schema) + + request = bigquery.BigqueryTablesInsertRequest( + projectId=cls.project, datasetId=cls.dataset_id, table=table) + cls.bigquery_client.client.tables.Insert(request) + cls.bigquery_client.insert_rows( + cls.project, cls.dataset_id, table_name, cls.table_data) + cls.table_name = f"{cls.project}.{cls.dataset_id}.{table_name}" + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.create_table('vector_test') + + def test_basic_vector_search(self): + """Test basic vector similarity search.""" + test_chunks = [ + Chunk( + id="query1", + index=0, + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content(text="test query"), + metadata={"language": "en"}) + ] + # Expected chunk will have enrichment_data in metadata + expected_chunks = [ + Chunk( + id="query1", + index=0, + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content(text="test query"), + metadata={ + "language": "en", + "enrichment_data": { + "id": "query1", + "chunks": [{ + "content": "This is a test document", + "metadata": [{ + "key": "language", "value": "en" + }] + }, + { + "content": "Another test document", + "metadata": [{ + "key": "language", "value": "en" + }] + }] + } + }) + ] + + params = BigQueryVectorSearchParameters( + table_name=self.table_name, + embedding_column='embedding', + columns=['content', 'metadata'], + neighbor_count=2, + metadata_restriction_template=( + "check_metadata(metadata, 'language', '{language}')")) + + handler = BigQueryVectorSearchEnrichmentHandler( + project=self.project, vector_search_parameters=params) + + with TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create(test_chunks) | Enrichment(handler)) + + assert_that(result, equal_to(expected_chunks)) + + def test_batched_metadata_filter_vector_search(self): + """Test vector similarity search with batching.""" + test_chunks = [ + Chunk( + id="query1", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content(text="test query 1"), + metadata={"language": "en"}, + index=0), + Chunk( + id="query2", + embedding=Embedding(dense_embedding=[0.2, 0.3, 0.4]), + content=Content(text="test query 2"), + metadata={"language": "en"}, + index=1), + Chunk( + id="query3", + embedding=Embedding(dense_embedding=[0.3, 0.4, 0.5]), + content=Content(text="test query 3"), + metadata={"language": "fr"}, + index=2) + ] + + params = BigQueryVectorSearchParameters( + table_name=self.table_name, + embedding_column='embedding', + columns=['content', 'metadata'], + neighbor_count=2, + metadata_restriction_template=( + "check_metadata(metadata, 'language', '{language}')")) + + handler = BigQueryVectorSearchEnrichmentHandler( + project=self.project, + vector_search_parameters=params, + min_batch_size=2, # Force batching + max_batch_size=2 # Process 2 chunks at a time + ) + + expected_chunks = [ + Chunk( + id="query1", + embedding=Embedding( + dense_embedding=[0.1, 0.2, 0.3], sparse_embedding=None), + content=Content(text="test query 1"), + metadata={ + "language": "en", + "enrichment_data": { + "id": "query1", + "chunks": [{ + "content": "This is a test document", + "metadata": [{ + "key": "language", "value": "en" + }] + }, + { + "content": "Another test document", + "metadata": [{ + "key": "language", "value": "en" + }] + }] + } + }, + index=0), + Chunk( + id="query2", + embedding=Embedding( + dense_embedding=[0.2, 0.3, 0.4], sparse_embedding=None), + content=Content(text="test query 2"), + metadata={ + "language": "en", + "enrichment_data": { + "id": "query2", + "chunks": [{ + "content": "Another test document", + "metadata": [{ + "key": "language", "value": "en" + }] + }, + { + "content": "This is a test document", + "metadata": [{ + "key": "language", "value": "en" + }] + }] + } + }, + index=1), + Chunk( + id="query3", + embedding=Embedding( + dense_embedding=[0.3, 0.4, 0.5], sparse_embedding=None), + content=Content(text="test query 3"), + metadata={ + "language": "fr", + "enrichment_data": { + "id": "query3", + "chunks": [{ + "content": "Un document de test", + "metadata": [{ + "key": "language", "value": "fr" + }] + }] + } + }, + index=2) + ] + + with TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create(test_chunks) | Enrichment(handler)) + + assert_that(result, equal_to(expected_chunks)) + + def test_euclidean_distance_search(self): + """Test vector similarity search using Euclidean distance.""" + test_chunks = [ + Chunk( + id="query1", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content(text="test query 1"), + metadata={"language": "en"}, + index=0), + Chunk( + id="query2", + embedding=Embedding(dense_embedding=[0.2, 0.3, 0.4]), + content=Content(text="test query 2"), + metadata={"language": "en"}, + index=1) + ] + + params = BigQueryVectorSearchParameters( + table_name=self.table_name, + embedding_column='embedding', + columns=['content', 'metadata'], + neighbor_count=2, + metadata_restriction_template=( + "check_metadata(metadata, 'language', '{language}')"), + distance_type='EUCLIDEAN' # Use Euclidean distance + ) + + handler = BigQueryVectorSearchEnrichmentHandler( + project=self.project, + vector_search_parameters=params, + min_batch_size=2, + max_batch_size=2) + + expected_chunks = [ + Chunk( + id="query1", + embedding=Embedding( + dense_embedding=[0.1, 0.2, 0.3], sparse_embedding=None), + content=Content(text="test query 1"), + metadata={ + "language": "en", + "enrichment_data": { + "id": "query1", + "chunks": [{ + "content": "This is a test document", + "metadata": [{ + "key": "language", "value": "en" + }] + }, + { + "content": "Another test document", + "metadata": [{ + "key": "language", "value": "en" + }] + }] + } + }, + index=0), + Chunk( + id="query2", + embedding=Embedding( + dense_embedding=[0.2, 0.3, 0.4], sparse_embedding=None), + content=Content(text="test query 2"), + metadata={ + "language": "en", + "enrichment_data": { + "id": "query2", + "chunks": [{ + "content": "Another test document", + "metadata": [{ + "key": "language", "value": "en" + }] + }, + { + "content": "This is a test document", + "metadata": [{ + "key": "language", "value": "en" + }] + }] + } + }, + index=1) + ] + + with TestPipeline(is_integration_test=True) as p: + result = (p | beam.Create(test_chunks) | Enrichment(handler)) + + assert_that(result, equal_to(expected_chunks)) + + def test_invalid_query(self): + """Test error handling for invalid queries.""" + test_chunks = [ + Chunk( + id="query1", + embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]), + content=Content(text="test query"), + metadata={"language": "en"}) + ] + + params = BigQueryVectorSearchParameters( + table_name=self.table_name, + embedding_column='nonexistent_column', # Invalid column + columns=['content'], + neighbor_count=1, + metadata_restriction_template=( + "language = '{language}'" # Invalid template + ) + ) + + handler = BigQueryVectorSearchEnrichmentHandler( + project=self.project, vector_search_parameters=params) + + with self.assertRaises(BadRequest): + with TestPipeline() as p: + _ = (p | beam.Create(test_chunks) | Enrichment(handler)) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main()