Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: mask api key for cerebriumai llm #14272

Merged
merged 5 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions libs/langchain/langchain/llms/cerebriumai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,22 @@
from typing import Any, Dict, List, Mapping, Optional

import requests
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
from langchain.utils import convert_to_secret_str, get_from_dict_or_env

logger = logging.getLogger(__name__)


class CerebriumAI(LLM):
"""CerebriumAI large language models.

To use, you should have the ``cerebrium`` python package installed, and the
environment variable ``CEREBRIUMAI_API_KEY`` set with your API key.
To use, you should have the ``cerebrium`` python package installed.
You should also have the environment variable ``CEREBRIUMAI_API_KEY``
set with your API key or pass it as a named argument in the constructor.

Any parameters that are valid to be passed to the call can be passed
in, even if not explicitly saved on this class.
Expand All @@ -25,7 +26,7 @@ class CerebriumAI(LLM):
.. code-block:: python

from langchain.llms import CerebriumAI
cerebrium = CerebriumAI(endpoint_url="")
cerebrium = CerebriumAI(endpoint_url="", cerebriumai_api_key="my-api-key")

"""

Expand All @@ -36,7 +37,7 @@ class CerebriumAI(LLM):
"""Holds any model parameters valid for `create` call not
explicitly specified."""

cerebriumai_api_key: Optional[str] = None
cerebriumai_api_key: Optional[SecretStr] = None

class Config:
"""Configuration for this pydantic config."""
Expand Down Expand Up @@ -64,8 +65,8 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
cerebriumai_api_key = get_from_dict_or_env(
values, "cerebriumai_api_key", "CEREBRIUMAI_API_KEY"
cerebriumai_api_key = convert_to_secret_str(
get_from_dict_or_env(values, "cerebriumai_api_key", "CEREBRIUMAI_API_KEY")
)
values["cerebriumai_api_key"] = cerebriumai_api_key
return values
Expand All @@ -91,7 +92,8 @@ def _call(
**kwargs: Any,
) -> str:
headers: Dict = {
"Authorization": self.cerebriumai_api_key,
"Authorization": self.cerebriumai_api_key
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm i dont think this is right, right? this expression will evaluate to a boolean?

and self.cerebriumai_api_key.get_secret_value(),
"Content-Type": "application/json",
}
params = self.model_kwargs or {}
Expand Down
33 changes: 33 additions & 0 deletions libs/langchain/tests/unit_tests/llms/test_cerebriumai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Test CerebriumAI llm"""


from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch

from langchain.llms.cerebriumai import CerebriumAI


def test_api_key_is_secret_string() -> None:
llm = CerebriumAI(cerebriumai_api_key="test-cerebriumai-api-key")
assert isinstance(llm.cerebriumai_api_key, SecretStr)


def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None:
llm = CerebriumAI(cerebriumai_api_key="secret-api-key")
print(llm.cerebriumai_api_key, end="")
captured = capsys.readouterr()

assert captured.out == "**********"
assert repr(llm.cerebriumai_api_key) == "SecretStr('**********')"


def test_api_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
monkeypatch.setenv("CEREBRIUMAI_API_KEY", "secret-api-key")
llm = CerebriumAI()
print(llm.cerebriumai_api_key, end="")
captured = capsys.readouterr()

assert captured.out == "**********"
assert repr(llm.cerebriumai_api_key) == "SecretStr('**********')"