From d4d64daa1e304aaa4bdf78d7a907f44a294b8fe3 Mon Sep 17 00:00:00 2001 From: newfinder Date: Thu, 7 Dec 2023 00:47:09 +0800 Subject: [PATCH] Mask API key for baidu qianfan (#14281) Description: This PR masked baidu qianfan - Chat_Models API Key and added unit tests. Issue: the issue langchain-ai#12165. Tag maintainer: @eyurtsev --------- Co-authored-by: xiayi --- .../chat_models/baidu_qianfan_endpoint.py | 31 ++++++----- .../chat_models/test_baiduqianfan.py | 53 +++++++++++++++++++ 2 files changed, 71 insertions(+), 13 deletions(-) create mode 100644 libs/langchain/tests/integration_tests/chat_models/test_baiduqianfan.py diff --git a/libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py b/libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py index 7c7e3f67edffd..51303ddbb74b9 100644 --- a/libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py +++ b/libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py @@ -13,7 +13,8 @@ SystemMessage, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.pydantic_v1 import Field, SecretStr, root_validator +from langchain_core.utils import convert_to_secret_str from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, @@ -88,8 +89,8 @@ class QianfanChatEndpoint(BaseChatModel): client: Any - qianfan_ak: Optional[str] = None - qianfan_sk: Optional[str] = None + qianfan_ak: Optional[SecretStr] = None + qianfan_sk: Optional[SecretStr] = None streaming: Optional[bool] = False """Whether to stream the results or not.""" @@ -118,19 +119,23 @@ class QianfanChatEndpoint(BaseChatModel): @root_validator() def validate_environment(cls, values: Dict) -> Dict: - values["qianfan_ak"] = get_from_dict_or_env( - values, - "qianfan_ak", - "QIANFAN_AK", + values["qianfan_ak"] = convert_to_secret_str( + get_from_dict_or_env( + values, + "qianfan_ak", + "QIANFAN_AK", + ) ) - values["qianfan_sk"] = get_from_dict_or_env( - values, - "qianfan_sk", - "QIANFAN_SK", + values["qianfan_sk"] = convert_to_secret_str( + get_from_dict_or_env( + values, + "qianfan_sk", + "QIANFAN_SK", + ) ) params = { - "ak": values["qianfan_ak"], - "sk": values["qianfan_sk"], + "ak": values["qianfan_ak"].get_secret_value(), + "sk": values["qianfan_sk"].get_secret_value(), "model": values["model"], "stream": values["streaming"], } diff --git a/libs/langchain/tests/integration_tests/chat_models/test_baiduqianfan.py b/libs/langchain/tests/integration_tests/chat_models/test_baiduqianfan.py new file mode 100644 index 0000000000000..e8a4dfae62e8c --- /dev/null +++ b/libs/langchain/tests/integration_tests/chat_models/test_baiduqianfan.py @@ -0,0 +1,53 @@ +from typing import cast + +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture, MonkeyPatch + +from langchain.chat_models.baidu_qianfan_endpoint import ( + QianfanChatEndpoint, +) + + +def test_qianfan_key_masked_when_passed_from_env( + monkeypatch: MonkeyPatch, capsys: CaptureFixture +) -> None: + """Test initialization with an API key provided via an env variable""" + monkeypatch.setenv("QIANFAN_AK", "test-api-key") + monkeypatch.setenv("QIANFAN_SK", "test-secret-key") + + chat = QianfanChatEndpoint() + print(chat.qianfan_ak, end="") + captured = capsys.readouterr() + assert captured.out == "**********" + + print(chat.qianfan_sk, end="") + captured = capsys.readouterr() + assert captured.out == "**********" + + +def test_qianfan_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + """Test initialization with an API key provided via the initializer""" + chat = QianfanChatEndpoint( + qianfan_ak="test-api-key", + qianfan_sk="test-secret-key", + ) + print(chat.qianfan_ak, end="") + captured = capsys.readouterr() + assert captured.out == "**********" + + print(chat.qianfan_sk, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + + +def test_uses_actual_secret_value_from_secret_str() -> None: + """Test that actual secret is retrieved using `.get_secret_value()`.""" + chat = QianfanChatEndpoint( + qianfan_ak="test-api-key", + qianfan_sk="test-secret-key", + ) + assert cast(SecretStr, chat.qianfan_ak).get_secret_value() == "test-api-key" + assert cast(SecretStr, chat.qianfan_sk).get_secret_value() == "test-secret-key"