Skip to content

Commit

Permalink
feat: Adding Weaviate Vector DB option for RAG corpuses to SDK
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 671132855
  • Loading branch information
speedstorm1 authored and copybara-github committed Sep 4, 2024
1 parent c29fa5d commit 9b28202
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 34 deletions.
65 changes: 54 additions & 11 deletions tests/unit/vertex_rag/test_rag_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,17 @@

from google.cloud import aiplatform

from vertexai.preview import rag
from vertexai.preview.rag import (
EmbeddingModelConfig,
RagCorpus,
RagFile,
RagResource,
SlackChannelsSource,
SlackChannel,
JiraSource,
JiraQuery,
Weaviate,
)
from google.cloud.aiplatform_v1beta1 import (
GoogleDriveSource,
RagFileChunkingConfig,
Expand All @@ -32,6 +42,7 @@
SlackSource as GapicSlackSource,
RagContexts,
RetrieveContextsResponse,
RagVectorDbConfig,
)
from google.cloud.aiplatform_v1beta1.types import api_auth
from google.protobuf import timestamp_pb2
Expand All @@ -47,6 +58,16 @@
TEST_RAG_CORPUS_RESOURCE_NAME = f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/ragCorpora/{TEST_RAG_CORPUS_ID}"

# RagCorpus
TEST_WEAVIATE_HTTP_ENDPOINT = "test.weaviate.com"
TEST_WEAVIATE_COLLECTION_NAME = "test-collection"
TEST_WEAVIATE_API_KEY_SECRET_VERSION = (
"projects/test-project/secrets/test-secret/versions/1"
)
TEST_WEAVIATE_CONFIG = Weaviate(
weaviate_http_endpoint=TEST_WEAVIATE_HTTP_ENDPOINT,
collection_name=TEST_WEAVIATE_COLLECTION_NAME,
api_key=TEST_WEAVIATE_API_KEY_SECRET_VERSION,
)
TEST_GAPIC_RAG_CORPUS = GapicRagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
Expand All @@ -57,15 +78,37 @@
TEST_PROJECT, TEST_REGION
)
)
TEST_EMBEDDING_MODEL_CONFIG = rag.EmbeddingModelConfig(
TEST_GAPIC_RAG_CORPUS_WEAVIATE = GapicRagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
description=TEST_CORPUS_DISCRIPTION,
rag_vector_db_config=RagVectorDbConfig(
weaviate=RagVectorDbConfig.Weaviate(
http_endpoint=TEST_WEAVIATE_HTTP_ENDPOINT,
collection_name=TEST_WEAVIATE_COLLECTION_NAME,
),
api_auth=api_auth.ApiAuth(
api_key_config=api_auth.ApiAuth.ApiKeyConfig(
api_key_secret_version=TEST_WEAVIATE_API_KEY_SECRET_VERSION
),
),
),
)
TEST_EMBEDDING_MODEL_CONFIG = EmbeddingModelConfig(
publisher_model="publishers/google/models/textembedding-gecko",
)
TEST_RAG_CORPUS = rag.RagCorpus(
TEST_RAG_CORPUS = RagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
description=TEST_CORPUS_DISCRIPTION,
embedding_model_config=TEST_EMBEDDING_MODEL_CONFIG,
)
TEST_RAG_CORPUS_WEAVIATE = RagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
description=TEST_CORPUS_DISCRIPTION,
vector_db=TEST_WEAVIATE_CONFIG,
)
TEST_PAGE_TOKEN = "test-page-token"

