diff --git a/tests/unit/vertex_rag/test_rag_constants.py b/tests/unit/vertex_rag/test_rag_constants.py index a525d2784f..cd2a74e30c 100644 --- a/tests/unit/vertex_rag/test_rag_constants.py +++ b/tests/unit/vertex_rag/test_rag_constants.py @@ -18,7 +18,17 @@ from google.cloud import aiplatform -from vertexai.preview import rag +from vertexai.preview.rag import ( + EmbeddingModelConfig, + RagCorpus, + RagFile, + RagResource, + SlackChannelsSource, + SlackChannel, + JiraSource, + JiraQuery, + Weaviate, +) from google.cloud.aiplatform_v1beta1 import ( GoogleDriveSource, RagFileChunkingConfig, @@ -32,6 +42,7 @@ SlackSource as GapicSlackSource, RagContexts, RetrieveContextsResponse, + RagVectorDbConfig, ) from google.cloud.aiplatform_v1beta1.types import api_auth from google.protobuf import timestamp_pb2 @@ -47,6 +58,16 @@ TEST_RAG_CORPUS_RESOURCE_NAME = f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/ragCorpora/{TEST_RAG_CORPUS_ID}" # RagCorpus +TEST_WEAVIATE_HTTP_ENDPOINT = "test.weaviate.com" +TEST_WEAVIATE_COLLECTION_NAME = "test-collection" +TEST_WEAVIATE_API_KEY_SECRET_VERSION = ( + "projects/test-project/secrets/test-secret/versions/1" +) +TEST_WEAVIATE_CONFIG = Weaviate( + weaviate_http_endpoint=TEST_WEAVIATE_HTTP_ENDPOINT, + collection_name=TEST_WEAVIATE_COLLECTION_NAME, + api_key=TEST_WEAVIATE_API_KEY_SECRET_VERSION, +) TEST_GAPIC_RAG_CORPUS = GapicRagCorpus( name=TEST_RAG_CORPUS_RESOURCE_NAME, display_name=TEST_CORPUS_DISPLAY_NAME, @@ -57,15 +78,37 @@ TEST_PROJECT, TEST_REGION ) ) -TEST_EMBEDDING_MODEL_CONFIG = rag.EmbeddingModelConfig( +TEST_GAPIC_RAG_CORPUS_WEAVIATE = GapicRagCorpus( + name=TEST_RAG_CORPUS_RESOURCE_NAME, + display_name=TEST_CORPUS_DISPLAY_NAME, + description=TEST_CORPUS_DISCRIPTION, + rag_vector_db_config=RagVectorDbConfig( + weaviate=RagVectorDbConfig.Weaviate( + http_endpoint=TEST_WEAVIATE_HTTP_ENDPOINT, + collection_name=TEST_WEAVIATE_COLLECTION_NAME, + ), + api_auth=api_auth.ApiAuth( + api_key_config=api_auth.ApiAuth.ApiKeyConfig( + api_key_secret_version=TEST_WEAVIATE_API_KEY_SECRET_VERSION + ), + ), + ), +) +TEST_EMBEDDING_MODEL_CONFIG = EmbeddingModelConfig( publisher_model="publishers/google/models/textembedding-gecko", ) -TEST_RAG_CORPUS = rag.RagCorpus( +TEST_RAG_CORPUS = RagCorpus( name=TEST_RAG_CORPUS_RESOURCE_NAME, display_name=TEST_CORPUS_DISPLAY_NAME, description=TEST_CORPUS_DISCRIPTION, embedding_model_config=TEST_EMBEDDING_MODEL_CONFIG, ) +TEST_RAG_CORPUS_WEAVIATE = RagCorpus( + name=TEST_RAG_CORPUS_RESOURCE_NAME, + display_name=TEST_CORPUS_DISPLAY_NAME, + description=TEST_CORPUS_DISCRIPTION, + vector_db=TEST_WEAVIATE_CONFIG, +) TEST_PAGE_TOKEN = "test-page-token" # RagFiles @@ -165,7 +208,7 @@ display_name=TEST_FILE_DISPLAY_NAME, description=TEST_FILE_DESCRIPTION, ) -TEST_RAG_FILE = rag.RagFile( +TEST_RAG_FILE = RagFile( name=TEST_RAG_FILE_RESOURCE_NAME, display_name=TEST_FILE_DISPLAY_NAME, description=TEST_FILE_DESCRIPTION, @@ -183,15 +226,15 @@ TEST_SLACK_API_KEY_SECRET_VERSION_2 = ( "projects/test-project/secrets/test-secret/versions/2" ) -TEST_SLACK_SOURCE = rag.SlackChannelsSource( +TEST_SLACK_SOURCE = SlackChannelsSource( channels=[ - rag.SlackChannel( + SlackChannel( channel_id=TEST_SLACK_CHANNEL_ID, api_key=TEST_SLACK_API_KEY_SECRET_VERSION, start_time=TEST_SLACK_START_TIME, end_time=TEST_SLACK_END_TIME, ), - rag.SlackChannel( + SlackChannel( channel_id=TEST_SLACK_CHANNEL_ID_2, api_key=TEST_SLACK_API_KEY_SECRET_VERSION_2, ), @@ -241,9 +284,9 @@ TEST_JIRA_API_KEY_SECRET_VERSION = ( "projects/test-project/secrets/test-secret/versions/1" ) -TEST_JIRA_SOURCE = rag.JiraSource( +TEST_JIRA_SOURCE = JiraSource( queries=[ - rag.JiraQuery( + JiraQuery( email=TEST_JIRA_EMAIL, jira_projects=[TEST_JIRA_PROJECT], custom_queries=[TEST_JIRA_CUSTOM_QUERY], @@ -286,11 +329,11 @@ ] ) TEST_RETRIEVAL_RESPONSE = RetrieveContextsResponse(contexts=TEST_CONTEXTS) -TEST_RAG_RESOURCE = rag.RagResource( +TEST_RAG_RESOURCE = RagResource( rag_corpus=TEST_RAG_CORPUS_RESOURCE_NAME, rag_file_ids=[TEST_RAG_FILE_ID], ) -TEST_RAG_RESOURCE_INVALID_NAME = rag.RagResource( +TEST_RAG_RESOURCE_INVALID_NAME = RagResource( rag_corpus="213lkj-1/23jkl/", rag_file_ids=[TEST_RAG_FILE_ID], ) diff --git a/tests/unit/vertex_rag/test_rag_data.py b/tests/unit/vertex_rag/test_rag_data.py index 2b789d6513..fd243920b1 100644 --- a/tests/unit/vertex_rag/test_rag_data.py +++ b/tests/unit/vertex_rag/test_rag_data.py @@ -47,6 +47,21 @@ def create_rag_corpus_mock(): yield create_rag_corpus_mock +@pytest.fixture +def create_rag_corpus_mock_weaviate(): + with mock.patch.object( + VertexRagDataServiceClient, + "create_rag_corpus", + ) as create_rag_corpus_mock_weaviate: + create_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation) + create_rag_corpus_lro_mock.done.return_value = True + create_rag_corpus_lro_mock.result.return_value = ( + tc.TEST_GAPIC_RAG_CORPUS_WEAVIATE + ) + create_rag_corpus_mock_weaviate.return_value = create_rag_corpus_lro_mock + yield create_rag_corpus_mock_weaviate + + @pytest.fixture def list_rag_corpora_pager_mock(): with mock.patch.object( @@ -141,6 +156,7 @@ def list_rag_files_pager_mock(): def rag_corpus_eq(returned_corpus, expected_corpus): assert returned_corpus.name == expected_corpus.name assert returned_corpus.display_name == expected_corpus.display_name + assert returned_corpus.vector_db.__eq__(expected_corpus.vector_db) def rag_file_eq(returned_file, expected_file): @@ -191,6 +207,15 @@ def test_create_corpus_success(self): rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS) + @pytest.mark.usefixtures("create_rag_corpus_mock_weaviate") + def test_create_corpus_weaviate_success(self): + rag_corpus = rag.create_corpus( + display_name=tc.TEST_CORPUS_DISPLAY_NAME, + vector_db=tc.TEST_WEAVIATE_CONFIG, + ) + + rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS_WEAVIATE) + @pytest.mark.usefixtures("rag_data_client_mock_exception") def test_create_corpus_failure(self): with pytest.raises(RuntimeError) as e: diff --git a/vertexai/preview/rag/__init__.py b/vertexai/preview/rag/__init__.py index fff380b359..56590fed0e 100644 --- a/vertexai/preview/rag/__init__.py +++ b/vertexai/preview/rag/__init__.py @@ -38,36 +38,38 @@ ) from vertexai.preview.rag.utils.resources import ( EmbeddingModelConfig, - JiraSource, JiraQuery, + JiraSource, RagCorpus, RagFile, RagResource, SlackChannel, SlackChannelsSource, + Weaviate, ) __all__ = ( + "EmbeddingModelConfig", + "JiraQuery", + "JiraSource", + "RagCorpus", + "RagFile", + "RagResource", + "Retrieval", + "SlackChannel", + "SlackChannelsSource", + "VertexRagStore", + "Weaviate", "create_corpus", - "list_corpora", - "get_corpus", "delete_corpus", - "upload_file", + "delete_file", + "get_corpus", + "get_file", "import_files", "import_files_async", - "get_file", + "list_corpora", "list_files", - "delete_file", "retrieval_query", - "EmbeddingModelConfig", - "Retrieval", - "VertexRagStore", - "RagResource", - "RagFile", - "RagCorpus", - "JiraSource", - "JiraQuery", - "SlackChannel", - "SlackChannelsSource", + "upload_file", ) diff --git a/vertexai/preview/rag/rag_data.py b/vertexai/preview/rag/rag_data.py index 32da89c0a3..0983037173 100644 --- a/vertexai/preview/rag/rag_data.py +++ b/vertexai/preview/rag/rag_data.py @@ -48,6 +48,7 @@ RagCorpus, RagFile, SlackChannelsSource, + Weaviate, ) @@ -55,6 +56,7 @@ def create_corpus( display_name: Optional[str] = None, description: Optional[str] = None, embedding_model_config: Optional[EmbeddingModelConfig] = None, + vector_db: Optional[Weaviate] = None, ) -> RagCorpus: """Creates a new RagCorpus resource. @@ -76,6 +78,8 @@ def create_corpus( consist of any UTF-8 characters. description: The description of the RagCorpus. embedding_model_config: The embedding model config. + vector_db: The vector db config of the RagCorpus. If unspecified, the + default database Spanner is used. Returns: RagCorpus. Raises: @@ -88,9 +92,14 @@ def create_corpus( rag_corpus = GapicRagCorpus(display_name=display_name, description=description) if embedding_model_config: - rag_corpus = _gapic_utils.set_embedding_model_config( - embedding_model_config, - rag_corpus, + _gapic_utils.set_embedding_model_config( + embedding_model_config=embedding_model_config, + rag_corpus=rag_corpus, + ) + if vector_db is not None: + _gapic_utils.set_vector_db( + vector_db=vector_db, + rag_corpus=rag_corpus, ) request = CreateRagCorpusRequest( diff --git a/vertexai/preview/rag/utils/_gapic_utils.py b/vertexai/preview/rag/utils/_gapic_utils.py index 6cc1c1d316..640fd8c5f0 100644 --- a/vertexai/preview/rag/utils/_gapic_utils.py +++ b/vertexai/preview/rag/utils/_gapic_utils.py @@ -28,6 +28,7 @@ RagFile as GapicRagFile, SlackSource as GapicSlackSource, JiraSource as GapicJiraSource, + RagVectorDbConfig, ) from google.cloud.aiplatform import initializer from google.cloud.aiplatform.utils import ( @@ -41,6 +42,7 @@ RagFile, SlackChannelsSource, JiraSource, + Weaviate, ) @@ -93,8 +95,22 @@ def convert_gapic_to_embedding_model_config( return embedding_model_config +def convert_gapic_to_vector_db( + gapic_vector_db: RagVectorDbConfig, +) -> Weaviate: + """Convert Gapic RagVectorDbConfig to Weaviate.""" + if gapic_vector_db.__contains__("weaviate"): + return Weaviate( + weaviate_http_endpoint=gapic_vector_db.weaviate.http_endpoint, + collection_name=gapic_vector_db.weaviate.collection_name, + api_key=gapic_vector_db.api_auth.api_key_config.api_key_secret_version, + ) + else: + return None + + def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus: - """ "Convert GapicRagCorpus to RagCorpus.""" + """Convert GapicRagCorpus to RagCorpus.""" rag_corpus = RagCorpus( name=gapic_rag_corpus.name, display_name=gapic_rag_corpus.display_name, @@ -102,12 +118,13 @@ def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus: embedding_model_config=convert_gapic_to_embedding_model_config( gapic_rag_corpus.rag_embedding_model_config ), + vector_db=convert_gapic_to_vector_db(gapic_rag_corpus.rag_vector_db_config), ) return rag_corpus def convert_gapic_to_rag_file(gapic_rag_file: GapicRagFile) -> RagFile: - """ "Convert GapicRagFile to RagFile.""" + """Convert GapicRagFile to RagFile.""" rag_file = RagFile( name=gapic_rag_file.name, display_name=gapic_rag_file.display_name, @@ -315,7 +332,7 @@ def get_file_name( def set_embedding_model_config( embedding_model_config: EmbeddingModelConfig, rag_corpus: GapicRagCorpus, -) -> GapicRagCorpus: +) -> None: if embedding_model_config.publisher_model and embedding_model_config.endpoint: raise ValueError("publisher_model and endpoint cannot be set at the same time.") if ( @@ -371,4 +388,27 @@ def set_embedding_model_config( "endpoint must be of the format `projects/{project}/locations/{location}/endpoints/{endpoint}` or `endpoints/{endpoint}`" ) - return rag_corpus + +def set_vector_db( + vector_db: Weaviate, + rag_corpus: GapicRagCorpus, +) -> None: + """Sets the vector db configuration for the rag corpus.""" + if isinstance(vector_db, Weaviate): + http_endpoint = vector_db.weaviate_http_endpoint + collection_name = vector_db.collection_name + api_key = vector_db.api_key + + rag_corpus.rag_vector_db_config = RagVectorDbConfig( + weaviate=RagVectorDbConfig.Weaviate( + http_endpoint=http_endpoint, + collection_name=collection_name, + ), + api_auth=api_auth.ApiAuth( + api_key_config=api_auth.ApiAuth.ApiKeyConfig( + api_key_secret_version=api_key + ), + ), + ) + else: + raise TypeError("vector_db must be a Weaviate.") diff --git a/vertexai/preview/rag/utils/resources.py b/vertexai/preview/rag/utils/resources.py index 1b5af451f6..aad7bad35d 100644 --- a/vertexai/preview/rag/utils/resources.py +++ b/vertexai/preview/rag/utils/resources.py @@ -69,6 +69,22 @@ class EmbeddingModelConfig: model_version_id: Optional[str] = None +@dataclasses.dataclass +class Weaviate: + """Weaviate. + + Attributes: + weaviate_http_endpoint: The Weaviate DB instance HTTP endpoint + collection_name: The corresponding Weaviate collection this corpus maps to + api_key: The SecretManager resource name for the Weaviate DB API token. Format: + ``projects/{project}/secrets/{secret}/versions/{version}`` + """ + + weaviate_http_endpoint: str + collection_name: str + api_key: str + + @dataclasses.dataclass class RagCorpus: """RAG corpus(output only). @@ -78,12 +94,15 @@ class RagCorpus: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}`` display_name: Display name that was configured at client side. description: The description of the RagCorpus. + embedding_model_config: The embedding model config of the RagCorpus. + vector_db: The Vector DB of the RagCorpus. """ name: Optional[str] = None display_name: Optional[str] = None description: Optional[str] = None embedding_model_config: Optional[EmbeddingModelConfig] = None + vector_db: Optional[Weaviate] = None @dataclasses.dataclass