Skip to content

Commit

Permalink
Mask API key for Minimax LLM (#14309)
Browse files Browse the repository at this point in the history
- **Description:** Added masking for the API key for Minimax LLM + tests
inspired by #12418.
- **Issue:** the issue # fixes
#12165
- **Dependencies:** this fix is dependent on Minimax instantiation fix
which is introduced in
#13439, so merge this one
after.
  - **Tag maintainer:** @eyurtsev

---------

Co-authored-by: Harrison Chase <[email protected]>
  • Loading branch information
rancomp and hwchase17 authored Dec 5, 2023
1 parent 29e993a commit d22c13e
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 16 deletions.
29 changes: 13 additions & 16 deletions libs/langchain/langchain/llms/minimax.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
)

import requests
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator

from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
from langchain.utils import convert_to_secret_str, get_from_dict_or_env

logger = logging.getLogger(__name__)

Expand All @@ -27,7 +27,7 @@ class _MinimaxEndpointClient(BaseModel):

host: str
group_id: str
api_key: str
api_key: SecretStr
api_url: str

@root_validator(pre=True, allow_reuse=True)
Expand All @@ -40,7 +40,7 @@ def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
return values

def post(self, request: Any) -> Any:
headers = {"Authorization": f"Bearer {self.api_key}"}
headers = {"Authorization": f"Bearer {self.api_key.get_secret_value()}"}
response = requests.post(self.api_url, headers=headers, json=request)
# TODO: error handling and automatic retries
if not response.ok:
Expand All @@ -56,7 +56,7 @@ def post(self, request: Any) -> Any:
class MinimaxCommon(BaseModel):
"""Common parameters for Minimax large language models."""

_client: Any = None
_client: _MinimaxEndpointClient
model: str = "abab5.5-chat"
"""Model name to use."""
max_tokens: int = 256
Expand All @@ -69,13 +69,13 @@ class MinimaxCommon(BaseModel):
"""Holds any model parameters valid for `create` call not explicitly specified."""
minimax_api_host: Optional[str] = None
minimax_group_id: Optional[str] = None
minimax_api_key: Optional[str] = None
minimax_api_key: Optional[SecretStr] = None

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["minimax_api_key"] = get_from_dict_or_env(
values, "minimax_api_key", "MINIMAX_API_KEY"
values["minimax_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "minimax_api_key", "MINIMAX_API_KEY")
)
values["minimax_group_id"] = get_from_dict_or_env(
values, "minimax_group_id", "MINIMAX_GROUP_ID"
Expand All @@ -87,6 +87,11 @@ def validate_environment(cls, values: Dict) -> Dict:
"MINIMAX_API_HOST",
default="https://api.minimax.chat",
)
values["_client"] = _MinimaxEndpointClient(
host=values["minimax_api_host"],
api_key=values["minimax_api_key"],
group_id=values["minimax_group_id"],
)
return values

@property
Expand All @@ -110,14 +115,6 @@ def _llm_type(self) -> str:
"""Return type of llm."""
return "minimax"

def __init__(self, **data: Any):
super().__init__(**data)
self._client = _MinimaxEndpointClient(
host=self.minimax_api_host,
api_key=self.minimax_api_key,
group_id=self.minimax_group_id,
)


class Minimax(MinimaxCommon, LLM):
"""Wrapper around Minimax large language models.
Expand Down
42 changes: 42 additions & 0 deletions libs/langchain/tests/unit_tests/llms/test_minimax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Test Minimax llm"""
from typing import cast

from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch

from langchain.llms.minimax import Minimax


def test_api_key_is_secret_string() -> None:
llm = Minimax(minimax_api_key="secret-api-key", minimax_group_id="group_id")
assert isinstance(llm.minimax_api_key, SecretStr)


def test_api_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("MINIMAX_API_KEY", "secret-api-key")
monkeypatch.setenv("MINIMAX_GROUP_ID", "group_id")
llm = Minimax()
print(llm.minimax_api_key, end="")
captured = capsys.readouterr()

assert captured.out == "**********"


def test_api_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
) -> None:
"""Test initialization with an API key provided via the initializer"""
llm = Minimax(minimax_api_key="secret-api-key", minimax_group_id="group_id")
print(llm.minimax_api_key, end="")
captured = capsys.readouterr()

assert captured.out == "**********"


def test_uses_actual_secret_value_from_secretstr() -> None:
"""Test that actual secret is retrieved using `.get_secret_value()`."""
llm = Minimax(minimax_api_key="secret-api-key", minimax_group_id="group_id")
assert cast(SecretStr, llm.minimax_api_key).get_secret_value() == "secret-api-key"

0 comments on commit d22c13e

Please sign in to comment.