# RagFiles
Expand Down Expand Up @@ -165,7 +208,7 @@
display_name=TEST_FILE_DISPLAY_NAME,
description=TEST_FILE_DESCRIPTION,
)
TEST_RAG_FILE = rag.RagFile(
TEST_RAG_FILE = RagFile(
name=TEST_RAG_FILE_RESOURCE_NAME,
display_name=TEST_FILE_DISPLAY_NAME,
description=TEST_FILE_DESCRIPTION,
Expand All @@ -183,15 +226,15 @@
TEST_SLACK_API_KEY_SECRET_VERSION_2 = (
"projects/test-project/secrets/test-secret/versions/2"
)
TEST_SLACK_SOURCE = rag.SlackChannelsSource(
TEST_SLACK_SOURCE = SlackChannelsSource(
channels=[
rag.SlackChannel(
SlackChannel(
channel_id=TEST_SLACK_CHANNEL_ID,
api_key=TEST_SLACK_API_KEY_SECRET_VERSION,
start_time=TEST_SLACK_START_TIME,
end_time=TEST_SLACK_END_TIME,
),
rag.SlackChannel(
SlackChannel(
channel_id=TEST_SLACK_CHANNEL_ID_2,
api_key=TEST_SLACK_API_KEY_SECRET_VERSION_2,
),
Expand Down Expand Up @@ -241,9 +284,9 @@
TEST_JIRA_API_KEY_SECRET_VERSION = (
"projects/test-project/secrets/test-secret/versions/1"
)
TEST_JIRA_SOURCE = rag.JiraSource(
TEST_JIRA_SOURCE = JiraSource(
queries=[
rag.JiraQuery(
JiraQuery(
email=TEST_JIRA_EMAIL,
jira_projects=[TEST_JIRA_PROJECT],
custom_queries=[TEST_JIRA_CUSTOM_QUERY],
Expand Down Expand Up @@ -286,11 +329,11 @@
]
)
TEST_RETRIEVAL_RESPONSE = RetrieveContextsResponse(contexts=TEST_CONTEXTS)
TEST_RAG_RESOURCE = rag.RagResource(
TEST_RAG_RESOURCE = RagResource(
rag_corpus=TEST_RAG_CORPUS_RESOURCE_NAME,
rag_file_ids=[TEST_RAG_FILE_ID],
)
TEST_RAG_RESOURCE_INVALID_NAME = rag.RagResource(
TEST_RAG_RESOURCE_INVALID_NAME = RagResource(
rag_corpus="213lkj-1/23jkl/",
rag_file_ids=[TEST_RAG_FILE_ID],
)
25 changes: 25 additions & 0 deletions tests/unit/vertex_rag/test_rag_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,21 @@ def create_rag_corpus_mock():
yield create_rag_corpus_mock


@pytest.fixture
def create_rag_corpus_mock_weaviate():
with mock.patch.object(
VertexRagDataServiceClient,
"create_rag_corpus",
) as create_rag_corpus_mock_weaviate:
create_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation)
create_rag_corpus_lro_mock.done.return_value = True
create_rag_corpus_lro_mock.result.return_value = (
tc.TEST_GAPIC_RAG_CORPUS_WEAVIATE
)
create_rag_corpus_mock_weaviate.return_value = create_rag_corpus_lro_mock
yield create_rag_corpus_mock_weaviate


@pytest.fixture
def list_rag_corpora_pager_mock():
with mock.patch.object(
Expand Down Expand Up @@ -141,6 +156,7 @@ def list_rag_files_pager_mock():
def rag_corpus_eq(returned_corpus, expected_corpus):
assert returned_corpus.name == expected_corpus.name
assert returned_corpus.display_name == expected_corpus.display_name
assert returned_corpus.vector_db.__eq__(expected_corpus.vector_db)


def rag_file_eq(returned_file, expected_file):
Expand Down Expand Up @@ -191,6 +207,15 @@ def test_create_corpus_success(self):

rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS)

@pytest.mark.usefixtures("create_rag_corpus_mock_weaviate")
def test_create_corpus_weaviate_success(self):
rag_corpus = rag.create_corpus(
display_name=tc.TEST_CORPUS_DISPLAY_NAME,
vector_db=tc.TEST_WEAVIATE_CONFIG,
)

rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS_WEAVIATE)

@pytest.mark.usefixtures("rag_data_client_mock_exception")
def test_create_corpus_failure(self):
with pytest.raises(RuntimeError) as e:
Expand Down
34 changes: 18 additions & 16 deletions vertexai/preview/rag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,36 +38,38 @@
)
from vertexai.preview.rag.utils.resources import (
EmbeddingModelConfig,
JiraSource,
JiraQuery,
JiraSource,
RagCorpus,
RagFile,
RagResource,
SlackChannel,
SlackChannelsSource,
Weaviate,
)


__all__ = (
"EmbeddingModelConfig",
"JiraQuery",
"JiraSource",
"RagCorpus",
"RagFile",
"RagResource",
"Retrieval",
"SlackChannel",
"SlackChannelsSource",
"VertexRagStore",
"Weaviate",
"create_corpus",
"list_corpora",
"get_corpus",
"delete_corpus",
"upload_file",
"delete_file",
"get_corpus",
"get_file",
"import_files",
"import_files_async",
"get_file",
"list_corpora",
"list_files",
"delete_file",
"retrieval_query",
"EmbeddingModelConfig",
"Retrieval",
"VertexRagStore",
"RagResource",
"RagFile",
"RagCorpus",
"JiraSource",
"JiraQuery",
"SlackChannel",
"SlackChannelsSource",
"upload_file",
)
15 changes: 12 additions & 3 deletions vertexai/preview/rag/rag_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,15 @@
RagCorpus,
RagFile,
SlackChannelsSource,
Weaviate,
)


def create_corpus(
display_name: Optional[str] = None,
description: Optional[str] = None,
embedding_model_config: Optional[EmbeddingModelConfig] = None,
vector_db: Optional[Weaviate] = None,
) -> RagCorpus:
"""Creates a new RagCorpus resource.
Expand All @@ -76,6 +78,8 @@ def create_corpus(
consist of any UTF-8 characters.
description: The description of the RagCorpus.
embedding_model_config: The embedding model config.
vector_db: The vector db config of the RagCorpus. If unspecified, the
default database Spanner is used.
Returns:
RagCorpus.
Raises:
Expand All @@ -88,9 +92,14 @@ def create_corpus(

rag_corpus = GapicRagCorpus(display_name=display_name, description=description)
if embedding_model_config:
rag_corpus = _gapic_utils.set_embedding_model_config(
embedding_model_config,
rag_corpus,
_gapic_utils.set_embedding_model_config(
embedding_model_config=embedding_model_config,
rag_corpus=rag_corpus,
)
if vector_db is not None:
_gapic_utils.set_vector_db(
vector_db=vector_db,
rag_corpus=rag_corpus,
)

request = CreateRagCorpusRequest(
Expand Down
48 changes: 44 additions & 4 deletions vertexai/preview/rag/utils/_gapic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
RagFile as GapicRagFile,
SlackSource as GapicSlackSource,
JiraSource as GapicJiraSource,
RagVectorDbConfig,
)
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.utils import (
Expand All @@ -41,6 +42,7 @@
RagFile,
SlackChannelsSource,
JiraSource,
Weaviate,
)


Expand Down Expand Up @@ -93,21 +95,36 @@ def convert_gapic_to_embedding_model_config(
return embedding_model_config


def convert_gapic_to_vector_db(
gapic_vector_db: RagVectorDbConfig,
) -> Weaviate:
"""Convert Gapic RagVectorDbConfig to Weaviate."""
if gapic_vector_db.__contains__("weaviate"):
return Weaviate(
weaviate_http_endpoint=gapic_vector_db.weaviate.http_endpoint,
collection_name=gapic_vector_db.weaviate.collection_name,
api_key=gapic_vector_db.api_auth.api_key_config.api_key_secret_version,
)
else:
return None


def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus:
""" "Convert GapicRagCorpus to RagCorpus."""
"""Convert GapicRagCorpus to RagCorpus."""
rag_corpus = RagCorpus(
name=gapic_rag_corpus.name,
display_name=gapic_rag_corpus.display_name,
description=gapic_rag_corpus.description,
embedding_model_config=convert_gapic_to_embedding_model_config(
gapic_rag_corpus.rag_embedding_model_config
),
vector_db=convert_gapic_to_vector_db(gapic_rag_corpus.rag_vector_db_config),
)
return rag_corpus


def convert_gapic_to_rag_file(gapic_rag_file: GapicRagFile) -> RagFile:
""" "Convert GapicRagFile to RagFile."""
"""Convert GapicRagFile to RagFile."""
rag_file = RagFile(
name=gapic_rag_file.name,
display_name=gapic_rag_file.display_name,
Expand Down Expand Up @@ -315,7 +332,7 @@ def get_file_name(
def set_embedding_model_config(
embedding_model_config: EmbeddingModelConfig,
rag_corpus: GapicRagCorpus,
) -> GapicRagCorpus:
) -> None:
if embedding_model_config.publisher_model and embedding_model_config.endpoint:
raise ValueError("publisher_model and endpoint cannot be set at the same time.")
if (
Expand Down Expand Up @@ -371,4 +388,27 @@ def set_embedding_model_config(
"endpoint must be of the format `projects/{project}/locations/{location}/endpoints/{endpoint}` or `endpoints/{endpoint}`"
)

return rag_corpus

def set_vector_db(
vector_db: Weaviate,
rag_corpus: GapicRagCorpus,
) -> None:
"""Sets the vector db configuration for the rag corpus."""
if isinstance(vector_db, Weaviate):
http_endpoint = vector_db.weaviate_http_endpoint
collection_name = vector_db.collection_name
api_key = vector_db.api_key

rag_corpus.rag_vector_db_config = RagVectorDbConfig(
weaviate=RagVectorDbConfig.Weaviate(
http_endpoint=http_endpoint,
collection_name=collection_name,
),
api_auth=api_auth.ApiAuth(
api_key_config=api_auth.ApiAuth.ApiKeyConfig(
api_key_secret_version=api_key
),
),
)
else:
raise TypeError("vector_db must be a Weaviate.")
Loading

0 comments on commit 9b28202

Please sign in to comment.