Skip to content

Commit

Permalink
feat: add support for Vertex AI Gemini Pro (#185)
Browse files Browse the repository at this point in the history
* update to LiteLLM 1.14.1 with Gemini support

* add third-party service credential handling pattern
  - add GCP credential handling for Vertex AI

* add Gemini Pro model settings

* docs: Vertex AI credentials creation.
  • Loading branch information
janaka authored Dec 14, 2023
1 parent d58c398 commit bbc207c
Show file tree
Hide file tree
Showing 13 changed files with 614 additions and 16 deletions.
1 change: 1 addition & 0 deletions docs/user-guide/configuration.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Configuration

- [LLM Service Configuration](./llm-config.md)
- [Configure Spaces](./config-spaces.md)
- [Configure File Storage Services](./file-storage-services.md)
33 changes: 33 additions & 0 deletions docs/user-guide/llm-config.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# LLM Configuration

Currently all supported LLM services are configured when Docq is deployed/started via environment variables.

These are documented in [/misc/docker.env.template](https://github.com/docqai/docq/blob/main/misc/docker.env.template).

Find a little bit more information on how to get these credentials from each respective service.

???+ warning
In production, set secrets values in the shell. e.g. `set SOME_API_KEY=<your secret api key value>`.

A tool like Infisical helps manage this process while keeping your secret values safe.

## OpenAI

- Sign up for an OpenAI account. Generate an API key

## Azure OpenAI

Assumes you have an Azure subscription

- If using the all in one ARM template to deploy Docq these env vars are set securely for you.
- If using and existing deployment or click-ops deployment. Login to your Azure console. Navigate to your Azure OpenAI deployment to find the required values.

## Vertex AI (PaLm2 and Gemini Pro)

Assumes you have a GCP account.

- Login to your GCP console.
- You can either select an existing GCP project or create a new one.
- Navigate to 'IAM and Admin' > 'Service Accounts'. Create a new service account. You may also use an existing one.
- Assign the SA account the 'Vertex AI User' role.
- switch to the 'Keys' tab > 'Add key' > 'Create new key'. This will generate a credentials.json file. Copy the entire content of this file and set the env var.
19 changes: 18 additions & 1 deletion misc/docker.env.template
Original file line number Diff line number Diff line change
@@ -1,9 +1,26 @@
STREAMLIT_SERVER_ADDRESS=0.0.0.0
STREAMLIT_SERVER_PORT=8501 #default
DOCQ_DATA=./.persisted/
DOCQ_OPENAI_API_KEY # ideally set value on shell, don't insert a value here because it's a secret.

DOCQ_COOKIE_HMAC_SECRET_KEY=cookie_password

## === #
# ideally set secret values on shell, don't insert a value here.
## === #

# LLM Services
DOCQ_OPENAI_API_KEY

DOCQ_GOOGLE_APPLICATION_CREDENTIALS_JSON # for VertexAI but can be used for any other GCP service.

DOCQ_AZURE_OPENAI_API_KEY1
DOCQ_AZURE_OPENAI_API_KEY2
DOCQ_AZURE_OPENAI_API_BASE # from your deployment
DOCQ_AZURE_OPENAI_API_VERSION # based version set in your deployment

TOKENIZERS_PARALLELISM = "True" # for HUGGINGFACE_OPTIMUM_BAAI


# SMTP Server Settings
DOCQ_SMTP_SERVER="smtp-relay.brevo.com" # SMTP server address
DOCQ_SMTP_PORT=587 # SMTP port tls
Expand Down
15 changes: 14 additions & 1 deletion misc/secrets.toml.template
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
DOCQ_DATA = "./.persisted/"
DOCQ_OPENAI_API_KEY = "YOUR-OPENAI-API-KEY"
DOCQ_COOKIE_HMAC_SECRET_KEY = "32_char_secret_used_to_encrypt"

# LLM Services
DOCQ_OPENAI_API_KEY = "YOUR-OPENAI-API-KEY"

DOCQ_GOOGLE_APPLICATION_CREDENTIALS_JSON = "for VertexAI but can be used for any other GCP service."

DOCQ_AZURE_OPENAI_API_KEY1
DOCQ_AZURE_OPENAI_API_KEY2
DOCQ_AZURE_OPENAI_API_BASE = "from your deployment"
DOCQ_AZURE_OPENAI_API_VERSION = "based version set in your deployment"

# for HUGGINGFACE_OPTIMUM_BAAI
TOKENIZERS_PARALLELISM = "True"


# SMTP Server Settings
DOCQ_SMTP_SERVER = "smtp-relay.brevo.com" # SMTP server address
DOCQ_SMTP_PORT = 587 # SMTP port tls
Expand Down
6 changes: 5 additions & 1 deletion mkdocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,18 @@ nav:
- "Deployment Scenarios": overview/deployment-scenarios.md
- "User Guides":
- "Getting Started": user-guide/getting-started.md
- "Configuration": user-guide/configuration.md
- "Configuration":
- "LLM Service Configuration": user-guide/llm-config.md
- "Configure Spaces": user-guides/config-spaces.md
- "Configure File Storage Services": user-guides/file-storage-services.md
- "Deploy to Streamlit": user-guide/deploy-to-streamlit.md
- "Deploy to Azure": user-guide/deploy-to-azure.md
- "Deploy to AWS": user-guide/deploy-to-aws.md
- "Deploy to GCP": user-guide/deploy-to-gcp.md
- "Configure Spaces": user-guide/config-spaces.md
- "Data Sources": user-guide/data-sources.md
- "Configure File Storage Services": user-guide/file-storage-services.md
- "LLM Service Configuration": user-guide/llm-config.md
- "Usage": user-guide/usage.md
- "FAQ": user-guide/faq.md
- "Developer Guides":
Expand Down
392 changes: 387 additions & 5 deletions poetry.lock

Large diffs are not rendered by default.

12 changes: 10 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "docq"
version = "0.7.1"
version = "0.7.2"
description = "Docq.AI - private and secure knowledge insight on your data."
authors = ["Docq.AI Team <[email protected]>"]
maintainers = ["Docq.AI Team <[email protected]>"]
Expand Down Expand Up @@ -47,9 +47,12 @@ google-api-python-client = "^2.104.0"
google-auth-httplib2 = "^0.1.1"
microsoftgraph-python = "^1.1.6"
llama-index = "^0.9.8.post1"
litellm = "^1.7.7"
pydantic = "^2.5.2"
mkdocs-material = "^9.4.14"
google-cloud-aiplatform = "^1.38.0"
litellm = "^1.14.1"
opentelemetry-instrumentation-httpx = "0.41b0"
opentelemetry-instrumentation-system-metrics = "0.41b0"

[tool.poetry.group.dev.dependencies]
pre-commit = "^2.18.1"
Expand Down Expand Up @@ -179,6 +182,11 @@ cmd = "opentelemetry-instrument --logs_exporter none streamlit run web/index.py
args = [{ name = "port", default = 8501, type = "integer" }]
env = { WATCHDOG_LOG_LEVEL = "ERROR", PYTHONPATH = "${PWD}/web/:${PWD}/source/:${PWD}/../docq-extensions/source/" }

[tool.poe.tasks.run-otel-infisical]
cmd = "infisical run --env=dev -- opentelemetry-instrument --logs_exporter none streamlit run web/index.py --server.port $port --browser.gatherUsageStats false --server.runOnSave true --server.fileWatcherType auto"
args = [{ name = "port", default = 8501, type = "integer" }]
env = { WATCHDOG_LOG_LEVEL = "ERROR", PYTHONPATH = "${PWD}/web/:${PWD}/source/:${PWD}/../docq-extensions/source/" }

[tool.poe.tasks.docker-build]
cmd = """
docker build
Expand Down
60 changes: 59 additions & 1 deletion source/docq/model_selection/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from enum import Enum
from typing import Dict, Optional

import vertexai

from ..config import OrganisationSettingsKey
from ..manage_settings import get_organisation_settings

Expand All @@ -32,6 +34,8 @@ class ModelVendor(str, Enum):
AWS_BEDROCK_STABILITYAI = "AWS Bedrock StabilityAI"
AWS_SAGEMAKER_META = "AWS Sagemaker Meta"
HUGGINGFACE_OPTIMUM_BAAI = "HuggingFace Optimum BAAI"
GOOGLE_VERTEXAI_PALM2 = "Google VertexAI Palm2"
GOOGLE_VERTEXTAI_GEMINI_PRO = "Google VertexAI Gemini Pro"


class ModelCapability(str, Enum):
Expand All @@ -45,6 +49,8 @@ class ModelCapability(str, Enum):
SUMMARISATION = "Summarisation"
IMAGE = "Image"
AUDIO = "Audio"
TEXT = "Text"
VISION = "Vision"


@dataclass
Expand Down Expand Up @@ -74,7 +80,8 @@ class ModelUsageSettingsCollection:
"""Unique key for the model collection."""
model_usage_settings: Dict[ModelCapability, ModelUsageSettings]


#NOTE: we are using LiteLLM client via Llama Index as the LLM client interface. This means model names need to follow the LiteLLM naming convention.
# SoT https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json
LLM_MODEL_COLLECTIONS = {
"openai_latest": ModelUsageSettingsCollection(
name="OpenAI Latest",
Expand Down Expand Up @@ -134,6 +141,57 @@ class ModelUsageSettingsCollection:
),
},
),
"google_vertexai_palm2": ModelUsageSettingsCollection(
name="Google VertexAI Palm2 Latest",
key="google_vertexai_palm2_latest",
model_usage_settings={
ModelCapability.CHAT: ModelUsageSettings(
model_vendor=ModelVendor.GOOGLE_VERTEXAI_PALM2,
model_name="chat-bison@002",
model_capability=ModelCapability.CHAT,
),
ModelCapability.EMBEDDING: ModelUsageSettings(
model_vendor=ModelVendor.HUGGINGFACE_OPTIMUM_BAAI,
model_name="BAAI/bge-small-en-v1.5",
model_capability=ModelCapability.EMBEDDING,
license_="MIT",
citation="""@misc{bge_embedding,
title={C-Pack: Packaged Resources To Advance General Chinese Embedding},
author={Shitao Xiao and Zheng Liu and Peitian Zhang and Niklas Muennighoff},
year={2023},
eprint={2309.07597},
archivePrefix={arXiv},
primaryClass={cs.CL}
}""",
),
},
),
"google_vertexai_gemini_pro": ModelUsageSettingsCollection(
name="Google VertexAI Gemini Pro Latest",
key="google_vertexai_gemini_pro_latest",
model_usage_settings={
ModelCapability.CHAT: ModelUsageSettings(
model_vendor=ModelVendor.GOOGLE_VERTEXTAI_GEMINI_PRO,
model_name="gemini-pro",
model_capability=ModelCapability.CHAT,
temperature=0.0,
),
ModelCapability.EMBEDDING: ModelUsageSettings(
model_vendor=ModelVendor.HUGGINGFACE_OPTIMUM_BAAI,
model_name="BAAI/bge-small-en-v1.5",
model_capability=ModelCapability.EMBEDDING,
license_="MIT",
citation="""@misc{bge_embedding,
title={C-Pack: Packaged Resources To Advance General Chinese Embedding},
author={Shitao Xiao and Zheng Liu and Peitian Zhang and Niklas Muennighoff},
year={2023},
eprint={2309.07597},
archivePrefix={arXiv},
primaryClass={cs.CL}
}""",
),
},
),
}


Expand Down
5 changes: 3 additions & 2 deletions source/docq/services/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Docq services."""
from . import google_drive, ms_onedrive, smtp_service
from . import credential_utils, google_drive, ms_onedrive, smtp_service

__all__ = [
"google_drive",
"smtp_service",
"ms_onedrive"
"ms_onedrive",
"credential_utils",
]

def _init() -> None:
Expand Down
59 changes: 59 additions & 0 deletions source/docq/services/credential_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Utils to help handle credentials for various third-party services.
If you there are credentials that needs to be written to a file or set to a different environment variable, this is the place to do it.
If a credentials object/json needs to be constructed from values in env vars or other sources, this is the place to do it.
"""
import os
from typing import Optional

from opentelemetry import trace

from .google_drive import CREDENTIALS_KEY

tracer = trace.get_tracer(__name__)


def load_gcp_credentials_from_env_var(save_path: Optional[str] = None) -> bool:
"""Saves credentials json from an env var to a file for services that only accept a file path.
Args:
save_path (str): Path inc filename to save the credentials JSON to. Default `./.streamlit/gcp_credentials.json`
"""
span = trace.get_current_span()

DOCQ_GOOGLE_APPLICATION_CREDENTIALS_JSON = "DOCQ_GOOGLE_APPLICATION_CREDENTIALS_JSON" # noqa: N806

success = False
credentials_json = f"<env var {DOCQ_GOOGLE_APPLICATION_CREDENTIALS_JSON} not set>"
try:
credentials_json = os.environ[DOCQ_GOOGLE_APPLICATION_CREDENTIALS_JSON]
except KeyError as e:
success = False
span.add_event(f"Failed to access env var {DOCQ_GOOGLE_APPLICATION_CREDENTIALS_JSON}", attributes={"error": str(e)})
span.record_exception(e)

path = save_path or "./.streamlit/gcp_credentials.json"

try:
with open(path, "w") as f:
f.write(credentials_json)
success = True
span.add_event("Wrote GCP credentials to file successfully", attributes={"file_path": path})
except IOError as e:
success = False
span.add_event("Failed to write GCP credentials to file", attributes={"error": str(e), "file_path": path})
span.record_exception(e)

# note: this will only be available to the current process thread. coroutine code will not have access.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = path # default GCP env var
os.environ[CREDENTIALS_KEY] = path # used by the google_drive.py module.

return success

def setup_all_service_credentials() -> None:
"""Setup all service credentials."""
with tracer.start_as_current_span("docq.services.credential_utils.setup_all_service_credentials") as span:
load_gcp_credentials_from_env_var()
span.add_event("GCP credentials loaded")
1 change: 1 addition & 0 deletions source/docq/services/google_drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseDownload

#TODO: move this to a config factory in credential_utils.py. If possible to provide the json, do that.
CREDENTIALS_KEY = "DOCQ_GOOGLE_APPLICATION_CREDENTIALS"
REDIRECT_URL_KEY = "DOCQ_GOOGLE_AUTH_REDIRECT_URL"

Expand Down
3 changes: 2 additions & 1 deletion source/docq/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _config_logging() -> None:
"""Configure logging."""
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(process)d %(levelname)s %(message)s", force=True) # force over rides Otel (or other) logging config with this.


#FIXME: right now this will run everytime a user hits the home page. add a global lock using st.cache to make this only run once.
def init() -> None:
"""Initialize Docq."""
with tracer.start_as_current_span("docq.setup.init") as span:
Expand All @@ -36,6 +36,7 @@ def init() -> None:
manage_spaces._init()
manage_users._init()
services._init()
services.credential_utils.setup_all_service_credentials()
store._init()
manage_organisations._init_default_org_if_necessary()
manage_users._init_admin_if_necessary()
Expand Down
24 changes: 22 additions & 2 deletions source/docq/support/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
)
from llama_index.callbacks.base import CallbackManager
from llama_index.chat_engine import SimpleChatEngine
from llama_index.chat_engine.types import AGENT_CHAT_RESPONSE_TYPE, AgentChatResponse, ChatMode
from llama_index.embeddings import AzureOpenAIEmbedding, OpenAIEmbedding, OptimumEmbedding
from llama_index.chat_engine.types import AGENT_CHAT_RESPONSE_TYPE, AgentChatResponse
from llama_index.embeddings import AzureOpenAIEmbedding, GooglePaLMEmbedding, OpenAIEmbedding, OptimumEmbedding
from llama_index.embeddings.base import BaseEmbedding
from llama_index.indices.base import BaseIndex
from llama_index.indices.composability import ComposableGraph
Expand Down Expand Up @@ -96,6 +96,8 @@ def _init_local_models() -> None:

@tracer.start_as_current_span(name="_get_generation_model")
def _get_generation_model(model_settings_collection: ModelUsageSettingsCollection) -> LLM | None:
import litellm
litellm.telemetry = False
model = None
if model_settings_collection and model_settings_collection.model_usage_settings[ModelCapability.CHAT]:
chat_model_settings = model_settings_collection.model_usage_settings[ModelCapability.CHAT]
Expand Down Expand Up @@ -131,9 +133,27 @@ def _get_generation_model(model_settings_collection: ModelUsageSettingsCollectio
_env_missing = not bool(os.getenv("DOCQ_OPENAI_API_KEY"))
if _env_missing:
log.warning("Chat model: env var values missing")
elif chat_model_settings.model_vendor == ModelVendor.GOOGLE_VERTEXAI_PALM2:
# GCP project_id is coming from the credentials json.
model = LiteLLM(
temperature=chat_model_settings.temperature,
model=chat_model_settings.model_name,
callback_manager=_callback_manager,
)
elif chat_model_settings.model_vendor == ModelVendor.GOOGLE_VERTEXTAI_GEMINI_PRO:
# GCP project_id is coming from the credentials json.
model = LiteLLM(
temperature=chat_model_settings.temperature,
model=chat_model_settings.model_name,
callback_manager=_callback_manager,
max_tokens=2048,
kwargs={"telemetry":False}
)
litellm.vertex_location = "us-central1"
else:
raise ValueError("Chat model: model settings with a supported model vendor not found.")

model.max_retries = 3
return model


Expand Down

0 comments on commit bbc207c

Please sign in to comment.