From f3c75554db67d3f7bc6da211e3d9b01c55afb896 Mon Sep 17 00:00:00 2001 From: Yuchen Liang Date: Tue, 5 Dec 2023 00:01:55 -0500 Subject: [PATCH 1/5] feat: use secret str for cerebriumai api key Signed-off-by: Yuchen Liang --- libs/langchain/langchain/llms/cerebriumai.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/libs/langchain/langchain/llms/cerebriumai.py b/libs/langchain/langchain/llms/cerebriumai.py index 0a162f5dfeaa4..fb74313c6792c 100644 --- a/libs/langchain/langchain/llms/cerebriumai.py +++ b/libs/langchain/langchain/llms/cerebriumai.py @@ -2,12 +2,12 @@ from typing import Any, Dict, List, Mapping, Optional import requests -from langchain_core.pydantic_v1 import Extra, Field, root_validator +from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.utils import get_from_dict_or_env +from langchain.utils import convert_to_secret_str, get_from_dict_or_env logger = logging.getLogger(__name__) @@ -36,7 +36,7 @@ class CerebriumAI(LLM): """Holds any model parameters valid for `create` call not explicitly specified.""" - cerebriumai_api_key: Optional[str] = None + cerebriumai_api_key: Optional[SecretStr] = None class Config: """Configuration for this pydantic config.""" @@ -64,8 +64,8 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - cerebriumai_api_key = get_from_dict_or_env( - values, "cerebriumai_api_key", "CEREBRIUMAI_API_KEY" + cerebriumai_api_key = convert_to_secret_str( + get_from_dict_or_env(values, "cerebriumai_api_key", "CEREBRIUMAI_API_KEY") ) values["cerebriumai_api_key"] = cerebriumai_api_key return values @@ -91,7 +91,8 @@ def _call( **kwargs: Any, ) -> str: headers: Dict = { - "Authorization": self.cerebriumai_api_key, + "Authorization": self.cerebriumai_api_key + and self.cerebriumai_api_key.get_secret_value(), "Content-Type": "application/json", } params = self.model_kwargs or {} From 6be9ae613adb683470b47a3b1e9584693e2d2f5b Mon Sep 17 00:00:00 2001 From: Yuchen Liang Date: Tue, 5 Dec 2023 00:32:31 -0500 Subject: [PATCH 2/5] add unit tests for masking cerebrium api key Signed-off-by: Yuchen Liang --- ...{test_cerebrium.py => test_cerebriumai.py} | 0 .../tests/unit_tests/llms/test_cerebriumai.py | 33 +++++++++++++++++++ 2 files changed, 33 insertions(+) rename libs/langchain/tests/integration_tests/llms/{test_cerebrium.py => test_cerebriumai.py} (100%) create mode 100644 libs/langchain/tests/unit_tests/llms/test_cerebriumai.py diff --git a/libs/langchain/tests/integration_tests/llms/test_cerebrium.py b/libs/langchain/tests/integration_tests/llms/test_cerebriumai.py similarity index 100% rename from libs/langchain/tests/integration_tests/llms/test_cerebrium.py rename to libs/langchain/tests/integration_tests/llms/test_cerebriumai.py diff --git a/libs/langchain/tests/unit_tests/llms/test_cerebriumai.py b/libs/langchain/tests/unit_tests/llms/test_cerebriumai.py new file mode 100644 index 0000000000000..b7d343081dd98 --- /dev/null +++ b/libs/langchain/tests/unit_tests/llms/test_cerebriumai.py @@ -0,0 +1,33 @@ +"""Test CerebriumAI llm""" + + +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture, MonkeyPatch + +from langchain.llms.cerebriumai import CerebriumAI + + +def test_api_key_is_secret_string() -> None: + llm = CerebriumAI(cerebriumai_api_key="test-cerebriumai-api-key") + assert isinstance(llm.cerebriumai_api_key, SecretStr) + + +def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None: + llm = CerebriumAI(cerebriumai_api_key="secret-api-key") + print(llm.cerebriumai_api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + assert repr(llm.cerebriumai_api_key) == "SecretStr('**********')" + + +def test_api_key_masked_when_passed_from_env( + monkeypatch: MonkeyPatch, capsys: CaptureFixture +) -> None: + monkeypatch.setenv("CEREBRIUMAI_API_KEY", "secret-api-key") + llm = CerebriumAI() + print(llm.cerebriumai_api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + assert repr(llm.cerebriumai_api_key) == "SecretStr('**********')" From 075f914cdb00d264ba4af82ed5e8e4d0a7e14b93 Mon Sep 17 00:00:00 2001 From: Yuchen Liang Date: Tue, 5 Dec 2023 00:43:53 -0500 Subject: [PATCH 3/5] improve doc Signed-off-by: Yuchen Liang --- libs/langchain/langchain/llms/cerebriumai.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/llms/cerebriumai.py b/libs/langchain/langchain/llms/cerebriumai.py index fb74313c6792c..7044cc97fd143 100644 --- a/libs/langchain/langchain/llms/cerebriumai.py +++ b/libs/langchain/langchain/llms/cerebriumai.py @@ -15,8 +15,9 @@ class CerebriumAI(LLM): """CerebriumAI large language models. - To use, you should have the ``cerebrium`` python package installed, and the - environment variable ``CEREBRIUMAI_API_KEY`` set with your API key. + To use, you should have the ``cerebrium`` python package installed. + You should also have the environment variable ``CEREBRIUMAI_API_KEY`` + set with your API key or pass it as a named arguement in the constructor. Any parameters that are valid to be passed to the call can be passed in, even if not explicitly saved on this class. @@ -25,7 +26,7 @@ class CerebriumAI(LLM): .. code-block:: python from langchain.llms import CerebriumAI - cerebrium = CerebriumAI(endpoint_url="") + cerebrium = CerebriumAI(endpoint_url="", cerebriumai_api_key="my-api-key") """ From 098cb7bc73e793de4c736af22845bb68e042d65e Mon Sep 17 00:00:00 2001 From: Yuchen Liang Date: Wed, 6 Dec 2023 00:11:37 -0500 Subject: [PATCH 4/5] fix spell Signed-off-by: Yuchen Liang --- libs/langchain/langchain/llms/cerebriumai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/langchain/llms/cerebriumai.py b/libs/langchain/langchain/llms/cerebriumai.py index 7044cc97fd143..dd9e98509d097 100644 --- a/libs/langchain/langchain/llms/cerebriumai.py +++ b/libs/langchain/langchain/llms/cerebriumai.py @@ -17,7 +17,7 @@ class CerebriumAI(LLM): To use, you should have the ``cerebrium`` python package installed. You should also have the environment variable ``CEREBRIUMAI_API_KEY`` - set with your API key or pass it as a named arguement in the constructor. + set with your API key or pass it as a named argument in the constructor. Any parameters that are valid to be passed to the call can be passed in, even if not explicitly saved on this class. From 3c53051610db24933d1d2419590b900480f1ee3a Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Wed, 6 Dec 2023 08:58:38 -0800 Subject: [PATCH 5/5] cr --- libs/langchain/langchain/llms/cerebriumai.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/llms/cerebriumai.py b/libs/langchain/langchain/llms/cerebriumai.py index dd9e98509d097..75c7c7b5fa701 100644 --- a/libs/langchain/langchain/llms/cerebriumai.py +++ b/libs/langchain/langchain/llms/cerebriumai.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, List, Mapping, Optional +from typing import Any, Dict, List, Mapping, Optional, cast import requests from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator @@ -92,8 +92,9 @@ def _call( **kwargs: Any, ) -> str: headers: Dict = { - "Authorization": self.cerebriumai_api_key - and self.cerebriumai_api_key.get_secret_value(), + "Authorization": cast( + SecretStr, self.cerebriumai_api_key + ).get_secret_value(), "Content-Type": "application/json", } params = self.model_kwargs or {}