Skip to content

Commit

Permalink
Fix Anthropic tokenizer protocol #17115 (#17201)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffreyling authored Dec 9, 2024
1 parent ae18106 commit baf1a8f
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,19 @@
# NOTE(review): presumably the fallback `max_tokens` for Anthropic completion
# requests when the caller does not specify one — confirm against usage.
DEFAULT_ANTHROPIC_MAX_TOKENS = 512


class AnthropicTokenizer:
    """Adapter exposing Anthropic's token-counting endpoint through the
    ``Tokenizer`` protocol (an object providing ``encode``).

    The integers returned by :meth:`encode` are placeholders — only the
    *length* of the list (the token count) is meaningful.
    """

    def __init__(self, client, model) -> None:
        # SDK client used to reach the beta token-count endpoint.
        self._client = client
        # Model name forwarded with every count request.
        self.model = model

    def encode(self, text: str, *args: Any, **kwargs: Any) -> List[int]:
        """Return a dummy token list whose length is the model's token count for ``text``.

        Extra positional/keyword arguments are accepted for protocol
        compatibility and ignored.
        """
        response = self._client.beta.messages.count_tokens(
            messages=[{"role": "user", "content": text}],
            model=self.model,
        )
        return [1 for _ in range(response.input_tokens)]


class Anthropic(FunctionCallingLLM):
"""Anthropic LLM.
Expand Down Expand Up @@ -210,13 +223,7 @@ def metadata(self) -> LLMMetadata:

@property
def tokenizer(self) -> Tokenizer:
    """Tokenizer satisfying the ``Tokenizer`` protocol.

    Delegates counting to Anthropic's token-count API through
    :class:`AnthropicTokenizer`, so callers can use ``.encode(text)``.
    """
    return AnthropicTokenizer(self._client, self.model)

@property
def _model_kwargs(self) -> Dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from llama_index.core.llms import ChatMessage
import os
import pytest
from unittest.mock import MagicMock


def test_text_inference_embedding_class():
Expand Down Expand Up @@ -166,3 +167,39 @@ async def test_anthropic_through_bedrock_async():
print(f"Assertion failed: full_response is not a string")
print(f"Content of full_response: {full_response}")
raise


def test_anthropic_tokenizer():
    """Verify Anthropic.tokenizer implements the Tokenizer protocol (encode -> list[int])."""
    expected_count = 5
    sample_text = "Hello, world!"
    model_name = "claude-3-5-sonnet-20241022"

    # MagicMock auto-creates attribute chains, so stubbing the leaf call
    # covers client.beta.messages.count_tokens(...).input_tokens.
    fake_client = MagicMock()
    count_tokens = fake_client.beta.messages.count_tokens
    count_tokens.return_value.input_tokens = expected_count

    llm = Anthropic(model=model_name)
    llm._client = fake_client

    # The tokenizer must expose the protocol's encode method.
    tokenizer = llm.tokenizer
    assert hasattr(tokenizer, "encode")

    # encode returns a list of ints whose length equals the mocked count.
    tokens = tokenizer.encode(sample_text)
    assert isinstance(tokens, list)
    assert len(tokens) == expected_count  # matches mocked token count
    assert all(isinstance(t, int) for t in tokens)

    # The count endpoint must be hit exactly once with the right payload.
    count_tokens.assert_called_once_with(
        messages=[{"role": "user", "content": sample_text}],
        model=model_name,
    )

0 comments on commit baf1a8f

Please sign in to comment.