Skip to content

Commit

Permalink
community[patch]: Make cohere_api_key a SecretStr (langchain-ai#12188)
Browse files Browse the repository at this point in the history
This PR makes `cohere_api_key` in `llms/cohere` a SecretStr, so that the
API Key is not leaked when `Cohere.cohere_api_key` is represented as a
string.

---------

Signed-off-by: Arun <[email protected]>
Co-authored-by: Eugene Yurtsev <[email protected]>
  • Loading branch information
2 people authored and gkorland committed Mar 30, 2024
1 parent ea2de9d commit 9e41626
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 7 deletions.
19 changes: 12 additions & 7 deletions libs/community/langchain_community/llms/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
)
from langchain_core.language_models.llms import LLM
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from tenacity import (
before_sleep_log,
retry,
Expand Down Expand Up @@ -73,7 +73,8 @@ class BaseCohere(Serializable):
temperature: float = 0.75
"""A non-negative float that tunes the degree of randomness in generation."""

cohere_api_key: Optional[str] = None
cohere_api_key: Optional[SecretStr] = None
"""Cohere API key. If not provided, will be read from the environment variable."""

stop: Optional[List[str]] = None

Expand All @@ -94,13 +95,17 @@ def validate_environment(cls, values: Dict) -> Dict:
"Please install it with `pip install cohere`."
)
else:
cohere_api_key = get_from_dict_or_env(
values, "cohere_api_key", "COHERE_API_KEY"
values["cohere_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "cohere_api_key", "COHERE_API_KEY")
)
client_name = values["user_agent"]
values["client"] = cohere.Client(cohere_api_key, client_name=client_name)
values["client"] = cohere.Client(
api_key=values["cohere_api_key"].get_secret_value(),
client_name=client_name,
)
values["async_client"] = cohere.AsyncClient(
cohere_api_key, client_name=client_name
api_key=values["cohere_api_key"].get_secret_value(),
client_name=client_name,
)
return values

Expand Down
13 changes: 13 additions & 0 deletions libs/community/tests/integration_tests/llms/test_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

from pathlib import Path

from langchain_core.pydantic_v1 import SecretStr
from pytest import MonkeyPatch

from langchain_community.llms.cohere import Cohere
from langchain_community.llms.loading import load_llm
from tests.integration_tests.llms.utils import assert_llm_equality
Expand All @@ -14,6 +17,16 @@ def test_cohere_call() -> None:
assert isinstance(output, str)


def test_cohere_api_key(monkeypatch: MonkeyPatch) -> None:
"""Test that cohere api key is a secret key."""
# test initialization from init
assert isinstance(Cohere(cohere_api_key="1").cohere_api_key, SecretStr)

# test initialization from env variable
monkeypatch.setenv("COHERE_API_KEY", "secret-api-key")
assert isinstance(Cohere().cohere_api_key, SecretStr)


def test_saving_loading_llm(tmp_path: Path) -> None:
"""Test saving/loading an Cohere LLM."""
llm = Cohere(max_tokens=10)
Expand Down

0 comments on commit 9e41626

Please sign in to comment.