-
Notifications
You must be signed in to change notification settings - Fork 205
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(vertex): add support for google vertex (#319)
- Loading branch information
1 parent
e1d0063
commit 3412bac
Showing
11 changed files
with
424 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import asyncio | ||
|
||
from anthropic import AnthropicVertex, AsyncAnthropicVertex | ||
|
||
|
||
def sync_client() -> None: | ||
print("------ Sync Vertex ------") | ||
|
||
client = AnthropicVertex() | ||
|
||
message = client.beta.messages.create( | ||
model="claude-instant-1p2", | ||
max_tokens=100, | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "Say hello there!", | ||
} | ||
], | ||
) | ||
print(message.model_dump_json(indent=2)) | ||
|
||
|
||
async def async_client() -> None: | ||
print("------ Async Vertex ------") | ||
|
||
client = AsyncAnthropicVertex() | ||
|
||
message = await client.beta.messages.create( | ||
model="claude-instant-1p2", | ||
max_tokens=1024, | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "Say hello there!", | ||
} | ||
], | ||
) | ||
print(message.model_dump_json(indent=2)) | ||
|
||
|
||
sync_client() | ||
asyncio.run(async_client()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -69,6 +69,7 @@ | |
"AI_PROMPT", | ||
] | ||
|
||
from .lib.vertex import * | ||
from .lib.streaming import * | ||
|
||
_setup_logging() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from ._google_auth import google_auth as google_auth |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from ..._exceptions import AnthropicError | ||
|
||
INSTRUCTIONS = """ | ||
Anthropic error: missing required dependency `{library}`. | ||
$ pip install anthropic[{extra}] | ||
""" | ||
|
||
|
||
class MissingDependencyError(AnthropicError): | ||
def __init__(self, *, library: str, extra: str) -> None: | ||
super().__init__(INSTRUCTIONS.format(library=library, extra=extra)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Any | ||
from typing_extensions import ClassVar, override | ||
|
||
from ._common import MissingDependencyError | ||
from ..._utils import LazyProxy | ||
|
||
if TYPE_CHECKING: | ||
import google.auth # type: ignore | ||
|
||
google_auth = google.auth | ||
|
||
|
||
class GoogleAuthProxy(LazyProxy[Any]): | ||
should_cache: ClassVar[bool] = True | ||
|
||
@override | ||
def __load__(self) -> Any: | ||
try: | ||
import google.auth # type: ignore | ||
except ImportError as err: | ||
raise MissingDependencyError(extra="vertex", library="google-auth") from err | ||
|
||
return google.auth | ||
|
||
|
||
if not TYPE_CHECKING: | ||
google_auth = GoogleAuthProxy() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from ._client import AnthropicVertex as AnthropicVertex, AsyncAnthropicVertex as AsyncAnthropicVertex |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
from .._extras import google_auth | ||
|
||
if TYPE_CHECKING: | ||
from google.auth.credentials import Credentials # type: ignore[import-untyped] | ||
|
||
# pyright: reportMissingTypeStubs=false, reportUnknownVariableType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false | ||
# google libraries don't provide types :/ | ||
|
||
# Note: these functions are blocking as they make HTTP requests, the async | ||
# client runs these functions in a separate thread to ensure they do not | ||
# cause synchronous blocking issues. | ||
|
||
|
||
def load_auth() -> tuple[Credentials, str]: | ||
from google.auth.transport.requests import Request # type: ignore[import-untyped] | ||
|
||
credentials, project_id = google_auth.default() | ||
credentials.refresh(Request()) | ||
|
||
if not project_id: | ||
raise ValueError("Could not resolve project_id") | ||
|
||
if not isinstance(project_id, str): | ||
raise TypeError(f"Expected project_id to be a str but got {type(project_id)}") | ||
|
||
return credentials, project_id | ||
|
||
|
||
def refresh_auth(credentials: Credentials) -> None: | ||
from google.auth.transport.requests import Request # type: ignore[import-untyped] | ||
|
||
credentials.refresh(Request()) |
Oops, something went wrong.