core, partners: implement standard tracing params for LLMs (#25410)
ccurme authored Aug 16, 2024
1 parent 9f0c76b commit b83f1eb
Showing 17 changed files with 298 additions and 36 deletions.
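At a high level, this commit moves the LangSmithParams TypedDict from chat_models.py into language_models/base.py, gives BaseLLM a default _get_ls_params() implementation, and merges the resulting keys into the inheritable run metadata in stream, astream, generate, and agenerate. Partner packages (Anthropic, Ollama) override _get_ls_params to map provider-specific options onto the standard keys. A rough sketch of the metadata an LLM run ends up carrying; the values below are illustrative and mirror the Fireworks unit test added in this commit:

{
    "ls_provider": "fireworks",   # derived from the class name unless a subclass overrides it
    "ls_model_type": "llm",       # chat models keep using "chat"
    "ls_model_name": "foo",       # picked up from a `model` or `model_name` attribute
    "ls_max_tokens": 10,          # from kwargs or a `max_tokens`-style attribute
    "ls_temperature": 0.1,        # from kwargs or a `temperature` attribute
}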
2 changes: 1 addition & 1 deletion libs/community/tests/unit_tests/chat_models/test_ollama.py
@@ -10,7 +10,7 @@ def test_standard_params() -> None:
class ExpectedParams(BaseModel):
ls_provider: str
ls_model_name: str
ls_model_type: Literal["chat"]
ls_model_type: Literal["chat", "llm"]
ls_temperature: Optional[float]
ls_max_tokens: Optional[int]
ls_stop: Optional[List[str]]
2 changes: 2 additions & 0 deletions libs/core/langchain_core/language_models/__init__.py
@@ -39,6 +39,7 @@

from langchain_core.language_models.base import (
BaseLanguageModel,
LangSmithParams,
LanguageModelInput,
LanguageModelLike,
LanguageModelOutput,
@@ -62,6 +63,7 @@
"LLM",
"LanguageModelInput",
"get_tokenizer",
"LangSmithParams",
"LanguageModelOutput",
"LanguageModelLike",
"FakeListLLM",
20 changes: 19 additions & 1 deletion libs/core/langchain_core/language_models/base.py
@@ -8,6 +8,7 @@
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
@@ -17,7 +18,7 @@
Union,
)

from typing_extensions import TypeAlias
from typing_extensions import TypeAlias, TypedDict

from langchain_core._api import deprecated
from langchain_core.messages import (
@@ -37,6 +38,23 @@
from langchain_core.outputs import LLMResult


class LangSmithParams(TypedDict, total=False):
"""LangSmith parameters for tracing."""

ls_provider: str
"""Provider of the model."""
ls_model_name: str
"""Name of the model."""
ls_model_type: Literal["chat", "llm"]
"""Type of the model. Should be 'chat' or 'llm'."""
ls_temperature: Optional[float]
"""Temperature for generation."""
ls_max_tokens: Optional[int]
"""Max tokens for generation."""
ls_stop: Optional[List[str]]
"""Stop words for generation."""


@lru_cache(maxsize=None) # Cache the tokenizer
def get_tokenizer() -> Any:
"""Get a GPT-2 tokenizer instance.
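Since LangSmithParams is declared with total=False, every key is optional. A minimal sketch of constructing one directly; the provider and model names here are made up purely for illustration:

from langchain_core.language_models import LangSmithParams

params = LangSmithParams(
    ls_provider="example",          # hypothetical provider name
    ls_model_type="llm",
    ls_model_name="example-model",  # hypothetical model name
    ls_temperature=0.2,
)
# Keys that are not set, such as ls_max_tokens and ls_stop, are simply omitted.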
25 changes: 5 additions & 20 deletions libs/core/langchain_core/language_models/chat_models.py
@@ -23,8 +23,6 @@
cast,
)

from typing_extensions import TypedDict

from langchain_core._api import deprecated
from langchain_core.caches import BaseCache
from langchain_core.callbacks import (
@@ -36,7 +34,11 @@
Callbacks,
)
from langchain_core.globals import get_llm_cache
from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
from langchain_core.language_models.base import (
BaseLanguageModel,
LangSmithParams,
LanguageModelInput,
)
from langchain_core.load import dumpd, dumps
from langchain_core.messages import (
AIMessage,
@@ -73,23 +75,6 @@
from langchain_core.tools import BaseTool


class LangSmithParams(TypedDict, total=False):
"""LangSmith parameters for tracing."""

ls_provider: str
"""Provider of the model."""
ls_model_name: str
"""Name of the model."""
ls_model_type: Literal["chat"]
"""Type of the model. Should be 'chat'."""
ls_temperature: Optional[float]
"""Temperature for generation."""
ls_max_tokens: Optional[int]
"""Max tokens for generation."""
ls_stop: Optional[List[str]]
"""Stop words for generation."""


def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
"""Generate from a stream.
85 changes: 82 additions & 3 deletions libs/core/langchain_core/language_models/llms.py
@@ -48,7 +48,11 @@
Callbacks,
)
from langchain_core.globals import get_llm_cache
from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
from langchain_core.language_models.base import (
BaseLanguageModel,
LangSmithParams,
LanguageModelInput,
)
from langchain_core.load import dumpd
from langchain_core.messages import (
AIMessage,
@@ -331,6 +335,43 @@ def _convert_input(self, input: LanguageModelInput) -> PromptValue:
"Must be a PromptValue, str, or list of BaseMessages."
)

def _get_ls_params(
self,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> LangSmithParams:
"""Get standard params for tracing."""

# get default provider from class name
default_provider = self.__class__.__name__
if default_provider.endswith("LLM"):
default_provider = default_provider[:-3]
default_provider = default_provider.lower()

ls_params = LangSmithParams(ls_provider=default_provider, ls_model_type="llm")
if stop:
ls_params["ls_stop"] = stop

# model
if hasattr(self, "model") and isinstance(self.model, str):
ls_params["ls_model_name"] = self.model
elif hasattr(self, "model_name") and isinstance(self.model_name, str):
ls_params["ls_model_name"] = self.model_name

# temperature
if "temperature" in kwargs and isinstance(kwargs["temperature"], float):
ls_params["ls_temperature"] = kwargs["temperature"]
elif hasattr(self, "temperature") and isinstance(self.temperature, float):
ls_params["ls_temperature"] = self.temperature

# max_tokens
if "max_tokens" in kwargs and isinstance(kwargs["max_tokens"], int):
ls_params["ls_max_tokens"] = kwargs["max_tokens"]
elif hasattr(self, "max_tokens") and isinstance(self.max_tokens, int):
ls_params["ls_max_tokens"] = self.max_tokens

return ls_params

def invoke(
self,
input: LanguageModelInput,
@@ -487,13 +528,17 @@ def stream(
params["stop"] = stop
params = {**params, **kwargs}
options = {"stop": stop}
inheritable_metadata = {
**(config.get("metadata") or {}),
**self._get_ls_params(stop=stop, **kwargs),
}
callback_manager = CallbackManager.configure(
config.get("callbacks"),
self.callbacks,
self.verbose,
config.get("tags"),
self.tags,
config.get("metadata"),
inheritable_metadata,
self.metadata,
)
(run_manager,) = callback_manager.on_llm_start(
@@ -548,13 +593,17 @@ async def astream(
params["stop"] = stop
params = {**params, **kwargs}
options = {"stop": stop}
inheritable_metadata = {
**(config.get("metadata") or {}),
**self._get_ls_params(stop=stop, **kwargs),
}
callback_manager = AsyncCallbackManager.configure(
config.get("callbacks"),
self.callbacks,
self.verbose,
config.get("tags"),
self.tags,
config.get("metadata"),
inheritable_metadata,
self.metadata,
)
(run_manager,) = await callback_manager.on_llm_start(
@@ -796,6 +845,21 @@ def generate(
f" argument of type {type(prompts)}."
)
# Create callback managers
if isinstance(metadata, list):
metadata = [
{
**(meta or {}),
**self._get_ls_params(stop=stop, **kwargs),
}
for meta in metadata
]
elif isinstance(metadata, dict):
metadata = {
**(metadata or {}),
**self._get_ls_params(stop=stop, **kwargs),
}
else:
pass
if (
isinstance(callbacks, list)
and callbacks
@@ -1017,6 +1081,21 @@ async def agenerate(
An LLMResult, which contains a list of candidate Generations for each input
prompt and additional model provider-specific output.
"""
if isinstance(metadata, list):
metadata = [
{
**(meta or {}),
**self._get_ls_params(stop=stop, **kwargs),
}
for meta in metadata
]
elif isinstance(metadata, dict):
metadata = {
**(metadata or {}),
**self._get_ls_params(stop=stop, **kwargs),
}
else:
pass
# Create callback managers
if isinstance(callbacks, list) and (
isinstance(callbacks[0], (list, BaseCallbackManager))
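To illustrate what the new default implementation returns, here is a sketch with a hypothetical toy subclass (EchoLLM is not part of the codebase): the base method derives ls_provider from the class name by stripping a trailing "LLM" and lower-casing the rest, and picks up model and temperature attributes automatically.

from typing import Any, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM


class EchoLLM(LLM):
    """Hypothetical toy LLM used only to illustrate the default tracing params."""

    model: str = "echo-1"
    temperature: float = 0.0

    @property
    def _llm_type(self) -> str:
        return "echo"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Trivial implementation: echo the prompt back.
        return prompt


llm = EchoLLM()
llm._get_ls_params()
# -> {'ls_provider': 'echo', 'ls_model_type': 'llm',
#     'ls_model_name': 'echo-1', 'ls_temperature': 0.0}

When such an instance is invoked, these keys are merged into the run's inheritable metadata, as shown in the stream, astream, generate, and agenerate hunks above.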
1 change: 1 addition & 0 deletions libs/core/tests/unit_tests/language_models/test_imports.py
@@ -6,6 +6,7 @@
"SimpleChatModel",
"BaseLLM",
"LLM",
"LangSmithParams",
"LanguageModelInput",
"LanguageModelOutput",
"LanguageModelLike",

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions libs/core/tests/unit_tests/runnables/test_runnable.py
@@ -2180,7 +2180,7 @@ async def test_prompt_with_llm(
"value": {
"end_time": None,
"final_output": None,
"metadata": {},
"metadata": {"ls_model_type": "llm", "ls_provider": "fakelist"},
"name": "FakeListLLM",
"start_time": "2023-01-01T00:00:00.000+00:00",
"streamed_output": [],
@@ -2384,7 +2384,10 @@ async def test_prompt_with_llm_parser(
"value": {
"end_time": None,
"final_output": None,
"metadata": {},
"metadata": {
"ls_model_type": "llm",
"ls_provider": "fakestreaminglist",
},
"name": "FakeStreamingListLLM",
"start_time": "2023-01-01T00:00:00.000+00:00",
"streamed_output": [],
15 changes: 14 additions & 1 deletion libs/partners/anthropic/langchain_anthropic/llms.py
@@ -17,7 +17,7 @@
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models import BaseLanguageModel, LangSmithParams
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.prompt_values import PromptValue
@@ -204,6 +204,19 @@ def _identifying_params(self) -> Dict[str, Any]:
"max_retries": self.max_retries,
}

def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = super()._get_ls_params(stop=stop, **kwargs)
identifying_params = self._identifying_params
if max_tokens := kwargs.get(
"max_tokens_to_sample",
identifying_params.get("max_tokens"),
):
params["ls_max_tokens"] = max_tokens
return params

def _wrap_prompt(self, prompt: str) -> str:
if not self.HUMAN_PROMPT or not self.AI_PROMPT:
raise NameError("Please ensure the anthropic package is loaded")
29 changes: 29 additions & 0 deletions libs/partners/anthropic/tests/unit_tests/test_llms.py
@@ -0,0 +1,29 @@
import os

from langchain_anthropic import AnthropicLLM

os.environ["ANTHROPIC_API_KEY"] = "foo"


def test_anthropic_model_params() -> None:
# Test standard tracing params
llm = AnthropicLLM(model="foo") # type: ignore[call-arg]

ls_params = llm._get_ls_params()
assert ls_params == {
"ls_provider": "anthropic",
"ls_model_type": "llm",
"ls_model_name": "foo",
"ls_max_tokens": 1024,
}

llm = AnthropicLLM(model="foo", temperature=0.1) # type: ignore[call-arg]

ls_params = llm._get_ls_params()
assert ls_params == {
"ls_provider": "anthropic",
"ls_model_type": "llm",
"ls_model_name": "foo",
"ls_max_tokens": 1024,
"ls_temperature": 0.1,
}
28 changes: 28 additions & 0 deletions libs/partners/fireworks/tests/unit_tests/test_llms.py
@@ -69,3 +69,31 @@ def test_fireworks_uses_actual_secret_value_from_secretstr() -> None:
max_tokens=250,
)
assert cast(SecretStr, llm.fireworks_api_key).get_secret_value() == "secret-api-key"


def test_fireworks_model_params() -> None:
# Test standard tracing params
llm = Fireworks(model="foo", api_key="secret-api-key") # type: ignore[arg-type]

ls_params = llm._get_ls_params()
assert ls_params == {
"ls_provider": "fireworks",
"ls_model_type": "llm",
"ls_model_name": "foo",
}

llm = Fireworks(
model="foo",
api_key="secret-api-key", # type: ignore[arg-type]
max_tokens=10,
temperature=0.1,
)

ls_params = llm._get_ls_params()
assert ls_params == {
"ls_provider": "fireworks",
"ls_model_type": "llm",
"ls_model_name": "foo",
"ls_max_tokens": 10,
"ls_temperature": 0.1,
}
11 changes: 10 additions & 1 deletion libs/partners/ollama/langchain_ollama/llms.py
@@ -16,7 +16,7 @@
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseLLM
from langchain_core.language_models import BaseLLM, LangSmithParams
from langchain_core.outputs import GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import Field, root_validator
from ollama import AsyncClient, Client, Options
@@ -155,6 +155,15 @@ def _llm_type(self) -> str:
"""Return type of LLM."""
return "ollama-llm"

def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = super()._get_ls_params(stop=stop, **kwargs)
if max_tokens := kwargs.get("num_predict", self.num_predict):
params["ls_max_tokens"] = max_tokens
return params

@root_validator(pre=False, skip_on_failure=True)
def _set_clients(cls, values: dict) -> dict:
"""Set clients to use for ollama."""
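For the Ollama override above, a rough usage sketch (this assumes the ollama Python package is installed; the model name is purely illustrative). The num_predict option is mapped onto the standard ls_max_tokens key:

from langchain_ollama import OllamaLLM

llm = OllamaLLM(model="llama3", num_predict=256)
llm._get_ls_params()
# Expected shape (plus ls_temperature if a temperature is set):
# {'ls_provider': 'ollama', 'ls_model_type': 'llm',
#  'ls_model_name': 'llama3', 'ls_max_tokens': 256}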